diff --git a/generate.py b/generate.py
index 14360f20..9da1ec83 100755
--- a/generate.py
+++ b/generate.py
@@ -4,7 +4,7 @@
 from model import StyledGenerator
 
 generator = StyledGenerator(512).cuda()
-generator.load_state_dict(torch.load('checkpoint/130000.model'))
+generator.load_state_dict(torch.load('checkpoint/140000.model'))
 
 mean_style = None
diff --git a/model.py b/model.py
index 465943ff..d0d1fa01 100755
--- a/model.py
+++ b/model.py
@@ -3,7 +3,7 @@
 from torch import nn
 from torch.nn import init
 from torch.nn import functional as F
-from torch.autograd import Variable
+from torch.autograd import Function
 
 from math import sqrt
 
@@ -53,6 +53,64 @@ def equal_lr(module, name='weight'):
     return module
 
 
+class FusedUpsample(nn.Module):
+    def __init__(self, in_channel, out_channel, kernel_size, padding=0):
+        super().__init__()
+
+        weight = torch.randn(in_channel, out_channel, kernel_size, kernel_size)
+        bias = torch.zeros(out_channel)
+
+        fan_in = in_channel * kernel_size * kernel_size
+        self.multiplier = sqrt(2 / fan_in)
+
+        self.weight = nn.Parameter(weight)
+        self.bias = nn.Parameter(bias)
+
+        self.pad = padding
+
+    def forward(self, input):
+        weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
+        weight = (
+            weight[:, :, 1:, 1:]
+            + weight[:, :, :-1, 1:]
+            + weight[:, :, 1:, :-1]
+            + weight[:, :, :-1, :-1]
+        ) / 4
+
+        out = F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad)
+
+        return out
+
+
+class FusedDownsample(nn.Module):
+    def __init__(self, in_channel, out_channel, kernel_size, padding=0):
+        super().__init__()
+
+        weight = torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+        bias = torch.zeros(out_channel)
+
+        fan_in = in_channel * kernel_size * kernel_size
+        self.multiplier = sqrt(2 / fan_in)
+
+        self.weight = nn.Parameter(weight)
+        self.bias = nn.Parameter(bias)
+
+        self.pad = padding
+
+    def forward(self, input):
+        weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
+        weight = (
+            weight[:, :, 1:, 1:]
+            + weight[:, :, :-1, 1:]
+            + weight[:, :, 1:, :-1]
+            + weight[:, :, :-1, :-1]
+        ) / 4
+
+        out = F.conv2d(input, weight, self.bias, stride=2, padding=self.pad)
+
+        return out
+
+
 class PixelNorm(nn.Module):
     def __init__(self):
         super().__init__()
@@ -61,6 +119,49 @@ def forward(self, input):
         return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
 
 
+class BlurFunctionBackward(Function):
+    @staticmethod
+    def forward(ctx, grad_output, kernel, kernel_flip):
+        ctx.save_for_backward(kernel, kernel_flip)
+
+        grad_input = F.conv2d(
+            grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]
+        )
+
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, gradgrad_output):
+        kernel, kernel_flip = ctx.saved_tensors
+
+        grad_input = F.conv2d(
+            gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]
+        )
+
+        return grad_input, None, None
+
+
+class BlurFunction(Function):
+    @staticmethod
+    def forward(ctx, input, kernel, kernel_flip):
+        ctx.save_for_backward(kernel, kernel_flip)
+
+        output = F.conv2d(input, kernel, padding=1, groups=input.shape[1])
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        kernel, kernel_flip = ctx.saved_tensors
+
+        grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
+
+        return grad_input, None, None
+
+
+blur = BlurFunction.apply
+
+
 class Blur(nn.Module):
     def __init__(self, channel):
         super().__init__()
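Note on the two fused modules: padding the 3x3 kernel and averaging its four one-pixel shifts spreads each tap over a 4x4 support, so a single stride-2 convolution (transposed for upsampling) stands in for resample-then-convolve, much like the fused resampling ops in the reference StyleGAN code. The custom autograd Function leaves the forward blur as a plain grouped convolution and only swaps in the flipped kernel for the backward pass. A minimal smoke test, not part of the patch, assuming the patched model.py is importable:

    import torch
    from torch.nn import functional as F
    from model import Blur, FusedDownsample, FusedUpsample

    x = torch.randn(2, 8, 16, 16)

    # One stride-2 conv_transpose2d doubles the resolution; one stride-2
    # conv2d halves it. The effective kernel is 4x4 after the shift-average.
    up = FusedUpsample(8, 8, kernel_size=3, padding=1)
    down = FusedDownsample(8, 8, kernel_size=3, padding=1)
    print(up(x).shape)    # torch.Size([2, 8, 32, 32])
    print(down(x).shape)  # torch.Size([2, 8, 8, 8])

    # The custom Function changes only the backward pass; its forward result
    # matches a plain grouped convolution with the same kernel.
    b = Blur(8)
    reference = F.conv2d(x, b.weight, padding=1, groups=8)
    print(torch.allclose(b(x), reference))  # True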
@@ -68,15 +169,14 @@ def __init__(self, channel):
         weight = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
         weight = weight.view(1, 1, 3, 3)
         weight = weight / weight.sum()
+        weight_flip = torch.flip(weight, [2, 3])
+
         self.register_buffer('weight', weight.repeat(channel, 1, 1, 1))
+        self.register_buffer('weight_flip', weight_flip.repeat(channel, 1, 1, 1))
 
     def forward(self, input):
-        return F.conv2d(
-            input,
-            self.weight,
-            padding=1,
-            groups=input.shape[1],
-        )
+        return blur(input, self.weight, self.weight_flip)
+        # return F.conv2d(input, self.weight, padding=1, groups=input.shape[1])
 
 
 class EqualConv2d(nn.Module):
@@ -115,8 +215,8 @@ def __init__(
         padding,
         kernel_size2=None,
         padding2=None,
-        pixel_norm=True,
-        spectral_norm=False,
+        downsample=False,
+        fused=False,
     ):
         super().__init__()
 
@@ -130,15 +230,36 @@ def __init__(
         if kernel_size2 is not None:
             kernel2 = kernel_size2
 
-        self.conv = nn.Sequential(
+        self.conv1 = nn.Sequential(
             EqualConv2d(in_channel, out_channel, kernel1, padding=pad1),
             nn.LeakyReLU(0.2),
-            EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
-            nn.LeakyReLU(0.2),
         )
 
+        if downsample:
+            if fused:
+                self.conv2 = nn.Sequential(
+                    Blur(out_channel),
+                    FusedDownsample(out_channel, out_channel, kernel2, padding=pad2),
+                    nn.LeakyReLU(0.2),
+                )
+
+            else:
+                self.conv2 = nn.Sequential(
+                    Blur(out_channel),
+                    EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
+                    nn.AvgPool2d(2),
+                    nn.LeakyReLU(0.2),
+                )
+
+        else:
+            self.conv2 = nn.Sequential(
+                EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
+                nn.LeakyReLU(0.2),
+            )
+
     def forward(self, input):
-        out = self.conv(input)
+        out = self.conv1(input)
+        out = self.conv2(out)
 
         return out
 
@@ -195,6 +316,8 @@ def __init__(
         padding=1,
         style_dim=512,
         initial=False,
+        upsample=False,
+        fused=False,
     ):
         super().__init__()
 
@@ -201,10 +324,29 @@ def __init__(
         if initial:
             self.conv1 = ConstantInput(in_channel)
 
         else:
-            self.conv1 = EqualConv2d(
-                in_channel, out_channel, kernel_size, padding=padding
-            )
+            if upsample:
+                if fused:
+                    self.conv1 = nn.Sequential(
+                        FusedUpsample(
+                            in_channel, out_channel, kernel_size, padding=padding
+                        ),
+                        Blur(out_channel),
+                    )
+
+                else:
+                    self.conv1 = nn.Sequential(
+                        nn.Upsample(scale_factor=2, mode='nearest'),
+                        EqualConv2d(
+                            in_channel, out_channel, kernel_size, padding=padding
+                        ),
+                        Blur(out_channel),
+                    )
+
+            else:
+                self.conv1 = EqualConv2d(
+                    in_channel, out_channel, kernel_size, padding=padding
+                )
 
         self.noise1 = equal_lr(NoiseInjection(out_channel))
         self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
@@ -230,20 +372,20 @@
 
 
 class Generator(nn.Module):
-    def __init__(self, code_dim):
+    def __init__(self, code_dim, fused=True):
         super().__init__()
 
         self.progression = nn.ModuleList(
             [
-                StyledConvBlock(512, 512, 3, 1, initial=True),
-                StyledConvBlock(512, 512, 3, 1),
-                StyledConvBlock(512, 512, 3, 1),
-                StyledConvBlock(512, 512, 3, 1),
-                StyledConvBlock(512, 256, 3, 1),
-                StyledConvBlock(256, 128, 3, 1),
-                StyledConvBlock(128, 64, 3, 1),
-                StyledConvBlock(64, 32, 3, 1),
-                StyledConvBlock(32, 16, 3, 1),
+                StyledConvBlock(512, 512, 3, 1, initial=True),  # 4
+                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 8
+                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 16
+                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 32
+                StyledConvBlock(512, 256, 3, 1, upsample=True),  # 64
+                StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused),  # 128
+                StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused),  # 256
+                StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused),  # 512
+                StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused),  # 1024
             ]
         )
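Note on the generator blocks: upsampling now lives inside StyledConvBlock (nearest-neighbour plus EqualConv2d, or FusedUpsample, each followed by Blur), so every upsample=True entry doubles its input resolution and the trailing comments record the output size per block. A rough sanity check, hypothetical usage rather than part of the patch:

    import torch
    from model import StyledConvBlock

    style = torch.randn(1, 512)
    x = torch.randn(1, 256, 64, 64)      # 64px input, as in the 128px row above
    noise = torch.randn(1, 1, 128, 128)  # noise is injected at the output size

    for fused in (False, True):
        block = StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused)
        # Either path should double the resolution: 64 -> 128.
        print(fused, block(x, style, noise).shape)  # torch.Size([1, 128, 128, 128])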
@@ -291,11 +433,7 @@ def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1)):
             if i > 0 and step > 0:
                 out_prev = out
 
-                upsample = F.interpolate(
-                    out, scale_factor=2, mode='bilinear', align_corners=False
-                )
-                # upsample = self.blur(upsample)
-                out = conv(upsample, style_step, noise[i])
+                out = conv(out, style_step, noise[i])
 
             else:
                 out = conv(out, style_step, noise[i])
 
@@ -305,7 +443,7 @@ def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1)):
 
                 if i > 0 and 0 <= alpha < 1:
                     skip_rgb = self.to_rgb[i - 1](out_prev)
-                    skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='bilinear', align_corners=False)
+                    skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest')
                     out = (1 - alpha) * skip_rgb + alpha * out
 
                 break
 
@@ -369,19 +507,19 @@ def mean_style(self, input):
 
 
 class Discriminator(nn.Module):
-    def __init__(self):
+    def __init__(self, fused=True):
         super().__init__()
 
         self.progression = nn.ModuleList(
             [
-                ConvBlock(16, 32, 3, 1),
-                ConvBlock(32, 64, 3, 1),
-                ConvBlock(64, 128, 3, 1),
-                ConvBlock(128, 256, 3, 1),
-                ConvBlock(256, 512, 3, 1),
-                ConvBlock(512, 512, 3, 1),
-                ConvBlock(512, 512, 3, 1),
-                ConvBlock(512, 512, 3, 1),
+                ConvBlock(16, 32, 3, 1, downsample=True, fused=fused),  # 512
+                ConvBlock(32, 64, 3, 1, downsample=True, fused=fused),  # 256
+                ConvBlock(64, 128, 3, 1, downsample=True, fused=fused),  # 128
+                ConvBlock(128, 256, 3, 1, downsample=True, fused=fused),  # 64
+                ConvBlock(256, 512, 3, 1, downsample=True),  # 32
+                ConvBlock(512, 512, 3, 1, downsample=True),  # 16
+                ConvBlock(512, 512, 3, 1, downsample=True),  # 8
+                ConvBlock(512, 512, 3, 1, downsample=True),  # 4
                 ConvBlock(513, 512, 3, 1, 4, 0),
             ]
         )
@@ -422,17 +560,9 @@ def forward(self, input, step=0, alpha=-1):
             out = self.progression[index](out)
 
             if i > 0:
-                # out = F.avg_pool2d(out, 2)
-                out = F.interpolate(
-                    out, scale_factor=0.5, mode='bilinear', align_corners=False
-                )
-
                 if i == step and 0 <= alpha < 1:
-                    # skip_rgb = F.avg_pool2d(input, 2)
-                    skip_rgb = self.from_rgb[index + 1](input)
-                    skip_rgb = F.interpolate(
-                        skip_rgb, scale_factor=0.5, mode='bilinear', align_corners=False
-                    )
+                    skip_rgb = F.avg_pool2d(input, 2)
+                    skip_rgb = self.from_rgb[index + 1](skip_rgb)
 
                     out = (1 - alpha) * skip_rgb + alpha * out
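Note on the fade-in paths: with the learned resampling moved into the blocks, both skip branches now use cheap fixed ops, nearest-neighbour interpolation on the generator side and average pooling applied ahead of from_rgb on the discriminator side. A quick smoke test of the discriminator branch, hypothetical usage rather than part of the patch (it assumes the unchanged from_rgb and linear layers from model.py):

    import torch
    from model import Discriminator

    disc = Discriminator(fused=True)
    img = torch.randn(2, 3, 64, 64)  # 64px images enter at step=4

    # alpha in (0, 1) exercises the new avg_pool2d -> from_rgb skip branch.
    score = disc(img, step=4, alpha=0.5)
    print(score.shape)  # torch.Size([2, 1])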