Can't load optimizer state due to state_steps #1083

Open
rowhanm opened this issue Sep 26, 2022 · 10 comments

rowhanm commented Sep 26, 2022

Hi, I recently upgraded to PyTorch 1.12 and have been having issues loading a saved optimizer state when using FSDP. The problem seems to be something that is addressed by this comment in the FSDP source:

# In PyTorch version 1.12, Adam's `step` state changed from an int to a singleton

From what I understand, Adam's step state changed from an int into a singleton tensor, and when I call gather_full_optim_state_dict() this step gets converted back to an int.

Sample saving dict code:

model = FSDP(model, ...)
# call on all ranks
optim_state = model.gather_full_optim_state_dict(optimizer)
if rank == 0:
    # save only on rank 0
    checkpoint = {
        'optimizer': optim_state,
        ...
    }
    torch.save(checkpoint, checkpoint_path)  # checkpoint_path = placeholder for where the file is written

Now when I load this optim state dict back, I do the following:

model = FSDP(model, ...)
torch.distributed.barrier()
# on all ranks
checkpoint = torch.load(snapshot_name)
curr_opt_state_dict = checkpoint["optimizer"]
optim_shard_dict = model.get_shard_from_optim_state_dict(curr_opt_state_dict)
optimizer.load_state_dict(optim_shard_dict)

This always fails the assertion in the Adam code - https://github.com/pytorch/pytorch/blob/master/torch/optim/adamw.py#L204 - because I imagine the step was converted to an int within FSDP, while Adam expects it to be a singleton tensor.

My question is: am I saving the state dict correctly? Do I need to call optimizer.state_dict() on top of model.gather_full_optim_state_dict()?

A workaround I'm using to bypass the assertion is to convert the ints back into singleton tensors inside the adamw function, but that does not seem safe. Any thoughts?
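
For example, the same idea applied outside of adamw would look roughly like this (a sketch only; it assumes the shard dict follows the usual optimizer state_dict layout, with a "state" mapping that has a per-parameter "step" entry):

# Convert each parameter's `step` from an int back into a singleton float tensor,
# which is what AdamW in PyTorch >= 1.12 expects in `state_steps`.
for param_state in optim_shard_dict["state"].values():
    if "step" in param_state and not torch.is_tensor(param_state["step"]):
        param_state["step"] = torch.tensor(float(param_state["step"]))
optimizer.load_state_dict(optim_shard_dict)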

Apologies if my understanding is incorrect, I followed some of the discussion here - #776 for the state_dict saving logic.

min-xu-ai (Contributor) commented:

Hey, thanks for the detailed question! I think what you are doing is correct. #776 is largely different from your issue, which is related to the optimizer state.

I am not sure whether you are running into problem 1 or problem 2 below or both.

  1. loading a pre-1.12 checkpoint crashes
  2. using the same version (post-1.12) to save a checkpoint and load it back crashes

For 1, I suggest you just use torch.load and torch.save manually and patch the checkpoint so that it is compatible with 1.12. You can save two versions of the checkpoint (one for pre-1.12, one for post-1.12) and load the correct one to avoid crashes.
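
For example, a one-off patching script along those lines might look roughly like this (only a sketch: it assumes the saved checkpoint follows the usual optimizer state_dict layout, and `old_ckpt_path` / `new_ckpt_path` are placeholder paths):

import torch

ckpt = torch.load(old_ckpt_path, map_location="cpu")
for param_state in ckpt["optimizer"]["state"].values():
    step = param_state.get("step")
    if step is not None and not torch.is_tensor(step):
        # pre-1.12 checkpoints store `step` as an int; 1.12+ Adam expects a singleton tensor
        param_state["step"] = torch.tensor(float(step))
torch.save(ckpt, new_ckpt_path)  # keep the original file around for pre-1.12 runs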

For 2, that would be a bug. Please send us a minimal reproduction case if you can. PR to fix is even more awesome! ;-)

rowhanm commented Sep 27, 2022

Hi, I think it is the second case: saving a checkpoint and then running optimizer.step() (with or without loading it back) causes a crash.

Here is a minimal reproduction - https://gist.github.com/rowhanm/71272f157d8c9450d6b1c7639a612126.
[python==3.7.5, pytorch==1.12.0, fairscale==0.4.6 (can't upgrade fairscale because I'm restricted to py3.7; it shouldn't matter since this function remains the same)]

I've narrowed the problem down to this line - https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2443 - and I am able to fix the issue in my script by converting the step state from an int back into a singleton tensor.

I don't have too much context on what the source comment "comparison with original state dict would fail." means, and I'm not sure if my fix would cause any issues later. If there are no side effects, my proposed fix to the bug would be to:

  1. either not do the original singleton tensor -> int conversion at all, or
  2. convert the step state back into a singleton tensor after it has been serialized as an int (not sure where exactly this should take place).

min-xu-ai (Contributor) commented:

I see, this makes sense. We likely don't have a test case that catches this issue. I will find time to fix this.

min-xu-ai (Contributor) commented:

btw, here is the error I got when running your sample code:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/private/home/m1n/e/py38_miniconda_pt_nightly_fairscale/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/private/home/m1n/git/fairscale/t.py", line 99, in demo_basic
    optimizer.step()
  File "/private/home/m1n/e/py38_miniconda_pt_nightly_fairscale/lib/python3.8/site-packages/torch/optim/optimizer.py", line 109, in wrapper
    return func(*args, **kwargs)
  File "/private/home/m1n/e/py38_miniconda_pt_nightly_fairscale/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/private/home/m1n/e/py38_miniconda_pt_nightly_fairscale/lib/python3.8/site-packages/torch/optim/adamw.py", line 161, in step
    adamw(params_with_grad,
  File "/private/home/m1n/e/py38_miniconda_pt_nightly_fairscale/lib/python3.8/site-packages/torch/optim/adamw.py", line 204, in adamw
    raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
RuntimeError: API has changed, `state_steps` argument must contain a list of singleton tensors

rowhanm commented Sep 27, 2022

Yep, that is the main issue :) Adam expects state_steps to be singleton tensors, but gather_full_optim_state_dict() converts them into ints. If you uncomment lines 108-112 in my gist, that basically fixes the issue.

I can throw in a small test case + PR that fixes it in a bit. Again, I'm not sure if that is the best possible fix, since the git blame on `for _, bufs in osd["state"].items():` doesn't quite tell me why this conversion needed to be added.
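
The kind of round-trip test I have in mind is roughly the following (only a sketch; it assumes torch.distributed is already initialized on a single GPU, similar to the repro above, rather than fairscale's actual test harness):

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def optim_state_round_trip():
    # Assumes an initialized process group and CUDA, as in the repro script.
    model = FSDP(torch.nn.Linear(8, 8).cuda())
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Take one real step so the optimizer has populated state, including `step`.
    model(torch.randn(4, 8).cuda()).sum().backward()
    optimizer.step()

    # Gather the full state, re-shard it, load it back, and step again.
    # Today the final step() crashes with
    # "state_steps argument must contain a list of singleton tensors".
    full_osd = model.gather_full_optim_state_dict(optimizer)
    shard_osd = model.get_shard_from_optim_state_dict(full_osd)
    optimizer.load_state_dict(shard_osd)
    model(torch.randn(4, 8).cuda()).sum().backward()
    optimizer.step()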

min-xu-ai (Contributor) commented:

Thanks for trying a fix!

My best memory is that this is needed because if the step is a singleton tensor, then it may be treated like a sharded optimizer state and get handled by the gather function. In a way, this step scalar is assumed to be the same across all ranks, which is true for FSDP at least. Maybe there are reasons why it changed from a scalar to a tensor in the first place, but I haven't looked into it.

min-xu-ai (Contributor) commented:

BTW, when I ran your test code with pt 1.8, it gave a different error in the loss function, which is very interesting too.

rowhanm commented Sep 27, 2022

Super weird. I had only tested on PyTorch 1.12, which gives this error, and on 1.11, which as expected does not, since there Adam still expects state_steps to be ints.
Unfortunately I can't test with 1.8 because I don't have a GPU with the right CUDA capability; could you tell me what error you see with 1.8?

min-xu-ai (Contributor) commented:

No need to worry, but here is the error with 1.8:


-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/private/home/m1n/e/py38_miniconda_pt181_fairscale/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/private/home/m1n/git/fairscale/t.py", line 82, in demo_basic
    loss = criterion(preds, target)
  File "/private/home/m1n/e/py38_miniconda_pt181_fairscale/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/private/home/m1n/e/py38_miniconda_pt181_fairscale/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1047, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/private/home/m1n/e/py38_miniconda_pt181_fairscale/lib/python3.8/site-packages/torch/nn/functional.py", line 2693, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/private/home/m1n/e/py38_miniconda_pt181_fairscale/lib/python3.8/site-packages/torch/nn/functional.py", line 2388, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward


rowhanm commented Sep 27, 2022

Hmm, not sure if it's a mixed-precision issue. It seems like something I've seen before with incorrect typecasting when using AMP.
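
If the target in the repro is meant to hold class indices, the 1.8-side fix would just be casting it to long before the loss (a sketch, using the names from the traceback above):

# In PyTorch 1.8, nll_loss/CrossEntropyLoss only accepts class-index targets,
# which must be of dtype long, not float.
loss = criterion(preds, target.long())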
