Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FullyShardedDataParallel (FSDP) #413

Merged
merged 51 commits into from
Feb 23, 2021
Merged

Add FullyShardedDataParallel (FSDP) #413

merged 51 commits into from
Feb 23, 2021

Conversation

myleott
Copy link
Contributor

@myleott myleott commented Feb 23, 2021

Co-authored-by: @min-xu-ai and @sshleifer

Overview

Recent work by Microsoft and Google has shown that data parallel training can be made significantly more efficient by sharding the model parameters and optimizer state across data parallel workers. These ideas are encapsulated in the new FullyShardedDataParallel (FSDP) wrapper, which is a drop-in replacement for PyTorch's DistributedDataParallel (DDP) wrapper.

Compared to PyTorch DDP:

  • FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs
  • FSDP with reshard_after_forward=False has the same communication cost as PyTorch DDP and is similar to ZeRO-2
  • FSDP with reshard_after_forward=True increases total communication by 50% and is similar to ZeRO-3:
    • all-gather parameters at start of forward pass and start of backward pass
    • reduce-scatter grads at end of backward pass
  • in practice, FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the extra communication can be overlapped with the forward pass
  • FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs. When using the cpu_offload=True option, it's possible to train 1T parameter models on 256 GPUs.

General usage notes

  • for best memory efficiency wrap each layer in your network with FSDP and set reshard_after_forward=True
  • for best training speed set reshard_after_forward=False (wrapping each layer is not required, but will improve speed further)
  • if you're using torch.cuda.amp.autocast for mixed precision, that's fully compatible with the FSDP wrapper, just set mixed_precision=True
  • if combining with activation checkpointing, prefer FSDP(checkpoint_wrapper(module)) over checkpoint_wrapper(FSDP(module)). The latter will result in more communication and will be slower.
  • this is full compatible with pointwise Optimizers, e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.. However, the sharding will result in slightly different results when using non-pointwise Optimizers, e.g., Adagrad, Adafactor, LAMB, etc.

How it works

In standard distributed data parallel (DDP) training every worker processes a separate batch and the gradients are summed across workers using an all-reduce operation. While DDP has become very popular, it wastes GPU memory because the model weights and optimizer states are replicated across all DDP workers.

The key insight to unlock full parameter sharding is that we can decompose the all-reduce operation in DDP into separate all-gather and reduce-scatter operations:

Screen Shot 2021-01-12 at 12 35 19 PM

Then, we can rearrange the reduce-scatter + all-gather so that each DDP worker only needs to store a single shard of parameters and optimizer state. The figure below illustrates standard DDP training (left) and fully sharded training (right):

Screen Shot 2021-02-24 at 4 39 55 PM

To maximize memory efficiency we can discard the full weights after each layer's forward pass, saving memory for subsequent layers. This can be implemented by applying the FSDP wrapper to every layer in your network (with reshard_after_forward=True). In pseudo-code:

FSDP forward pass:
    for layer_i in layers:
        all-gather full weights for layer_i
        forward pass for layer_i
        discard full weights for layer_i
FSDP backward pass:
    for layer_i in layers:
        all-gather full weights for layer_i
        backward pass for layer_i
        discard full weights for layer_i
        reduce-scatter gradients for layer_i

myleott and others added 30 commits February 1, 2021 11:14
* [mypy]: fixed all the mypy errors

* make older version python happy with typing hints

