Performance gap between Unet and efficient Unet #72

Closed
jacobwjs opened this issue Jun 21, 2022 · 149 comments

@jacobwjs
Contributor

Hi Phil,

Wanted to bring a bit more rigor and testing to your great work. Ran a few quick experiments with a simple dataset as mentioned in another thread, and noticed a large disparity in results between the two Unet configurations.

To note, the efficient version brings large memory savings at what seems to be a sacrifice in fidelity. Tinting still seems to be hanging around a bit in both, but in the efficient version it is quite noticeable.

See the report here: https://tinyurl.com/imagen-pytorch

@lucidrains
Owner

@jacobwjs thanks for sharing these results! the EMA decay rate is actually defaulted to 0.9999, so i'd definitely run it for longer than 16k steps before passing judgement

but good to see the differences visually between the two unet designs!
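
(Side note on the numbers: with a decay of 0.9999 the EMA weights average over roughly 1 / (1 - 0.9999) = 10,000 steps, so a 16k-step run has barely warmed up the EMA copy. A minimal sketch of the update rule, not the trainer's actual code:)

```python
import torch

@torch.no_grad()
def ema_update(ema_model, online_model, decay=0.9999):
    # exponential moving average of weights: ema <- decay * ema + (1 - decay) * online
    # with decay = 0.9999 the average spans roughly 1 / (1 - decay) = 10k steps,
    # which is why 16k training steps is still early for the EMA copy
    for ema_p, online_p in zip(ema_model.parameters(), online_model.parameters()):
        ema_p.lerp_(online_p, 1.0 - decay)
```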

@jacobwjs
Contributor Author

jacobwjs commented Jun 22, 2022

Fair point! Those runs were killed a bit early indeed, but I've run this a few times pushing past 100K and didn't see much difference.

Happy to run any bleeding edge ideas through this and post results when you add/change something significant so we can start baselining. Please reach out when you want to test.

@lucidrains
Owner

@jacobwjs oh dang, do you have a link to the wandb for the 100k run? is this on the latest version? (fixed a bug with the time conditioning in the final resnet block, but not sure if that will make a difference)

@marunine

Something I found useful in my own experiments was adding the time conditioning to every ResNetBlock in the chain.

marunine@ab88a2d

This is similar to how the OpenAI Unets function, and seems to have negligible compute performance penalty from my testing.

I did extend that further to add the conditioning as well to everything in the chain (provided cross embeds were set), but I'm not sure that's strictly necessary.
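
For readers who want to see what "time conditioning in every ResNetBlock" means in practice, here is a simplified, self-contained sketch where each block gets its own projection of the shared time embedding and applies it as a scale/shift to the features. The block layout and names are illustrative, not the exact code in either repo:

```python
import torch
from torch import nn

class TimeConditionedResnetBlock(nn.Module):
    # simplified block: every block gets its own projection of the shared time embedding,
    # applied as a scale/shift (FiLM-style) of the intermediate features
    def __init__(self, dim, dim_out, time_cond_dim):
        super().__init__()
        self.to_scale_shift = nn.Sequential(nn.SiLU(), nn.Linear(time_cond_dim, dim_out * 2))
        self.block1 = nn.Sequential(nn.Conv2d(dim, dim_out, 3, padding=1), nn.GroupNorm(8, dim_out), nn.SiLU())
        self.block2 = nn.Sequential(nn.Conv2d(dim_out, dim_out, 3, padding=1), nn.GroupNorm(8, dim_out), nn.SiLU())
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb):
        h = self.block1(x)
        scale, shift = self.to_scale_shift(time_emb).chunk(2, dim=1)
        h = h * (scale[..., None, None] + 1) + shift[..., None, None]
        h = self.block2(h)
        return h + self.res_conv(x)

block = TimeConditionedResnetBlock(64, 128, time_cond_dim=256)
out = block(torch.randn(2, 64, 32, 32), torch.randn(2, 256))   # -> (2, 128, 32, 32)
```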

@lucidrains
Owner

@marunine ❤️ ❤️ thank you! ugh, this is why i love open source

@lucidrains
Owner

@marunine let me add it now, and credit you for this finding

lucidrains added a commit that referenced this issue Jun 22, 2022
@samedii

samedii commented Jun 22, 2022

Same thing here https://github.com/yang-song/score_sde_pytorch/blob/main/models/ddpm.py
I've gotten worse results in another project trying the imagen architecture. This might well be the reason.

@lucidrains
Owner

ok, v0.9.0 should have the time conditioning across all the resnet blocks

@lucidrains
Owner

lucidrains commented Jun 22, 2022

Same thing here https://github.com/yang-song/score_sde_pytorch/blob/main/models/ddpm.py I've gotten worse results in another project trying the imagen architecture. This might well be the reason.

@samedii do you mean the ddpm architecture in Yang Song's repository also exhibits the same color shifting issue? the architecture in his repo isn't identical to the Imagen one

@samedii

samedii commented Jun 22, 2022

Sorry, no, I just meant that they are also adding the conditioning to every ResNetBlock, and it's been working better for me than imagen's Unet when I tried to replace it. I wouldn't be surprised if your Unet arch is better after this change.

@lucidrains
Owner

@samedii great! yea, i may just add it to all my ddpm related repositories

this all feels like the early days of attention and transformers a few years back, so much to learn

@lucidrains
Owner

ohh, i had it already in my other repository https://github.com/lucidrains/dalle2-pytorch , it was because of the diagrams in Imagen paper that i removed it

@marunine

I think what implies that it's still there (even in the Efficient Unet) is the positioning of CombineEmbs in Figure A.28. If it were only passed into SelfAttention, then I'd expect it to come after the ResNetBlock group. Standard unets have the time conditioning on all the blocks, so it makes sense they kept it there as well.

@lucidrains
Owner

@marunine yes I do believe you are right! thank you!

@jacobwjs
Contributor Author

@jacobwjs oh dang, do you have a link to the wandb for the 100k run? is this on the latest version? (fixed a bug with the time conditioning in the final resnet block, but not sure if that will make a difference)

Thanks for asking, will start keeping historical runs around. Just started syncing to W&B, and all previous runs were blown away unfortunately.

This behavior was on v0.0.7, and what I linked to above was v0.0.8. Currently re-running v0.0.8 with reduced precision (fp16) just to verify sane behavior. Will fire up a run on v0.0.9 later tonight.

Not throwing any significant compute at these quick & dirty tests, so expect a few days.

@KhrulkovV

KhrulkovV commented Jun 23, 2022

Something I found useful in my own experiments was adding the time conditioning to every ResNetBlock in the chain.

marunine@ab88a2d

This is similar to how the OpenAI Unets function, and seems to have negligible compute performance penalty from my testing.

I did extend that further to add the conditioning as well to everything in the chain (provided cross embeds were set), but I'm not sure that's strictly necessary.

Also noticed this today while trying to squeeze low FIDs out of CIFAR-10 and inspecting the guided diffusion code; could not get below FID ~20 with the current conditioning. Will see if the time conditioning in every block will do it; got 3 slightly varying runs pumping.

@jacobwjs
Contributor Author

@KhrulkovV if anything (good or bad) comes out of this please share configs/results. I'm very interested in knowing what works and doesn't before launching a big training session.

@pacocp
Contributor

pacocp commented Jun 24, 2022

I am also experiencing color shifting with the efficient nets; it usually starts happening after step 5k. It didn't happen before with the vanilla Unet. I am trying to find out what might be the problem!

@lucidrains
Owner

I am also experiencing color shifting with the efficient nets; it usually starts happening after step 5k. It didn't happen before with the vanilla Unet. I am trying to find out what might be the problem!

oh bummer 😢 could you share what you are seeing?

another thing we should rule out is the scaling of the skip connections (can be turned off by setting scale_resnet_skip_connection = False in the Unet class)
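
For concreteness, a sketch of where that flag would go when constructing the Unet. The hyperparameters below are placeholder values, and the flag name is taken from the comment above (it applies to the version of the library being discussed in this thread):

```python
from imagen_pytorch import Unet

# placeholder hyperparameters; the flag on the last line is the one to toggle for this test
unet = Unet(
    dim = 128,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    scale_resnet_skip_connection = False   # turn off the skip connection scaling mentioned above
)
```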

@marunine

I think I'm seeing artifacts in some of the upscaler Unets' training (though not color shifting), and I'm wondering if it has to do with the extra upsample/downsample pass on the last layer:

downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,

I don't think there should be a difference in whether or not the upsample/downsample layer is present compared to the base Unet, just in its positioning, e.g. if memory_efficient and not is_last else dim_in would be used instead of the current logic.

ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups, skip_connection_scale = skip_connect_scale),

Would have to change to match this as well.

The last bit that would change is the size of the hiddens for the last layer: they'd no longer be downsampled at all for that final upsample layer, making this no longer dim_out * 2, but dim_out + hidden_size:

ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups, skip_connection_scale = skip_connect_scale),

One other thing I can think of that's strange about this is where the hiddens are added - in the OpenAI implementation they're appended after the downsample, whereas in this one they're appended before it.

I do have good performance on the base Unet the way it is (I'm sure having more data doesn't hurt) but I guess that also hurts efficiency since the channels are doubling.
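
To make the ordering difference concrete, here is a toy sketch of one resolution level in each variant, using stand-in convolutions rather than the repo's actual modules. It also shows why the saved hiddens end up at half resolution in the memory-efficient case:

```python
import torch
from torch import nn

# toy stand-ins for one resolution level, just to show where the downsample sits
# relative to the blocks and the saved skip connection (not the repo's modules)
dim_in, dim_out = 64, 128
blocks = nn.Conv2d(dim_in, dim_out, 3, padding=1)                     # stand-in for the resnet block group
downsample_base = nn.Conv2d(dim_out, dim_out, 4, stride=2, padding=1)
downsample_eff = nn.Conv2d(dim_in, dim_in, 4, stride=2, padding=1)

x = torch.randn(1, dim_in, 64, 64)

# base Unet ordering: blocks at full resolution, skip saved, then downsample on the way out
h_base = blocks(x)                          # (1, 128, 64, 64) -> this is what gets saved as the skip
out_base = downsample_base(h_base)          # (1, 128, 32, 32)

# memory-efficient ordering: downsample on the way in, so blocks and skip live at half resolution
h_eff = blocks(downsample_eff(x))           # (1, 128, 32, 32) -> the saved skip is half resolution
```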

@lucidrains
Owner

lucidrains commented Jun 25, 2022

@marunine thank you for your insights as usual!

i think removing the downsampling block in the last layer is worth trying for the memory efficient case, but i'll do that on a separate branch since it is a rather big change

decided to make two changes to try to figure out this issue. one is to make sure the memory efficient downsampling block does not change the channel dimensions on downsample (so it is more like the base unet, which we know is trouble free). the other is to add an option 032bfa8#diff-edef3c5fe92797a22c0b8fc6cca1d57b4b84ef03dfdfe802ed9147e21fc88109R995 to allow for earlier concatenation of the hiddens (just set this to True for the memory efficient unet)

@marunine

After seeing your code change I realize that's probably the main culprit - the channel reduction on the downsampling doesn't match what the paper does and essentially halves the channels if my intuition is correct.

I'll give that one a try first, thanks for finding that!

@pacocp
Contributor

pacocp commented Jun 25, 2022

I am also experiencing color shifting with the efficient nets; it usually starts happening after step 5k. It didn't happen before with the vanilla Unet. I am trying to find out what might be the problem!

oh bummer 😢 could you share what you are seeing?

another thing we should rule out is the scaling of the skip connections (can be turned off by setting scale_resnet_skip_connection = False in the Unet class)

This is one example where they shift to the red channel:
[image]

And another one where they shift to the green channel:

[image]

I think maybe it's what you are saying about the downsampling. I will give it a try and check what it is doing to the channel dimensions! Thanks a lot! 🚀

@lucidrains
Owner

oh gosh, yes, that is pretty bad 😢

i guess worst comes to worst, if we can't figure it out, then we can just use the base unet and i can try to get the partial reversibility working from https://arxiv.org/abs/1906.06148

@lucidrains
Owner

After seeing your code change I realize that's probably the main culprit - the channel reduction on the downsampling doesn't match what the paper does and essentially halves the channels if my intuition is correct.

I'll give that one a try first, thanks for finding that!

hopefully!

@jacobwjs
Contributor Author

jacobwjs commented Jun 25, 2022

You all move too fast, can't keep up :)

Looks like we're iterating towards the culprit. I'd like to pin down exactly what's going on with the efficient unet, which means a sweep across a few values at this point.

Based on a quick glance, the current options seem to be:

  • scale_resnet_skip_connection = True & False
  • downsample_concat_hiddens_earlier = True & False

What else would be good?
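
For reference, a minimal grid over those two flags could be scripted roughly like this. The flag names come from the list above, and run_experiment is a hypothetical placeholder for whatever training/eval loop is already in use:

```python
from itertools import product

# hypothetical sweep skeleton; flag names are the ones listed above,
# and run_experiment stands in for an existing training/eval loop
def run_experiment(**unet_kwargs):
    ...

for scale_skip, concat_early in product((True, False), repeat=2):
    run_experiment(
        scale_resnet_skip_connection = scale_skip,
        downsample_concat_hiddens_earlier = concat_early,
    )
```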

@lucidrains
Owner

You all move too fast, can't keep up :)

Looks like we're iterating towards the culprit. I'd like to pin down exactly what's going on with the efficient unet, which means a sweep across a few values at this point.

Based on a quick glance, the current options seem to be:

  • scale_resnet_skip_connection = True & False
  • downsample_concat_hiddens_earlier = True & False

What else would be good?

those 2 options would be great to start with! and also, if you have spare compute, what @marunine proposed with removing the last downsampling layer, which is implemented at the latest commit here https://github.com/lucidrains/imagen-pytorch/tree/pw/memory-efficient-no-last-layer-downsample

@marunine

I'm still getting artifacts even with the channels no longer being reduced, but I went through and traced the size of the tensors throughout the execution of the Unet.

Using the default Unet sizes from the readme, with the normal Unet on the left and the memory efficient one on the right, you can see the differences in dimensions:

[image]

The paper, for reference, is:

[image]

The resolution in the middle convolutions technically matches the paper in the efficient case, whereas the base Unet's does not. For the purposes of matching the existing base Unet the test branch is useful, but it does differ in architecture. I'll try a run to see how it fares.

I'm still unclear on what the dimensions of the skip connection/hiddens are supposed to be per the paper, but as noted they are half the resolution in the memory efficient version.

@lucidrains
Owner

one more minor version update incoming... ugh so many breaking changes

@lucidrains
Owner

@KhrulkovV done b8c3d35

@KhrulkovV

KhrulkovV commented Jun 28, 2022

After carefully looking into the guided-diffusion Unet I am very confused, e.g., they use (num_resnet_blocks) blocks on the down branch and (num_resnet_blocks + 1) blocks on the up branch. Probably it is reasonable, but still

Also, in both branches the upsample/downsample follows the convolutions, contrary to what they say about the efficient Unet motivation

@lucidrains
Owner

After carefully looking into the guided-diffusion Unet I am very confused, e.g., they use (num_resnet_blocks) blocks on the down branch and (num_resnet_blocks + 1) blocks on the up branch. Probably it is reasonable, but still

Also, in both branches the upsample/downsample follows the convolutions, contrary to what they say about the efficient Unet motivation

that probably isn't that important 😄 but good to know!

@lucidrains
Owner

ok, do let me know if the skip connections end up speeding up convergence! i'm going to work on some elucidating ddpm code in another repo

@pacocp
Contributor

pacocp commented Jun 28, 2022

Just to add a comment on the efficient Unet training. It has some loss spikes during the training, while the base Unet does not present this behavior:

[image]

(the first Unet is never an efficient Unet; for the second one, the red and green plots are efficient Unets and the orange one is a non-efficient Unet)

Nothing to worry about, but I just wanted to put it out there! Thanks for all the effort! 🚀

@lucidrains
Owner

@pacocp awesome, that is good to know! in general anything with attention or transformer blocks will exhibit a bit of rockiness (as long as it does not explode it is fine)

but perhaps i can add some more skip connections in the memory efficient, if the early downsample is the culprit (and not the cross attention blocks)

@pacocp more importantly, if you still see the color shifting problems that you posted earlier, do let me know! we need to defeat that problem lol

@pacocp
Contributor

pacocp commented Jun 28, 2022

@pacocp awesome, that is good to know! in general anything with attention or transformer blocks will exhibit a bit of rockiness (as long as it does not explode it is fine)

but perhaps i can add some more skip connections in the memory efficient, if the early downsample is the culprit (and not the cross attention blocks)

@pacocp more importantly, if you still see the color shifting problems that you posted earlier, do let me know! we need to defeat that problem lol

I am training for longer now with the latest changes. It takes some time; I would need to speed up the training. The last time I profiled it, a step took between 1-2 seconds for the base Unet and between 3-4 seconds for the SRUnet (forward and update), so it takes me a while to train for a large number of steps (training on a single GPU). Right now, at 4k steps, it looks like it has started to shift the color, but let's see how it continues:

[image]

@lucidrains
Owner

lucidrains commented Jun 28, 2022

@pacocp how large are your batch sizes? and yep, 4k is still pretty early, but if it persists past 20k then we are in trouble

@pacocp
Contributor

pacocp commented Jun 28, 2022

@lucidrains batch size is 64, but I am using 8 as maximum batch size to avoid the CUDA error 😅
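
(For anyone wondering how a batch size of 64 squares with a per-step maximum of 8: it effectively amounts to gradient accumulation over 8 micro-batches. A generic, self-contained PyTorch sketch of the idea, not the trainer's actual implementation:)

```python
import torch
from torch import nn

# generic gradient-accumulation sketch with a toy model: an effective batch of 64
# is processed as 8 micro-batches of 8, accumulating gradients before one optimizer step
model = nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
effective_batch, micro_batch = 64, 8
accum_steps = effective_batch // micro_batch

data = torch.randn(effective_batch, 16)
target = torch.randn(effective_batch, 1)

optimizer.zero_grad()
for i in range(accum_steps):
    xb = data[i * micro_batch:(i + 1) * micro_batch]
    yb = target[i * micro_batch:(i + 1) * micro_batch]
    loss = nn.functional.mse_loss(model(xb), yb) / accum_steps  # scale so the summed gradient matches the full batch
    loss.backward()
optimizer.step()
optimizer.zero_grad()
```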

@lucidrains
Owner

@pacocp yup, 64 is good 😄

@pacocp
Contributor

pacocp commented Jun 29, 2022

@lucidrains 9k steps and the color shifting doesn't show up yet 🥳
[image]

@lucidrains
Owner

@pacocp woohoo! 🥳

@pacocp
Contributor

pacocp commented Jun 29, 2022

@lucidrains no! it is TCGA data. Only a small portion to do some tests

@lucidrains
Owner

@pacocp ahh cool! show it to some pathologists at your hospital, they'll get a kick out of it 😉

@lucidrains
Owner

anyone still seeing issues with the latest version of the unets?

@marunine

I had a run going prior to the nearest-neighbor upsampling changes, and that stopped color shifting on all samples after about 100k steps. It seemed to depend on how strongly the Unet understood the text conditioning itself, but I'm not entirely sure.
I was still getting checkerboard artifacts, but they were relatively minor on the output.

I am trying out the elucidated network now after updating the branch. Once the base Unet is trained (it's looking a lot better than the original Imagen), I'm going to try out the memory efficient Unet and let you know the results.

I think the extra residuals/nearest upsampling helps.
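
(For context on the "upsampling with nearest" change: the standard trick for avoiding checkerboard artifacts is to swap a strided transposed convolution for a nearest-neighbor upsample followed by a plain convolution. A minimal illustration of the two, not the repo's exact modules:)

```python
import torch
from torch import nn

dim = 128

# strided transposed convolution: a common source of checkerboard artifacts in upsamplers
upsample_transposed = nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1)

# nearest-neighbor upsample followed by a plain conv: the usual fix for checkerboarding
upsample_nearest = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.Conv2d(dim, dim, kernel_size=3, padding=1),
)

x = torch.randn(1, dim, 32, 32)
assert upsample_transposed(x).shape == upsample_nearest(x).shape == (1, dim, 64, 64)
```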

@pacocp
Contributor

pacocp commented Jun 30, 2022

@lucidrains I have a few checkerboard artifacts in the first training steps, but I am sure they will disappear with more training. Also, I am testing the elucidated network, using the latest version of the unets. Quite different results on the sampling for the same number of steps. This is 3k steps on the elucidated version:

[image]

and here for the "normal" version:

[image]

Less "noisy" in the elucidated version, but quite a bit more "abstract" in terms of the training dataset and how it should look. But it is only 3k steps, so let's see how it continues!

@lucidrains
Owner

@pacocp @marunine thanks! and yes, the elucidated training trajectory is noticeably different on my machine too

ok, i think we can close this issue, and any elucidated imagen issues can go in the other thread

@KhrulkovV

KhrulkovV commented Jul 10, 2022

What is the current understanding of color shifting? @lucidrains I am training a large scale efficient SR diffusion model and getting some artifacts (still around 40k steps though, maybe it will disappear later)
[image]
left - original images, right - ddpm samples with 1000 steps

@BIGJUN777

What is the current understanding of color shifting? @lucidrains I am training a large scale efficient SR diffusion model and getting some artifacts (still around 40k steps though, maybe it will disappear later). Left - original images, right - ddpm samples with 1000 steps.

Hi @KhrulkovV, have you solved the color shifting problem? I also encountered this issue in all my experiments when predicting the noise, but it worked well when predicting the X_0.
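
(For readers comparing the two parameterizations mentioned here: the network can be trained to regress either the added noise or the clean image x_0; both come from the same forward-process equation, and only the regression target changes. A schematic sketch in standard DDPM notation, not this repo's code:)

```python
import torch
import torch.nn.functional as F

def ddpm_loss(model, x0, alphas_cumprod, objective='noise'):
    # forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
    b = x0.shape[0]
    t = torch.randint(0, len(alphas_cumprod), (b,), device=x0.device)
    a_bar = alphas_cumprod[t].view(b, 1, 1, 1)
    eps = torch.randn_like(x0)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps

    pred = model(x_t, t)
    # 'noise' regresses the added noise eps; 'x0' regresses the clean image directly
    target = eps if objective == 'noise' else x0
    return F.mse_loss(pred, target)
```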

@XavierXiao

XavierXiao commented Oct 8, 2022

What is the current understanding of color shifting? @lucidrains I am training a large scale efficient SR diffusion model and getting some artifacts (still around 40k steps though, maybe it will disappear later). Left - original images, right - ddpm samples with 1000 steps.

Hi @KhrulkovV, have you solved the color shifting problem? I also encountered this issue in all my experiments when predicting the noise, but it worked well when predicting the X_0.

I have the same issue; does anyone have an idea?

@KhrulkovV

Hey! Color shifting was gone after training for longer and it was working fine after that

@XavierXiao

XavierXiao commented Oct 8, 2022

Hey! Color shifting was gone after training for longer and it was working fine after that

Thanks @KhrulkovV ! May I ask how many images are in your dataset and how many steps were needed to get over the color shift? In a test run, I am training a 64->256 SR model on CelebA-HQ, which has 30,000 images, with batch size 64, and after 30k steps (70 epochs) I still observe color shift.

low res:
[image]

sample:
[image]

@KhrulkovV

Thanks @KhrulkovV ! May I ask how many images are in your dataset and how many steps were needed to get over the color shift? In a test run, I am training a 64->256 SR model on CelebA-HQ, which has 30,000 images, with batch size 64, and after 30k steps (70 epochs) I still observe color shift.

Hello! I am training a large scale model on text-to-image data (think LAION datasets); for a large SR model with 700m parameters the color shifts were gone after ~150k iterations, and for a smaller one about 200k-ish steps seems fine

The most critical piece in my opinion is the skip-first-to-last parameter; without it, it was not working that well

@XavierXiao

XavierXiao commented Oct 9, 2022

Hello! I am training a large scale model on text-to-image data (think LAION datasets); for a large SR model with 700m parameters the color shifts were gone after ~150k iterations, and for a smaller one about 200k-ish steps seems fine

The most critical piece in my opinion is the skip-first-to-last parameter; without it, it was not working that well

Thanks @KhrulkovV , that's very helpful. Eventually I will train on a large dataset; CelebA is just a sanity check. I think you are referring to init_conv_to_final_conv_residual, which adds a skip connection from the first conv layer to the last?

In addition, what is your unet parametrization? Some people in the thread say that the unet needs to predict x_0 to avoid the color shift.
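
(A rough picture of what that flag does, reduced to a toy module: the output of the very first conv is carried all the way to the end and concatenated right before the final conv, acting as one extra long skip connection. Names and layout here are illustrative only:)

```python
import torch
from torch import nn

class TinyUnetWithInitToFinalResidual(nn.Module):
    # highly simplified sketch: keep the features from the initial conv and concatenate
    # them back in right before the final conv, i.e. one extra long skip connection
    def __init__(self, channels=3, dim=64):
        super().__init__()
        self.init_conv = nn.Conv2d(channels, dim, 3, padding=1)
        self.body = nn.Sequential(nn.Conv2d(dim, dim, 3, padding=1), nn.SiLU())  # stand-in for the whole Unet body
        self.final_conv = nn.Conv2d(dim * 2, channels, 3, padding=1)             # dim * 2 because of the concat

    def forward(self, x):
        init = self.init_conv(x)
        h = self.body(init)
        h = torch.cat((h, init), dim=1)   # the init-conv-to-final-conv residual
        return self.final_conv(h)
```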

@KhrulkovV

Thanks @KhrulkovV , that's very helpful. Eventually I will train on a large dataset; CelebA is just a sanity check. I think you are referring to init_conv_to_final_conv_residual, which adds a skip connection from the first conv layer to the last?

In addition, what is your unet parametrization? Some people in the thread say that the unet needs to predict x_0 to avoid the color shift.

Yes, exactly.
I am training in the default setting (noise prediction); it probably matters only at the start.
