diff --git a/.gitignore b/.gitignore
index 894a44cc..64c70f38 100644
--- a/.gitignore
+++ b/.gitignore
@@ -102,3 +102,5 @@ venv.bak/
 
 # mypy
 .mypy_cache/
+
+*.lmdb
diff --git a/README.md b/README.md
index 49659bb4..77681bb6 100755
--- a/README.md
+++ b/README.md
@@ -14,12 +14,7 @@ for FFHQ
 
 ## Sample
 
-![Sample of the model trained on CelebA](doc/sample.png)
-![Style mixing sample of the model trained on CelebA](doc/sample_mixing.png)
+![Sample of the model trained on FFHQ](doc/sample_ffhq_new.png)
+![Style mixing sample of the model trained on FFHQ](doc/sample_mixing_ffhq_new.png)
 
-I have mixed styles at 4^2 - 8^2 scale. I can't get samples as dramatic as samles in the original paper. I think my model too dependent on 4^2 scale features - it seems like that much of details determined in that scale, so little variations can be acquired after it.
-
-![Sample of the model trained on FFHQ](doc/sample_ffhq.png)
-![Style mixing sample of the model trained on FFHQ](doc/sample_mixing_ffhq.png)
-
-Trained high resolution model on FFHQ. I think result seems more interesting.
\ No newline at end of file
+512px sample from the generator trained on FFHQ.
\ No newline at end of file
diff --git a/doc/sample_ffhq_new.png b/doc/sample_ffhq_new.png
new file mode 100755
index 00000000..6364801e
Binary files /dev/null and b/doc/sample_ffhq_new.png differ
diff --git a/doc/sample_mixing_ffhq_new.png b/doc/sample_mixing_ffhq_new.png
new file mode 100755
index 00000000..51d76737
Binary files /dev/null and b/doc/sample_mixing_ffhq_new.png differ
diff --git a/generate.py b/generate.py
index 9da1ec83..82226b7a 100755
--- a/generate.py
+++ b/generate.py
@@ -4,66 +4,69 @@ from model import StyledGenerator
 
 generator = StyledGenerator(512).cuda()
-generator.load_state_dict(torch.load('checkpoint/140000.model'))
+generator.load_state_dict(torch.load('checkpoint/180000.model'))
+generator.eval()
 
 mean_style = None
 
-step = 6
+step = 7
+alpha = 1
 
 shape = 4 * 2 ** step
 
-for i in range(10):
-    style = generator.mean_style(torch.randn(1024, 512).cuda())
+with torch.no_grad():
+    for i in range(10):
+        style = generator.mean_style(torch.randn(1024, 512).cuda())
 
-    if mean_style is None:
-        mean_style = style
+        if mean_style is None:
+            mean_style = style
 
-    else:
-        mean_style += style
+        else:
+            mean_style += style
 
-mean_style /= 10
+    mean_style /= 10
 
-image = generator(
-    torch.randn(50, 512).cuda(),
-    step=step,
-    alpha=1,
-    mean_style=mean_style,
-    style_weight=0.7,
-)
+    image = generator(
+        torch.randn(15, 512).cuda(),
+        step=step,
+        alpha=alpha,
+        mean_style=mean_style,
+        style_weight=0.7,
+    )
 
-utils.save_image(image, 'sample.png', nrow=10, normalize=True, range=(-1, 1))
+    utils.save_image(image, 'sample.png', nrow=5, normalize=True, range=(-1, 1))
 
-for j in range(20):
-    source_code = torch.randn(9, 512).cuda()
-    target_code = torch.randn(5, 512).cuda()
+    for j in range(20):
+        source_code = torch.randn(5, 512).cuda()
+        target_code = torch.randn(3, 512).cuda()
 
-    images = [torch.ones(1, 3, shape, shape).cuda() * -1]
+        images = [torch.ones(1, 3, shape, shape).cuda() * -1]
 
-    source_image = generator(
-        source_code, step=step, alpha=1, mean_style=mean_style, style_weight=0.7
-    )
-    target_image = generator(
-        target_code, step=step, alpha=1, mean_style=mean_style, style_weight=0.7
-    )
+        source_image = generator(
+            source_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7
+        )
+        target_image = generator(
+            target_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7
+        )
 
-    images.append(source_image)
+        images.append(source_image)
 
-    for i in range(5):
-        image = generator(
-            [target_code[i].unsqueeze(0).repeat(9, 1), source_code],
-            step=step,
-            alpha=1,
-            mean_style=mean_style,
-            style_weight=0.7,
-            mixing_range=(0, 1),
-        )
-        images.append(target_image[i].unsqueeze(0))
-        images.append(image)
+        for i in range(3):
+            image = generator(
+                [target_code[i].unsqueeze(0).repeat(5, 1), source_code],
+                step=step,
+                alpha=alpha,
+                mean_style=mean_style,
+                style_weight=0.7,
+                mixing_range=(0, 1),
+            )
+            images.append(target_image[i].unsqueeze(0))
+            images.append(image)
 
-        # print([i.shape for i in images])
+        # print([i.shape for i in images])
 
-        images = torch.cat(images, 0)
+        images = torch.cat(images, 0)
 
-    utils.save_image(
-        images, f'sample_mixing_{j}.png', nrow=10, normalize=True, range=(-1, 1)
-    )
+        utils.save_image(
+            images, f'sample_mixing_{j}.png', nrow=6, normalize=True, range=(-1, 1)
+        )
 
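
Note on the generate.py changes above: mean_style and style_weight implement the truncation trick, and mixing_range=(0, 1) takes the styles for the first two generator blocks (the 4^2 and 8^2 scales) from the second latent code in the list. A minimal sketch of the truncation interpolation this relies on, assuming StyledGenerator applies the standard formula internally (model.py holds the authoritative version):

    import torch

    from model import StyledGenerator

    generator = StyledGenerator(512).cuda()
    generator.eval()

    style_weight = 0.7  # same value as passed to the generator above

    with torch.no_grad():
        # Estimate the mean style w by mapping a large batch of random z
        # through the style network; mean_style() averages over the batch.
        mean_style = generator.mean_style(torch.randn(1024, 512).cuda())

        # Truncation trick: interpolate a sampled style toward the mean.
        # Lower style_weight trades sample diversity for sample quality.
        w = generator.mean_style(torch.randn(1, 512).cuda())
        w_truncated = mean_style + style_weight * (w - mean_style)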