* [chore] Fix lint errors that broke master (#348)

authored-by: Anjali Sridhar <[email protected]>

Co-authored-by: anj-s <[email protected]>
* Test CPU offload

* remove dead code
…pper (#42)

* Only sync fp32_to_fp16 stream for the top-most (root) ShardParams wrapper

* Fix mypy, add test, address some comments

* Add missing assert

* Comments
* Test backward hooks are registered

* expand

* fs_test

* passing

* assert again

* add assert not called

* naming
* format change

* [test]: test apply_to_tensors

* formatting

* added some skeletons

* name for TODO

* fixing the use of lru_cache

* formatting
* Do reduce-scatter in a separate CUDA stream

* Add _post_backward_stream to stubs
* Fix delayed_reduce_scatter test
* Decompose NestedWrappedModule from ModuleWithDelay
* add unit test pack/unpack kwargs

* added two more corner cases

* more doc and more tests

* more corner cases

* formatting

* Update fairscale/utils/containers.py

Co-authored-by: Sam Shleifer <[email protected]>

* with pytest.raises is awesome

* addressed comment

* add tuple to be tested

Co-authored-by: Sam Shleifer <[email protected]>
- Add two new tests (TestParamInit and TestSerialization) which would have failed previously. These mostly cover the fairseq usage that was not captured by tests before.
- Deprecate the `compute_device` option, since we don't actually use it anywhere (nor do we need it in fairseq)
- Remove `_move_fp32_shard_to_cuda` and embrace a stronger invariant: p.data and p._fp32_shard should always be the same at the start and end of each function (namely, state_dict, and also forward/backward).
- Slightly unrelated, but refactor streams logic a bit, so we have a single `self._streams` dictionary -- this will make an upcoming PR that adds more streams easier
Copy link
Contributor

@min-xu-ai min-xu-ai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@min-xu-ai min-xu-ai changed the title Add FullyShardedDataParallel Add FullyShardedDataParallel (FSDP) Feb 23, 2021
Copy link
Contributor

@blefaudeux blefaudeux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congrats @myleott, @min-xu-ai and @sshleifer :)

optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
for _ in range(num_steps):
optim.zero_grad()
with torch.cuda.amp.autocast(enabled=autocast):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had a quick sneak at the PR, was wondering if this test will pass if we add a GradScaler object? If weights/states are sharded across will we need to use something similar to the ShardedGradScaler: https://github.com/facebookresearch/fairscale/blob/master/fairscale/optim/grad_scaler.py#L24

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question and I think you're probably right. PR's welcome :)

@myleott myleott merged commit 15512d9 into master Feb 23, 2021
@myleott myleott deleted the fsdp branch February 23, 2021 01:11
@myleott
Copy link
Contributor Author

myleott commented Feb 23, 2021

Merging so that future developments can be done as separate PRs. Thanks for all the help reviewing this @anj-s and @blefaudeux!

@anj-s
Copy link
Contributor

anj-s commented Feb 23, 2021

Great work @myleott @sshleifer @min-xu-ai! Excited to see this PR :)

@stas00
Copy link
Contributor

stas00 commented Feb 23, 2021

This is full ZeRO-DP stage 3, right?

@myleott
Copy link
Contributor Author

myleott commented Feb 23, 2021

Yep, ZeRO stage 3 and CPU offloading. It can also implement stage 2 (with reshard_after_forward=False), which is a bit faster than stage 3.

@stas00
Copy link
Contributor

stas00 commented Feb 23, 2021

Absolutely amazing!!! This dramatically changes what we can do with huge models. Can't wait to test it out.

@stas00
Copy link
Contributor

stas00 commented Feb 23, 2021

OK, I'd like to start integrating this right away in HF transformers - do you have examples I could learn from/mimic?

Also wrt:

FSDP with reshard_after_forward=False has the same communication cost as PyTorch DDP and is similar to ZeRO-2

where does ShardedDDP fit into this picture now? Is FSDP w/ reshard_after_forward=False == ShardedDDP?

And I'm also thinking about the long term - as fairscale has been backporting this upstream to pytorch, we should probably extended the HF trainer API to assume that eventually this will be native and until then just require fairscale to be loaded to get the components, right?

@min-xu-ai
Copy link
Contributor

OK, I'd like to start integrating this right away in HF transformers - do you have examples I could learn from/mimic?

The doc is still being worked on. We are adding more tests and making sure FSDP work with more models. The docstring right now has a minimal example. test_fsdp.py has more examples. But with an vision model in VISSL, I basically did the following:

