condition the entire resnet blocks with time, seems to alleviate the color-shifting issue, thanks to @marunine in #72
lucidrains committed Jun 22, 2022
1 parent 5986d11 commit 8fb4683
Showing 3 changed files with 13 additions and 11 deletions.
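
Some context on the fix before the diffs: prior to this commit, the inner per-level resnet blocks (the nn.ModuleList ones below) were constructed without time_cond_dim and invoked without the time embedding, so they ran unconditioned on time while the init, middle, and final blocks did receive it. The sketch below shows FiLM-style time conditioning of the kind the diff's time_mlp / scale_shift code hints at; the class name TimeConditionedBlock and the exact layer layout are illustrative assumptions, not the repository's actual ResnetBlock.

import torch
from torch import nn

class TimeConditionedBlock(nn.Module):
    # hypothetical stand-in for the library's ResnetBlock
    def __init__(self, dim, dim_out, time_cond_dim = None, groups = 8):
        super().__init__()
        # when time_cond_dim is given, an MLP maps the time embedding to a
        # per-channel (scale, shift) pair, as the diff's scale_shift suggests
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_cond_dim, dim_out * 2)
        ) if time_cond_dim is not None else None

        self.conv = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        h = self.norm(self.conv(x))

        if self.time_mlp is not None and time_emb is not None:
            scale, shift = self.time_mlp(time_emb).chunk(2, dim = -1)
            # broadcast (batch, channels) over the spatial dimensions
            h = h * (scale[..., None, None] + 1) + shift[..., None, None]

        return self.act(h) + self.res_conv(x)

x = torch.randn(2, 64, 32, 32)
t = torch.randn(2, 256)
block = TimeConditionedBlock(64, 64, time_cond_dim = 256)
out = block(x, t)  # (2, 64, 32, 32), modulated by t

A block constructed with time_cond_dim = None simply ignores time_emb, which is effectively how the inner blocks behaved before this change.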
2 changes: 2 additions & 0 deletions README.md
@@ -251,6 +251,8 @@ Not at the moment but one will likely be trained and open sourced within the year

- <a href="https://github.com/marunine">Marunine</a> and <a href="https://github.com/Netruk44">Netruk44</a>, for reviewing code, sharing experimental results, and help with debugging

+- <a href="https://github.com/marunine">Marunine</a> for providing a <a href="https://github.com/lucidrains/imagen-pytorch/issues/72#issuecomment-1163275757">potential solution</a> for a color shifting issue in the memory efficient u-nets
+
- You? It isn't done yet, chip in if you are a researcher or skilled ML engineer


20 changes: 10 additions & 10 deletions imagen_pytorch/imagen_pytorch.py
@@ -658,7 +658,7 @@ def __init__(
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else Scale(skip_connection_scale)


-def forward(self, x, cond = None, time_emb = None):
+def forward(self, x, time_emb = None, cond = None):

scale_shift = None
if exists(self.time_mlp) and exists(time_emb):
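
The reorder above matters because later in forward the blocks are called positionally. A tiny self-contained illustration (Stub is a hypothetical class, not repository code) of how the second positional argument now binds to time_emb rather than cond:

class Stub:
    # mirrors the new ordering: forward(self, x, time_emb = None, cond = None)
    def forward(self, x, time_emb = None, cond = None):
        return time_emb, cond

s = Stub()
assert s.forward('x', 't') == ('t', None)      # a bare two-argument call now supplies time
assert s.forward('x', 't', 'c') == ('t', 'c')  # cross-attention conditioning still reachable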
@@ -1173,7 +1173,7 @@ def __init__(
self.downs.append(nn.ModuleList([
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups, skip_connection_scale = skip_connect_scale),
-nn.ModuleList([ResnetBlock(dim_out, dim_out, groups = groups, use_gca = use_global_context_attn, skip_connection_scale = skip_connect_scale) for _ in range(layer_num_resnet_blocks)]),
+nn.ModuleList([ResnetBlock(dim_out, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn, skip_connection_scale = skip_connect_scale) for _ in range(layer_num_resnet_blocks)]),
transformer_block_klass(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, ff_mult = ff_mult),
downsample_klass(dim_out) if not memory_efficient and not is_last else None,
]))
@@ -1192,7 +1192,7 @@ def __init__(

self.ups.append(nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups, skip_connection_scale = skip_connect_scale),
-nn.ModuleList([ResnetBlock(dim_in, dim_in, groups = groups, skip_connection_scale = skip_connect_scale, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
+nn.ModuleList([ResnetBlock(dim_in, dim_in, time_cond_dim = time_cond_dim, groups = groups, skip_connection_scale = skip_connect_scale, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
transformer_block_klass(dim = dim_in, heads = attn_heads, dim_head = attn_dim_head, ff_mult = ff_mult),
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
]))
@@ -1390,30 +1390,30 @@ def forward(
if exists(pre_downsample):
x = pre_downsample(x)

-x = init_block(x, c, t)
+x = init_block(x, t, c)

for resnet_block in resnet_blocks:
-x = resnet_block(x)
+x = resnet_block(x, t)

x = attn_block(x)
hiddens.append(x)

if exists(post_downsample):
x = post_downsample(x)

-x = self.mid_block1(x, c, t)
+x = self.mid_block1(x, t, c)

if exists(self.mid_attn):
x = self.mid_attn(x)

-x = self.mid_block2(x, c, t)
+x = self.mid_block2(x, t, c)

for init_block, resnet_blocks, attn_block, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim = 1)
-x = init_block(x, c, t)
+x = init_block(x, t, c)

for resnet_block in resnet_blocks:
-x = resnet_block(x)
+x = resnet_block(x, t)

x = attn_block(x)
x = upsample(x)
@@ -1424,7 +1424,7 @@ def forward(
x = torch.cat((x, init_conv_residual), dim = 1)

if exists(self.final_res_block):
-x = self.final_res_block(x, None, t)
+x = self.final_res_block(x, t)

return self.final_conv(x)
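
Continuing the sketch from the top of this page (reusing its imports and the hypothetical TimeConditionedBlock), here is a hedged illustration of the call pattern this hunk establishes: the time embedding t is computed once per forward pass and handed to every block, including the inner per-level ones that previously ran without it. The dimensions and the toy embedding MLP are arbitrary choices for the example.

time_cond_dim = 256
to_time_emb = nn.Sequential(          # toy timestep -> embedding MLP
    nn.Linear(1, time_cond_dim),
    nn.SiLU(),
    nn.Linear(time_cond_dim, time_cond_dim)
)

init_block = TimeConditionedBlock(64, 64, time_cond_dim = time_cond_dim)
resnet_blocks = nn.ModuleList([
    TimeConditionedBlock(64, 64, time_cond_dim = time_cond_dim) for _ in range(2)
])

x = torch.randn(2, 64, 32, 32)
t = to_time_emb(torch.rand(2, 1))     # one embedding per sample, reused everywhere

x = init_block(x, t)
for resnet_block in resnet_blocks:
    x = resnet_block(x, t)            # before this commit: resnet_block(x), no time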

2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.8.8'
+__version__ = '0.9.0'
