Revised model to more closely match official implementation
rosinality committed Jul 4, 2019
1 parent c6814c2 commit d4c0438
Showing 2 changed files with 182 additions and 52 deletions.
2 changes: 1 addition & 1 deletion generate.py
@@ -4,7 +4,7 @@
from model import StyledGenerator

generator = StyledGenerator(512).cuda()
- generator.load_state_dict(torch.load('checkpoint/130000.model'))
+ generator.load_state_dict(torch.load('checkpoint/140000.model'))

mean_style = None

232 changes: 181 additions & 51 deletions model.py
@@ -3,7 +3,7 @@
from torch import nn
from torch.nn import init
from torch.nn import functional as F
- from torch.autograd import Variable
+ from torch.autograd import Function

from math import sqrt

@@ -53,6 +53,64 @@ def equal_lr(module, name='weight'):
return module


class FusedUpsample(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, padding=0):
super().__init__()

weight = torch.randn(in_channel, out_channel, kernel_size, kernel_size)
bias = torch.zeros(out_channel)

fan_in = in_channel * kernel_size * kernel_size
self.multiplier = sqrt(2 / fan_in)

self.weight = nn.Parameter(weight)
self.bias = nn.Parameter(bias)

self.pad = padding

def forward(self, input):
weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
weight = (
weight[:, :, 1:, 1:]
+ weight[:, :, :-1, 1:]
+ weight[:, :, 1:, :-1]
+ weight[:, :, :-1, :-1]
) / 4

out = F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad)

return out


class FusedDownsample(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, padding=0):
super().__init__()

weight = torch.randn(out_channel, in_channel, kernel_size, kernel_size)
bias = torch.zeros(out_channel)

fan_in = in_channel * kernel_size * kernel_size
self.multiplier = sqrt(2 / fan_in)

self.weight = nn.Parameter(weight)
self.bias = nn.Parameter(bias)

self.pad = padding

def forward(self, input):
weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
weight = (
weight[:, :, 1:, 1:]
+ weight[:, :, :-1, 1:]
+ weight[:, :, 1:, :-1]
+ weight[:, :, :-1, :-1]
) / 4

out = F.conv2d(input, weight, self.bias, stride=2, padding=self.pad)

return out
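These two new modules implement the fused-scale trick from the official StyleGAN code: the weight is zero-padded and its four shifted copies are averaged, so a single stride-2 convolution (transposed for upsampling, plain for downsampling) resamples and filters in one op instead of a separate resize plus conv. A minimal shape-check sketch, not part of this commit, assuming the revised model.py is importable:

import torch
from model import FusedUpsample, FusedDownsample

x = torch.randn(2, 64, 16, 16)

# kernel_size=3, padding=1 mirrors how the blocks below instantiate them
up = FusedUpsample(64, 32, kernel_size=3, padding=1)
down = FusedDownsample(32, 64, kernel_size=3, padding=1)

y = up(x)    # averaged 4x4 weight, conv_transpose2d with stride 2: 16 -> 32
z = down(y)  # averaged 4x4 weight, conv2d with stride 2: 32 -> 16

print(y.shape)  # torch.Size([2, 32, 32, 32])
print(z.shape)  # torch.Size([2, 64, 16, 16])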


class PixelNorm(nn.Module):
def __init__(self):
super().__init__()
@@ -61,22 +119,64 @@ def forward(self, input):
return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
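PixelNorm itself is unchanged: it rescales each pixel's feature vector to roughly unit root-mean-square across channels, input / sqrt(mean(input ** 2) + 1e-8). A standalone check of that property, not part of this commit:

import torch

x = torch.randn(2, 512, 4, 4)
y = x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)

# after normalization, the per-pixel mean square over channels is ~1
print(torch.mean(y ** 2, dim=1).mean())  # ~1.0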


class BlurFunctionBackward(Function):
@staticmethod
def forward(ctx, grad_output, kernel, kernel_flip):
ctx.save_for_backward(kernel, kernel_flip)

grad_input = F.conv2d(
grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]
)

return grad_input

@staticmethod
def backward(ctx, gradgrad_output):
kernel, kernel_flip = ctx.saved_tensors

grad_input = F.conv2d(
gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]
)

return grad_input, None, None


class BlurFunction(Function):
@staticmethod
def forward(ctx, input, kernel, kernel_flip):
ctx.save_for_backward(kernel, kernel_flip)

output = F.conv2d(input, kernel, padding=1, groups=input.shape[1])

return output

@staticmethod
def backward(ctx, grad_output):
kernel, kernel_flip = ctx.saved_tensors

grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)

return grad_input, None, None


blur = BlurFunction.apply


class Blur(nn.Module):
def __init__(self, channel):
super().__init__()

weight = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
weight = weight.view(1, 1, 3, 3)
weight = weight / weight.sum()
+ weight_flip = torch.flip(weight, [2, 3])

self.register_buffer('weight', weight.repeat(channel, 1, 1, 1))
+ self.register_buffer('weight_flip', weight_flip.repeat(channel, 1, 1, 1))

def forward(self, input):
- return F.conv2d(
- input,
- self.weight,
- padding=1,
- groups=input.shape[1],
- )
+ return blur(input, self.weight, self.weight_flip)
+ # return F.conv2d(input, self.weight, padding=1, groups=input.shape[1])
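Blur now routes through a custom autograd Function instead of calling F.conv2d directly. The gradient of a stride-1, padding-1 depthwise convolution with respect to its input is another depthwise convolution with the spatially flipped kernel, which is why the module registers weight_flip and BlurFunction/BlurFunctionBackward hand it to the gradient passes; since the 3x3 binomial kernel is symmetric, the flipped kernel equals the original and the forward result matches the commented-out F.conv2d call. A quick sketch, not part of this commit, assuming the revised model.py is importable:

import torch
from torch.nn import functional as F
from model import Blur

blur_mod = Blur(channel=8)
x = torch.randn(1, 8, 16, 16, requires_grad=True)

out = blur_mod(x)
ref = F.conv2d(x, blur_mod.weight, padding=1, groups=8)
print(torch.allclose(out, ref))  # True: symmetric kernel, same forward result

out.sum().backward()
print(x.grad.shape)  # torch.Size([1, 8, 16, 16]), gradient via the flipped kernel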


class EqualConv2d(nn.Module):
@@ -115,8 +215,8 @@ def __init__(
padding,
kernel_size2=None,
padding2=None,
- pixel_norm=True,
- spectral_norm=False,
+ downsample=False,
+ fused=False,
):
super().__init__()

@@ -130,15 +230,36 @@ def __init__(
if kernel_size2 is not None:
kernel2 = kernel_size2

- self.conv = nn.Sequential(
+ self.conv1 = nn.Sequential(
EqualConv2d(in_channel, out_channel, kernel1, padding=pad1),
nn.LeakyReLU(0.2),
EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
nn.LeakyReLU(0.2),
)

if downsample:
if fused:
self.conv2 = nn.Sequential(
Blur(out_channel),
FusedDownsample(out_channel, out_channel, kernel2, padding=pad2),
nn.LeakyReLU(0.2),
)

else:
self.conv2 = nn.Sequential(
Blur(out_channel),
EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
nn.AvgPool2d(2),
nn.LeakyReLU(0.2),
)

else:
self.conv2 = nn.Sequential(
EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
nn.LeakyReLU(0.2),
)

def forward(self, input):
- out = self.conv(input)
+ out = self.conv1(input)
+ out = self.conv2(out)

return out
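ConvBlock is now split into conv1 and conv2, and when downsample=True the 2x reduction happens inside conv2: a blur followed either by a single fused stride-2 convolution (fused=True) or by an ordinary EqualConv2d plus AvgPool2d(2). Both paths halve the spatial size; a small sketch, not part of this commit, assuming the revised model.py is importable:

import torch
from model import ConvBlock

x = torch.randn(2, 64, 128, 128)

fused_block = ConvBlock(64, 128, 3, 1, downsample=True, fused=True)
plain_block = ConvBlock(64, 128, 3, 1, downsample=True, fused=False)

print(fused_block(x).shape)  # torch.Size([2, 128, 64, 64])
print(plain_block(x).shape)  # torch.Size([2, 128, 64, 64])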

@@ -195,16 +316,37 @@ def __init__(
padding=1,
style_dim=512,
initial=False,
+ upsample=False,
+ fused=False,
):
super().__init__()

if initial:
self.conv1 = ConstantInput(in_channel)

else:
- self.conv1 = EqualConv2d(
- in_channel, out_channel, kernel_size, padding=padding
- )
+ if upsample:
+ if fused:
+ self.conv1 = nn.Sequential(
+ FusedUpsample(
+ in_channel, out_channel, kernel_size, padding=padding
+ ),
+ Blur(out_channel),
+ )
+
+ else:
+ self.conv1 = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='nearest'),
+ EqualConv2d(
+ in_channel, out_channel, kernel_size, padding=padding
+ ),
+ Blur(out_channel),
+ )
+
+ else:
+ self.conv1 = EqualConv2d(
+ in_channel, out_channel, kernel_size, padding=padding
+ )

self.noise1 = equal_lr(NoiseInjection(out_channel))
self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
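In the generator the same idea runs in reverse: a non-initial block with upsample=True replaces its first convolution with either FusedUpsample (stride-2 transposed conv) or nn.Upsample(nearest) followed by EqualConv2d, and both variants are followed by Blur, matching the official blur-after-upsample ordering. A shape sketch, not part of this commit, assuming the revised model.py is importable and that the noise tensor is sized for the upsampled output as in Generator.forward:

import torch
from model import StyledConvBlock

block = StyledConvBlock(512, 256, 3, 1, upsample=True, fused=True)

x = torch.randn(4, 512, 32, 32)    # features at 32x32
style = torch.randn(4, 512)        # w vector from the mapping network
noise = torch.randn(4, 1, 64, 64)  # per-pixel noise at the new resolution

out = block(x, style, noise)
print(out.shape)  # torch.Size([4, 256, 64, 64])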
@@ -230,20 +372,20 @@ def forward(self, input, style, noise):


class Generator(nn.Module):
- def __init__(self, code_dim):
+ def __init__(self, code_dim, fused=True):
super().__init__()

self.progression = nn.ModuleList(
[
- StyledConvBlock(512, 512, 3, 1, initial=True),
- StyledConvBlock(512, 512, 3, 1),
- StyledConvBlock(512, 512, 3, 1),
- StyledConvBlock(512, 512, 3, 1),
- StyledConvBlock(512, 256, 3, 1),
- StyledConvBlock(256, 128, 3, 1),
- StyledConvBlock(128, 64, 3, 1),
- StyledConvBlock(64, 32, 3, 1),
- StyledConvBlock(32, 16, 3, 1),
+ StyledConvBlock(512, 512, 3, 1, initial=True), # 4
+ StyledConvBlock(512, 512, 3, 1, upsample=True), # 8
+ StyledConvBlock(512, 512, 3, 1, upsample=True), # 16
+ StyledConvBlock(512, 512, 3, 1, upsample=True), # 32
+ StyledConvBlock(512, 256, 3, 1, upsample=True), # 64
+ StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused), # 128
+ StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused), # 256
+ StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused), # 512
+ StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused), # 1024
]
)
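The trailing comments give each block's output resolution (4 through 1024): the initial block starts from the learned 4x4 constant, every later block upsamples by 2, and the fused path is enabled only from 128x128 upward. At a given progressive-growing step the output resolution is therefore 4 * 2**step. A usage sketch, not part of this commit, assuming StyledGenerator can be called without explicit noise as generate.py does:

import torch
from model import StyledGenerator

generator = StyledGenerator(512)  # random weights, CPU, shapes only
z = torch.randn(2, 512)

with torch.no_grad():
    img = generator(z, step=6, alpha=1)  # step 6 -> 4 * 2**6 = 256

print(img.shape)  # expected torch.Size([2, 3, 256, 256])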

@@ -291,11 +433,7 @@ def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1)):
if i > 0 and step > 0:
out_prev = out

- upsample = F.interpolate(
- out, scale_factor=2, mode='bilinear', align_corners=False
- )
- # upsample = self.blur(upsample)
- out = conv(upsample, style_step, noise[i])
+ out = conv(out, style_step, noise[i])

else:
out = conv(out, style_step, noise[i])
@@ -305,7 +443,7 @@ def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1)):

if i > 0 and 0 <= alpha < 1:
skip_rgb = self.to_rgb[i - 1](out_prev)
- skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='bilinear', align_corners=False)
+ skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest')
out = (1 - alpha) * skip_rgb + alpha * out

break
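Since every non-initial StyledConvBlock now upsamples internally, the explicit F.interpolate before each conv is gone, and the progressive-growing fade-in blends the previous block's RGB output, upsampled with nearest-neighbor as in the official implementation, against the new block's output. A standalone sketch of that blend, not part of this commit:

import torch
from torch.nn import functional as F

alpha = 0.3
rgb_prev = torch.randn(2, 3, 64, 64)   # to_rgb of the previous (lower-res) block
rgb_new = torch.randn(2, 3, 128, 128)  # to_rgb of the newly faded-in block

skip_rgb = F.interpolate(rgb_prev, scale_factor=2, mode='nearest')
out = (1 - alpha) * skip_rgb + alpha * rgb_new
print(out.shape)  # torch.Size([2, 3, 128, 128])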
@@ -369,19 +507,19 @@ def mean_style(self, input):


class Discriminator(nn.Module):
- def __init__(self):
+ def __init__(self, fused=True):
super().__init__()

self.progression = nn.ModuleList(
[
- ConvBlock(16, 32, 3, 1),
- ConvBlock(32, 64, 3, 1),
- ConvBlock(64, 128, 3, 1),
- ConvBlock(128, 256, 3, 1),
- ConvBlock(256, 512, 3, 1),
- ConvBlock(512, 512, 3, 1),
- ConvBlock(512, 512, 3, 1),
- ConvBlock(512, 512, 3, 1),
+ ConvBlock(16, 32, 3, 1, downsample=True, fused=fused), # 512
+ ConvBlock(32, 64, 3, 1, downsample=True, fused=fused), # 256
+ ConvBlock(64, 128, 3, 1, downsample=True, fused=fused), # 128
+ ConvBlock(128, 256, 3, 1, downsample=True, fused=fused), # 64
+ ConvBlock(256, 512, 3, 1, downsample=True), # 32
+ ConvBlock(512, 512, 3, 1, downsample=True), # 16
+ ConvBlock(512, 512, 3, 1, downsample=True), # 8
+ ConvBlock(512, 512, 3, 1, downsample=True), # 4
ConvBlock(513, 512, 3, 1, 4, 0),
]
)
@@ -422,17 +560,9 @@ def forward(self, input, step=0, alpha=-1):
out = self.progression[index](out)

if i > 0:
- # out = F.avg_pool2d(out, 2)
- out = F.interpolate(
- out, scale_factor=0.5, mode='bilinear', align_corners=False
- )

if i == step and 0 <= alpha < 1:
- # skip_rgb = F.avg_pool2d(input, 2)
- skip_rgb = self.from_rgb[index + 1](input)
- skip_rgb = F.interpolate(
- skip_rgb, scale_factor=0.5, mode='bilinear', align_corners=False
- )
+ skip_rgb = F.avg_pool2d(input, 2)
+ skip_rgb = self.from_rgb[index + 1](skip_rgb)

out = (1 - alpha) * skip_rgb + alpha * out
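The discriminator mirrors the change: the 2x reduction now happens inside each ConvBlock (downsample=True), so the loop no longer pools between blocks, and the fade-in skip first average-pools the raw RGB input and only then converts it with from_rgb, the order the official implementation uses. A standalone sketch of the revised skip path, not part of this commit; the 1x1 conv is a stand-in for self.from_rgb[index + 1]:

import torch
from torch import nn
from torch.nn import functional as F

alpha = 0.3
rgb = torch.randn(2, 3, 128, 128)      # image fed to the discriminator
out = torch.randn(2, 256, 64, 64)      # output of the strided ConvBlock path
from_rgb_small = nn.Conv2d(3, 256, 1)  # stand-in for self.from_rgb[index + 1]

skip_rgb = F.avg_pool2d(rgb, 2)        # downsample the image first ...
skip_rgb = from_rgb_small(skip_rgb)    # ... then lift it to feature space
out = (1 - alpha) * skip_rgb + alpha * out
print(out.shape)  # torch.Size([2, 256, 64, 64])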
