[FSDP] Support with AMP Grad scaler #421
I haven't looked closely at the PyTorch one. Re: CPU offload, glancing at the PyTorch version makes me think it's pretty hardcoded for CUDA :/ I can take a look though, maybe we can patch it.
At least on the server! It runs fine for me in the Lightning integration, and I double-checked with the fairscale tests, which all pass except CPU offload :)
Agreed, the CUDA assertions everywhere in the grad scaler plus the randomly required CUDA tensors are insane... I was on the same thought: I tried to hack around it by swapping the grads to CUDA and then bringing them back to CPU, but I gave up because that can't be a real fix... I should get time today to investigate, as I think this is the largest blocker for the Lightning integration.

EDIT: @tchaton has been working on a fix for this. Currently he delays the move to CPU in the post-reduction hook until after the scaling/unscaling, and manually moves the grads back before the optimizer step. This seems to work, but it might not be the best long-term fix, since it means the move to CPU happens all at once after the scaling logic, with no overlap. Not sure what the throughput will be if we do this.
Hey @myleott, would delaying the grad move to CPU when using CPU offload be viable, like this?

```python
import torch
from torch.nn import Parameter

from fairscale.nn import FullyShardedDataParallel
# Import path for TrainingState may differ across fairscale versions.
from fairscale.nn.data_parallel import TrainingState


class LightningFullyShardedDataParallel(FullyShardedDataParallel):
    def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        """Hook to call on each param after the reduce-scatter."""
        assert torch.cuda.current_stream() == self._streams["post_backward"]
        assert param.grad is not None
        self.assert_state(TrainingState.BACKWARD)
        param.grad.data = reduced_grad
        # Cast grad to param's dtype (typically FP32). Note: we do this
        # before the move_grads_to_cpu step so that this entire hook remains
        # non-blocking. The downside is a bit more D2H transfer in that case.
        if self.mixed_precision:
            param.grad.data = param.grad.data.to(dtype=param.data.dtype)
        # Optionally move gradients to CPU, typically used if one is running
        # the optimizer on the CPU. This part needs to happen after unscaling
        # the gradients, so the lines below are commented out and the move is
        # deferred to a separate call from the Lightning plugin.
        # if self.move_grads_to_cpu:
        #     param._cpu_grad.copy_(param.grad.data, non_blocking=True)
        #     param.grad.data = param._cpu_grad
        #     # Don't let this memory get reused until after the transfers.
        #     reduced_grad.record_stream(torch.cuda.current_stream())
```

And within the Lightning code:
```python
...
self.scaler.unscale_(optimizer)
self.move_grad_to_cpu(model.trainer.model)


def move_grad_to_cpu(self, model):
    if hasattr(model, "cpu_offload"):
        if model.cpu_offload:
            for param in model.params:
                param._cpu_grad.copy_(param.grad.data, non_blocking=True)
                param.grad.data = param._cpu_grad
```

I'm not sure how this would work with multiple nested FSDP wrappers (if that's even an issue) or what the performance hit looks like. It's also not a nice long-term fix, since we're having to do some weird delay for the scaler op to run on GPU. I also looked into the AMP implementation for unscaling grads and am curious what would happen if we swapped this implementation for an iterative inf check across all grads. Would be great to get your thoughts here!
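For reference, a minimal sketch of the iterative, device-agnostic unscale + inf/NaN check mentioned above (this is an illustration only, not fairscale's or PyTorch's implementation; the helper name `unscale_and_check_overflow` is made up):

```python
import torch


def unscale_and_check_overflow(params, inv_scale: float) -> bool:
    """Unscale gradients in place and report whether any grad overflowed.

    Uses only element-wise ops, so it works on CPU-offloaded grads as well
    as CUDA grads, unlike torch._amp_foreach_non_finite_check_and_unscale_,
    which expects CUDA tensors.
    """
    found_inf = False
    for p in params:
        if p.grad is None:
            continue
        p.grad.data.mul_(inv_scale)
        if not torch.isfinite(p.grad.data).all():
            found_inf = True
    return found_inf
```

A scaler built on this would skip `optimizer.step()` and shrink the scale whenever the helper reports an overflow, mirroring what `GradScaler.step`/`update` do on CUDA.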
Maybe you can contribute a small test that demonstrates this problem, like in #437? That way, we can put the test in and hopefully fix the issue with the test in place.
Absolutely, apologies for being lazy. Here you go @min-xu-ai:

```python
import os
import unittest
from unittest import mock

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

from fairscale.nn import FullyShardedDataParallel
from fairscale.optim.grad_scaler import ShardedGradScaler


@mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "1337"}, clear=True)
@unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
def test_scaler_cpu_offload():
    device = torch.device("cuda")
    torch.cuda.set_device(0)
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    scaler = ShardedGradScaler()
    model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)

    input = torch.rand((1, 5), dtype=torch.float).to(device)
    optim.zero_grad()
    with autocast():
        output = model(input)
        loss = F.mse_loss(input, output)

    scaler.scale(loss).backward()
    scaler.step(optim)
    scaler.update()

    torch.distributed.destroy_process_group()
```
[test] FSDP: add the failing test for #421
* skip on 1.5
* better skipping
* Update tests/nn/data_parallel/test_fsdp_grad_scaler.py

Co-authored-by: Sam Shleifer <[email protected]>
I will work on this at a later point, since no use case is waiting on it yet. Likely we need a scaler that's specific to FSDP.
@anj-s you were running into something similar to this. Have you solved it in your code?
Yes, that's right. We have a similar problem with OffloadModel. It's work in progress: we are adapting a dynamic loss scaler from fairseq that will work for optimizer updates on the CPU. It might need additional logic to work with FSDP.
Hi @min-xu-ai, that's not entirely true. As I mentioned in the PR linked above and to Sean in private, this is blocking the ZeRO offload option with fairscale in the Transformers library. We've not developed a short-term fix on our side since we have ZeRO DP3 with ZeRO offload via DeepSpeed for now, but it would be nice to be able to use fairscale for this too!
I think what @min-xu-ai is getting at is that for some use cases that don't need to scale that large, such as VISSL, you do not currently need CPU offloading (to scale model size you can just add GPUs). I do agree that this is annoying for Transformers/PL, because we now have to make it clear that CPU offloading will not work. Is #589 tied to this? It seems like that is what allowed fairseq to sidestep the issue, since they use their own grad scaler.
It seems DeepSpeed also uses its own DynamicLossScaler, similar to fairseq. Theirs is pretty disconnected from the rest of DeepSpeed, so one could probably just import and use that with FSDP...
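For context, these dynamic loss scalers all follow roughly the same pattern; below is a generic sketch of the idea (a simplified illustration only, not the fairseq or DeepSpeed API; the class name and default hyperparameters are made up):

```python
class DynamicLossScalerSketch:
    """Grow the loss scale while training is stable; shrink it on overflow.

    Device-agnostic: it only needs to be told whether an overflow occurred,
    so the inf/NaN check can run on CPU-offloaded grads.
    """

    def __init__(self, init_scale=2.0**15, scale_factor=2.0, scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window  # steps without overflow before growing
        self._good_steps = 0

    def scale(self, loss):
        return loss * self.loss_scale

    def update(self, overflow: bool) -> None:
        if overflow:
            # Overflow: shrink the scale; the caller skips this optimizer step.
            self.loss_scale = max(self.loss_scale / self.scale_factor, 1.0)
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self.scale_window == 0:
                self.loss_scale *= self.scale_factor
```

The training loop would unscale the grads by `1 / loss_scale` (e.g. with a helper like the one sketched earlier in the thread), call `update(overflow)`, and only run `optimizer.step()` when no overflow was found.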
Thanks @anupambhatnagar for getting this fixed :)
🐛 Bug
If you modify the FSDP test here to include a Grad Scaler, the cpu_offload test fails:
Modification:
Command:
Error:
If I remove the `to(model_device)` call, I get a different error, but probably still due to the devices:

Proposed fix
Modify the ShardedGradScaler to work with `cpu_offload`. It seems the issue boils down to having to deal with gradients that are on CPU, so the question becomes: can we modify the inf check and unscale operation to work on CPU?

Another direction is motivated by the PyTorch GradScaler class being hella confusing: after finding this in fairseq, maybe it's better we define our own grad scaler logic for FSDP? If this is something that would be preferred, I can work on it!
Environment
cc @myleott @sshleifer @blefaudeux @min-xu-ai