
[0.4.1] ValueError: Attempting to unscale FP16 gradients. #834

Closed
carmocca opened this issue Oct 27, 2021 · 10 comments
Labels: FSDP FullyShardedDataParallel (zero-3)

@carmocca
Contributor

🐛 Bug

Traceback (most recent call last):
  File "kk.py", line 43, in <module>
    trainer.fit(model, DataLoader(RandomDataset(32, 64), batch_size=2))
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 552, in fit
    self._run(model)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 922, in _run
    self._dispatch()
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _dispatch
    self.accelerator.start_training(self)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1000, in run_stage
    return self._run_train()
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1049, in _run_train
    self.fit_loop.run()
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
    epoch_output = self.epoch_loop.run(train_dataloader)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 130, in advance
    batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 100, in run
    super().run(batch, batch_idx, dataloader_idx)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 147, in advance
    result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 201, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 395, in _optimizer_step
    model_ref.optimizer_step(
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1616, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 206, in step
    self.__optimizer_step(closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 128, in __optimizer_step
    trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 292, in optimizer_step
    make_optimizer_step = self.precision_plugin.pre_optimizer_step(
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py", line 57, in pre_optimizer_step
    result = lambda_closure()  # native amp does not support closures
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 235, in _training_step_and_backward_closure
    result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 548, in training_step_and_backward
    self.backward(result, optimizer, opt_idx)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 589, in backward
    result.closure_loss = self.trainer.accelerator.backward(result.closure_loss, optimizer, *args, **kwargs)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 278, in backward
    closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss, optimizer)
  File "/home/carlos/venv/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py", line 77, in post_backward
    self.scaler.unscale_(optimizer)
  File "/home/carlos/venv/lib/python3.8/site-packages/fairscale/optim/grad_scaler.py", line 62, in unscale_
    super().unscale_(optimizer)
  File "/home/carlos/venv/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 279, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
  File "/home/carlos/venv/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 207, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")

Command

python script.py on a GPU machine

To Reproduce

import torch
from fairscale.nn import wrap
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.plugins import DDPFullyShardedPlugin


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class TestFSDPModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

    def forward(self, x):
        return self.layer(x)

    def configure_sharded_model(self):
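        # Lightning calls this hook inside fairscale's enable_wrap context, so wrap() shards this layer with FSDP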
        self.layer = wrap(self.layer)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


if __name__ == "__main__":
    model = TestFSDPModel()
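    # precision=16 with DDPFullyShardedPlugin makes Lightning use fairscale's ShardedGradScaler (see the traceback above)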
    trainer = Trainer(gpus=1, plugins=DDPFullyShardedPlugin(), precision=16, max_epochs=1)
    trainer.fit(model, DataLoader(RandomDataset(32, 64), batch_size=2))

Expected behavior

I'm not entirely sure whether this is a problem on our end or in the release, but this did work with the 0.4.0 release.

cc @SeanNaren

Environment

pytorch-lightning==1.4.9
torch==1.9.0+cu111
fairscale==0.4.1

Thank you for your help!

@SeanNaren

@anj-s has anything changed with respect to FSDP and gradients in FP16 mode?

@anj-s
Contributor

anj-s commented Oct 28, 2021

@SeanNaren Nothing should have changed as far as gradients and fp16.

@carmocca I can reproduce this issue in FairScale 0.4.0 without Lightning. Let me look into this issue to figure out the root cause.

@anj-s added the FSDP FullyShardedDataParallel (zero-3) label Oct 28, 2021
@anupambhatnagar self-assigned this Oct 29, 2021
@anupambhatnagar

anupambhatnagar commented Oct 29, 2021

I have been able to reproduce the issue. The root cause is the _unscale_grads_ function call here: allow_fp16 is set to False, hence your code breaks.

Copy the _unscale_grads_ function from the fairscale GradScaler class (https://github.com/facebookresearch/fairscale/blob/main/fairscale/optim/grad_scaler.py#L18-L21) into the ShardedGradScaler class, and the code works (note the parameters on line 21). 😄 This is not a good solution; we can work on a better one. Give us some time and this will be resolved. We are waiting to land this PR, and it should be easy from there.

Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
distributed_backend=nccl
All DDP processes registered. Starting ddp with 1 processes
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
  | Name  | Type                     | Params
---------------------------------------------------
0 | layer | FullyShardedDataParallel | 1.1 K
---------------------------------------------------
1.1 K     Trainable params
0         Non-trainable params
1.1 K     Total params
0.004     Total estimated model params size (MB)
/private/home/anupamb/miniconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:326: UserWarning: The number of training samples (32) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 0:   0%|                                                                                                                                                           | 0/32 [00:00<00:00, 3823.43it/s]
WARNING:root:ShardedGradScaler is to be used in combination with a sharded optimizer, this could not be checked
Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 175.90it/s, loss=-5.94e+03, v_num=22]
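For illustration, a minimal sketch of that workaround as a local subclass (the class name is just illustrative); it is equivalent to copying the override into ShardedGradScaler and simply forwards allow_fp16=True, mirroring what fairscale's GradScaler does:

from fairscale.optim.grad_scaler import ShardedGradScaler


class AllowFP16ShardedGradScaler(ShardedGradScaler):
    """Illustrative stopgap: permit unscaling of FP16 gradients."""

    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
        # Ignore the incoming allow_fp16=False and pass True instead, so that
        # unscaling does not reject the FP16 gradients produced by FSDP.
        return super()._unscale_grads_(optimizer, inv_scale, found_inf, True)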

@anupambhatnagar

@anj-s If we add _unscale_grads_ to ShardedGradScaler with allow_fp16 set to True, this issue is resolved. I checked that all tests in the TestComparisonToPyTorchDDP class pass on top of the cpu_gradscaler branch. If the other tests pass as well (which I haven't checked yet), this will be an easy fix.

@anj-s
Contributor

anj-s commented Nov 1, 2021

@anupam-fb thanks for looking into this! The fix sounds good. The only things I would follow up on are:

  1. Do we have existing tests for FSDP + FP16 + ShardedGradScaler? This is just to confirm whether we did not cover this use case or whether the test did not fire as expected.
  2. Understand more about the use case where this might be useful. For example, what kind of optimizer would be used with FP16 grads? How is the optimizer state being handled?
  3. Is there a reason PyTorch has not enabled this and has by default set it to False?

@anupambhatnagar

Answering the questions above:

  1. Yes, we have added tests to cover this case.
  2. I'll look into these as I continue to work on the layer-wise grad scaler.
  3. I don't know why PyTorch didn't enable it. The function is for internal use only (private), hence the limited functionality.

@anupambhatnagar

@carmocca The solution to this issue has been merged into main. Could you try installing fairscale from the main branch (e.g. pip install -e . from a clone of the repo) and let us know if that works for you? The fix will be available out of the box in FairScale 0.4.3.

@carmocca
Contributor Author

@anupambhatnagar Seems to be working, thank you! Feel free to close this issue.

Also, please ping me here once 0.4.3 is released; the Lightning CI has fairscale pinned at the moment.

@anupambhatnagar

@carmocca FYI - FairScale 0.4.5 will be released soon. ICYMI, we are introducing per-layer gradient scaling in the upcoming version.

Could you please share how you are using the ShardedGradScaler? Feedback on that would be helpful to us. Thanks!

@carmocca
Contributor Author

Thanks for the heads-up!

Could you please share how you are using the ShardedGradScaler? Feedback on that would be helpful to us. Thanks!

Our integration is quite simple; internally we just select this class when the user requests sharded training with mixed precision:

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/precision/sharded_native_amp.py

We also override clip_grad_norm from the base implementation, as it does not use torch.nn.utils.clip_grad_norm_.
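For reference, a rough sketch of what that integration amounts to (simplified; the exact class and hook names vary across Lightning versions, so treat this as an assumption rather than the actual file): the plugin swaps in fairscale's ShardedGradScaler and delegates gradient clipping to the sharded (OSS) optimizer instead of torch.nn.utils.clip_grad_norm_.

from fairscale.optim.grad_scaler import ShardedGradScaler
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin


class ShardedMixedPrecisionSketch(NativeMixedPrecisionPlugin):
    """Sketch: same native AMP flow, but with a sharded-aware scaler."""

    def __init__(self) -> None:
        super().__init__()
        # Replace torch.cuda.amp.GradScaler with fairscale's sharded variant.
        self.scaler = ShardedGradScaler()

    def clip_grad_by_norm(self, optimizer, clip_val, norm_type: float = 2.0) -> None:
        # The sharded optimizer (fairscale OSS) knows how to clip across shards,
        # so delegate to it rather than calling torch.nn.utils.clip_grad_norm_.
        optimizer.clip_grad_norm(clip_val, norm_type=norm_type)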