model = FSDP(nn.Sequential(
checkpoint_wrapper(FSDP(model.layer1, mixed_precision=True),
checkpoint_wrapper(FSDP(model.layer2, mixed_precision=True),
...), mixed_precision=True
)

I think Myle's usage in Transformer might be slightly different with respect to checkpoint_wrapper & FSDP's order.

Again, I'd warn that this is fresh off the press. Sharp edges are likely in multiple places.

where does ShardedDDP fit into this picture now? Is FSDP w/ reshard_after_forward=False == ShardedDDP?

They are only roughly == in terms memory overhead. I'd think they will be coexisting for a while if not forever. They may have different fits for different model and model sizes. We are still gaining experiences with them. In the end, it will be a comparison of usability, generalizability, performance, memory efficiency, etc.

And I'm also thinking about the long term - as fairscale has been backporting this upsteam, we should probably extended the HF trainer API to assume that eventually this will be native and until then just require fairscale to be loaded to get the components, right?

yes

@stas00
Copy link
Contributor

stas00 commented Feb 23, 2021

OK, I'd like to start integrating this right away in HF transformers - do you have examples I could learn from/mimic?

The doc is still being worked on. We are adding more tests and making sure FSDP work with more models. The docstring right now has a minimal example. test_fsdp.py has more examples. But with an vision model in VISSL, I basically did the following:

model = FSDP(nn.Sequential(
checkpoint_wrapper(FSDP(model.layer1, mixed_precision=True),
checkpoint_wrapper(FSDP(model.layer2, mixed_precision=True),
...), mixed_precision=True
)

Thank you for the example, @min-xu-ai

Oh, the OP mentioned nothing about the model needing to be nn.Sequential. Is this just an example, or is this a requirement?

Could one traverse the model graph and inject those intermediary layers automatically? Perhaps with pytorch FX?

Again, I'd warn that this is fresh off the press. Sharp edges are likely in multiple places.

I appreciate the disclaimer. I realize this is an early alpha and I promise to thread with care.

where does ShardedDDP fit into this picture now? Is FSDP w/ reshard_after_forward=False == ShardedDDP?

They are only roughly == in terms memory overhead. I'd think they will be coexisting for a while if not forever. They may have different fits for different model and model sizes. We are still gaining experiences with them. In the end, it will be a comparison of usability, generalizability, performance, memory efficiency, etc.

Thank you for clarifying that. I'm just starting to think about the user-facing API in the HF trainer, hence trying to think how to best name things.

@min-xu-ai
Copy link
Contributor

Oh, the OP mentioned nothing about the model needing to be nn.Sequential. Is this just an example, or is this a requirement?

Sorry for misleading you. That's just an example. In fact, it was a ModuleDict in vissl's case. I used nn.Sequential out of habit.

Could one traverse the model graph and inject those intermediary layers automatically? Perhaps with pytorch FX?

Definitely a possibility! I think figuring out where to inject for best results will be the harder part though.

@myleott
Copy link
Contributor Author

myleott commented Feb 23, 2021

where does ShardedDDP fit into this picture now? Is FSDP w/ reshard_after_forward=False == ShardedDDP?

They are only roughly == in terms memory overhead.

Correct, it’s not implemented the same way as ShardedDDP, so there may be different use cases that prefer one or the other. That said, the version here (with reshard_after_forward=False) should be faster in most cases, since it can overlap the all-gather with the forward pass.

Oh, the OP mentioned nothing about the model needing to be nn.Sequential. Is this just an example, or is this a requirement?
Could one traverse the model graph and inject those intermediary layers automatically? Perhaps with pytorch FX?

To clarify, no need for nn.Sequential, almost any module structure works. I agree we should traverse the graph and wrap the intermediate layers with FSDP automatically. In fact, there’s almost no downside in terms of perf of just wrapping every layer that’s big enough, because FSDP internally manages the child instances (e.g., it will share CUDA streams with the children to avoid unnecessary synchronization).

The one caveat is that some model structures break if you wrap a child module. In particular, if the parent module accesses the weights of the child module directly (i.e., not via the child module’s forward), then we don’t execute the hook to all-gather the params, so it breaks. The nn.Linears in fairseq’s MultiheadAttention is one example; had we wrapped out_proj with FSDP, then this line would break, since the weight is still sharded at this point: https://github.com/pytorch/fairseq/blob/ab560669cd9baaa4009e1fd01c970f8ffccd1ee0/fairseq/modules/multihead_attention.py#L175

So for now we’re requiring layers to be wrapped manually, but I do think this can be relaxed once we add better fallback logic/error reporting for the case above.

do you have examples I could learn from/mimic?

I’ll try to put up a diff in fairseq today or tomorrow that works for Transformer.

@blefaudeux
Copy link
Contributor

Correct, it’s not implemented the same way as ShardedDDP, so there may be different use cases that prefer one or the other. That said, the version here (with reshard_after_forward=False) should be faster in most cases, since it can overlap the all-gather with the forward pass.

For the sake of precision, that's true (overlap FW and gather) if you can wrap subparts of your model with FSDP, not if the model is in one FSDP wrap

@stas00
Copy link
Contributor

stas00 commented Feb 23, 2021

Thank you very much for your detailed commentary to my questions - that's very useful and I'm looking forward to studying the examples

Request: to make things consistent with DDP could this be named FullyShardedDistributedDataParallel - otherwise it's confusing as it implies DP, which is different from the convention of: DDP is multiproc, DP is multi-thread.

On the other hand it implements ZeRO-DP, so it matches the name in the paper. These are conflicting arguments.

May I suggest that removing ambiguity for the user is more important than matching the name paper?

Note that you rename ShardedDataParallel to ShardedDDP here:
https://github.com/facebookresearch/fairscale#optimizer-state-sharding-zero

@tchaton
Copy link

tchaton commented Feb 23, 2021

Dear @myleott,

Thanks for your incredible work.

I have a sill question:

Currently, sharding is done by splitting the contiguous parameter tensor into world_size chunks.

            chunks = list(torch.flatten(p.data).chunk(self.world_size))

I was wondering If you considered to perform a chunking in the following way.

Chunk each individual parameters into their own world_size chunks and concatenate them across the entire model in a contiguous way.

For all rank I, [chunks_world_size(p1, rank_i), chunks_world_size(p2, rank_i), ..., chunks_world_size(pn, rank_i)]].

Then, it could be possible to override the the parameter getattr function to perform the in-place all_gather when accessed + the following ones in async way, and perform memory release with a post forward hook, which could maybe reduce some memory.

My assumption is: It would avoid to call _rebuild_full_params and _free_full_params as it would be lazy execution. It could also save some memory as there would be no need to have the entire model on each rank, but only the next n layers parameters for each nn.Module forward call.

I hope I was able to properly express myself :)

Best,
T.C

@myleott
Copy link
Contributor Author

myleott commented Feb 23, 2021

Would it be possible to override getattr for all the sharded parameters to perform all_gather directly when being accessed during forward + adding a forward_hook to clean the memory after each Module forward ?

Nice idea. I think this would be possible, especially as we move towards automatic wrapping of children modules (discussed above). As always, I'm sure there are some corner cases we'll need to be careful of, but this is a good direction to explore 😄

A related, but separate next step is to plug into the DDP Communication Hooks interface to handle the reductions. Right now we're doing our own bucketing, which is really slow when flatten_parameters=False.

@tchaton
Copy link

tchaton commented Feb 23, 2021

Dear @myleott,

Sounds great ! If you are interested, we could try to explore this idea with your team.

I recently worked on integrating a DDP Comm Hook to perform training for unstable NaN Loss such as the CTC Loss: https://github.com/PyTorchLightning/pytorch-lightning/blob/490c40a8be339afd08dcfdcdd658db6b0be4671b/pytorch_lightning/plugins/ddp_comm_hooks/allreduce_invalid.py#L52.
I need to finish off the PR, but the interface was great to work with !

Side Note: I think it would be also interesting to work on composable DDP Comm Hooks by making it more modular.

I am going to keep studying the code. Thanks you and your team for such a clean work ! Learning a lot :)

Best,
T.C

@min-xu-ai
Copy link
Contributor

Chunk each individual parameters into their own world_size chunks and concatenate them across the entire model in a contiguous way.

Not sure I totally follow this idea, Thomas. Say the world_size is 2. Do you mean break each param in 2 chunks and concat all 1st chunks on rank 0 and concat all 2nd chunks on rank 1? Is the concat needed so that you don't all_gather small params one by one? With concat, you just all_gather a partial view into the big array? The idea to use getattr to trigger all_gather is nice but it assumes different ranks will always execute all getattr in the same order, which is controlled by the model code that we don't control.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants