Commit b7a058f: Update samples

rosinality committed Jul 6, 2019
1 parent d4c0438 commit b7a058f

Showing 5 changed files with 52 additions and 52 deletions.
2 changes: 2 additions & 0 deletions .gitignore

```diff
@@ -102,3 +102,5 @@ venv.bak/
 
 # mypy
 .mypy_cache/
+
+*.lmdb
```
11 changes: 3 additions & 8 deletions README.md

```diff
@@ -14,12 +14,7 @@ for FFHQ
 ## Sample
 
 ![Sample of the model trained on CelebA](doc/sample.png)
-![Style mixing sample of the model trained on CelebA](doc/sample_mixing.png)
+![Sample of the model trained on FFHQ](doc/sample_ffhq_new.png)
+![Style mixing sample of the model trained on FFHQ](doc/sample_mixing_ffhq_new.png)
 
-I have mixed styles at the 4^2 to 8^2 scales. I can't get samples as dramatic as those in the original paper; the model seems too dependent on 4^2 scale features. Much of the detail appears to be determined at that scale, so little variation can be introduced after it.
-
-![Sample of the model trained on FFHQ](doc/sample_ffhq.png)
-![Style mixing sample of the model trained on FFHQ](doc/sample_mixing_ffhq.png)
-
-Trained a higher-resolution model on FFHQ; the results seem more interesting.
+512px sample from the generator trained on FFHQ.
```
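The style mixing that the removed README paragraph describes (swapping styles at the 4^2 and 8^2 scales) corresponds to the `mixing_range=(0, 1)` argument used in `generate.py` below. Here is a minimal sketch of that call, assuming the `StyledGenerator` API from this repository; the checkpoint path and batch sizes are illustrative:

```python
import torch
from torchvision import utils

from model import StyledGenerator

generator = StyledGenerator(512).cuda()
generator.load_state_dict(torch.load('checkpoint/180000.model'))
generator.eval()

with torch.no_grad():
    # Average the mapped styles of many random latents; style_weight < 1
    # then truncates samples toward this mean style.
    mean_style = generator.mean_style(torch.randn(1024, 512).cuda())

    source = torch.randn(5, 512).cuda()  # supplies styles outside the range
    target = torch.randn(1, 512).cuda()  # supplies styles inside the range

    # mixing_range=(0, 1) applies the first latent at step indices 0 and 1,
    # i.e. the 4x4 and 8x8 resolutions, and the second latent everywhere else.
    mixed = generator(
        [target.repeat(5, 1), source],
        step=7,  # 4 * 2**7 = 512 px output
        alpha=1,
        mean_style=mean_style,
        style_weight=0.7,
        mixing_range=(0, 1),
    )
    utils.save_image(mixed, 'mixing_demo.png', nrow=5, normalize=True, range=(-1, 1))
```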
Binary file added doc/sample_ffhq_new.png
Binary file added doc/sample_mixing_ffhq_new.png
91 changes: 47 additions & 44 deletions generate.py

```diff
@@ -4,66 +4,69 @@
 from model import StyledGenerator
 
 generator = StyledGenerator(512).cuda()
-generator.load_state_dict(torch.load('checkpoint/140000.model'))
+generator.load_state_dict(torch.load('checkpoint/180000.model'))
 generator.eval()
 
 mean_style = None
 
-step = 6
+step = 7
+alpha = 1
 
 shape = 4 * 2 ** step
 
-for i in range(10):
-    style = generator.mean_style(torch.randn(1024, 512).cuda())
+with torch.no_grad():
+    for i in range(10):
+        style = generator.mean_style(torch.randn(1024, 512).cuda())
 
-    if mean_style is None:
-        mean_style = style
+        if mean_style is None:
+            mean_style = style
 
-    else:
-        mean_style += style
+        else:
+            mean_style += style
 
-mean_style /= 10
+    mean_style /= 10
 
-image = generator(
-    torch.randn(50, 512).cuda(),
-    step=step,
-    alpha=1,
-    mean_style=mean_style,
-    style_weight=0.7,
-)
+    image = generator(
+        torch.randn(15, 512).cuda(),
+        step=step,
+        alpha=alpha,
+        mean_style=mean_style,
+        style_weight=0.7,
+    )
 
-utils.save_image(image, 'sample.png', nrow=10, normalize=True, range=(-1, 1))
+    utils.save_image(image, 'sample.png', nrow=5, normalize=True, range=(-1, 1))
 
-for j in range(20):
-    source_code = torch.randn(9, 512).cuda()
-    target_code = torch.randn(5, 512).cuda()
+    for j in range(20):
+        source_code = torch.randn(5, 512).cuda()
+        target_code = torch.randn(3, 512).cuda()
 
-    images = [torch.ones(1, 3, shape, shape).cuda() * -1]
+        images = [torch.ones(1, 3, shape, shape).cuda() * -1]
 
-    source_image = generator(
-        source_code, step=step, alpha=1, mean_style=mean_style, style_weight=0.7
-    )
-    target_image = generator(
-        target_code, step=step, alpha=1, mean_style=mean_style, style_weight=0.7
-    )
+        source_image = generator(
+            source_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7
+        )
+        target_image = generator(
+            target_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7
+        )
 
-    images.append(source_image)
+        images.append(source_image)
 
-    for i in range(5):
-        image = generator(
-            [target_code[i].unsqueeze(0).repeat(9, 1), source_code],
-            step=step,
-            alpha=1,
-            mean_style=mean_style,
-            style_weight=0.7,
-            mixing_range=(0, 1),
-        )
-        images.append(target_image[i].unsqueeze(0))
-        images.append(image)
+        for i in range(3):
+            image = generator(
+                [target_code[i].unsqueeze(0).repeat(5, 1), source_code],
+                step=step,
+                alpha=alpha,
+                mean_style=mean_style,
+                style_weight=0.7,
+                mixing_range=(0, 1),
+            )
+            images.append(target_image[i].unsqueeze(0))
+            images.append(image)
 
-    # print([i.shape for i in images])
+        # print([i.shape for i in images])
 
-    images = torch.cat(images, 0)
+        images = torch.cat(images, 0)
 
-    utils.save_image(
-        images, f'sample_mixing_{j}.png', nrow=10, normalize=True, range=(-1, 1)
-    )
+        utils.save_image(
+            images, f'sample_mixing_{j}.png', nrow=6, normalize=True, range=(-1, 1)
+        )
```
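Two of the changes above are worth spelling out. Wrapping the whole script in `torch.no_grad()` stops autograd from recording the sampling passes, which saves considerable memory at these resolutions. And the move from `step = 6` to `step = 7` is what yields the 512px samples, via the same relation as the `shape = 4 * 2 ** step` line. A quick sanity check of that arithmetic, using a hypothetical helper name for illustration:

```python
# Resolution doubles per progressive step, starting from 4 px; this mirrors
# the `shape = 4 * 2 ** step` computation in generate.py. `resolution` is a
# hypothetical helper, not part of the repository.
def resolution(step: int) -> int:
    return 4 * 2 ** step

assert resolution(6) == 256  # the old setting
assert resolution(7) == 512  # the new 512 px samples
```

The `nrow=6` change matches the new grid layout: the first row holds the blank tile plus the five source images, and each later row holds one target image followed by its five mixed images.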
