Add FullyShardedDataParallel (FSDP) #413

Merged
51 commits merged on Feb 23, 2021

Commits (changes shown below are from all 51 commits):
a7be0d2  Add fairscale.utils.testing.DeviceAndTypeCheckModule (myleott, Jan 28, 2021)
e2ad716  Add fairscale.utils.containers (myleott, Jan 28, 2021)
4d6a5c9  Add ShardParamsDataParallel (myleott, Jan 26, 2021)
dd57e30  [test]: skip autocast-needed tests on torch < 1.6 (#34) (min-xu-ai, Jan 29, 2021)
80678fc  [mypy]: fairscale/utils/containers.py (#33) (min-xu-ai, Jan 29, 2021)
5fb36d1  [mypy]: fixed fairscale/utils/testing.py (#32) (min-xu-ai, Jan 29, 2021)
774e130  More docs (#35) (myleott, Jan 29, 2021)
b35a28a  [mypy]: fixed all the mypy errors (#37) (min-xu-ai, Feb 1, 2021)
92c550b  Sharded DDP: test cpu_offload arg (#40) (sshleifer, Feb 1, 2021)
5bb212f  Misc comments from @anj-s (#43) (myleott, Feb 1, 2021)
cbd243e  Only sync fp32_to_fp16 stream for the top-most (root) ShardParams wra… (myleott, Feb 2, 2021)
8db0cf6  Fix streams test (#45) (myleott, Feb 2, 2021)
bc7e337  move_grads_to_cpu defaults to same value as cpu_offload (#44) (sshleifer, Feb 2, 2021)
a1b3924  formatting change (#46) (min-xu-ai, Feb 2, 2021)
e10df73  Test that backward hooks are registered (#49) (sshleifer, Feb 3, 2021)
e139857  [test] add test for apply_to_tensors (#50) (min-xu-ai, Feb 3, 2021)
6f153b0  Test save/load state_dict V2 (#51) (sshleifer, Feb 4, 2021)
4114772  Replace x.view(-1) with torch.flatten(x) (#59) (myleott, Feb 4, 2021)
36b2d39  Add more comments + docstrings (#58) (myleott, Feb 4, 2021)
bc5190b  Rearrange dtype and device change in post-backward hook (#61) (myleott, Feb 4, 2021)
8a5f81c  Do reduce-scatter in a separate CUDA stream (#62) (myleott, Feb 4, 2021)
72c1f63  tests use spawn_for_all_world_sizes (#63) (sshleifer, Feb 5, 2021)
f481877  Fix state_dict bugs (#60) (sshleifer, Feb 6, 2021)
515411b  update comments to reflect where we are in stack (#69) (sshleifer, Feb 7, 2021)
ec4e75e  [CI] use parameterized.expand to make each test faster (#68) (sshleifer, Feb 7, 2021)
b1460d3  Fix delayed reduce_scatter test (#74) (myleott, Feb 7, 2021)
014ad05  add unit test pack/unpack kwargs (#65) (min-xu-ai, Feb 7, 2021)
0ca378b  Refactor param init and streams logic (#73) (myleott, Feb 7, 2021)
bebe7fd  Add test for NoGrad mode (myleott, Feb 8, 2021)
74b0223  Add all_gather stream and disable reshard_after_forward on root insta… (myleott, Feb 8, 2021)
6797964  Leave buffers on self.compute_device (#67) (sshleifer, Feb 9, 2021)
0937594  Pad parameters inside: right before gather, scatter (#76) (sshleifer, Feb 9, 2021)
93670cb  Add no_sync() context manager (#77) (myleott, Feb 10, 2021)
1d0bf73  rename (sshleifer, Feb 10, 2021)
5fc1f12  Slightly faster execution when world_size == 1 (#81) (myleott, Feb 11, 2021)
3681242  merge new base (which is public/master) (#82) (min-xu-ai, Feb 11, 2021)
dfada29  Merge branch 'shard_params_ddp_base' into shard_params_ddp (myleott, Feb 11, 2021)
9fb974b  Merge branch 'shard_params_ddp_base' into shard_params_ddp (myleott, Feb 11, 2021)
7bd82d1  two small changes (#83) (min-xu-ai, Feb 12, 2021)
d8f3349  fixing uneven shard support and add uneven shard unit test (#80) (min-xu-ai, Feb 14, 2021)
366c38e  rename test file (#86) (min-xu-ai, Feb 15, 2021)
87d28c9  adding TrainingState enum for asserting purpose (#87) (min-xu-ai, Feb 19, 2021)
72b5967  Misc (#88) (myleott, Feb 19, 2021)
a0fe832  Bugfix for backward stream (#91) (myleott, Feb 20, 2021)
2abf21f  Small fix to activation checkpointing + FSDP (#92) (myleott, Feb 22, 2021)
977f459  Bugfix (+ add test) for no_sync before first forward (#93) (myleott, Feb 22, 2021)
43d1f73  FSDP.clip_grad_norm (#89) (sshleifer, Feb 22, 2021)
77dc364  Add ReduceScatterBucketer (#94) (myleott, Feb 22, 2021)
bf30e59  Merge branch 'oss-master' into shard_params_ddp (myleott, Feb 22, 2021)
ec8bf9f  Fix merge conflict (myleott, Feb 22, 2021)
3f2b7f1  Update docstring slightly (myleott, Feb 22, 2021)

1 change: 1 addition & 0 deletions fairscale/nn/data_parallel/__init__.py
@@ -3,4 +3,5 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from .fully_sharded_data_parallel import FullyShardedDataParallel
from .sharded_ddp import ShardedDataParallel
947 changes: 947 additions & 0 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Large diffs are not rendered by default.
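As a quick orientation for this new file (whose diff is not rendered above), here is a minimal, hedged sketch of how the wrapper is meant to be used. It is based on the import added in __init__.py and on options referenced in the commit messages (reshard_after_forward, cpu_offload); the exact signature and defaults live in fully_sharded_data_parallel.py and may differ:

    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Assumes torch.distributed is already initialized (e.g. NCCL backend) and
    # torch.cuda.set_device(rank) has been called, as validate_process_group checks.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024),
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 1024),
    ).cuda()

    # reshard_after_forward / cpu_offload are options referenced in this PR's
    # commits; treat them as illustrative rather than the definitive API.
    model = FSDP(model, reshard_after_forward=True, cpu_offload=False)

    optim = torch.optim.SGD(model.parameters(), lr=0.01)
    out = model(torch.randn(8, 1024, device="cuda"))
    out.sum().backward()   # gradients are reduce-scattered into per-rank shards
    optim.step()           # each rank updates only its own parameter shard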

25 changes: 22 additions & 3 deletions fairscale/nn/misc/checkpoint_activations.py
@@ -3,8 +3,9 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
import functools
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Generator, Optional, Tuple

import torch
from torch import Tensor
@@ -73,6 +74,23 @@ def set_rng_state(state: Dict[str, Any]) -> None:
torch.cuda.set_rng_state(state["cuda_rng_state"])


def is_autocast_enabled() -> bool:
"""Similar to torch.is_autocast_enabled, but compatible with torch 1.5.1"""
if hasattr(torch, "is_autocast_enabled"):
return torch.is_autocast_enabled()
return False


@contextmanager
def autocast(enabled: bool) -> Generator:
"""Similar to torch.cuda.amp.autocast, but compatible with torch 1.5.1"""
if enabled:
with torch.cuda.amp.autocast(enabled):
yield
else:
yield


class CheckpointFunction(torch.autograd.Function):
"""Similar to the torch version, but support non-Tensor outputs.

@@ -96,13 +114,13 @@ def forward(  # type: ignore
ctx.run_function = run_function
ctx.kwarg_keys = kwarg_keys
ctx.fwd_rng_state = get_rng_state()
ctx.had_autocast_in_fwd = is_autocast_enabled()

tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
if parent_ctx_dict["offload"]:
ctx.fwd_device = tuple(x.device for x in tensor_inputs)
ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
tensor_inputs = tuple(x.cpu() for x in tensor_inputs)

else:
ctx.fwd_device, ctx.grad_requirements = None, None

@@ -142,10 +160,11 @@ def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
# Set the states to what it used to be before the forward pass.
set_rng_state(ctx.fwd_rng_state)

with torch.enable_grad():
with torch.enable_grad(), autocast(ctx.had_autocast_in_fwd):
unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
tensor_outputs, _ = split_non_tensors(outputs)

# Set the states back to what it was at the start of this function.
set_rng_state(bwd_rng_state)

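The forward pass now records whether autocast was active (ctx.had_autocast_in_fwd) and the backward pass re-enables it around the recomputation, so recomputed activations match the dtypes produced by the original forward. A minimal sketch of the behavior this enables, assuming the checkpoint_wrapper helper defined elsewhere in this file, a CUDA device, and torch >= 1.6:

    import torch
    from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper

    layer = checkpoint_wrapper(torch.nn.Linear(16, 16)).cuda()
    x = torch.randn(4, 16, device="cuda", requires_grad=True)

    with torch.cuda.amp.autocast():
        y = layer(x)             # forward runs under autocast and is recorded
    y.float().sum().backward()   # recomputation in backward also runs under autocast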
17 changes: 11 additions & 6 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -2,12 +2,15 @@
# Licensed under the MIT License.

from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union

import torch
from torch import Tensor
import torch.nn as nn

if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401


class FlattenParamsWrapper(nn.Module):
"""
@@ -127,21 +130,23 @@ def __getattr__(self, name: str) -> Any:
except AttributeError:
return getattr(self.module, name) # fallback to wrapped module

def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]": # type: ignore
def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, Tensor]": # type: ignore
"""Return an unflattened state_dict."""
with self.unflatten_params():
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
return self.module.state_dict(*args, **kwargs)

def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Return the flattened state_dict."""
return super().state_dict(*args, **kwargs)

def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> None:
def load_state_dict(
self, state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], strict: bool = True
) -> NamedTuple:
if "flat_param" in state_dict:
super().load_state_dict(state_dict, strict=True)
return super().load_state_dict(state_dict, strict=strict)
else:
with self.unflatten_params():
return self.module.load_state_dict(state_dict, *args, **kwargs)
return self.module.load_state_dict(state_dict, strict)

def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
self._unflatten_params_as_views()
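With this change, state_dict() keeps returning keys in the wrapped module's (unflattened) format, flat_state_dict() returns the flattened form, and load_state_dict() dispatches based on whether a "flat_param" key is present. A hedged round-trip sketch with a toy module:

    import torch
    from fairscale.nn.misc.flatten_params_wrapper import FlattenParamsWrapper

    wrapped = FlattenParamsWrapper(torch.nn.Linear(4, 4))

    sd = wrapped.state_dict()            # unflattened keys, e.g. "weight", "bias"
    flat_sd = wrapped.flat_state_dict()  # flattened form with a "flat_param" entry

    wrapped.load_state_dict(sd)          # routed through the wrapped module
    wrapped.load_state_dict(flat_sd)     # routed through the parent implementation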
10 changes: 3 additions & 7 deletions fairscale/optim/oss.py
@@ -15,7 +15,7 @@
from torch.nn import Parameter
from torch.optim import SGD, Optimizer

from .utils import broadcast_object, recursive_copy_to_device
from .utils import broadcast_object, calc_grad_norm, recursive_copy_to_device

__all__ = ["OSS"]

@@ -284,18 +284,14 @@ def clip_grad_norm(
# https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params

local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
# Compute the norm on this grad set,
# then sync all the norms from all ranks
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(self._default_device) for p in local_params)
total_norm = local_norm
# all reduce over data parallel and model parallel workers
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
else:
local_norm = torch.norm(
input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._default_device) for p in local_params]), # type: ignore
p=norm_type,
)

# local norm result can be accumulated with the remote ones if put to the right power
# n_i = sum_rank(a^p)^1/p
# -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
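The refactor moves the per-rank norm computation into calc_grad_norm and keeps only the cross-rank combination here: the infinity norm is combined with a MAX all-reduce, while a finite p-norm is raised to the p-th power, summed across ranks, and taken back to the 1/p power (the identity in the comment above). A small standalone sketch of that combination step, assuming torch.distributed is already initialized:

    from math import inf

    import torch
    import torch.distributed as dist

    def combine_norms(local_norm: torch.Tensor, p: float) -> torch.Tensor:
        """Combine per-rank gradient norms into the global norm."""
        if p == inf:
            total = local_norm.clone()
            dist.all_reduce(total, op=dist.ReduceOp.MAX)   # max over ranks
            return total
        total = local_norm.clone() ** p
        dist.all_reduce(total)                             # sum of local_norm_i ** p
        return total ** (1.0 / p)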
22 changes: 21 additions & 1 deletion fairscale/optim/utils.py
@@ -5,7 +5,8 @@

import collections
import io
from typing import Any, Callable, Dict, Optional
from math import inf
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist
@@ -102,3 +103,22 @@ def reset(self) -> None:
def full(self) -> bool:
""" is the bucket full ? """
return self.max_params_checked_in == self.params_checked_in


def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:
r"""Calculate gradient norm of an iterable of parameters.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda par: par.grad is not None, parameters))

if len(parameters) == 0:
return torch.tensor(0.0)
p = float(p)
if p == inf:
local_norm = max(par.grad.detach().abs().max() for par in parameters) # type: ignore
else:
local_norm = torch.norm(torch.stack([torch.norm(par.grad.detach(), p) for par in parameters]), p) # type: ignore
return local_norm
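For reference, a quick local sanity check of calc_grad_norm on a toy module (no distributed setup needed):

    from math import inf

    import torch
    from fairscale.optim.utils import calc_grad_norm

    m = torch.nn.Linear(3, 3)
    m(torch.randn(2, 3)).sum().backward()

    n2 = calc_grad_norm(list(m.parameters()), 2.0)    # Euclidean norm over all grads
    ninf = calc_grad_norm(list(m.parameters()), inf)  # largest absolute grad entry
    assert ninf <= n2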
45 changes: 45 additions & 0 deletions fairscale/utils/parallel.py
@@ -0,0 +1,45 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""Useful functions for parallel training."""

from typing import List

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn.functional as F


def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
"""Chunk a given Tensor into num_chunks parts and add any necessary padding."""
chunks = list(torch.flatten(tensor).chunk(num_chunks))
# torch.chunk may return fewer than num_chunks chunks, pad accordingly.
num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel()
if num_pad_for_partial_chunk > 0:
chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk])
if len(chunks) < num_chunks:
chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
return chunks


def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
"""Do a quick test in case user called FSDP without calling torch.cuda.set_device()
correctly. This can easily happen in cpu_offload case where the model resides on
the CPU.
"""
if not hasattr(process_group, "allgather"):
# Likely a dummy pg for unit test, skip checking.
return

world_size = process_group.size()
if "cuda" in str(device):
input_tensor = torch.ones(1).to(device)
output = list(torch.zeros(world_size).to(device).chunk(world_size))
dist.all_gather(output, input_tensor, group=process_group)
assert torch.cat(output).sum() == float(world_size), (
f"found {torch.cat(output).sum()} devices in process group but "
f"world_size={world_size}. Check torch.cuda.set_device is called properly"
)
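chunk_and_pad always returns exactly num_chunks equally sized pieces: the flattened tensor is chunked, the last partial chunk is zero-padded, and all-zero filler chunks are appended if torch.chunk produced fewer pieces than requested. A quick illustration:

    import torch
    from fairscale.utils.parallel import chunk_and_pad

    t = torch.arange(5.0)              # 5 elements split across 4 chunks of size 2
    chunks = chunk_and_pad(t, 4)
    assert len(chunks) == 4
    assert all(c.numel() == 2 for c in chunks)
    # chunks[2] is padded with one zero; chunks[3] is an all-zero filler chunk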
151 changes: 151 additions & 0 deletions fairscale/utils/reduce_scatter_bucketer.py
@@ -0,0 +1,151 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor
import torch.distributed as dist
from torch.distributed import ProcessGroup


class Bucket:
def __init__(self, data: Tensor, group: ProcessGroup):
self.data = data
self.group = group
self.offset = 0
self.callbacks: List[Callable] = []
self.output_shard = torch.zeros_like(data[0])

def flush(self) -> None:
if self.offset == 0:
assert len(self.callbacks) == 0
return
# reduce-scatter bucket
dist.reduce_scatter(
self.output_shard[: self.offset], list(self.data[:, : self.offset].unbind(0)), group=self.group
)
# execute post-reduction callbacks
for callback_fn in self.callbacks:
callback_fn()
# reuse input bucket but allocate a fresh output shard
self.data[:, : self.offset].zero_()
self.offset = 0
self.callbacks.clear()
self.output_shard = torch.zeros_like(self.data[0])


class ReduceScatterBucketer:
"""
Helper for bucketing multiple reduce-scatter operations on small tensors
into larger reduce-scatter ops to improve communication efficiency.

Usage::

bucketer = ReduceScatterBucketer()
bucketer.reduce_scatter_async(
small_tensors, callback_fn=lambda result: print("small")
)
bucketer.reduce_scatter_async(
big_tensors, callback_fn=lambda result: print("big")
)
bucketer.reduce_scatter_async(
more_small_tensors, callback_fn=lambda result: print("small2")
)
bucketer.flush() # callbacks only guaranteed to be called after flush()
# Example output (note that it is out of order, due to bucketing):
# big
# small
# small2

Args:
bucket_cap_mb (int, Optional): bucket size for communicating. Buckets
are sub-divided based on world_size. Values <= 0 disable bucketing.
"""

def __init__(self, bucket_cap_mb: int = 25):
self.bucket_cap_mb = bucket_cap_mb
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}

@torch.no_grad()
def reduce_scatter_async(
self, input_list: List[Tensor], group: ProcessGroup, callback_fn: Optional[Callable] = None,
) -> None:
"""
Reduce-scatter a list of tensors asynchronously, so smaller reductions
can be bucketed together. The given callback (``callback_fn``) will be
called with the reduced result at some later time. Call ``flush()`` to
force all queued ops and callbacks to be executed.

Note that large inputs will be reduced immediately, and this function
may also flush the relevant bucket to make room for ``input_list``.

Args:
input_list (List[Tensor]): list of tensors to reduce-scatter. List
should contain ``group.size()`` tensors and each tensor should
have identical shape, dtype and device.
group (ProcessGroup): process group for reduction
callback_fn (Callable, Optional): callback function to call after
the reduction executes. Function will be called with a single
argument corresponding to the reduced result.
"""
world_size = group.size()

assert (
len(input_list) == world_size
), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

first_input = input_list[0]
first_input_size = first_input.numel()

bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
if first_input_size > bucket_shard_size:
# input is too big to fit in the bucket, reduce-scatter directly
output = torch.zeros_like(input_list[0])
dist.reduce_scatter(output, input_list, group=group)
if callback_fn is not None:
callback_fn(output)
return

bucket = self._get_bucket(first_input, group)
if first_input_size > bucket.data.size(1) - bucket.offset:
# not enough space remaining in bucket, flush it now
bucket.flush()

# copy data from input_list into bucket
stacked_input = torch.stack(input_list).view(world_size, first_input_size)
offset = bucket.offset
bucket.data[:, offset : offset + first_input_size].copy_(stacked_input)
bucket.offset += first_input_size

# callback will be given the reduced result
if callback_fn is not None:
result_view = bucket.output_shard[offset : offset + first_input_size].view_as(first_input)
bucket.callbacks.append(functools.partial(callback_fn, result_view))

@torch.no_grad()
def flush(self) -> None:
"""Reduce-scatter any partial buckets."""
for bucket in self.buckets.values():
bucket.flush()

@functools.lru_cache()
def _get_shard_size(self, element_size: int, num_shards: int) -> int:
if self.bucket_cap_mb <= 0: # Values <= 0 disable bucketing.
return 0
MB = 1024 * 1024
bucket_size = self.bucket_cap_mb * MB / element_size
return int(bucket_size // num_shards)

def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
key = (tensor.dtype, tensor.device, group)
if key not in self.buckets:
# buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
world_size = group.size()
shard_size = self._get_shard_size(tensor.element_size(), world_size)
data = tensor.new_zeros((world_size, shard_size))
self.buckets[key] = Bucket(data, group)
return self.buckets[key]
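Each bucket is a (world_size, shard_size) buffer per (dtype, device, group) key, and shard_size is derived from bucket_cap_mb, the element size, and the world size. A small worked example of that arithmetic for the default 25 MB cap with fp16 tensors on 8 ranks (standalone, no process group needed):

    # mirrors ReduceScatterBucketer._get_shard_size
    bucket_cap_mb, element_size, world_size = 25, 2, 8   # 25 MB cap, fp16, 8 ranks
    MB = 1024 * 1024
    shard_size = int(bucket_cap_mb * MB / element_size // world_size)
    assert shard_size == 1_638_400   # elements per rank in each bucket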