diff --git a/fairscale/nn/data_parallel/__init__.py b/fairscale/nn/data_parallel/__init__.py index f8bdd814b..d119dfb00 100644 --- a/fairscale/nn/data_parallel/__init__.py +++ b/fairscale/nn/data_parallel/__init__.py @@ -3,4 +3,5 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +from .fully_sharded_data_parallel import FullyShardedDataParallel from .sharded_ddp import ShardedDataParallel diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py new file mode 100644 index 000000000..2c19b6d42 --- /dev/null +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -0,0 +1,947 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import copy +from enum import Enum, auto +import functools +from math import inf +from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union + +import torch +from torch.autograd import Variable +import torch.distributed as dist +from torch.distributed import ProcessGroup +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F + +from fairscale.nn.misc import FlattenParamsWrapper +from fairscale.optim.utils import calc_grad_norm +from fairscale.utils.containers import ( + apply_to_tensors, + pack_kwargs, + split_non_tensors, + unpack_kwargs, + unpack_non_tensors, +) +from fairscale.utils.parallel import chunk_and_pad, validate_process_group +from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer + +if TYPE_CHECKING: + from collections import OrderedDict # noqa: F401 + + +class TrainingState(Enum): + """ + Simple enum to indicate what state FSDP is in. Used for asserting + to make sure APIs are called in the correct state. + + TODO (Min): It would be nice to capture the stepping state as well. + Maybe we can use the model.zero_grad() call, but not sure if it + is called if optim.zero_grad() is used instead. + It would be nice to have clear state transition be explicit like: + + zero_grad -> fwd -> bwd -> optionally accum grad by repeating + fwd/bwd -> stepping -> loop back to zero_grad + """ + + IDLE = auto() + FORWARD = auto() + BACKWARD = auto() + + +class FullyShardedDataParallel(nn.Module): + """ + A wrapper for sharding Module parameters across data parallel workers. This + is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_. + + .. _`Xu et al.`: https://arxiv.org/abs/2004.13336 + .. _DeepSpeed: https://www.deepspeed.ai/ + + Usage:: + + sharded_module = FullyShardedDataParallel(my_module) + optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) + x = sharded_module(x, y=3, z=torch.Tensor([1])) + loss = x.sum() + loss.backward() + optim.step() + + It is also possible to shard individual layers separately and have an outer + wrapper handle any leftover parameters. This can be helpful to further + reduce memory usage and to improve training speed by distributing the + unsharding (all-gather) across the forward pass. 
For example:: + + sharded_model = FullyShardedDataParallel( + nn.Sequential( + nn.Linear(5, 100), + FullyShardedDataParallel(nn.Linear(100, 100)), + FullyShardedDataParallel(nn.Linear(100, 100)), + nn.Linear(100, 5), + ) + ) + + Args: + module (nn.Module): module to checkpoint + process_group (Optional): process group for sharding + reshard_after_forward (bool, Optional): if ``True``, reshard parameters + after the forward pass. This saves memory but slows training. This + is only relevant when resharding individual layers. + mixed_precision (bool, Optional): if ``True``, inputs, activations and + gradients will be kept in FP16; computation and communication will + occur in FP16; and a (sharded) master copy of the model weights will + be maintained in FP32. + fp32_reduce_scatter (bool, Optional): if ``True``, then reduce-scatter + gradients in FP32. This is only relevant when *``mixed_precision``* + is ``True``. + flatten_parameters (bool, Optional): if ``True``, flatten parameters + into a single contiguous tensor, which improves training speed. + cpu_offload (bool, Optional): if ``True``, offload FP32 params to CPU. + This is only relevant when *``mixed_precision``* is ``True``. + compute_dtype (torch.dtype, Optional): dtype for full parameters for + computation. This defaults to ``torch.float32`` unless + *``mixed_precision``* is set, in which case it defaults to + ``torch.float16``. + move_grads_to_cpu (bool, Optional): move gradient shard to CPU after + reduction. This is useful when combined with CPU-based optimizers. + It defaults to the value of *``cpu_offload``*. + bucket_cap_mb (int, Optional): FSDP will bucket parameters so that + gradient reduction can potentially overlap with backward + computation. bucket_cap_mb controls the bucket size in MegaBytes + (MB). Buckets are sub-divided based on world_size, so the max shard + size is roughly ``bucket_cap_mb / world_size``. Values <= 0 disable + bucketing. Default: 25. + """ + + def __init__( + self, + module: nn.Module, + process_group: Optional[ProcessGroup] = None, + reshard_after_forward: bool = True, + mixed_precision: bool = False, + fp32_reduce_scatter: bool = False, + flatten_parameters: bool = True, + cpu_offload: bool = False, + compute_dtype: Optional[torch.dtype] = None, + move_grads_to_cpu: Optional[bool] = None, + bucket_cap_mb: int = 25, + ): + super().__init__() + self.process_group = process_group or dist.new_group() + self.rank = self.process_group.rank() + self.world_size = self.process_group.size() + self.reshard_after_forward = reshard_after_forward + self.mixed_precision = mixed_precision + self.fp32_reduce_scatter = fp32_reduce_scatter + self.flatten_parameters = flatten_parameters + self.cpu_offload = cpu_offload + self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32) + self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu + self.bucket_cap_mb = bucket_cap_mb + + if self.fp32_reduce_scatter and not self.mixed_precision: + raise ValueError("fp32_reduce_scatter requires mixed_precision=True") + if self.cpu_offload and not self.mixed_precision: + raise ValueError("cpu_offload requires mixed_precision=True") + + compute_device = torch.device("cuda") if self.cpu_offload else next(module.parameters()).device + validate_process_group(compute_device, self.process_group) + + # Only handle params which are not already sharded. This enables + # sharding individual layers of a Module, with an outer wrapper to + # shard any leftover parameters. 
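The nested-wrapping behaviour described above hinges on the ``_is_sharded`` attribute that inner FSDP instances stamp onto their parameters. A minimal standalone sketch of that filtering (illustrative only; the marker is faked by hand here rather than set by a real inner wrapper)::

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(5, 100), nn.Linear(100, 5))
    # pretend an inner FSDP wrapper already sharded the first layer
    for p in model[0].parameters():
        p._is_sharded = True

    # the outer wrapper only picks up the leftover (unsharded) parameters
    leftover = [p for p in model.parameters() if not hasattr(p, "_is_sharded")]
    assert len(leftover) == 2  # weight and bias of the second Linear

This is why nested instances compose: the outer wrapper never touches parameters that an inner wrapper already owns.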
+ params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded")) + + if self.flatten_parameters and len(params) > 0: + self.module: nn.Module = FlattenParamsWrapper(module, param_list=params) + del module # free original module in case it helps garbage collection + self.params = [self.module.flat_param] + else: + self.module = module + self.params = params + + # Shard module parameters in place + self._shard_parameters_() + + # Make sure all parameters are sharded. + for n, p in self.named_parameters(): + assert hasattr(p, "_is_sharded"), f"found unsharded parameter: {n} ; {p.size()}" + + self._reset_lazy_init() + + # Flag to indicate if we require gradient reduction in the backward + # pass. This will be False when inside the no_sync context manager. + self.require_backward_grad_sync: bool = True + + self.training_state = TrainingState.IDLE + + @torch.no_grad() + def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None: + """Move all buffers to the specified device and dtype, recursively.""" + cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype) + self.apply(cast_fn) + + @property + def params_with_grad(self) -> List[Parameter]: + """[p for p in self.parameters() if p.grad is not None] """ + return [p for p in self.parameters() if p.grad is not None] + + @torch.no_grad() + def clip_grad_norm_( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + # filter_params_fn: Callable[[Any], Any] = None, + ) -> torch.Tensor: + """ + Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Arguments: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + + .. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple devices per rank + under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads + in the model, so calling it in the OSS context would lead to different scaling being applied per subset of model parameters + + .. warning: This needs to be called on all ranks, since synchronization primitives will be used + + """ + assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance" + assert self.training_state == TrainingState.IDLE + + max_norm = float(max_norm) + norm_type = float(norm_type) + params_with_grad = self.params_with_grad + if not self.children_share_process_group: + raise NotImplementedError( + "clip_grad_norm requires that all params share one process group. 
clip_grad_by_value_ should work" + ) + # Computes the max norm for this shard's gradients and sync's across workers + local_norm = calc_grad_norm(params_with_grad, norm_type).cuda() + if norm_type == inf: + total_norm = local_norm + dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group) + else: + total_norm = local_norm ** norm_type + dist.all_reduce(total_norm, group=self.process_group) + total_norm = total_norm ** (1.0 / norm_type) + + if self.move_grads_to_cpu: + total_norm = total_norm.cpu() + # Now multiply each grad by (max_norm/total_norm), same as torch 1.7 https://tinyurl.com/3wtxhhqq) + clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6) + if clip_coef < 1: + + # multiply by clip_coef + for p in params_with_grad: + p.grad.detach().mul_(clip_coef.to(p.grad.device)) # type: ignore + + return total_norm + + @torch.no_grad() + def _shard_parameters_(self) -> None: + """ + At initialization we wrap a module with full parameters and shard the + parameters in-place. Sharding is implemented by viewing each parameter + as a 1D Tensor and retaining only a single slice, where the slice size + is determined by the number of data parallel workers. + + Wrapping modules with many small parameters (or with a very large data + parallel world size) will result in many small parameter shards and slow + performance. In this case it's better to set *``flatten_parameters``* to + ``True``, so that all of the small parameters in the module are combined + into a single contiguous Tensor and sharded once. + + After this initial sharding is complete, the user can initialize a + ``torch.optim.Optimizer`` in the usual way, i.e.:: + + .. code-block:: python + + optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) + + The optimizer will see only a single slice of parameters and will thus + allocate less memory for optimizer state, avoiding redundancy across + data parallel workers. + """ + for p in self.params: + assert not hasattr(p, "_is_sharded") + assert p.is_floating_point() + if self.mixed_precision: + assert p.dtype == torch.float32 + + # If world_size is 1, then we all-reduce grads instead of sharding. + p._is_sharded = self.world_size > 1 + p._orig_size = p.data.size() + + if not p._is_sharded: + continue + p._is_sharded = True + + # Shard using torch.chunk to match all-gather/reduce-scatter. + chunks = list(torch.flatten(p.data).chunk(self.world_size)) + while len(chunks) < self.world_size: + chunks.append(chunks[0].new_empty(0)) + + # Determine number of padding elements. + num_to_pad = chunks[0].numel() - chunks[self.rank].numel() + assert num_to_pad >= 0, num_to_pad + + # Replace p.data with the relevant shard. 
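A standalone sketch of the chunk-and-pad step being described here, using made-up sizes (``world_size=4``, a 10-element parameter seen from rank 3)::

    import torch
    import torch.nn.functional as F

    world_size, rank = 4, 3
    p_data = torch.arange(10, dtype=torch.float32)        # pretend this is a parameter
    chunks = list(torch.flatten(p_data).chunk(world_size))
    while len(chunks) < world_size:                        # torch.chunk may return fewer chunks
        chunks.append(chunks[0].new_empty(0))
    num_to_pad = chunks[0].numel() - chunks[rank].numel()
    shard = F.pad(chunks[rank], [0, num_to_pad]) if num_to_pad > 0 else chunks[rank]
    print(shard)                                           # tensor([9., 0., 0.]) -- rank 3's padded slice

Every rank ends up with a shard of identical length, which is what allows the later all-gather and reduce-scatter to use fixed-size buffers.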
+ orig_data = p.data + p.data = chunks[self.rank].clone() # clone since we free storage below + if num_to_pad > 0: + p.data = F.pad(p.data, [0, num_to_pad]) + free_storage_(orig_data) + + def extra_repr(self) -> str: + return ( + f"rank={self.rank}, world_size={self.world_size}, " + f"reshard_after_forward={self.reshard_after_forward}, " + f"mixed_precision={self.mixed_precision}, " + f"fp32_reduce_scatter={self.fp32_reduce_scatter}, " + f"flatten_parameters={self.flatten_parameters}, " + f"cpu_offload={self.cpu_offload}, " + f"compute_dtype={self.compute_dtype}, " + f"move_grads_to_cpu={self.move_grads_to_cpu}" + ) + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.module, name) + + def __getstate__(self) -> Dict[str, str]: + """Serialize the state of the current FullyShardedDataParallel instance. + + Some properties are not serializable (e.g., process groups, streams), so + we remove them and try to reconstruct them in :func:`__setstate__`. + """ + state = copy.copy(self.__dict__) + state["is_sharded"] = [p._is_sharded for p in self.params] + state["orig_sizes"] = [p._orig_size for p in self.params] + if state["process_group"] is not None: + state["process_group"] = "MISSING" # process_group isn't pickleable + self._reset_lazy_init() + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + """Intercept state setting and perform needed changes on params.""" + super().__setstate__(state) + + def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter: + assert isinstance(p, Parameter) + p.data = p.data.clone() # move tensors out of shared memory + p._is_sharded = is_sharded + p._orig_size = size + return p + + self.params = [ + fixup(p, is_sharded, size) for p, is_sharded, size in zip(self.params, self.is_sharded, self.orig_sizes) + ] + del self.is_sharded + del self.orig_sizes + self._reset_lazy_init() + + # TODO (Min): figuring out how to do typing for this overloaded function. + def state_dict(self, *args, **kwargs): # type: ignore + """ + Returns the whole (unsharded) state of the module. Parameters are not + sharded, so the resulting state_dict can be loaded directly by the + wrapped Module without any sharding-specific logic. Returned tensors will always be typed float32 + """ + torch.cuda.synchronize() + self._lazy_init() + self._rebuild_full_params() + self._all_buffers_to(dtype=torch.float32) # Buffers dtype stays consistent with parameters. + state_dict = self.module.state_dict(*args, **kwargs) + # We don't free the params after generating the state dict, since + # freeing is done in-place (via the Storage) and would corrupt the + # returned state dict. However, we need to maintain the invariant that + # p.data corresponds to the FP32 param shard, so we do that here. + self._use_fp32_param_shard() + self._all_buffers_to(dtype=self.compute_dtype) + return state_dict + + # TODO (Min): figuring out how to do typing for this overloaded function. + def local_state_dict(self, *args, **kwargs): # type: ignore + """ + Returns the local (sharded) state of the module. Parameters are sharded, + so the resulting state_dict can only be loaded after the Module has been + wrapped with FullyShardedDataParallel. 
+ """ + torch.cuda.synchronize() + self._lazy_init() + if self.flatten_parameters: + return self.module.flat_state_dict(*args, **kwargs) # type: ignore + else: + return self.module.state_dict(*args, **kwargs) + + def load_state_dict( + self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True + ) -> NamedTuple: + """Load a whole (unsharded) state_dict.""" + torch.cuda.synchronize() + self._lazy_init() + self._rebuild_full_params() + output = self.module.load_state_dict(state_dict, strict) + self._free_full_params() + return output + + def load_local_state_dict( + self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True + ) -> NamedTuple: + """Load a local (sharded) state_dict.""" + torch.cuda.synchronize() + return self.module.load_state_dict(state_dict, strict) + + @contextlib.contextmanager + def no_sync(self) -> Generator: + """ + A context manager to disable gradient synchronizations across DDP + processes. Within this context, gradients will be accumulated on module + variables, which will later be synchronized in the first + forward-backward pass exiting the context. + """ + self._lazy_init() + assert self._is_root, "no_sync on inner FSDP is not supported" + self.assert_state(TrainingState.IDLE) + # This instance may wrap other FullyShardedDataParallel instances and we + # need to set all of them to accumulate gradients. + old_flags = [] + for m in self.modules(): # includes self + if isinstance(m, FullyShardedDataParallel): + old_flags.append((m, m.require_backward_grad_sync)) + m.require_backward_grad_sync = False + try: + yield + finally: + for m, old_flag in old_flags: + m.require_backward_grad_sync = old_flag + + def _reset_lazy_init(self) -> None: + """Reset instance so :func:`_lazy_init` will run on the next forward.""" + self._is_root: Optional[bool] = None + self._streams: Dict[str, torch.cuda.Stream] = {} + self._reducer: Optional[ReduceScatterBucketer] = None + + def _lazy_init(self) -> None: + """Initialization steps that should happen lazily, typically right + before the first forward pass.""" + # Initialize param attributes lazily, in case the param's dtype or + # device changes after __init__. + for p in self.params: + self._init_param_attributes(p) + + # Initialize _is_root and setup streams. These steps would ideally + # happen in __init__, but _is_root can only be determined after the + # entire model hierarchy is setup, thus we run it lazily. + if self._is_root is None: + self._set_is_root() + self._setup_streams() + + if self.cpu_offload: # Buffers stay on GPU, and don't get sharded + self._all_buffers_to(device=torch.device("cuda"), dtype=self.compute_dtype) + else: + self._all_buffers_to(dtype=self.compute_dtype) + + if self._is_root: + # Don't free the full params for the outer-most (root) instance, + # since those params will be needed immediately after for the + # backward pass. + self.reshard_after_forward = False + + # Due to the use of streams, we need to make sure the previous + # ``optim.step()`` is done before we all-gather parameters. + self._wait_for_previous_optim_step() + + @torch.no_grad() + def _init_param_attributes(self, p: Parameter) -> None: + """ + We manage several attributes on each Parameter instance. The first two + are set by :func:`_shard_parameters_`: + + ``_is_sharded``: ``True`` if the Parameter is sharded or ``False`` + if the Parameter is intentionally not sharded (in which case we + will all-reduce grads for this param). 
+ ``_orig_size``: the size of the original Parameter (before sharding) + + The remaining attributes are set here: + ``_fp32_shard``: a single shard of the parameters in full precision + (typically FP32, but this is dependent on the dtype of the model + as it's passed in by the user). This can be on CPU or GPU + depending on the value of *``cpu_offload``*. + ``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be + a single shard of the parameters in FP16, used for all-gather. + ``_full_param_padded``: the full weight (padded to be evenly + divisible by ``world_size``), used for computation in the + forward and backward pass. This will be resized in place and + only materialized (via all-gather) as needed. + """ + assert hasattr(p, "_is_sharded") and hasattr(p, "_orig_size") + if hasattr(p, "_fp32_shard"): + return + + # Compute device defaults to CUDA when *cpu_offload* is enabled, or the + # param's current device otherwise (could be CPU). + compute_device = torch.device("cuda") if self.cpu_offload else p.device + + # A single shard of the parameters in full precision. + p._fp32_shard = p.data + + if self.mixed_precision: + assert p._fp32_shard.dtype == torch.float32 + + if self.cpu_offload: + assert p._fp32_shard.device == torch.device("cpu") + # If we plan to keep the FP32 parameters on CPU, then pinning + # memory allows us to later use non-blocking transfers when moving + # the FP32 param shard to compute_device. + p._fp32_shard = p._fp32_shard.pin_memory() + p.data = p._fp32_shard + + # In mixed precision mode, we maintain a reduced precision + # (typically FP16) parameter shard on compute_device for performing + # the computation in the forward/backward pass. We resize the + # storage to size 0 at init (here) and re-materialize (by copying + # from _fp32_shard) as needed. + p._fp16_shard = torch.zeros_like(p._fp32_shard, device=compute_device, dtype=self.compute_dtype) + free_storage_(p._fp16_shard) + else: + p._fp16_shard = None # use _fp32_shard + + # We also maintain a full-sized parameter of type self.compute_dtype + # (FP16 for mixed_precision or FP32 otherwise). We resize the + # storage to size 0 at init (here) and only materialize as needed. The + # storage may contain padding elements so that it is evenly divisible by + # world_size, although these padding elements will be removed before the + # relevant computation. + if p._is_sharded: + p._full_param_padded = torch.zeros( + p.data.numel() * self.world_size, device=compute_device, dtype=self.compute_dtype + ) + free_storage_(p._full_param_padded) + + if self.move_grads_to_cpu: + # We can optionally move the grad shard to CPU during the backward + # pass. In this case, it's important to pre-allocate the CPU grad + # shard in pinned memory so that we can do a non-blocking transfer. + p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory() + + def _set_is_root(self) -> None: + """If ``True``, implies that no other :class:`FullyShardedDataParallel` + instance wraps this one. Called once by :func:`_lazy_init`. + Also sets self.children_share_process_group = True if all child instances share the same process group. + If some child instances use a different process group, self.clip_grad_norm_ will raise an error. + """ + if self._is_root is not None: + return + # No FullyShardedDataParallel instance wraps this, else _is_root would be set to False + self._is_root = True + # As the root, we now set all children instances to False. 
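The resize-to-zero trick used for ``_fp16_shard`` and ``_full_param_padded`` above can be sketched on its own (illustrative; the patch wraps it in the ``free_storage_`` and ``alloc_storage_`` helpers defined at the bottom of this file)::

    import torch

    world_size = 4
    shard = torch.zeros(3)                                     # local (padded) parameter shard
    full_param_padded = torch.zeros(shard.numel() * world_size)
    full_param_padded.storage().resize_(0)                     # free the backing memory at init
    print(full_param_padded.storage().size())                  # 0 -- nothing held between uses
    full_param_padded.storage().resize_(shard.numel() * world_size)  # re-materialize before all-gather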
+ self.children_share_process_group = True + for n, m in self.named_modules(): + if n != "" and isinstance(m, FullyShardedDataParallel): + assert m._is_root is None + m._is_root = False + if m.process_group != self.process_group: + self.children_share_process_group = False + + def _setup_streams(self) -> None: + """Create streams to overlap data transfer and computation.""" + if len(self._streams) > 0 or not self._is_root: + return + # Stream to move main FP32 params (may be on CPU) to FP16 for forward. + self._streams["fp32_to_fp16"] = torch.cuda.Stream() + # Stream for all-gathering parameters. + self._streams["all_gather"] = torch.cuda.Stream() + # Stream for overlapping grad reduction with the backward pass. + self._streams["post_backward"] = torch.cuda.Stream() + # Helper for bucketing reduce-scatter ops. This is also shared with + # children instances to improve bucket utilization. + self._reducer = ReduceScatterBucketer(self.bucket_cap_mb) + # We share streams with all children instances, which allows them to + # overlap transfers across the forward pass without synchronizing with + # the default stream. + for n, m in self.named_modules(): + if n != "" and isinstance(m, FullyShardedDataParallel): + m._streams = self._streams + m._reducer = self._reducer + + def _wait_for_previous_optim_step(self) -> None: + """ + The outer-most :class:`FullyShardedDataParallel` instance (i.e., the root + instance) needs to synchronize with the default stream to ensure the + previous optimizer step is done. + """ + if self.mixed_precision: + self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream()) + else: + self._streams["all_gather"].wait_stream(torch.cuda.current_stream()) + + def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: + self._lazy_init() + + # Start of a forward pass. + self.training_state = TrainingState.FORWARD + + if self.mixed_precision: + args, kwargs = cast_inputs_to_fp16(*args, **kwargs) + + # All-gather full parameters. This will also transfer FP32 parameters to + # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``). + self._rebuild_full_params() + + # Register backward hooks to reshard params and reduce-scatter grads. + # These need to be re-registered every forward pass. + self._register_post_backward_hooks() + + outputs = self.module(*args, **kwargs) + + if self.reshard_after_forward: + self._free_full_params() + + # Switch to main FP32 param shard. We maintain this invariant throughout + # the code, i.e., ``p.data == p._fp32_shard`` after each function. This + # also ensures that after the first forward, the optimizer state will be + # initialized with the correct dtype and (sharded) size, since optimizer + # state is typically initialized lazily in ``optim.step()``. + self._use_fp32_param_shard() + + # Register pre-backward hooks to all-gather the params for the backward + # pass (if needed). + outputs = self._register_pre_backward_hooks(outputs) + + # Done with a forward pass. + self.training_state = TrainingState.IDLE + + return outputs + + def _register_pre_backward_hooks(self, outputs: Any) -> Any: + """Register pre-backward hook to run before the wrapped module's + backward. 
Hooks should be attached to all outputs from the forward.""" + if not torch.is_grad_enabled(): + return outputs # don't register hooks if grad isn't enabled + + pre_backward_hook_has_run = [False] + + def _pre_backward_hook(*unused: Any) -> None: + if pre_backward_hook_has_run[0]: + return # only run once + pre_backward_hook_has_run[0] = True + + # Start of a backward pass. + self.training_state = TrainingState.BACKWARD + + # All-gather full parameters. + if self.reshard_after_forward: + self._rebuild_full_params() + else: + self._use_full_params() + # Make sure p.grad has the correct size/device (or set it to None). + self._prep_grads_for_backward() + + def _register_hook(t: torch.Tensor) -> torch.Tensor: + t.register_hook(_pre_backward_hook) + return t + + # Attach hooks to Tensor outputs. + outputs = apply_to_tensors(_register_hook, outputs) + + return outputs + + def _register_post_backward_hooks(self) -> None: + """Register backward hooks to reshard params and reduce-scatter grads.""" + if not torch.is_grad_enabled(): + return # don't register grad hooks if grad isn't enabled + self._post_backward_callback_queued = False + for p in self.params: + if p.requires_grad: + if hasattr(p, "_shard_bwd_hook"): + p._shard_bwd_hook[1].remove() # remove existing handle + p_tmp = p.expand_as(p) + grad_acc = p_tmp.grad_fn.next_functions[0][0] + handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p)) + p._shard_bwd_hook = (grad_acc, handle) + + @torch.no_grad() + def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: + """ + At the start of :func:`_post_backward_hook`, ``param.grad`` contains the + full gradient for the local batch. The reduce-scatter op will replace + ``param.grad`` with a single shard of the summed gradient across all + GPUs. This shard will align with the current GPU rank. For example:: + + before reduce_scatter: + param.grad (GPU #0): [1, 2, 3, 4] + param.grad (GPU #1): [5, 6, 7, 8] + + after reduce_scatter: + param.grad (GPU #0): [6, 8] # 1+5, 2+6 + param.grad (GPU #1): [10, 12] # 3+7, 4+8 + + The local GPU's ``optim.step`` is responsible for updating a single + shard of params, also corresponding to the current GPU's rank. This + alignment is created by :func:`_shard_parameters_`, which ensures that + the local optimizer only sees the relevant parameter shard. + """ + self.assert_state(TrainingState.BACKWARD) + if param.grad is None: + return + if param.grad.requires_grad: + raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad") + + # Free full params and switch to FP32 shard after backward. + self._free_full_params([param]) + self._use_fp32_param_shard([param]) + if self.mixed_precision: + # This is a no-op if reshard_after_forward is True, since we already + # free the param shard when rebuilding the full params in the + # pre_backward_hook. + self._free_fp16_param_shard([param]) + + # Enqueue a callback at the end of the backward pass to ensure that all + # post-backward work has finished. We only need one callback and it only + # needs to be called from the outer-most (root) instance. + if self._is_root and not self._post_backward_callback_queued: + self._post_backward_callback_queued = True + Variable._execution_engine.queue_callback(self._wait_for_post_backward) + + if not self.require_backward_grad_sync: + return + + # Wait for all work in the current stream to finish, then start the + # reductions in post_backward stream. 
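The before/after numbers in the docstring above can be reproduced with a tiny single-process simulation of reduce-scatter (a sketch only; the real code issues ``dist.reduce_scatter`` through the bucketer)::

    import torch

    world_size = 2
    grads = [torch.tensor([1., 2., 3., 4.]),   # param.grad as computed on GPU #0
             torch.tensor([5., 6., 7., 8.])]   # param.grad as computed on GPU #1
    summed = torch.stack(grads).sum(dim=0)     # tensor([ 6.,  8., 10., 12.])
    shards = list(summed.chunk(world_size))
    print(shards[0], shards[1])                # rank 0 keeps [6., 8.], rank 1 keeps [10., 12.]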
+ self._streams["post_backward"].wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._streams["post_backward"]): + orig_grad_data = param.grad.data + + if self.mixed_precision and self.fp32_reduce_scatter: + # Cast grad to FP32. + param.grad.data = param.grad.data.to(param.dtype) + + if self.world_size > 1: + # Average grad by world_size for consistency with PyTorch DDP. + param.grad.data.div_(self.world_size) + + callback_fn = functools.partial(self._post_reduction_hook, param) + if param._is_sharded: + assert param._is_sharded + assert self._reducer is not None + grad_chunks = chunk_and_pad(param.grad.data, self.world_size) + self._reducer.reduce_scatter_async(grad_chunks, group=self.process_group, callback_fn=callback_fn) + else: + # Currently the only way for _is_sharded to be False is if + # world_size == 1. This could be relaxed in the future, in which + # case grads should be all-reduced here. + assert self.world_size == 1 + callback_fn(param.grad.data) + + # After _post_backward_hook returns, orig_grad_data will eventually + # go out of scope, at which point it could otherwise be freed for + # further reuse by the main stream while the div/reduce_scatter/copy + # are underway in the post_backward stream. See: + # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py + orig_grad_data.record_stream(self._streams["post_backward"]) + + def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None: + """Hook to call on each param after the reduce-scatter.""" + assert torch.cuda.current_stream() == self._streams["post_backward"] + assert param.grad is not None + self.assert_state(TrainingState.BACKWARD) + param.grad.data = reduced_grad + # Cast grad to param's dtype (typically FP32). Note: we do this + # before the move_grads_to_cpu step so that this entire hook remains + # non-blocking. The downside is a bit more D2H transfer in that case. + if self.mixed_precision: + param.grad.data = param.grad.data.to(dtype=param.data.dtype) + # Optionally move gradients to CPU, typically used if one is running + # the optimizer on the CPU. + if self.move_grads_to_cpu: + param._cpu_grad.copy_(param.grad.data, non_blocking=True) + param.grad.data = param._cpu_grad + # Don't let this memory get reused until after the transfers. + reduced_grad.record_stream(torch.cuda.current_stream()) + + @torch.no_grad() + def _wait_for_post_backward(self) -> None: + """Wait for post-backward work to finish. Only called on root instance.""" + assert self._is_root + self.assert_state(TrainingState.BACKWARD) + # Flush any unreduced buckets in the post_backward stream. + with torch.cuda.stream(self._streams["post_backward"]): + assert self._reducer is not None + self._reducer.flush() + torch.cuda.current_stream().wait_stream(self._streams["post_backward"]) + if self.move_grads_to_cpu: + # Wait for the non-blocking GPU -> CPU grad transfers to finish. + torch.cuda.current_stream().synchronize() + # A backward pass is done. + self.training_state = TrainingState.IDLE + + @torch.no_grad() + def _rebuild_full_params(self) -> None: + """Gather all shards of params.""" + with torch.cuda.stream(self._streams["all_gather"]): + if self.mixed_precision: + self._cast_fp32_param_shards_to_fp16() + + for p in self.params: + if not p._is_sharded: + if self.mixed_precision: + p.data = p._fp16_shard + continue + + p_size = p._full_param_padded.size() + if p._full_param_padded.storage().size() != p_size.numel(): + # Allocate based on full size from all shards. 
+ alloc_storage_(p._full_param_padded, size=p_size) + assert p_size.numel() % self.world_size == 0 + if p._is_sharded: + # Fill p._full_param_padded with (p.data for each shard in self.world_size) + chunks = list(p._full_param_padded.chunk(self.world_size)) + dist.all_gather(chunks, p.data, group=self.process_group) + else: + p._full_param_padded.copy_(torch.flatten(p.data), non_blocking=True) + + p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size) + + if self.mixed_precision: + self._free_fp16_param_shard([p]) + torch.cuda.current_stream().wait_stream(self._streams["all_gather"]) + + @torch.no_grad() + def _use_full_params(self) -> None: + for p in self.params: + if not p._is_sharded: + if self.mixed_precision: + assert p._fp16_shard.storage().size() != 0 + p.data = p._fp16_shard + else: + assert p._full_param_padded.storage().size() != 0 + p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size) + + @torch.no_grad() + def _prep_grads_for_backward(self) -> None: + """Make sure p.grad has the correct size/device, otherwise set it to None.""" + for p in self.params: + if p.grad is not None and (p.grad.size() != p._orig_size or p.grad.device != p.data.device): + p.grad = None + + @torch.no_grad() + def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: + """Free up storage for full parameters.""" + if params is None: + params = self.params + current_stream = torch.cuda.current_stream() + with torch.cuda.stream(self._streams["all_gather"]): + for p in params: + if not p._is_sharded: + if self.mixed_precision: + self._free_fp16_param_shard([p]) + continue + # There may be external references to the Tensor Storage that we + # can't modify, such as references that are created by + # ctx.save_for_backward in the forward pass. Thus when we + # unshard parameters, we should reuse the original Tensor + # Storage object and unshard it in-place. For now, just resize + # the Storage to 0 to save memory. + p._full_param_padded.record_stream(current_stream) + free_storage_(p._full_param_padded) + + @torch.no_grad() + def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None: + """Use FP32 shard for a list of params.""" + if params is None: + params = self.params + for p in params: + p.data = p._fp32_shard + + @torch.no_grad() + def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None: + """Cast FP32 param shard to FP16 for a list of params.""" + if params is None: + params = self.params + with torch.cuda.stream(self._streams["fp32_to_fp16"]): + for p in params: + assert p._fp16_shard is not None + alloc_storage_(p._fp16_shard, size=p._fp32_shard.size()) + p._fp16_shard.copy_( + # If cpu_offload is True, this will be non-blocking because + # _fp32_shard is pinned, otherwise it's a no-op. + p._fp32_shard.to(p._fp16_shard.device, non_blocking=True) + ) + p.data = p._fp16_shard + torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"]) + + @torch.no_grad() + def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None: + """Free storage for FP16 shards for a list of params.""" + if params is None: + params = self.params + current_stream = torch.cuda.current_stream() + for p in params: + if p._fp16_shard is not None: + # _fp16_shard is allocated in _fp32_to_fp16_stream, so we can't + # free it until the work in the current stream completes. 
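A rough standalone sketch of the cast-and-copy that ``_cast_fp32_param_shards_to_fp16`` above performs into a storage-resized buffer (illustrative; the real code also routes the copy through the ``fp32_to_fp16`` stream and handles pinned CPU shards)::

    import torch

    fp32_shard = torch.randn(1024)                                   # master shard, full precision
    fp16_shard = torch.zeros_like(fp32_shard, dtype=torch.float16)
    fp16_shard.storage().resize_(0)                                  # kept empty between uses

    fp16_shard.storage().resize_(fp32_shard.numel())                 # re-materialize before forward
    fp16_shard.copy_(fp32_shard)                                     # down-casting copy
    assert fp16_shard.dtype == torch.float16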
+ p._fp16_shard.record_stream(current_stream) + free_storage_(p._fp16_shard) + + def assert_state(self, state: TrainingState) -> None: + """Assert we are in the given state.""" + assert ( + self.training_state == state + ), f"expected to be in state {state} but current state is {self.training_state}" + + +@torch.no_grad() +def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: + """ + Cast any Tensors in *args or **kwargs to FP16. + + Doesn't currently support Tensors nested inside containers (e.g., dict). + """ + kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) + tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args) + tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t for t in tensor_inputs) + flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs) + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + return args, kwargs + + +def cast_buffers_( + module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None +) -> None: + """Cast all of module.named_buffers to device, dtype.""" + # if buffers are already on the right device and/or dtype this is just python loop cost + for key, buf in module.named_buffers(recurse=False): + if buf is not None: + setattr(module, key, buf.to(dtype=dtype, device=device)) + + +def free_storage_(data: torch.Tensor) -> None: + """Free underlying storage of a Tensor.""" + if data.storage().size() > 0: + # Since we're modifying the Tensor's Storage directly, make sure the Tensor + # is the sole occupant of the Storage. + assert data.storage_offset() == 0 + assert data.storage().size() == data.numel() + data.storage().resize_(0) + + +@torch.no_grad() +def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None: + """Allocate storage for a tensor.""" + if data.storage().size() == size.numel(): # no need to reallocate + return + assert data.storage().size() == 0 + data.storage().resize_(size.numel()) diff --git a/fairscale/nn/misc/checkpoint_activations.py b/fairscale/nn/misc/checkpoint_activations.py index 519503229..6eb40b16f 100644 --- a/fairscale/nn/misc/checkpoint_activations.py +++ b/fairscale/nn/misc/checkpoint_activations.py @@ -3,8 +3,9 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +from contextlib import contextmanager import functools -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Generator, Optional, Tuple import torch from torch import Tensor @@ -73,6 +74,23 @@ def set_rng_state(state: Dict[str, Any]) -> None: torch.cuda.set_rng_state(state["cuda_rng_state"]) +def is_autocast_enabled() -> bool: + """Similar to torch.is_autocast_enabled, but compatible with torch 1.5.1""" + if hasattr(torch, "is_autocast_enabled"): + return torch.is_autocast_enabled() + return False + + +@contextmanager +def autocast(enabled: bool) -> Generator: + """Similar to torch.cuda.amp.autocast, but compatible with torch 1.5.1""" + if enabled: + with torch.cuda.amp.autocast(enabled): + yield + else: + yield + + class CheckpointFunction(torch.autograd.Function): """Similar to the torch version, but support non-Tensor outputs. 
@@ -96,13 +114,13 @@ def forward( # type: ignore ctx.run_function = run_function ctx.kwarg_keys = kwarg_keys ctx.fwd_rng_state = get_rng_state() + ctx.had_autocast_in_fwd = is_autocast_enabled() tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) if parent_ctx_dict["offload"]: ctx.fwd_device = tuple(x.device for x in tensor_inputs) ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs) tensor_inputs = tuple(x.cpu() for x in tensor_inputs) - else: ctx.fwd_device, ctx.grad_requirements = None, None @@ -142,10 +160,11 @@ def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]: # Set the states to what it used to be before the forward pass. set_rng_state(ctx.fwd_rng_state) - with torch.enable_grad(): + with torch.enable_grad(), autocast(ctx.had_autocast_in_fwd): unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs) outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) tensor_outputs, _ = split_non_tensors(outputs) + # Set the states back to what it was at the start of this function. set_rng_state(bwd_rng_state) diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py index 095828f30..d05f0f3a0 100644 --- a/fairscale/nn/misc/flatten_params_wrapper.py +++ b/fairscale/nn/misc/flatten_params_wrapper.py @@ -2,12 +2,15 @@ # Licensed under the MIT License. from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union import torch from torch import Tensor import torch.nn as nn +if TYPE_CHECKING: + from collections import OrderedDict # noqa: F401 + class FlattenParamsWrapper(nn.Module): """ @@ -127,21 +130,23 @@ def __getattr__(self, name: str) -> Any: except AttributeError: return getattr(self.module, name) # fallback to wrapped module - def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]": # type: ignore + def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, Tensor]": # type: ignore """Return an unflattened state_dict.""" with self.unflatten_params(): - return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) + return self.module.state_dict(*args, **kwargs) def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: """Return the flattened state_dict.""" return super().state_dict(*args, **kwargs) - def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> None: + def load_state_dict( + self, state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], strict: bool = True + ) -> NamedTuple: if "flat_param" in state_dict: - super().load_state_dict(state_dict, strict=True) + return super().load_state_dict(state_dict, strict=strict) else: with self.unflatten_params(): - return self.module.load_state_dict(state_dict, *args, **kwargs) + return self.module.load_state_dict(state_dict, strict) def forward(self, *inputs: Any, **kwinputs: Any) -> Any: self._unflatten_params_as_views() diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py index 805352d21..035c80137 100644 --- a/fairscale/optim/oss.py +++ b/fairscale/optim/oss.py @@ -15,7 +15,7 @@ from torch.nn import Parameter from torch.optim import SGD, Optimizer -from .utils import broadcast_object, recursive_copy_to_device +from .utils import broadcast_object, calc_grad_norm, recursive_copy_to_device __all__ = ["OSS"] @@ -284,18 +284,14 @@ def clip_grad_norm( # 
https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54 local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params + local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device) # Compute the norm on this grad set, # then sync all the norms from all ranks if norm_type == inf: - total_norm = max(p.grad.detach().abs().max().to(self._default_device) for p in local_params) + total_norm = local_norm # all reduce over data parallel and model parallel workers dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD) else: - local_norm = torch.norm( - input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._default_device) for p in local_params]), # type: ignore - p=norm_type, - ) - # local norm result can be accumulated with the remote ones if put to the right power # n_i = sum_rank(a^p)^1/p # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p diff --git a/fairscale/optim/utils.py b/fairscale/optim/utils.py index e8ac44fcd..cb84be661 100644 --- a/fairscale/optim/utils.py +++ b/fairscale/optim/utils.py @@ -5,7 +5,8 @@ import collections import io -from typing import Any, Callable, Dict, Optional +from math import inf +from typing import Any, Callable, Dict, List, Optional import torch import torch.distributed as dist @@ -102,3 +103,22 @@ def reset(self) -> None: def full(self) -> bool: """ is the bucket full ? """ return self.max_params_checked_in == self.params_checked_in + + +def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor: + r"""Calculate gradient norm of an iterable of parameters. + Returns: + Total norm of the parameters (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda par: par.grad is not None, parameters)) + + if len(parameters) == 0: + return torch.tensor(0.0) + p = float(p) + if p == inf: + local_norm = max(par.grad.detach().abs().max() for par in parameters) # type: ignore + else: + local_norm = torch.norm(torch.stack([torch.norm(par.grad.detach(), p) for par in parameters]), p) # type: ignore + return local_norm diff --git a/fairscale/utils/parallel.py b/fairscale/utils/parallel.py new file mode 100644 index 000000000..2b3cfaf30 --- /dev/null +++ b/fairscale/utils/parallel.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Useful functions for parallel training.""" + +from typing import List + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +import torch.nn.functional as F + + +def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]: + """Chunk a given Tensor into num_chunks parts and add any necessary padding.""" + chunks = list(torch.flatten(tensor).chunk(num_chunks)) + # torch.chunk may return fewer than num_chunks chunks, pad accordingly. 
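The accumulation identity used by ``calc_grad_norm`` and the all-reduce above (``n_total = (sum_i n_i^p)^(1/p)``) can be checked numerically in a single process (a sketch for ``p=2``; any finite ``p`` works the same way)::

    import torch

    g = torch.randn(100)                               # pretend this is the full gradient vector
    shards = g.chunk(4)                                # one piece per rank
    p = 2.0
    local = torch.stack([s.norm(p) for s in shards])   # per-rank norms n_i
    total = (local ** p).sum() ** (1.0 / p)            # all_reduce(n_i ** p) ** (1/p)
    assert torch.allclose(total, g.norm(p))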
+ num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel() + if num_pad_for_partial_chunk > 0: + chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk]) + if len(chunks) < num_chunks: + chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))]) + return chunks + + +def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None: + """Do a quick test in case user called FSDP without calling torch.cuda.set_device() + correctly. This can easily happen in cpu_offload case where the model resides on + the CPU. + """ + if not hasattr(process_group, "allgather"): + # Likely a dummy pg for unit test, skip checking. + return + + world_size = process_group.size() + if "cuda" in str(device): + input_tensor = torch.ones(1).to(device) + output = list(torch.zeros(world_size).to(device).chunk(world_size)) + dist.all_gather(output, input_tensor, group=process_group) + assert torch.cat(output).sum() == float(world_size), ( + f"found {torch.cat(output).sum()} devices in process group but " + f"world_size={world_size}. Check torch.cuda.set_device is called properly" + ) diff --git a/fairscale/utils/reduce_scatter_bucketer.py b/fairscale/utils/reduce_scatter_bucketer.py new file mode 100644 index 000000000..b8e2eba54 --- /dev/null +++ b/fairscale/utils/reduce_scatter_bucketer.py @@ -0,0 +1,151 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from torch import Tensor +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +class Bucket: + def __init__(self, data: Tensor, group: ProcessGroup): + self.data = data + self.group = group + self.offset = 0 + self.callbacks: List[Callable] = [] + self.output_shard = torch.zeros_like(data[0]) + + def flush(self) -> None: + if self.offset == 0: + assert len(self.callbacks) == 0 + return + # reduce-scatter bucket + dist.reduce_scatter( + self.output_shard[: self.offset], list(self.data[:, : self.offset].unbind(0)), group=self.group + ) + # execute post-reduction callbacks + for callback_fn in self.callbacks: + callback_fn() + # reuse input bucket but allocate a fresh output shard + self.data[:, : self.offset].zero_() + self.offset = 0 + self.callbacks.clear() + self.output_shard = torch.zeros_like(self.data[0]) + + +class ReduceScatterBucketer: + """ + Helper for bucketing multiple reduce-scatter operations on small tensors + into larger reduce-scatter ops to improve communication efficiency. + + Usage:: + + bucketer = ReduceScatterBucketer() + bucketer.reduce_scatter_async( + small_tensors, callback_fn=lambda result: print("small") + ) + bucketer.reduce_scatter_async( + big_tensors, callback_fn=lambda result: print("big") + ) + bucketer.reduce_scatter_async( + more_small_tensors, callback_fn=lambda result: print("small2") + ) + bucketer.flush() # callbacks only guaranteed to be called after flush() + # Example output (note that it is out of order, due to bucketing): + # big + # small + # small2 + + Args: + bucket_cap_mb (int, Optional): bucket size for communicating. Buckets + are sub-divided based on world_size. Values <= 0 disable bucketing. 
+ """ + + def __init__(self, bucket_cap_mb: int = 25): + self.bucket_cap_mb = bucket_cap_mb + self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {} + + @torch.no_grad() + def reduce_scatter_async( + self, input_list: List[Tensor], group: ProcessGroup, callback_fn: Optional[Callable] = None, + ) -> None: + """ + Reduce-scatter a list of tensors asynchronously, so smaller reductions + can be bucketed together. The given callback (``callback_fn``) will be + called with the reduced result at some later time. Call ``flush()`` to + force all queued ops and callbacks to be executed. + + Note that large inputs will be reduced immediately, and this function + may also flush the relevant bucket to make room for ``input_list``. + + Args: + input_list (List[Tensor]): list of tensors to reduce-scatter. List + should contain ``group.size()`` tensors and each tensor should + have identical shape, dtype and device. + group (ProcessGroup): process group for reduction + callback_fn (Callable, Optional): callback function to call after + the reduction executes. Function will be called with a single + argument corresponding to the reduced result. + """ + world_size = group.size() + + assert ( + len(input_list) == world_size + ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})" + + first_input = input_list[0] + first_input_size = first_input.numel() + + bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size) + if first_input_size > bucket_shard_size: + # input is too big to fit in the bucket, reduce-scatter directly + output = torch.zeros_like(input_list[0]) + dist.reduce_scatter(output, input_list, group=group) + if callback_fn is not None: + callback_fn(output) + return + + bucket = self._get_bucket(first_input, group) + if first_input_size > bucket.data.size(1) - bucket.offset: + # not enough space remaining in bucket, flush it now + bucket.flush() + + # copy data from input_list into bucket + stacked_input = torch.stack(input_list).view(world_size, first_input_size) + offset = bucket.offset + bucket.data[:, offset : offset + first_input_size].copy_(stacked_input) + bucket.offset += first_input_size + + # callback will be given the reduced result + if callback_fn is not None: + result_view = bucket.output_shard[offset : offset + first_input_size].view_as(first_input) + bucket.callbacks.append(functools.partial(callback_fn, result_view)) + + @torch.no_grad() + def flush(self) -> None: + """Reduce-scatter any partial buckets.""" + for bucket in self.buckets.values(): + bucket.flush() + + @functools.lru_cache() + def _get_shard_size(self, element_size: int, num_shards: int) -> int: + if self.bucket_cap_mb <= 0: # Values <= 0 disable bucketing. 
+ return 0 + MB = 1024 * 1024 + bucket_size = self.bucket_cap_mb * MB / element_size + return int(bucket_size // num_shards) + + def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket: + key = (tensor.dtype, tensor.device, group) + if key not in self.buckets: + # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size) + world_size = group.size() + shard_size = self._get_shard_size(tensor.element_size(), world_size) + data = tensor.new_zeros((world_size, shard_size)) + self.buckets[key] = Bucket(data, group) + return self.buckets[key] diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py index fd099b4b9..96bc81237 100644 --- a/fairscale/utils/testing.py +++ b/fairscale/utils/testing.py @@ -33,11 +33,12 @@ import random import sys import tempfile -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy import pytest import torch +from torch import Tensor import torch.distributed as dist from torch.distributed import rpc import torch.multiprocessing as mp @@ -46,6 +47,11 @@ from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed +if TYPE_CHECKING: + Base = nn.Module[Tensor] +else: + Base = nn.Module + skip_if_no_cuda = pytest.mark.skipif( not torch.cuda.is_available() or torch.cuda.device_count() < 1, reason="CUDA required" ) @@ -75,12 +81,12 @@ _, filename_mpi = tempfile.mkstemp() -class IdentityLayer(torch.nn.Module): +class IdentityLayer(Base): def __init__(self, size: int, scale: float = 1.0) -> None: super(IdentityLayer, self).__init__() self.weight = torch.nn.Parameter(scale * torch.randn(size)) - def forward(self, *_: Any, **__: Any) -> Any: + def forward(self, *_: Any, **__: Any) -> Tensor: return self.weight @@ -103,7 +109,7 @@ def torch_version() -> Tuple[int, ...]: # Assuming that we're interested in the second usecase more than the first, # return the pre-release or dev numbering - logging.warning(f"Pytorch pre-relase version {torch.__version__} - assuming intent to test it") + logging.warning(f"Pytorch pre-release version {torch.__version__} - assuming intent to test it") numbering[2] = "0" return tuple(int(n) for n in numbering) @@ -301,7 +307,7 @@ def replacement(*args: Any, **kwargs: Any) -> None: return prepare_test -class _Block(nn.Module): +class _Block(Base): def __init__(self, embed_dim: int, num_heads: int) -> None: super().__init__() self.ln_1 = nn.LayerNorm(embed_dim) @@ -309,7 +315,7 @@ def __init__(self, embed_dim: int, num_heads: int) -> None: self.attn = nn.MultiheadAttention(embed_dim, num_heads) # type: ignore self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),) - def forward(self, *inputs: Any, **kwargs: Any) -> Any: + def forward(self, *inputs: Any, **kwargs: Any) -> Tensor: x = inputs[0] attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype) attn_mask = torch.triu(attn_mask, diagonal=1) @@ -322,7 +328,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: return x -class GPT2(nn.Module): +class GPT2(Base): """ GPT2 pytorch implementation, for testing purposes in the image-GPT context Credits: https://github.com/teddykoker/image-gpt""" @@ -349,7 +355,7 @@ def __init__( self.head = nn.Linear(embed_dim, num_vocab, bias=False) self.clf_head = nn.Linear(embed_dim, num_classes) - def 
forward(self, x: torch.Tensor, classify=False) -> Any: # type: ignore + def forward(self, x: Tensor, classify: bool = False) -> Any: # type: ignore """ Expect input as shape [sequence len, batch] If classify, return classification logits @@ -451,3 +457,89 @@ def check_same_models_across_ranks( assert not params_should_be_equal or torch.all( torch.eq(receptacle[0], sync_b) ), "Models differ in between ranks" + + +class DeviceAndTypeCheckModule(Base): + """A simple module for checking Tensor devices and dtypes.""" + + def __init__( + self, + expected_input_dtype: Optional[torch.dtype] = None, + expected_input_device: Optional[torch.device] = None, + expected_param_dtype: Optional[torch.dtype] = None, + expected_param_device: Optional[torch.device] = None, + expected_loss_dtype: Optional[torch.dtype] = None, + expected_loss_device: Optional[torch.device] = None, + ): + super().__init__() + self.expected_input_dtype = expected_input_dtype + self.expected_input_device = expected_input_device + self.expected_param_dtype = expected_param_dtype + self.expected_param_device = expected_param_device + self.expected_loss_dtype = expected_loss_dtype + self.expected_loss_device = expected_loss_device + + self.linear = nn.Linear(5, 5) + + def _check( + self, + key: str, + x: Union[torch.device, torch.dtype], + expected: Union[Optional[torch.device], Optional[torch.dtype]], + ) -> None: + assert expected in {None, x}, f"{key} ({x}) != expected ({expected})" + + def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: + x = input[0] + self._check("input.dtype", x.dtype, self.expected_input_dtype) + self._check("input.device", x.device, self.expected_input_device) + + param = self.linear.weight + self._check("param.dtype", param.dtype, self.expected_param_dtype) + self._check("param.device", param.device, self.expected_param_device) + + loss = self.linear(x).sum() + self._check("loss.dtype", loss.dtype, self.expected_loss_dtype) + self._check("loss.device", loss.device, self.expected_loss_device) + + return loss + + +@functools.lru_cache() +def get_cycles_per_ms() -> float: + """Approximate number of cycles per millisecond for torch.cuda._sleep + + Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py + + ..note:: + This doesn't seems to return consistent cycles on desktop GPUs likely + due to frequency scaling. + >>> get_cycles_per_ms() + 227.6441091140009 + # new python process + >>> get_cycles_per_ms() + 564.652154766248 + # new python process + >>> get_cycles_per_ms() + 245.56459442962856 + """ + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + torch.cuda._sleep(1000000) + end.record() + end.synchronize() + cycles_per_ms = 1000000 / start.elapsed_time(end) + return cycles_per_ms + + +class DummyProcessGroup: + def __init__(self, rank: int, size: int): + self._rank = rank + self._size = size + + def rank(self) -> int: + return self._rank + + def size(self) -> int: + return self._size diff --git a/requirements-test.txt b/requirements-test.txt index 77ea87544..47b7173cc 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -13,3 +13,5 @@ pytest-cov == 2.10.0 pytest-mpi == 0.4 pytest-timeout == 1.4.2 mpi4py == 3.0.3 +remote-pdb >= 2.1.0 +parameterized >= 0.8.1 diff --git a/stubs/torch/__init__.pyi b/stubs/torch/__init__.pyi index 64049ba9a..5732644b8 100644 --- a/stubs/torch/__init__.pyi +++ b/stubs/torch/__init__.pyi @@ -84,6 +84,7 @@ class Size(tuple): class Storage: def size(self) -> _int: ... 
def element_size(self) -> _int: ... + def resize_(self, int) -> None: ... #END # See https://github.com/python/mypy/issues/4146 for why these workarounds @@ -1913,6 +1914,7 @@ def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API def set_default_dtype(d : _dtype) -> None: ... def manager_path() -> str: ... def compiled_with_cxx11_abi() -> _bool: ... +def is_autocast_enabled() -> _bool: ... # The return value of this function depends on the value of `as_tuple`, # (similar to `unique`, `lu`, etc.); as such, it is not diff --git a/stubs/torch/cuda/amp/__init__.pyi b/stubs/torch/cuda/amp/__init__.pyi index 848378a60..f5bc87a59 100644 --- a/stubs/torch/cuda/amp/__init__.pyi +++ b/stubs/torch/cuda/amp/__init__.pyi @@ -1,3 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import Any, Generator + from .grad_scaler import GradScaler as GradScaler + +class autocast: + def __init__(self, enabled=True) -> None: ... + def __enter__(self) -> None: ... + def __exit__(self, *args: Any) -> None: ... diff --git a/stubs/torch/distributed/__init__.pyi b/stubs/torch/distributed/__init__.pyi index 8c951b65e..56307fe61 100644 --- a/stubs/torch/distributed/__init__.pyi +++ b/stubs/torch/distributed/__init__.pyi @@ -37,12 +37,15 @@ def broadcast_object_list(object_list: List[Any], src: int, group:Optional[Proce def is_initialized() -> bool: ... def init_process_group(backend: Union[str, Backend], init_method: Optional[str] = None, timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ... -def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ... +def new_group(ranks: Optional[List[int]] = None, + timeout: Optional[datetime.timedelta] = datetime.timedelta(0, 1800), + backend: Optional[Union[str, Backend]] = None): ... def all_to_all(output: List[Tensor], input: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional[List[int]] = None, input_split_size: Optional[List[int]] = None, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ... +def reduce_scatter(tensor: Tensor, input_list: List[Tensor], op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def destroy_process_group() -> None: ... diff --git a/stubs/torch/nn/modules/module.pyi b/stubs/torch/nn/modules/module.pyi index 2838fdd72..458b32668 100644 --- a/stubs/torch/nn/modules/module.pyi +++ b/stubs/torch/nn/modules/module.pyi @@ -2,7 +2,7 @@ from ... import Tensor, device, dtype from .. import Parameter -from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, Generic +from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, Generic, NamedTuple from collections import OrderedDict from ...utils.hooks import RemovableHandle @@ -65,9 +65,10 @@ class Module(Generic[T_co]): def __getattr__(self, name: str) -> Union[Tensor, 'Module']: ... - # TODO double-check this def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: ... 
+ def __setstate__(self, state: Dict[str, Any]) -> None: ... + # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns # back that same object. But if they pass nothing, an `OrederedDict` is created and returned. T_destination = TypeVar('T_destination', bound=Mapping[str, Tensor]) @@ -78,7 +79,7 @@ class Module(Generic[T_co]): @overload def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> OrderedDict[str, Tensor]: ... - def load_state_dict(self, state_dict: Union[Dict[str, Tensor], OrderedDict[str, Tensor]], strict: bool = ...): ... + def load_state_dict(self, state_dict: Union[Dict[str, Tensor], OrderedDict[str, Tensor]], strict: bool = ...) -> NamedTuple: ... def parameters(self, recurse: bool = ...) -> Iterator[Parameter]: ... diff --git a/stubs/torch/nn/parameter.pyi b/stubs/torch/nn/parameter.pyi index 14f5a0bf9..c6fdd30d0 100644 --- a/stubs/torch/nn/parameter.pyi +++ b/stubs/torch/nn/parameter.pyi @@ -1,9 +1,20 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .. import Tensor +from typing import Optional +from .. import Size, Tensor +from ..cuda import Stream import builtins class Parameter(Tensor): + # These are dynamic attributes added by shard_params_data_parallel class. + # Added here for better type checking. + _is_sharded: bool + _orig_size: Size + _cpu_grad: Tensor + _full_param_padded: Tensor + _fp32_shard: Tensor + _fp16_shard: Optional[Tensor] + def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ... ... diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py new file mode 100644 index 000000000..27f4f5986 --- /dev/null +++ b/tests/nn/data_parallel/test_fsdp.py @@ -0,0 +1,810 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import functools +import itertools +from math import inf +import pickle +import sys +from typing import Dict +import unittest +from unittest import mock + +from parameterized import parameterized +import torch +from torch import nn + +from fairscale.nn.data_parallel import FullyShardedDataParallel +from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper +from fairscale.utils.testing import ( + DeviceAndTypeCheckModule, + DummyProcessGroup, + dist_init, + get_cycles_per_ms, + objects_are_equal, + spawn_for_all_world_sizes, +) + +# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4 +# All helper functions called by spawn must be either @classmethod, @staticmethod + +_BUFFER_NAME = "vocab_bias" + + +class DistributedTest(unittest.TestCase): + def setUp(self): + major, minor = torch.__version__.split(".")[:2] + major, minor = int(major), int(minor) + if major < 1 or (major == 1 and minor < 6): + raise unittest.SkipTest("Need pytorch version >= 1.6 due to autocast") + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA not available, skipping test") + if sys.platform == "win32": + raise unittest.SkipTest("NCCL doesn't support Windows, skipping test") + if torch.cuda.device_count() < 2: + raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") + + @staticmethod + def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None): + model_device = next(model.parameters()).device + # use SGD with momentum instead of Adam, since Adam is scale invariant + # and this makes it bad for tests + optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) + for _ in range(num_steps): + optim.zero_grad() + with torch.cuda.amp.autocast(enabled=autocast): + # Inputs always cuda regardless of move_grads_cpu, or model.device + input = model.module.get_input(torch.device("cuda")) + output = model(*input) + loss = model.module.get_loss(input, output).to(model_device) + assert loss.dtype == torch.float32 + model.module.run_backward(loss) + if norm_type is not None: + clip_norm = 0.3 + if isinstance(model, FullyShardedDataParallel): + model.clip_grad_norm_(clip_norm, norm_type) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type) + optim.step() + if hasattr(model, "assert_idle"): + model.assert_idle() + return loss.detach() + + @staticmethod + def get_wrapped_model(group, cuda_first=False, config={}, **model_kwargs) -> FullyShardedDataParallel: + if cuda_first: + model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs).cuda(), group, **config) + else: + model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs), group, **config).cuda() + return model + + +class TestMixedPrecision(DistributedTest): + def test_all_fp32(self): + self._spawn_test_case( + {"mixed_precision": False}, + False, # autocast enabled + torch.float32, # expected_input_dtype + torch.float32, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float32, # expected_reduce_dtype + ) + + def test_mixed_precision(self): + self._spawn_test_case( + {"mixed_precision": True}, + False, # autocast enabled + torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float16, # expected_loss_dtype + torch.float16, # expected_reduce_dtype + ) + + def test_mixed_precision_autocast(self): + """If autocast enabled, loss should be fp32.""" + self._spawn_test_case( + {"mixed_precision": True}, + True, # autocast enabled + 
torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float16, # expected_reduce_dtype + ) + + def test_mixed_precision_autocast_fp32_compute(self): + self._spawn_test_case( + {"mixed_precision": True, "compute_dtype": torch.float32}, + True, # autocast enabled + torch.float16, # expected_input_dtype + torch.float32, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float32, # expected_reduce_dtype + ) + + def test_fp32_reduce_scatter(self): + self._spawn_test_case( + {"mixed_precision": True, "fp32_reduce_scatter": True}, + False, # autocast enabled + torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float16, # expected_loss_dtype + torch.float32, # expected_reduce_dtype + ) + + def test_fp32_reduce_scatter_autocast(self): + self._spawn_test_case( + {"mixed_precision": True, "fp32_reduce_scatter": True}, + True, # autocast enabled + torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float32, # expected_reduce_dtype + ) + + def _spawn_test_case(self, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype, world_size=2): + """Call test_dtypes inside of torch.multiprocessing.spawn""" + fn = functools.partial(self._test_dtypes, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype) + spawn_and_init(fn, world_sizes=[world_size]) + + @staticmethod + def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group): + # Patch torch.distributed.reduce_scatter to check the dtype of the reduction + orig_reduce_scatter = torch.distributed.reduce_scatter + + model = DeviceAndTypeCheckModule( + expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype, + ) + + def _reduce_scatter(output, input_list, **kwargs): + for tensor in input_list: + model._check("reduce_scatter.dtype", tensor.dtype, expected=reduce_dtype) + return orig_reduce_scatter(output, input_list, **kwargs) + + with mock.patch("torch.distributed.reduce_scatter", new=_reduce_scatter): + model = FullyShardedDataParallel(model, group, **cfg).cuda() + device = next(model.parameters()).device + x = torch.rand(2, 5).to(device) + with torch.cuda.amp.autocast(enabled=autocast): + loss = model(x) + loss.backward() + + +keys = ["reshard_after_forward", "mixed_precision", "flatten_parameters"] +CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))] + + +def rename_test(testcase_func, param_num, param): + return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),) + + +class TestComparisonToPyTorchDDP(DistributedTest): + """ + Compare losses and parameter values after several updates when using + PyTorch DDP vs. FullyShardedDataParallel. + """ + + @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) + def test_transformer_parameterized(self, config): + # Test every combination of these options: + spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config)) + + def test_cpu_offload_and_cpu_grads(self): + # We don't test the False condition because that requires the optimizer to internally do + # the device transfer and PyTorch optimizers don't support this. 
+ config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True} + test_fn = functools.partial( + self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01 + ) + spawn_and_init(test_fn) + + def test_cpu_offload_and_cuda_grads_breaks(self): + # If grads are on gpu, but model and optimizer are on cpu, backward breaks. + config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": False} + with self.assertRaises(Exception): # RuntimeError inside spawn + test_fn = functools.partial( + self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False + ) + spawn_and_init(test_fn) + + def test_delayed_optim_step(self): + # We use a model with a long CUDA delay right before the optimizer step. + # This tests our streams logic, and that we don't start the FP32 -> FP16 + # transfer until after the optimization step completes. + config = {"mixed_precision": True} + model_fn = functools.partial(NestedWrappedModuleWithDelay, delay_after_loss_ms=250) + test_fn = functools.partial(self._test_identical_outputs, model_fn, config) + spawn_and_init(test_fn) + + def test_delayed_reduce_scatter(self): + # We insert a delay in the torch.distributed.reduce_scatter op, so that + # the post_backward_stream takes much longer than the backward pass. + # This tests that we properly block at the end of the backward pass for + # the reductions to finish. + config = {"mixed_precision": True} + model_fn = functools.partial(NestedWrappedModuleWithDelay, delay_before_reduction_ms=250) + test_fn = functools.partial(self._test_identical_outputs, model_fn, config) + spawn_and_init(test_fn) + + @parameterized.expand([[{"checkpoint_act": False}], [{"checkpoint_act": True}]], name_func=rename_test) + def test_mixture_of_experts(self, moe_config): + fsdp_config = {"mixed_precision": True} + test_fn = functools.partial( + self._test_identical_outputs, + functools.partial(MixtureOfExperts, **moe_config), + fsdp_config, + # MixtureOfExperts implements custom reduce logic, so the reference + # behavior should use that logic instead of PyTorch DDP. + ref_ddp_fn=self._dummy_ddp_fn, + norm_type=None, + ) + spawn_and_init(test_fn) + + def test_mixture_of_experts_grad_clip_breaks(self): + config = {"mixed_precision": True} + test_fn = functools.partial( + self._test_identical_outputs, MixtureOfExperts, config, ref_ddp_fn=self._dummy_ddp_fn, norm_type=2, + ) + with self.assertRaises(Exception): + spawn_and_init(test_fn) + + @classmethod + def _dummy_ddp_fn(self, model, group): + return DummyDDP(model) + + @classmethod + def _test_identical_outputs( + cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2, + ): + if config["mixed_precision"]: + autocast = True + # Force the compute dtype to be torch.float32 so that we get + # identical results as PyTorch DDP when using autocast. Note that + # this will cause the all-gather to happen in FP32, which is slower + # than necessary in most cases. + config["compute_dtype"] = torch.float32 + else: + autocast = False + + # Establish reference behavior with PyTorch DDP (+ optionally autocast). 
+ model = model_init_fn(group=group, wrapper_config=None).cuda() + if ref_ddp_fn is None: + model = nn.parallel.DistributedDataParallel( + model, device_ids=[rank], output_device=rank, process_group=group + ) + else: + model = ref_ddp_fn(model, group) + ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type) + ref_state_dict = model.module.state_dict() + + # Confirm we get the same behavior using FullyShardedDataParallel. + model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config) + if use_cuda: + model = model.cuda() + else: + assert next(model.parameters()).device == torch.device("cpu") + shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type) + shard_state_dict = model.state_dict() + + try: + torch.testing.assert_allclose(ref_loss, shard_loss) + assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True) + except (AssertionError, RuntimeError) as e: + raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}") + + @parameterized.expand([[1], [inf]], name_func=rename_test) + def test_clip_norm_transformer(self, norm_type): + config = {"mixed_precision": True} + test_fn = functools.partial( + self._test_identical_outputs, TransformerWithSharedParams, config, norm_type=norm_type, + ) + spawn_and_init(test_fn) + + +class TestParamInit(DistributedTest): + def test_param_change_after_init(self): + test_fn = functools.partial(self._test_param_change_after_init, config={"mixed_precision": True}) + spawn_and_init(test_fn) + + @classmethod + def _test_param_change_after_init(self, rank, group, config): + # Establish reference behavior. + model = self.get_wrapped_model(group, cuda_first=False, config=config) + model.eval() # no dropout for this test + input = model.module.get_input(torch.device("cuda")) + ref_output = model(*input) + + # Change the weights in place. 
+ model = self.get_wrapped_model(group, cuda_first=False, config=config) + model.eval() # no dropout for this test + first_param = next(model.parameters()) + nn.init.normal_(first_param.data) + new_output = model(*input) + + assert not objects_are_equal(ref_output, new_output), "new_output did not reflect change to param after init" + + +class TestSerialization(DistributedTest): + @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test) + def test_pickle(self, mixed_precision, cpu_offload): + """Ensure that wrapped modules can be pickled/unpickled.""" + config = {"mixed_precision": mixed_precision, "cpu_offload": cpu_offload} + test_fn = functools.partial(self._test_pickle, config=config) + spawn_and_init(test_fn, world_sizes=[2]) + + @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test) + def test_multiprocessing(self, mixed_precision, cpu_offload): + """Ensure that wrapped modules can be sent via multiprocessing.""" + config = {"mixed_precision": mixed_precision, "cpu_offload": cpu_offload} + test_fn = functools.partial(self._test_multiprocessing, config=config) + spawn_and_init(test_fn, world_sizes=[2]) + + @classmethod + def _test_pickle(self, rank, group, config): + model = self._get_model(group, config) + model = pickle.loads(pickle.dumps(model)) + if not config["cpu_offload"]: + model = model.cuda() + self._one_step(model, group) + + @classmethod + def _test_multiprocessing(self, rank, group, config): + mp = torch.multiprocessing.Pool(1) + dummy_group = DummyProcessGroup(rank=group.rank(), size=group.size()) + model = mp.apply(self._get_model, (dummy_group, config)) + if not config["cpu_offload"]: + model = model.cuda() + self._one_step(model, group) + + @classmethod + def _get_model(self, group, config): + with torch.no_grad(): # required for multiprocessing + model = NestedWrappedModule(group, wrapper_config=config) + return FullyShardedDataParallel(model, group, **config) + + @classmethod + def _one_step(self, model, group): + # reset the process group (required after unpickling) + for m in model.modules(): + if isinstance(m, FullyShardedDataParallel): + m.process_group = group + optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + input = model.module.get_input(torch.device("cuda")) + output = model(*input) + loss = model.module.get_loss(input, output) + model.module.run_backward(loss) + optim.step() + + +class TestLocalStateDict(DistributedTest): + @parameterized.expand([[True, True], [False, False]], name_func=rename_test) + def test_load_local_state_dict(self, flatten_params, mixed_precision): + test_fn = functools.partial( + self._load_local_and_train, {"flatten_parameters": flatten_params, "mixed_precision": mixed_precision} + ) + spawn_and_init(test_fn) + + @classmethod + def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23): + """Check that local_state_dict can be saved and loaded for a given worker, and that training updates it""" + model = self.get_wrapped_model(group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model) + state_1 = model.local_state_dict() + state_before_training = {k: v.cpu().clone() for k, v in state_1.items()} + assert len(state_1) > 0 + model.load_local_state_dict(state_1) + weight_key = "flat_param" if model.flatten_parameters else "embed_tokens.weight" + + state_1_weight = state_1[weight_key] + assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32" + if not 
model.flatten_parameters: + # The weight will be sharded since we access module.state_dict directly + state_1_module_weight = model.module.state_dict()[weight_key] + torch.testing.assert_allclose(state_1_weight, state_1_module_weight) + torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight) + self._train_for_several_steps(model, 1, model.mixed_precision) + + state_2 = model.local_state_dict() + state_after_training = {k: v.cpu().clone() for k, v in state_2.items()} + model.load_local_state_dict(state_2) + + assert state_1.keys() == state_2.keys() + + # Assert that parameters were updated since before training + unchanged = [] + for k in state_1: + if (state_before_training[k] == state_after_training[k]).all() and (_BUFFER_NAME not in k): + unchanged.append(k) + if unchanged: + raise AssertionError(f"params {unchanged} not changed after training") + + +class TestSaveLoadStateDict(DistributedTest): + @parameterized.expand([[False], [True]], name_func=rename_test) + def test_calling_state_dict_twice_mixed_precision(self, mixed_precision): + test_fn = functools.partial( + self._test_calling_state_dict_twice, {"flatten_parameters": False, "mixed_precision": mixed_precision} + ) + spawn_and_init(test_fn) + + @classmethod + def _test_calling_state_dict_twice(self, config, rank, group, **model_kwargs): + ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config, **model_kwargs) + autocast = ddp_model.mixed_precision + self._train_for_several_steps(ddp_model, 1, autocast) + ddp_model.state_dict() + ddp_model.state_dict() # second call + + @parameterized.expand([[False], [True]], name_func=rename_test) + def test_state_dict_after_forward_mixed_precision(self, mixed_precision): + test_fn = functools.partial( + self._test_module_state_dict, {"flatten_parameters": False, "mixed_precision": mixed_precision} + ) + spawn_and_init(test_fn) + + @parameterized.expand([[False], [True]], name_func=rename_test) + def test_state_dict_before_forward(self, mixed_precision): + test_fn = functools.partial( + self._test_state_dict_before_forward, {"flatten_parameters": False, "mixed_precision": mixed_precision} + ) + spawn_and_init(test_fn) + + @classmethod + def _test_state_dict_before_forward(cls, config, rank, group): + ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config) + sd = ddp_model.state_dict() + expected_dtype = torch.float16 if ddp_model.mixed_precision else torch.float32 + wt = sd["embed_tokens.weight"] + assert wt.dtype == expected_dtype, f"got dtype {wt.dtype} expected {expected_dtype}" + cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision) + + @classmethod + def _test_module_state_dict(cls, config, rank, group): + ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config) + autocast = ddp_model.mixed_precision + cls._train_for_several_steps(ddp_model, 2, autocast) + state_1 = ddp_model.state_dict() + # You must make a new FullyShardedDataParallel instance to use module.load_state_dict + unwrapped_model = TransformerWithSharedParams(group) + unwrapped_model.load_state_dict(state_1) + new_ddp_model = FullyShardedDataParallel(unwrapped_model, group, **config).cuda() + cls._train_for_several_steps(new_ddp_model, 2, autocast) + try: + ddp_model.load_state_dict(new_ddp_model.state_dict()) + assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded" + except Exception: + pass + + +class TestHooks(DistributedTest): + # Feel free to modify these tests as the implementation changes. 
+ # They aspire to make sure that backward hooks are registered and used + + @parameterized.expand([[True], [False]]) + def test_output_backward_hooks(self, cuda_first): + fn = functools.partial(self._test_output_backward_hooks, cuda_first=cuda_first) + spawn_and_init(fn) + + def test_backward_hooks_after_save(self): + fn = functools.partial(self._test_backward_hooks_after_save, cuda_first=False) + spawn_and_init(fn) + + @classmethod + def _test_backward_hooks_after_save(self, rank, group, cuda_first=False): + model = self.get_wrapped_model(group, cuda_first=cuda_first) + self._train_for_several_steps(model, 2, model.mixed_precision) + state_1 = model.local_state_dict() + model.load_local_state_dict(state_1) + self._test_output_backward_hooks(rank, group, cuda_first=cuda_first, model=model) + + @classmethod + def _test_output_backward_hooks(self, rank, group, cuda_first=False, model=None): + if model is None: + model = self.get_wrapped_model(group, cuda_first=cuda_first) + optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + optim.zero_grad() + # Inputs always cuda regardless of move_grads_cpu, or model.device + input = model.module.get_input(torch.device("cuda")) + output = model(*input) + assert len(output._backward_hooks) == 1 # this is pre-bwd hook + loss = model.module.get_loss(input, output).cuda() + loss.backward() + assert len(output._backward_hooks) == 1 # It doesn't get removed + optim.step() + assert len(output._backward_hooks) == 1 + + @parameterized.expand([[True], [False]]) + def test_register_functions_called(self, cuda_first): + fn = functools.partial(self._test_register_functions_called, cuda_first=cuda_first) + spawn_and_init(fn) + + @classmethod + def _test_register_functions_called(self, rank, group, cuda_first=False): + """Tests that _register_{pre|post}_backward_hooks called during forward.""" + model = self.get_wrapped_model(group, cuda_first=cuda_first) + input = model.module.get_input(torch.device("cuda")) + model._register_post_backward_hooks = mock.MagicMock(return_value=None) + model._register_pre_backward_hooks = mock.MagicMock(return_value=None) + assert not model._register_post_backward_hooks.called + assert not model._register_pre_backward_hooks.called + model(*input) + assert model._register_post_backward_hooks.called + assert model._register_pre_backward_hooks.called + + +class TestNoGrad(DistributedTest): + @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) + def test_transformer_parameterized(self, config): + test_fn = functools.partial(self._test_transformer, config=config) + spawn_and_init(test_fn) + + @classmethod + def _test_transformer(self, rank, group, config): + autocast = config["mixed_precision"] + + # Train model for a step + model = self.get_wrapped_model(group, cuda_first=False, config=config) + self._train_for_several_steps(model, 1, autocast) + + model.eval() # no dropout for this test + + # Eval in standard mode (i.e., without no_grad) + input = model.module.get_input(torch.device("cuda")) + ref_output = model(*input) + + # Eval with no_grad and compare + with torch.no_grad(): + no_grad_output = model(*input) + + assert objects_are_equal(ref_output, no_grad_output), "no_grad_output did not match ref_output" + + +class TestNoSync(DistributedTest): + def test_transformer(self): + fn = functools.partial(self._test_transformer, config={}) + spawn_and_init(fn) + + def test_transformer_no_flat_params(self): + config = {"flatten_parameters": False} + fn = functools.partial(self._test_transformer, config=config) + 
spawn_and_init(fn) + + def test_nested_wrapper(self): + fn = functools.partial(self._test_nested_wrapper, config={}) + spawn_and_init(fn) + + def test_no_sync_before_first_forward(self): + group = DummyProcessGroup(rank=0, size=1) + model = self.get_wrapped_model(group, config={}) + batch = model.module.get_input(torch.device("cuda")) + with model.no_sync(): + output = model(*batch) + loss = model.module.get_loss(batch, output) + loss.backward() + output = model(*batch) + loss = model.module.get_loss(batch, output) + loss.backward() + + @classmethod + def _test_transformer(self, rank, group, config): + model = self.get_wrapped_model(group, config=config) + model.eval() # turn off dropout for the test + self._test_no_sync(model, batch_dim=1) + + @classmethod + def _test_nested_wrapper(self, rank, group, config): + model = NestedWrappedModule(group, config) + model = FullyShardedDataParallel(model, group, **config).cuda() + self._test_no_sync(model, batch_dim=0) + + @classmethod + def _test_no_sync(self, model, batch_dim): + # Generate two input batches. We'll test that we get the same grads if + # we train on them sequentially while accumulating grads (with no_sync) + # vs. concatenating the batches and training in one go. + batch1 = model.module.get_input(torch.device("cuda")) + assert isinstance(batch1, tuple) + batch2 = tuple( + # This randomly permutes the values in a multi-dim tensor. + x.view(-1)[torch.randperm(x.numel())].view_as(x) + for x in batch1 + ) + for x, y in zip(batch1, batch2): + assert not torch.all(x == y) + + # Concat the batches along batch dimension. + concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2)) + + # Establish reference behavior on the concat batch. + model.zero_grad() + output = model(*concat_batch) + ref_loss = model.module.get_loss(concat_batch, output) + ref_loss.backward() + ref_grads = [p.grad.detach().clone() for p in model.parameters()] + + # Test that we get the same results by accumulating grads. 
+ model.zero_grad() + with model.no_sync(): # accumulate gradients from the first batch + output = model(*batch1) + loss1 = model.module.get_loss(batch1, output) + loss1.backward() + output = model(*batch2) + loss2 = model.module.get_loss(batch2, output) + loss2.backward() + accumulated_loss = loss1 + loss2 + accumulated_grads = [p.grad.detach().clone() for p in model.parameters()] + + torch.testing.assert_allclose(ref_loss, accumulated_loss) + assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True) + + +class TransformerWithSharedParams(nn.Module): + def __init__(self, group, *unused_args, d_vocab=23, d_model=16, **unused_kwargs): + super().__init__() + self.rank = group.rank() + self.world_size = group.size() + torch.manual_seed(0) # keep everything deterministic + assert d_vocab >= 12 # we use torch.arange(12) as input + self.embed_tokens = nn.Embedding(d_vocab, d_model) + self.transformer = nn.Transformer( + d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1, + ) + self.output_proj = nn.Linear(d_model, d_vocab) + # share the embedding and output projection weights + self.output_proj.weight = self.embed_tokens.weight + self.register_buffer(_BUFFER_NAME, self.embed_tokens.weight.new_ones((d_model,))) + + def get_input(self, device): + torch.manual_seed(1 + self.rank) # keep everything deterministic + src = torch.arange(12, device=device).view(6, 2) # T x B + tgt = torch.arange(8, device=device).view(4, 2) # T x B + return (src, tgt) + + def forward(self, src_ids, tgt_ids): + src = self.embed_tokens(src_ids) + src = src + self.vocab_bias + tgt = self.embed_tokens(tgt_ids) + x = self.transformer(src, tgt) + return self.output_proj(x) + + def get_loss(self, input, output): + _, tgt = input + return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum") + + def run_backward(self, loss): + loss.backward() + + +class NestedWrappedModule(nn.Module): + def __init__(self, group, wrapper_config): + super().__init__() + self.rank = group.rank() + self.world_size = group.size() + self.wrapper_config = wrapper_config + + def _maybe_wrap(layer): + if wrapper_config is not None: + return FullyShardedDataParallel(layer, group, **wrapper_config) + return layer + + torch.manual_seed(0) # keep everything deterministic + self.module = nn.Sequential( + nn.Linear(8, 4), _maybe_wrap(nn.Linear(4, 16)), _maybe_wrap(nn.Linear(16, 4)), nn.Linear(4, 8), + ) + + def get_input(self, device): + torch.manual_seed(1 + self.rank) # keep everything deterministic + return (torch.rand(4, 8, device=device),) + + def forward(self, x): + return self.module(x) + + def get_loss(self, input, output): + loss = output.sum() + return loss + + def run_backward(self, loss): + loss.backward() + + +class DummyDDP(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + +class MixtureOfExperts(NestedWrappedModule): + def __init__(self, group, wrapper_config, checkpoint_act=False): + super().__init__(group, wrapper_config) + self.group = group + + # "expert" params are different on each rank + torch.manual_seed(42 + group.rank()) + expert = nn.Linear(16, 4) + for p in expert.parameters(): + p.expert = True + + # everything else is shared + torch.manual_seed(0) + shared = nn.Linear(4, 16) + + if checkpoint_act: + expert = checkpoint_wrapper(expert) + shared = checkpoint_wrapper(shared) + + if wrapper_config is not None: + # we 
create a process group of size 1 for the expert params + expert_group = torch.distributed.new_group([group.rank()]) + expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config) + + shared = FullyShardedDataParallel(shared, group, **wrapper_config) + + self.module = nn.Sequential(nn.Linear(8, 4), shared, expert, nn.Linear(4, 8)) + + def run_backward(self, loss): + loss.backward() + + # manually reduce gradients if not wrapped in FullyShardedDataParallel + if self.wrapper_config is None: + with torch.no_grad(): + for p in self.parameters(): + if hasattr(p, "expert"): + continue # these params don't need grad reduction + p.grad.data.div_(self.world_size) + torch.distributed.all_reduce(p.grad.data, group=self.group) + + +class ModuleWithDelay(nn.Module): + def __init__(self, module, delay_after_loss_ms=0, delay_before_reduction_ms=0): + super().__init__() + self.delay_after_loss_ms = delay_after_loss_ms + self.delay_before_reduction_ms = delay_before_reduction_ms + self.module = module + + def get_input(self, device): + return self.module.get_input(device) + + def forward(self, x): + return self.module(x) + + def get_loss(self, input, output): + loss = self.module.get_loss(input, output) + if self.delay_after_loss_ms > 0: + torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms())) + return loss + + def run_backward(self, loss): + orig_reduce_scatter = torch.distributed.reduce_scatter + + def _delayed_reduce_scatter(*args, **kwargs): + if self.delay_before_reduction_ms > 0: + torch.cuda._sleep(int(self.delay_before_reduction_ms * get_cycles_per_ms())) + return orig_reduce_scatter(*args, **kwargs) + + with mock.patch("torch.distributed.reduce_scatter", _delayed_reduce_scatter): + self.module.run_backward(loss) + + +class NestedWrappedModuleWithDelay(ModuleWithDelay): + def __init__(self, group, wrapper_config, **kwargs): + super().__init__(NestedWrappedModule(group, wrapper_config), **kwargs) + + +def spawn_and_init(fn, args=None, **spawn_kwargs): + if args is None: + args = () + + run_fn = functools.partial(init_and_run, fn, args) + spawn_for_all_world_sizes(run_fn, **spawn_kwargs) + + +def init_and_run(fn, args, rank, world_size, filename, filename_rpc): + dist_init(rank, world_size, filename, filename_rpc) + group = torch.distributed.new_group() + fn(rank, group, *args) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/nn/data_parallel/test_fsdp_uneven.py b/tests/nn/data_parallel/test_fsdp_uneven.py new file mode 100644 index 000000000..e2ce116db --- /dev/null +++ b/tests/nn/data_parallel/test_fsdp_uneven.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring + +""" Test FSDP with uneven parameter shards. 
""" + +import tempfile + +import pytest +import torch +from torch import Tensor +import torch.multiprocessing as mp +from torch.nn import Linear, Sequential +from torch.optim import SGD + +from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP +from fairscale.nn.data_parallel.fully_sharded_data_parallel import TrainingState +from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version + + +def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case): + result = dist_init(rank, world_size, tempfile_name, unused) + assert result, "Dist init failed" + + if test_case["assert_ref_out"]: + with torch.no_grad(): + weight = model.weight.T.clone().cuda() + v = torch.Tensor(test_case["inputs"][0][rank]).cuda() + ref_out = torch.matmul(v, weight) + model.to("cuda") + assert isinstance(fsdp_config, dict), str(fsdp_config) + model = FSDP(model, **fsdp_config) + optim = SGD(model.parameters(), lr=0.1) + inputs = test_case["inputs"] + assert len(inputs) == 1 or not test_case["assert_ref_out"] + assert len(inputs[0]) >= world_size + for in_data in inputs: + in_data = Tensor(in_data[rank]).cuda() + out = model(in_data) + out.sum().backward() + optim.step() + optim.zero_grad() + + if test_case["assert_ref_out"]: + torch.testing.assert_allclose(ref_out, out) + + model.assert_state(TrainingState.IDLE) + teardown() + + +@skip_if_single_gpu +@pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3)], "assert_ref_out": True}]) +@pytest.mark.parametrize( + "fsdp_config", [{}, {"flatten_parameters": False}], +) +@pytest.mark.parametrize("world_size", list(range(2, 9))) +def test_one_iteration(world_size, test_case, fsdp_config): + """Test FSDP with uneven divide of parameter shards.""" + if torch_version() < (1, 6, 0): + pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend") + + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs.") + + temp_file_name = tempfile.mkstemp()[1] + unused = tempfile.mkstemp()[1] + + # TODO (Min): we may want to extend this to a simple 2 layer model so that it covers + # more cases in FSDP. Also, assert_ref_out can be extended to multiple + # iterations. This could be a good bootcamp task. I should file a github + # issue once we merge. + model = Linear(3, 3, bias=False) + mp.spawn( + _test_func, + args=(world_size, model, fsdp_config, temp_file_name, unused, test_case), + nprocs=world_size, + join=True, + ) + + +@skip_if_single_gpu +@pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3), torch.rand(8, 3)], "assert_ref_out": False}]) +@pytest.mark.parametrize("fsdp_config", [{}, {"flatten_parameters": False}]) +@pytest.mark.parametrize("world_size", list(range(2, 9))) +def test_smaller_than_world_size(world_size, test_case, fsdp_config): + """Test FSDP with uneven divide of parameter shards.""" + if torch_version() < (1, 6, 0): + pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend") + + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs.") + + temp_file_name = tempfile.mkstemp()[1] + unused = tempfile.mkstemp()[1] + + model = Sequential( + Linear(3, 3, bias=False), + Linear(3, 4, bias=False), + Linear(4, 5, bias=False), + Linear(5, 4, bias=False), + Linear(4, 3, bias=False), + Linear(3, 1, bias=False), + Linear(1, 1, bias=False), # param here is smaller than world_size if unflattened. 
+ ) + mp.spawn( + _test_func, + args=(world_size, model, fsdp_config, temp_file_name, unused, test_case), + nprocs=world_size, + join=True, + ) diff --git a/tests/optim/test_oss.py b/tests/optim/test_oss.py index d8d2bc45c..6aa146892 100644 --- a/tests/optim/test_oss.py +++ b/tests/optim/test_oss.py @@ -631,6 +631,7 @@ def check(norm): loss_oss = loss_fn(outputs_oss, target) loss_oss.backward() + torch.testing.assert_allclose(loss_oss, loss) # Check the equivalence with the non-sharded optim oss_total_norm = sharded_optimizer.clip_grad_norm(CLIP_NORM, norm_type=norm) diff --git a/tests/utils/test_parallel.py b/tests/utils/test_parallel.py new file mode 100644 index 000000000..a09c2a385 --- /dev/null +++ b/tests/utils/test_parallel.py @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring + +""" Test utility classes from fairscale.utils.parallel """ + +from parameterized import parameterized +import torch + +from fairscale.utils.parallel import chunk_and_pad + + +@parameterized.expand([[num_chunks] for num_chunks in range(1, 33)]) +def test_chunk_and_pad(num_chunks): + max_tensor_size = 256 + tensor = torch.zeros(max_tensor_size) + for tensor_size in range(1, max_tensor_size + 1): + tensor_i = tensor[:tensor_size] + chunks = chunk_and_pad(tensor_i, num_chunks) + assert len(chunks) == num_chunks + assert all(len(chunks[0]) == len(chunk) for chunk in chunks) diff --git a/tests/utils/test_reduce_scatter_bucketer.py b/tests/utils/test_reduce_scatter_bucketer.py new file mode 100644 index 000000000..4baa23c29 --- /dev/null +++ b/tests/utils/test_reduce_scatter_bucketer.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import itertools +import sys +import unittest +from unittest import mock + +from parameterized import parameterized +import torch + +from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer +from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes + + +def rename_test(testcase_func, param_num, param): + return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),) + + +CONFIG_OPTIONS = [ + [dict(zip(["bucket_cap_mb", "shard_size"], config))] for config in itertools.product([0, 0.25], [1, 262144]) +] + + +class TestReduceScatterBucketer(unittest.TestCase): + # TODO(sshleifer): check if possible to reuse `DistributedTest, spawn_and_init`. 
+ def setUp(self): + major, minor = torch.__version__.split(".")[:2] + major, minor = int(major), int(minor) + if major < 1 or (major == 1 and minor < 6): + raise unittest.SkipTest("Need pytorch version >= 1.6 due to reduce_scatter") + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA not available, skipping test") + if sys.platform == "win32": + raise unittest.SkipTest("NCCL doesn't support Windows, skipping test") + if torch.cuda.device_count() < 2: + raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") + + @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) + def test_reduce_scatter(self, config): + spawn_and_init(functools.partial(self._test_reduce_scatter, **config)) + + @staticmethod + def _test_reduce_scatter(rank, group, bucket_cap_mb=None, shard_size=None): + bucketer = ReduceScatterBucketer(bucket_cap_mb=bucket_cap_mb) + world_size = group.size() + + tensors = [torch.ones(shard_size).cuda() for _ in range(world_size)] + tensors[rank].fill_(0) + + input_bytes = shard_size * world_size * 4 + bucket_bytes = bucket_cap_mb * 1024 * 1024 + + callback = mock.MagicMock() + bucketer.reduce_scatter_async(tensors, group, callback_fn=callback) + + if bucket_cap_mb > 0 and input_bytes < bucket_bytes: + assert callback.call_count == 0 + bucketer.flush() + assert callback.call_count == 1 + + result = callback.call_args[0][0] # get first positional arg + assert torch.is_tensor(result), result + assert torch.all(result == (world_size - 1)) + + def test_out_of_order_reduction(self): + spawn_and_init(self._test_out_of_order_reduction) + + @staticmethod + def _test_out_of_order_reduction(rank, group): + bucketer = ReduceScatterBucketer(bucket_cap_mb=0.25) + world_size = group.size() + + small_tensors = [torch.ones(1).cuda() for _ in range(world_size)] + big_tensors = [torch.ones(262144).cuda() for _ in range(world_size)] + more_small_tensors = [torch.ones(2).cuda() for _ in range(world_size)] + + callback1 = mock.MagicMock() + callback2 = mock.MagicMock() + callback3 = mock.MagicMock() + + bucketer.reduce_scatter_async(small_tensors, group, callback_fn=callback1) + assert callback1.call_count == 0 + bucketer.reduce_scatter_async(big_tensors, group, callback_fn=callback2) + assert callback1.call_count == 0 + assert callback2.call_count == 1 + bucketer.reduce_scatter_async(more_small_tensors, group, callback_fn=callback3) + assert callback1.call_count == 0 + assert callback2.call_count == 1 + assert callback3.call_count == 0 + + bucketer.flush() + assert callback1.call_count == 1 + assert callback2.call_count == 1 + assert callback3.call_count == 1 + + +def spawn_and_init(fn, args=None, **spawn_kwargs): + if args is None: + args = () + run_fn = functools.partial(init_and_run, fn, args) + spawn_for_all_world_sizes(run_fn, **spawn_kwargs) + + +def init_and_run(fn, args, rank, world_size, filename, filename_rpc): + dist_init(rank, world_size, filename, filename_rpc) + group = torch.distributed.new_group() + fn(rank, group, *args) + + +if __name__ == "__main__": + unittest.main()
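

For reference, here is a minimal sketch of how ``ReduceScatterBucketer`` is driven outside these unit tests, assuming ``torch.distributed`` is already initialized (for example via ``torch.multiprocessing.spawn`` plus ``dist_init``, as in the tests above) and that each rank owns one CUDA device. ``reduce_gradients`` and its handling of uneven sizes are illustrative only, not an API provided by fairscale::

    from typing import List

    import torch

    from fairscale.utils.parallel import chunk_and_pad
    from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer


    def reduce_gradients(grads: List[torch.Tensor], group, bucket_cap_mb: int = 25):
        """Reduce-scatter each gradient across ``group``; small ones share a bucket."""
        bucketer = ReduceScatterBucketer(bucket_cap_mb=bucket_cap_mb)
        world_size = group.size()
        reduced = [None] * len(grads)

        def make_callback(index):
            def callback(shard: torch.Tensor) -> None:
                # ``shard`` is this rank's slice of the summed gradient; clone it in
                # case the bucket buffer backing it gets reused.
                reduced[index] = shard.clone()

            return callback

        for i, grad in enumerate(grads):
            # reduce_scatter_async expects one equal-sized chunk per rank, so pad
            # the flattened gradient until it divides evenly across the group.
            chunks = chunk_and_pad(grad.detach().view(-1), world_size)
            bucketer.reduce_scatter_async(chunks, group, callback_fn=make_callback(i))

        # Anything still sitting in a bucket is reduced here; flush() also runs the
        # remaining callbacks, so ``reduced`` is fully populated afterwards.
        bucketer.flush()
        return reduced

With the numbers used in ``_test_out_of_order_reduction`` (``bucket_cap_mb=0.25``, fp32, world size 2), ``_get_shard_size`` gives 0.25 * 1024 * 1024 / 4 // 2 = 32768 elements per rank, i.e. a ``(2, 32768)`` bucket buffer; the 262144-element per-rank chunks do not fit and are reduced immediately, while the 1- and 2-element chunks wait for ``flush()``.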