
[FSDP] Support with AMP Grad scaler #421

Closed
SeanNaren opened this issue Feb 23, 2021 · 12 comments · Fixed by #831
Assignees: min-xu-ai
Labels: FSDP FullyShardedDataParallel (zero-3)

Comments


SeanNaren commented Feb 23, 2021

🐛 Bug

If you modify the FSDP test here to include a Grad Scaler, the cpu_offload test fails:

Modification:

    from fairscale.optim.grad_scaler import ShardedGradScaler

    @staticmethod
    def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None):
        model_device = next(model.parameters()).device
        # use SGD with momentum instead of Adam, since Adam is scale invariant
        # and this makes it bad for tests
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        scaler = ShardedGradScaler()
        for _ in range(num_steps):
            optim.zero_grad()
            with torch.cuda.amp.autocast(enabled=autocast):
                # Inputs always cuda regardless of move_grads_cpu, or model.device
                input = model.module.get_input(torch.device("cuda"))
                output = model(*input)
                loss = model.module.get_loss(input, output).to(model_device)
            loss = scaler.scale(loss)
            assert loss.dtype == torch.float32
            model.module.run_backward(loss)
            if norm_type is not None:
                clip_norm = 0.3
                if isinstance(model, FullyShardedDataParallel):
                    model.clip_grad_norm_(clip_norm, norm_type)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
            scaler.step(optim)
            scaler.update()
        if hasattr(model, "assert_idle"):
            model.assert_idle()
        return loss.detach()

Command:

pytest tests/nn/data_parallel/test_fsdp.py::TestComparisonToPyTorchDDP::test_cpu_offload_and_cpu_grads

Error:

E       torch.multiprocessing.spawn.ProcessRaisedException:
E
E       -- Process 0 terminated with the following error:
E       Traceback (most recent call last):
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
E           fn(i, *args)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 812, in init_and_run
E           fn(rank, group, *args)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 292, in _test_identical_outputs
E           shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 63, in _train_for_several_steps
E           loss = scaler.scale(loss)
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 161, in scale
E           assert outputs.is_cuda
E       AssertionError

If I remove the to(model_device) call, I get a different error, though it is probably still device-related:

E       torch.multiprocessing.spawn.ProcessRaisedException:
E
E       -- Process 0 terminated with the following error:
E       Traceback (most recent call last):
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
E           fn(i, *args)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 812, in init_and_run
E           fn(rank, group, *args)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 292, in _test_identical_outputs
E           shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 74, in _train_for_several_steps
E           scaler.step(optim)
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 324, in step
E           self.unscale_(optimizer)
E         File "/home/sean/fairscale/fairscale/optim/grad_scaler.py", line 48, in unscale_
E           super().unscale_(optimizer)
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 275, in unscale_
E           optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 223, in _unscale_grads_
E           torch._amp_foreach_non_finite_check_and_unscale_(grads,
E       RuntimeError: Could not run 'aten::_amp_foreach_non_finite_check_and_unscale_' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_amp_foreach_non_finite_check_and_unscale_' is only available for these backends: [CUDA, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].
E
E       CUDA: registered at /pytorch/build/aten/src/ATen/RegisterCUDA.cpp:7100 [kernel]
E       BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
E       Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
E       AutogradOther: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradCPU: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradCUDA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradXLA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradNestedTensor: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       UNKNOWN_TENSOR_TYPE_ID: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradPrivateUse1: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradPrivateUse2: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradPrivateUse3: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       Tracer: registered at /pytorch/torch/csrc/autograd/generated/TraceType_0.cpp:10499 [kernel]
E       Autocast: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:250 [backend fallback]
E       Batched: registered at /pytorch/aten/src/ATen/BatchingRegistrations.cpp:1016 [backend fallback]
E       VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

Proposed fix

Modify the ShardedGradScaler to work with cpu_offload. The issue seems to boil down to handling gradients that live on the CPU, so the question becomes: can we implement the inf check and unscale operation for CPU tensors?
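
For illustration, here is a minimal sketch of what that CPU-capable unscale and inf check could look like. The helper name is made up and this is not the fairscale implementation; it just avoids torch._amp_foreach_non_finite_check_and_unscale_, which only has a CUDA kernel:

import torch


def unscale_and_check_grads_(optimizer: torch.optim.Optimizer, scale: float) -> bool:
    """Multiply every gradient by 1/scale in place; return True if any grad has inf/nan.

    Uses only generic tensor ops, so it works whether grads live on CPU or CUDA.
    """
    inv_scale = 1.0 / scale
    found_inf = False
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            p.grad.mul_(inv_scale)  # in-place unscale, valid on CPU tensors too
            if not torch.isfinite(p.grad).all():
                found_inf = True
    return found_inf

A CPU-aware ShardedGradScaler could route CPU gradients through something like this in its unscale path and then all-reduce the found_inf flag across ranks, so every shard agrees on whether to skip the step.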

Another direction is motivated by the PyTorch GradScaler class being quite confusing: after finding this in fairseq, maybe it's better to define our own grad scaler logic for FSDP? If that would be preferred, I can work on it!

Environment

PyTorch version: 1.8.0+cu112
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: A100-SXM4-40GB
GPU 1: A100-SXM4-40GB
GPU 2: A100-SXM4-40GB
GPU 3: A100-SXM4-40GB
GPU 4: A100-SXM4-40GB
GPU 5: A100-SXM4-40GB
GPU 6: A100-SXM4-40GB
GPU 7: A100-SXM4-40GB

Nvidia driver version: 460.32.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.1
[pip3] pytorch-lightning==1.3.0.dev0
[pip3] torch==1.8.0+cu112
[pip3] torchvision==0.8.2
[conda] numpy                     1.20.1                   pypi_0    pypi
[conda] pytorch-lightning         1.3.0.dev0                dev_0    <develop>
[conda] torch                     1.8.0+cu112              pypi_0    pypi
[conda] torchvision               0.8.2                    pypi_0    pypi

cc @myleott @sshleifer @blefaudeux @min-xu-ai


myleott commented Feb 23, 2021

I haven't looked closely at the PyTorch GradScaler, but do the tests pass for ShardedGradScaler + CUDA FSDP? If so, that's great!

Re: CPU offload, glancing at the PyTorch version makes me think it's pretty hardcoded for CUDA :/ I can take a look though; maybe we can patch ShardedGradScaler to make it work with CPU grads.


SeanNaren commented Feb 24, 2021

I haven't looked closely at the PyTorch GradScaler, but do the tests pass for ShardedGradScaler + CUDA FSDP? If so, that's great!

At least on the server! It runs fine for me in the Lightning integration, and I double-checked with the fairscale tests, which all pass except CPU offload :)

Re: CPU offload, glancing at the PyTorch version makes me think it's pretty hardcoded for CUDA :/ I can take a look though; maybe we can patch ShardedGradScaler to make it work with CPU grads.

Agreed, the CUDA assertions scattered through the grad scaler and the randomly required CUDA tensors are maddening... I had the same thought and tried to hack around it by swapping the grads to CUDA and then bringing them back to CPU, but I gave up because that can't be the fix...

I should get time today to investigate, as I think that's the largest blocker for the Lightning integration.

EDIT: @tchaton has been working on a fix for this. Currently he delays the move to CPU in the post-reduction hook until after the scaling/unscaling, which seems to work (and he manually moves the grads back before the optimizer step).

It might not be the best long-term fix, however, since it means we move everything to CPU at once after the scaling logic, so the transfers no longer overlap with the backward pass. I'm not sure what the throughput impact would be.

SeanNaren changed the title from "FSDP support with AMP Grad scaler" to "[FSDP] support with AMP Grad scaler" Feb 24, 2021
SeanNaren changed the title from "[FSDP] support with AMP Grad scaler" to "[FSDP] Support with AMP Grad scaler" Feb 24, 2021
SeanNaren commented:

Hey @myleott, would delaying the grad move to CPU when using CPU offload be viable, like this?

class LightningFullyShardedDataParallel(FullyShardedDataParallel):

    def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        """Hook to call on each param after the reduce-scatter."""
        assert torch.cuda.current_stream() == self._streams["post_backward"]
        assert param.grad is not None
        self.assert_state(TrainingState.BACKWARD)
        param.grad.data = reduced_grad
        # Cast grad to param's dtype (typically FP32). Note: we do this
        # before the move_grads_to_cpu step so that this entire hook remains
        # non-blocking. The downside is a bit more D2H transfer in that case.
        if self.mixed_precision:
            param.grad.data = param.grad.data.to(dtype=param.data.dtype)
        # Optionally move gradients to CPU, typically used if one is running
        # the optimizer on the CPU.
        # issues with this part

        # This part needs to be done after unscaling the gradients, so we comment out the lines
        #if self.move_grads_to_cpu:
        #    param._cpu_grad.copy_(param.grad.data, non_blocking=True)
        #    param.grad.data = param._cpu_grad
        # Don't let this memory get reused until after the transfers.
        #reduced_grad.record_stream(torch.cuda.current_stream())    

And within the Lightning code, after the unscale_ call:

...
self.scaler.unscale_(optimizer)
self.move_grad_to_cpu(model.trainer.model)

move_grad_to_cpu:

def move_grad_to_cpu(self, model):
    if hasattr(model, "cpu_offload"):
        if model.cpu_offload:
            for param in model.params:
                param._cpu_grad.copy_(param.grad.data, non_blocking=True)
                param.grad.data = param._cpu_grad

I'm not sure how this would work with multiple nested FSDP wrappers (if that's even an issue) or what the performance hit looks like.

It's also not a nice long-term fix, since we're introducing an awkward delay just so the scaler op can run on GPU. I also looked into the AMP implementation for unscaling grads and am curious what would happen if we swapped that implementation for an iterative inf check across all grads. Would be great to get your thoughts here!

min-xu-ai commented:

Maybe you can contribute a small test that demonstrates this problem, like in #437? That way we can put the test in and hopefully fix the issue with the test in place.

SeanNaren commented:

Absolutely, apologies for being lazy; here you go @min-xu-ai:

import os
import unittest
from unittest import mock

import torch
from torch.cuda.amp import autocast
import torch.nn as nn
import torch.nn.functional as F

from fairscale.nn import FullyShardedDataParallel
from fairscale.optim.grad_scaler import ShardedGradScaler


@mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "1337"}, clear=True)
@unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
def test_scaler_cpu_offload():
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    scaler = ShardedGradScaler()
    model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)

    input = torch.rand((1, 5), dtype=torch.float).to(device)
    optim.zero_grad()
    with autocast():
        output = model(input)
        loss = F.mse_loss(input, output)

    scaler.scale(loss).backward()
    scaler.step(optim)
    scaler.update()
    torch.distributed.destroy_process_group()

min-xu-ai pushed a commit that referenced this issue Mar 1, 2021
min-xu-ai added a commit that referenced this issue Mar 1, 2021
* [test] FSDP: add the failing test for #421

* skip on 1.5

* better skipping

* Update tests/nn/data_parallel/test_fsdp_grad_scaler.py

Co-authored-by: Sam Shleifer <[email protected]>

min-xu-ai self-assigned this Mar 18, 2021
min-xu-ai commented:

I will work on this at a later point since no use case is waiting on this yet. Likely we need a scaler that's specific to FSDP.

min-xu-ai added the FSDP FullyShardedDataParallel (zero-3) label Apr 3, 2021
min-xu-ai commented:

@anj-s you were running into something similar to this. Have you solved it in your code?


anj-s commented Apr 3, 2021

@anj-s you were running into something similar to this. Have you solved it in your code?

Yes, that's right. We have a similar problem for OffloadModel. It's a work in progress: we are adapting a dynamic loss scaler from fairseq so that it works for optimizer updates on the CPU. It might need additional logic to work with FSDP.
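
For context, here is a minimal sketch of the dynamic loss scaling pattern being described. The class name, defaults, and methods are illustrative, not fairseq's or DeepSpeed's actual API: scale the loss by a factor, and after each step either grow the factor when updates succeed or shrink it and skip the step when an overflow is detected.

import torch


class SimpleDynamicLossScaler:
    """Illustrative only; not the fairseq or DeepSpeed implementation."""

    def __init__(self, init_scale=2.0 ** 15, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scale_loss(self, loss):
        # Scale the loss so fp16 gradients stay in a representable range.
        return loss * self.scale

    def has_overflow(self, params):
        # Plain tensor ops, so this also works for gradients offloaded to CPU.
        return any(p.grad is not None and not torch.isfinite(p.grad).all() for p in params)

    def unscale_(self, params):
        inv = 1.0 / self.scale
        for p in params:
            if p.grad is not None:
                p.grad.mul_(inv)

    def update(self, overflow):
        if overflow:
            # Back off and let the caller skip this optimizer step.
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                self.scale *= self.growth_factor

Because it only uses generic tensor ops, a scaler along these lines has no hard CUDA dependency, which is what makes it attractive for CPU optimizer updates.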


sgugger commented Apr 15, 2021

I will work on this at a later point since no use case is waiting on this yet.

Hi @min-xu-ai, that's not entirely true. As I mentioned in the PR linked above and to Sean in private, this is blocking the ZeRO offload option with fairscale in the Transformers library. We haven't developed a short-term fix on our side since we have ZeRO DP3 with ZeRO offload via DeepSpeed for now, but it would be nice to be able to use fairscale for this too!

SeanNaren commented:

I think what @min-xu-ai is getting at is that for some use cases that don't need to scale that large, such as VISSL, you currently don't need CPU offloading (to scale model size you can just add GPUs). I do agree that this is annoying for Transformers/PL, because we now have to make it clear that CPU offloading will not work.

Is #589 tied to this? It seems like that's what allowed fairseq to sidestep the issue, since they use their own grad scaler.


myleott commented Apr 15, 2021

It seems deepspeed also uses their own DynamicLossScaler, similar to fairseq. Theirs is pretty disconnected from the rest of deepspeed, so one could probably just import and use that with FSDP...

Usage example: https://github.com/microsoft/DeepSpeed/blob/ab5534fc4c0f8ca21ada321f9730d723aa31288b/deepspeed/runtime/fp16/loss_scaler.py#L173-L221
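
Roughly, the loop that kind of scaler implies looks like the sketch below. It is written against the illustrative SimpleDynamicLossScaler above rather than DeepSpeed's actual class, with a toy model and optimizer standing in for an FSDP-wrapped setup:

import torch
import torch.nn as nn

# Toy stand-ins; under FSDP with cpu_offload the optimizer step would run on CPU grads.
model = nn.Linear(5, 5)
optim = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = SimpleDynamicLossScaler()
params = list(model.parameters())

for _ in range(10):
    optim.zero_grad()
    loss = model(torch.rand(1, 5)).pow(2).mean()
    scaler.scale_loss(loss).backward()       # backward on the scaled loss

    overflow = scaler.has_overflow(params)   # works even when grads sit on CPU
    if not overflow:
        scaler.unscale_(params)              # bring grads back to their true scale
        optim.step()
    scaler.update(overflow)                  # grow or back off the scale factor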

SeanNaren commented:

Thanks @anupambhatnagar for getting this fixed :)
