Can't load optimizer state due to state_steps #1083
Comments
Hey, thanks for the detailed question! I think what you are doing is correct. #776 is largely different from your issue, which is related to the optimizer state. I am not sure whether you are running into problem 1 (a checkpoint saved with a pre-1.12 PyTorch that is incompatible with 1.12's optimizer state format), problem 2 (a checkpoint saved and loaded with the same version that still fails), or both.
For 1, I suggest you just use torch.load and torch.save manually and patch the checkpoint so that it is compatible with 1.12. You can save two versions of the checkpoint (one for pre-1.12, one for post-1.12) and load the correct one to avoid crashes. For 2, that would be a bug. Please send us a minimal reproduction case if you can. A PR to fix it would be even more awesome! ;-)
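A minimal sketch of that kind of offline patching, assuming the checkpoint is a plain dict produced by torch.save and that only Adam's `step` entries differ between versions (paths and keys below are illustrative):

```python
import torch

ckpt = torch.load("checkpoint.pt", map_location="cpu")  # illustrative path
optim_state = ckpt["optim"]  # hypothetical key holding the optimizer state dict

# Post-1.12 copy: Adam in 1.12 expects each "step" to be a singleton float tensor.
for param_state in optim_state["state"].values():
    if isinstance(param_state.get("step"), int):
        param_state["step"] = torch.tensor(float(param_state["step"]))
torch.save(ckpt, "checkpoint_pt112.pt")

# Pre-1.12 copy: older Adam stores "step" as a plain Python int.
for param_state in optim_state["state"].values():
    if torch.is_tensor(param_state.get("step")):
        param_state["step"] = int(param_state["step"].item())
torch.save(ckpt, "checkpoint_pre112.pt")
```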
Hi, I think it is the second alternative: saving a checkpoint and then loading it back fails even with the same PyTorch version (1.12). Here is a minimal reproduction - https://gist.github.com/rowhanm/71272f157d8c9450d6b1c7639a612126. I've narrowed the problem down to this line - https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2443 - and I am able to fix the issue in my script by converting the `step` values in the gathered state dict back to singleton tensors before loading. I don't have much context on what the comment in the source is referring to, though, so I'm not sure that is the right fix.
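A sketch of that conversion (the idea rather than the exact code; it assumes the usual torch.optim state_dict layout, with per-parameter states under the "state" key):

```python
import torch

def int_steps_to_tensors(optim_state_dict):
    """Rough sketch of the workaround: turn int "step" entries in a gathered
    full optimizer state dict back into singleton float tensors."""
    for param_state in optim_state_dict["state"].values():
        step = param_state.get("step")
        if isinstance(step, int):
            param_state["step"] = torch.tensor(float(step))
    return optim_state_dict
```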
I see. This makes sense. We likely don't have a test case that catches this issue. I will find time to fix it.
btw, when I ran your sample code I got the same state_steps assertion error from Adam.
Yep, that is the main issue :) Adam in 1.12 expects `state_steps` to be a list of singleton tensors. I can throw in a small test case + PR that fixes it in a bit. Again, not sure if that is the best possible fix, since the comment in the source suggests the conversion to an int is deliberate.
Thanks for trying a fix! My best memory is that this conversion is needed because if the step is a singleton tensor, it may be treated like a sharded optimizer state and get handled by the gather function. In a way, this step scalar is assumed to be the same across all ranks, which is true for FSDP at least. Maybe there are reasons why it changed from a scalar to a tensor in the first place, but I haven't looked into it.
BTW, when I ran your test code with pt 1.8, it gave a different error in the loss function, which is very interesting too.
Super weird. I had only tested on pyt 1.12, which gives this error; 1.11, as expected, does not, since Adam there still expects `step` to be a plain int.
No need to worry, but for reference, the error with 1.8 is a different one raised from inside the loss function.
hmmm...not sure if it's a mixed precision issue. Seems like something I've seen before with incorrect typecasting when using AMP |
Hi, I recently upgraded to PyTorch 1.12 and have had issues loading a saved optimizer state when using FSDP. The issue seems related to what is addressed in the comments here:

fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py, line 2436 (commit 4975b05)
From what I understand, Adam's `step` state changed into a singleton tensor in PyTorch 1.12, and when I call `gather_full_optim_state_dict()` this `step` is converted to an int.

Sample saving dict code:
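Roughly, something like the following (a sketch rather than the exact snippet; it assumes `model` is a fairscale FullyShardedDataParallel instance and `optimizer` is a torch.optim.AdamW over its parameters, with an illustrative checkpoint path):

```python
import torch
import torch.distributed as dist

# Sketch only: gather the full optimizer state on rank 0 and save it
# together with the model weights.
full_optim_state = model.gather_full_optim_state_dict(optimizer)
if dist.get_rank() == 0:
    torch.save(
        {"model": model.state_dict(), "optim": full_optim_state},
        "checkpoint.pt",  # illustrative path
    )
```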
Now when I load this optim state dict back - I do the following:
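Roughly (again a sketch under the same assumptions, using fairscale's `get_shard_from_optim_state_dict` to re-shard the gathered state for this rank):

```python
import torch

# Sketch only: load the checkpoint, restore the model, then re-shard the
# gathered optimizer state for this rank and hand it to the optimizer.
ckpt = torch.load("checkpoint.pt", map_location="cpu")  # illustrative path
model.load_state_dict(ckpt["model"])
optim_shard = model.get_shard_from_optim_state_dict(ckpt["optim"])
optimizer.load_state_dict(optim_shard)
# The next optimizer.step() then trips the state_steps assertion on 1.12,
# because "step" comes back as a plain int instead of a singleton tensor.
```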
This always fails the assertion in the Adam code (https://github.com/pytorch/pytorch/blob/master/torch/optim/adamw.py#L204), because I imagine the `step` was converted to an `int` within FSDP while Adam expects it to be a singleton tensor.

My question is: am I saving the state dict correctly? Do I need to call `optimizer.state_dict()` on top of `model.gather_full_optim_state_dict()`?

A workaround I'm using to bypass the assertion is to convert the `int`s back to singleton tensors inside the adamw function, but that does not seem safe. Any thoughts?

Apologies if my understanding is incorrect; I followed some of the discussion in #776 for the state_dict saving logic.