Performance gap between Unet and efficient Unet #72
@jacobwjs thanks for sharing these results! the EMA decay rate is actually defaulted to 0.9999, so i'd definitely run it for longer than 16k steps before passing judgement. but good to see the differences visually between the two unet designs!
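For reference, a minimal sketch of a parameter-wise EMA update with that default decay (a generic illustration, not the library's actual EMA class); with decay 0.9999 the effective averaging window is roughly 1 / (1 - decay) ≈ 10k steps, which is why 16k steps is early for judging EMA samples:

```python
import torch

@torch.no_grad()
def ema_update(ema_model: torch.nn.Module, online_model: torch.nn.Module, decay: float = 0.9999):
    # ema = decay * ema + (1 - decay) * online, applied parameter-wise
    for ema_p, p in zip(ema_model.parameters(), online_model.parameters()):
        ema_p.mul_(decay).add_(p, alpha = 1. - decay)
```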
Fair point! Those runs were killed a bit early indeed, but I've run this a few times pushing past 100K and didn't see much difference. Happy to run any bleeding edge ideas through this and post results when you add/change something significant so we can start baselining. Please reach out when you want to test.
@jacobwjs oh dang, do you have a link to the wandb for the 100k run? is this on the latest version? (fixed a bug with the time conditioning in the final resnet block, but not sure if that will make a difference)
Something I found useful in my own experiments was adding the time conditioning to every ResNetBlock in the chain. This is similar to how the OpenAI Unets function, and seems to have a negligible compute performance penalty in my testing. I did extend that further and added the conditioning to everything in the chain as well (provided cross embeds were set), but I'm not sure that's strictly necessary.
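A minimal sketch of what per-block time conditioning can look like (hypothetical module names, not the repo's actual classes): the shared time embedding is projected to a scale/shift inside every ResNet block rather than only in the final one.

```python
import torch
from torch import nn

class TimeConditionedResnetBlock(nn.Module):
    # simplified illustration; every block in the down/up chain receives time_emb
    # (assumes dim is divisible by 8 for the GroupNorm)
    def __init__(self, dim: int, time_emb_dim: int):
        super().__init__()
        self.to_scale_shift = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim * 2))
        self.block1 = nn.Sequential(nn.GroupNorm(8, dim), nn.SiLU(), nn.Conv2d(dim, dim, 3, padding = 1))
        self.block2 = nn.Sequential(nn.GroupNorm(8, dim), nn.SiLU(), nn.Conv2d(dim, dim, 3, padding = 1))

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        h = self.block1(x)
        # FiLM-style conditioning from the time embedding, applied in this block
        scale, shift = self.to_scale_shift(time_emb).chunk(2, dim = -1)
        h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
        return self.block2(h) + x
```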
@marunine ❤️ ❤️ thank you! ugh, this is why i love open source
@marunine let me add it now, and credit you for this finding
Same thing here https://github.com/yang-song/score_sde_pytorch/blob/main/models/ddpm.py
ok, v0.9.0 should have the time conditioning across all the resnet blocks
@samedii do you mean the ddpm architecture in Yang Song's repository also exhibits the same color shifting issue? the architecture in his repo isn't identical to the Imagen one
Sorry, no, I just meant that they are also adding the conditioning to every ResNetBlock, and it's been working better for me than Imagen's Unet when I tried to replace it. I wouldn't be surprised if your Unet arch is better after this change.
@samedii great! yea, i may just add it to all my ddpm related repositories. this all feels like the early days of attention and transformers a few years back, so much to learn
ohh, i had it already in my other repository https://github.com/lucidrains/dalle2-pytorch , it was because of the diagrams in the Imagen paper that i removed it
I think what implies that it's still there (even in the Efficient Unet) is the positioning of
@marunine yes I do believe you are right! thank you!
Thanks for asking, will start keeping historical runs around. Just started syncing to W&B, and all previous runs were blown away unfortunately. This behavior was v0.0.7, and what I linked to above was v0.0.8. Currently re-running v0.0.8 with reduced precision (fp16) just to verify sane behavior. Will fire up a run on v0.0.9 later tonight. Not throwing any significant compute at these quick & dirty tests, so expect a few days.
Also noticed this today while trying to squeeze low FIDs out of CIFAR-10 and inspecting the guided diffusion code; could not get below FID ~20 with the current conditioning. Will see if the time conditioning in every block will do it, got 3 slightly varying runs pumping.
@KhrulkovV if anything (good or bad) comes out of this please share configs/results. I'm very interested in knowing what works and doesn't before launching a big training session.
I am also experiencing color shifting with the efficient nets, it usually starts happening after step 5k. It didn't happen before with the vanilla Unet. I am trying to find out what might be the problem!
oh bummer 😢 could you share what you are seeing? another thing we should rule out is the scaling of the skip connections (can be turned off by setting
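For context, a rough sketch of the skip-connection scaling being referred to, assuming it follows the Imagen paper's 1/sqrt(2) scaling before concatenation (a generic illustration; the flag name is elided above and not reproduced here):

```python
import torch

def concat_skip(h: torch.Tensor, skip: torch.Tensor, scale_skip: bool = True) -> torch.Tensor:
    # scale the saved hidden by 1/sqrt(2) before concatenating onto the upsampling path;
    # turning the scaling off concatenates the raw hidden instead
    if scale_skip:
        skip = skip * (2 ** -0.5)
    return torch.cat((h, skip), dim = 1)
```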
I think I'm seeing artifacts in some of the upscaler Unets' training (though not color shifting), and I'm wondering if it's to do with the extra upsample/downsample pass on the last layer (imagen-pytorch/imagen_pytorch/imagen_pytorch.py, line 1174 in 29f6475).
I don't think there should be a difference in whether or not the upsample/downsample layer is present compared to the base Unet, just the positioning of it, e.g. imagen-pytorch/imagen_pytorch/imagen_pytorch.py, line 1175 in 29f6475.
That would have to change to match this as well. The last bit that would change is the size of the hiddens for the last layer; they'd no longer be downsampled at all for that final upsample layer, making this no longer apply (imagen-pytorch/imagen_pytorch/imagen_pytorch.py, line 1194 in 29f6475).
The last bit I can think of that's strange about this is where the hiddens are added - in the OpenAI implementation they're appended after the downsample, whereas in this one they're appended before it (imagen-pytorch/imagen_pytorch/imagen_pytorch.py, line 1399 in 29f6475).
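A rough sketch of the two orderings being contrasted (a hypothetical down-path loop, not the repo's actual code): the only difference is whether the skip feature is stored before or after the downsample.

```python
def run_down_path(x, t, downs, append_after_downsample: bool = False):
    # downs: iterable of (resnet_block, downsample) module pairs (hypothetical structure);
    # each resnet_block takes (x, t), as in the per-block time conditioning sketch above
    hiddens = []
    for resnet_block, downsample in downs:
        x = resnet_block(x, t)
        if not append_after_downsample:
            hiddens.append(x)   # this repo (at the time): skip saved at the pre-downsample resolution
        x = downsample(x)
        if append_after_downsample:
            hiddens.append(x)   # guided-diffusion style: skip saved after the downsample
    return x, hiddens
```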
I do have good performance on the base Unet the way it is (I'm sure having more data doesn't hurt), but I guess that also hurts efficiency since the channels are doubling.
@marunine thank you for your insights as usual! i think removing the downsampling block in the last layer is worth trying for the memory efficient case, but i'll do that on a separate branch since it is a rather big change. decided to make two changes to try to figure out this issue: one is to make sure the memory efficient downsampling block does not change the channel dimensions on downsample (so it is more like the base unet, which we know is trouble free). the other is to add an option 032bfa8#diff-edef3c5fe92797a22c0b8fc6cca1d57b4b84ef03dfdfe802ed9147e21fc88109R995 to allow for earlier concatenation of the hiddens (just set this to
After seeing your code change I realize that's probably the main culprit - the channel reduction on the downsampling doesn't match what the paper does and essentially halves the channels, if my intuition is correct. I'll give that one a try first, thanks for finding that!
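For illustration, a simplified sketch of the kind of space-to-channel downsample under discussion (a hypothetical helper, not the repo's exact layer): the pixel-unshuffle step multiplies channels by 4, and the 1x1 projection decides whether the stage keeps its width or gets effectively narrowed.

```python
from torch import nn

def memory_efficient_downsample(dim_in: int, dim_out: int) -> nn.Module:
    # (b, c, h, w) -> (b, 4c, h/2, w/2) via pixel unshuffle, then project to dim_out;
    # choosing dim_out smaller than the block's working width is what effectively
    # halves the channels the rest of the stage sees (the suspected culprit)
    return nn.Sequential(
        nn.PixelUnshuffle(2),
        nn.Conv2d(dim_in * 4, dim_out, 1)
    )
```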
oh gosh, yes, that is pretty bad 😢 i guess worst comes to worst, if we can't figure it out, then we can just use the base unet and i can try to get the partial reversibility working from https://arxiv.org/abs/1906.06148
hopefully!
You all move too fast, can't keep up :) Looks like we're iterating towards the culprit. I'd like to pin down exactly what's going on with the efficient unet, which means a sweep across a few values at this point. Based on a quick glance, the current options seem to be:
What else would be good?
those 2 options would be great to start with! and also, if you have spare compute, what @marunine proposed with removing the last downsampling layer, which is implemented at the latest commit here https://github.com/lucidrains/imagen-pytorch/tree/pw/memory-efficient-no-last-layer-downsample
I'm still getting artifacts even with the channels no longer being reduced, but I went through and traced the size of the tensors throughout the execution of the Unet. Using the default Unet sizes from the readme, with the normal unet on the left and the memory efficient one on the right, you can see the differences in dimensions: [tensor shape trace]. The paper's architecture, for reference: [Efficient UNet diagram from the Imagen paper]. The resolution in the middle convolutions technically matches the paper in the efficient case, whereas the base Unet does not. For the purposes of matching the existing base unet the test branch is useful, but it does differ in architecture. I'll try a run to see how it fares. I'm still unclear on what the dimensions of the skip connection/hiddens are supposed to be per the paper, but as noted they are half the resolution in the memory efficient version.
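One way to reproduce this kind of shape trace (generic PyTorch, assuming nothing about the repo beyond the model being an nn.Module):

```python
import torch
from torch import nn

def trace_conv_shapes(model: nn.Module, *model_inputs, **model_kwargs):
    # register forward hooks on every conv layer, run one forward pass, print output shapes
    handles = []

    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor):
            print(f'{module.__class__.__name__:20s} -> {tuple(output.shape)}')

    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            handles.append(module.register_forward_hook(hook))

    with torch.no_grad():
        model(*model_inputs, **model_kwargs)

    for handle in handles:
        handle.remove()
```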
one more minor version update incoming... ugh so many breaking changes
@KhrulkovV done b8c3d35
After carefully looking into the guided-diffusion Unet I am very confused, e.g., they use (num_resnet_blocks) blocks on the down branch and (num_resnet_blocks + 1) blocks on the up branch. Probably it is reasonable, but still. Also, in both branches the upsample/downsample follows the convolutions, contrary to what they are saying about the efficient Unet motivation.
that sounds probably not that important 😄 but good to know!
ok, do let me know if the skip connections end up speeding up convergence! i'm going to work on some elucidating ddpm code in another repo
Just to add a comment on the efficient Unet training. It has some loss spikes during the training, while the base Unet does not present this behavior (the first unet is never an efficient Unet; for the second one, the red and green plots are the efficient Unet and the orange one a non-efficient Unet). Nothing to worry about, but I just wanted to put it out there! Thanks for all the effort! 🚀
@pacocp awesome, that is good to know! in general anything with attention or transformer blocks will exhibit a bit of rockiness (as long as it does not explode it is fine), but perhaps i can add some more skip connections in the memory efficient unet, if the early downsample is the culprit (and not the cross attention blocks). @pacocp more importantly, if you still see the color shifting problems that you posted earlier, do let me know! we need to defeat that problem lol
I am training for longer now with the last changes. It takes some time, I would need to speed up the training. The last time I profiled it, it took between 1-2 seconds for the base Unet and between 3-4 seconds for the SRUnet (forward and update), so it takes me a while to train for a large number of steps (training on a single GPU). Right now, at 4k steps, it looks like it has started to shift the color, but let's see how it continues.
@pacocp how large are your batch sizes? and yep, 4k is still pretty early, but if it persists past 20k then we are in trouble
@lucidrains batch size is 64, but I am using 8 as maximum batch size to avoid the CUDA error 😅
@pacocp yup, 64 is good 😄
@lucidrains 9k steps and the color shifting doesn't show up yet 🥳
@pacocp woohoo! 🥳
@lucidrains no! it is TCGA data. Only a small portion to do some tests
@pacocp ahh cool! show it to some pathologists at your hospital, they'll get a kick out of it 😉
anyone still seeing issues with the latest version of the unets?
I had a run going prior to the nearest-upsampling changes, and that stopped color shifting on all samples after about 100k steps. It seemed to be based on how strongly the Unet understood the conditioning of the text itself, but I'm not entirely sure. I am trying out the elucidating network now after updating the branch. Once the base Unet is trained (it's looking a lot better than the original Imagen) I'm going to try out the memory efficient Unet and let you know the results. I think the extra residuals/nearest upsampling helps.
@lucidrains I have a few checkerboard artifacts in the first training steps, but I am sure they will disappear with more training. Also, I am testing the elucidated network, using the last version of the unets. Quite different results on the sampling for the same number of steps. This is 3k steps on the elucidated version: [sample image]. And here for the "normal" version: [sample image]. Less "noisy" in the elucidated version, but quite a bit more "abstract" in terms of the training dataset and how it should look. But it is only 3k steps, so let's see how it continues!
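Since checkerboard artifacts came up, here is the standard nearest-neighbor-upsample-plus-conv pattern that the earlier "upsampling with nearest changes" comment refers to, sketched generically (not the repo's exact layer):

```python
from torch import nn

def upsample(dim_in: int, dim_out: int) -> nn.Module:
    # nearest-neighbor upsample followed by a 3x3 conv, the usual alternative to a
    # strided transposed conv when checkerboard artifacts appear
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim_in, dim_out, 3, padding = 1)
    )
```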
What is the recent understanding of color shifting? @lucidrains I am training a large scale efficient SR diffusion model and getting some artifacts (still around 40k steps though, maybe they will disappear later)
Hi @KhrulkovV, have you solved the color shifting problem? I also encountered this issue in all my experiments when predicting the noise, but it worked well when predicting X_0.
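For reference, the two parametrizations being compared are related by standard DDPM algebra, with x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps; a minimal sketch, not tied to this repo's implementation:

```python
import torch

def x0_from_eps(x_t: torch.Tensor, eps: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    # recover the implied x_0 when the network is parametrized to predict the noise
    return (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()

def eps_from_x0(x_t: torch.Tensor, x_0: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    # recover the implied noise when the network is parametrized to predict x_0 directly
    return (x_t - alpha_bar_t.sqrt() * x_0) / (1 - alpha_bar_t).sqrt()
```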
I have the same issue, does anyone have an idea?
Hey! Color shifting was gone after training for longer and it was working fine after that
Thanks @KhrulkovV! May I ask how many images are in your dataset and how many steps are needed to get over the color shift? In a test run, I am training a 64->256 SR model on CelebA-HQ, which has 30000 images, with batch size 64, and after 30k steps (70 epochs), I still observe color shift.
Hello! I am training a large scale model on text to image data (think LAION datasets); for a large SR model with 700m parameters color shifts were gone after ~150k iterations, for a smaller one about 200k-ish steps seems fine. The most critical piece in my opinion is the init_conv_to_final_conv_residual
Thanks @KhrulkovV, that's very helpful. Eventually I will train on a large dataset and CelebA is just a sanity check. I think you are referring to init_conv_to_final_conv_residual, which adds a skip connection from the first conv layer to the last? In addition, what is your unet parametrization? Some people in the thread say that the unet needs to predict x_0 to avoid color shift.
Yes, exactly
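A minimal sketch of what that residual does (a hypothetical toy module, not the repo's Unet): the feature map produced by the very first conv is carried all the way to the end and concatenated just before the final conv, which plausibly gives the network an easy path for global statistics such as overall color.

```python
import torch
from torch import nn

class TinyUnetWithInitResidual(nn.Module):
    # toy illustration of an init-conv -> final-conv residual (names are hypothetical)
    def __init__(self, channels: int = 3, dim: int = 64):
        super().__init__()
        self.init_conv = nn.Conv2d(channels, dim, 7, padding = 3)
        self.body = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding = 1), nn.SiLU(),
            nn.Conv2d(dim, dim, 3, padding = 1), nn.SiLU(),
        )
        self.final_conv = nn.Conv2d(dim * 2, channels, 3, padding = 1)  # * 2 for the concat below

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        init_features = self.init_conv(x)
        h = self.body(init_features)
        return self.final_conv(torch.cat((h, init_features), dim = 1))
```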
Hi Phil,
Wanted to bring a bit more rigor and testing to your great work. Ran a few quick experiments with a simple dataset as mentioned in another thread, and noticed a large disparity in results between the two Unet configurations.
To note, the efficient version brings in large memory savings at what seems to be a sacrifice of fidelity. Tinting seems to still be hanging around a bit in both, but in the efficient version it is quite noticeable.
See the report here: https://tinyurl.com/imagen-pytorch