-
Notifications
You must be signed in to change notification settings - Fork 145
/
generate.py
119 lines (95 loc) · 4.42 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
## this script is for generating images from pre-trained network based on StyleGAN1 (TensorFlow) and StyleGAN2-ada (PyTorch) ##
import os
import click
import dnnlib
import numpy as np
import PIL.Image
import legacy
from typing import List, Optional
"""
Generate images using pretrained network pickle.
Examples:
\b
# Generate human full-body images without truncation
python generate.py --outdir=outputs/generate/stylegan_human_v2_1024 --trunc=1 --seeds=1,3,5,7 \\
--network=pretrained_models/stylegan_human_v2_1024.pkl --version 2
\b
# Generate human full-body images with truncation
python generate.py --outdir=outputs/generate/stylegan_human_v2_1024 --trunc=0.8 --seeds=0-100\\
--network=pretrained_models/stylegan_human_v2_1024.pkl --version 2
# \b
# Generate human full-body images using stylegan V1
# python generate.py --outdir=outputs/generate/stylegan_human_v1_1024 \\
# --network=pretrained_models/stylegan_human_v1_1024.pkl --version 1
"""
@click.command()
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=legacy.num_range, help='List of random seeds')
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--outdir', help='Where to save the output images', default= 'outputs/generate/' , type=str, required=True, metavar='DIR')
@click.option('--version', help="stylegan version, 1, 2 or 3", type=int, default=2)
def generate_images(
ctx: click.Context,
network_pkl: str,
seeds: Optional[List[int]],
truncation_psi: float,
noise_mode: str,
outdir: str,
version: int
):
print('Loading networks from "%s"...' % network_pkl)
if version == 1:
import dnnlib.tflib as tflib
tflib.init_tf()
G, D, Gs = legacy.load_pkl(network_pkl)
else:
import torch
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
os.makedirs(outdir, exist_ok=True)
if seeds is None:
ctx.fail('--seeds option is required.')
# Generate images.
target_z = np.array([])
target_w = np.array([])
latent_out = outdir.replace('/images/','')
for seed_idx, seed in enumerate(seeds):
if seed % 5000 == 0:
print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
if version == 1: ## stylegan v1
z = np.random.RandomState(seed).randn(1, Gs.input_shape[1])
# Generate image.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
if noise_mode == 'const': randomize_noise=False
else: randomize_noise = True
images = Gs.run(z, None, truncation_psi=truncation_psi, randomize_noise=randomize_noise, output_transform=fmt)
PIL.Image.fromarray(images[0], 'RGB').save(f'{outdir}/seed{seed:04d}.png')
else: ## stylegan v2/v3
label = torch.zeros([1, G.c_dim], device=device)
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
if target_z.size==0:
target_z= z.cpu()
else:
target_z=np.append(target_z, z.cpu(), axis=0)
w = G.mapping(z, label,truncation_psi=truncation_psi)
img = G.synthesis(w, noise_mode=noise_mode,force_fp32 = True)
if target_w.size==0:
target_w= w.cpu()
else:
target_w=np.append(target_w, w.cpu(), axis=0)
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
# print(target_z)
# print(target_z.shape,target_w.shape)
#----------------------------------------------------------------------------
if __name__ == "__main__":
generate_images()
#----------------------------------------------------------------------------