From a7be0d2e6303082f00879dd89bd8ea7c3c1a7c16 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 27 Jan 2021 17:11:00 -0800 Subject: [PATCH 01/48] Add fairscale.utils.testing.DeviceAndTypeCheckModule --- fairscale/utils/testing.py | 40 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py index d902383ba..a2b8b0658 100644 --- a/fairscale/utils/testing.py +++ b/fairscale/utils/testing.py @@ -396,3 +396,43 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool: return False else: return a == b + + +class DeviceAndTypeCheckModule(nn.Module): + """A simple module for checking Tensor devices and dtypes.""" + + def __init__( + self, + expected_input_dtype: Optional[torch.dtype] = None, + expected_input_device: Optional[torch.device] = None, + expected_param_dtype: Optional[torch.dtype] = None, + expected_param_device: Optional[torch.device] = None, + expected_loss_dtype: Optional[torch.dtype] = None, + expected_loss_device: Optional[torch.device] = None, + ): + super().__init__() + self.expected_input_dtype = expected_input_dtype + self.expected_input_device = expected_input_device + self.expected_param_dtype = expected_param_dtype + self.expected_param_device = expected_param_device + self.expected_loss_dtype = expected_loss_dtype + self.expected_loss_device = expected_loss_device + + self.linear = nn.Linear(5, 5) + + def _check(self, key, x, expected): + assert expected in {None, x}, f"{key} ({x}) != expected ({expected})" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self._check("input.dtype", x.dtype, self.expected_input_dtype) + self._check("input.device", x.device, self.expected_input_device) + + param = self.linear.weight + self._check("param.dtype", param.dtype, self.expected_param_dtype) + self._check("param.device", param.device, self.expected_param_device) + + loss = self.linear(x).sum() + self._check("loss.dtype", loss.dtype, self.expected_loss_dtype) + self._check("loss.device", loss.device, self.expected_loss_device) + + return loss From e2ad716f28c7f091189d2c0282a131fc98bdedcd Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 27 Jan 2021 18:51:04 -0800 Subject: [PATCH 02/48] Add fairscale.utils.containers --- fairscale/utils/containers.py | 94 +++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 fairscale/utils/containers.py diff --git a/fairscale/utils/containers.py b/fairscale/utils/containers.py new file mode 100644 index 000000000..651931789 --- /dev/null +++ b/fairscale/utils/containers.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
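The checker added above is exercised by the mixed-precision tests later in this series; a minimal standalone sketch of the intended usage, assuming a CUDA device (which is what the sharded wrapper introduced below defaults to), is::

    import torch
    from fairscale.utils.testing import DeviceAndTypeCheckModule

    checker = DeviceAndTypeCheckModule(
        expected_input_dtype=torch.float16,
        expected_param_dtype=torch.float16,
        expected_loss_dtype=torch.float16,
    ).cuda().half()
    loss = checker(torch.rand(2, 5, device="cuda", dtype=torch.float16))
    loss.backward()  # the asserts in forward() fire if any dtype or device differs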
+ +from typing import Any, Dict, List, Tuple, Union + +import torch + + +def apply_to_tensors(fn, container): + def _apply(x): + if torch.is_tensor(x): + return fn(x) + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + return [_apply(x) for x in x] + elif isinstance(x, tuple): + return tuple(_apply(x) for x in x) + elif isinstance(x, set): + return {_apply(x) for x in x} + else: + return x + + return _apply(container) + + +def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]: + """ + Usage:: + + kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + assert args == [1, 2] + assert kwargs == {"a": 3, "b": 4} + """ + kwarg_keys = [] + flat_args = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + return kwarg_keys, flat_args + + +def unpack_kwargs(kwarg_keys: List[str], flat_args: List[Any]) -> Tuple[List[Any], Dict[str, Any]]: + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])} + return args, kwargs + + +def split_non_tensors(mixed: Union[torch.Tensor, Tuple[Any]]) -> Tuple[Tuple[torch.Tensor], Dict[str, List[Any]]]: + """ + Usage:: + + x = torch.Tensor([1]) + y = torch.Tensor([2]) + tensors, packed_non_tensors = split_non_tensors((x, y, None, 3)) + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (x, y, None, 3) + """ + if isinstance(mixed, torch.Tensor): + return (mixed,), None + tensors = [] + packed_non_tensors = {"is_tensor": [], "objects": []} + for o in mixed: + if isinstance(o, torch.Tensor): + packed_non_tensors["is_tensor"].append(True) + tensors.append(o) + else: + packed_non_tensors["is_tensor"].append(False) + packed_non_tensors["objects"].append(o) + return tuple(tensors), packed_non_tensors + + +def unpack_non_tensors(tensors: Tuple[torch.Tensor], packed_non_tensors: Dict[str, List[Any]],) -> Tuple[Any]: + if packed_non_tensors is None: + return tensors + assert isinstance(packed_non_tensors, dict) + mixed = [] + is_tensor_list = packed_non_tensors["is_tensor"] + objects = packed_non_tensors["objects"] + assert len(tensors) + len(objects) == len(is_tensor_list) + obj_i = tnsr_i = 0 + for is_tensor in is_tensor_list: + if is_tensor: + mixed.append(tensors[tnsr_i]) + tnsr_i += 1 + else: + mixed.append(objects[obj_i]) + obj_i += 1 + return tuple(mixed) From 4d6a5c9724eb42f6acb837149f0ab3b506f7fb16 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 26 Jan 2021 15:40:11 -0800 Subject: [PATCH 03/48] Add ShardParamsDataParallel --- fairscale/nn/data_parallel/__init__.py | 1 + .../shard_params_data_parallel.py | 453 ++++++++++++++++++ .../test_shard_params_data_parallel.py | 253 ++++++++++ 3 files changed, 707 insertions(+) create mode 100644 fairscale/nn/data_parallel/shard_params_data_parallel.py create mode 100644 tests/nn/data_parallel/test_shard_params_data_parallel.py diff --git a/fairscale/nn/data_parallel/__init__.py b/fairscale/nn/data_parallel/__init__.py index f8bdd814b..c9f70e949 100644 --- a/fairscale/nn/data_parallel/__init__.py +++ b/fairscale/nn/data_parallel/__init__.py @@ -3,4 +3,5 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. 
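For reference, apply_to_tensors above recurses through dicts, lists, tuples, and sets, applying fn to every tensor it finds and passing anything else through unchanged; a small illustrative sketch::

    import torch
    from fairscale.utils.containers import apply_to_tensors

    nested = {"a": [torch.ones(2), 3], "b": (torch.zeros(1),)}
    doubled = apply_to_tensors(lambda t: t * 2, nested)
    assert torch.equal(doubled["a"][0], torch.tensor([2.0, 2.0]))
    assert doubled["a"][1] == 3  # non-tensors are returned as-is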
+from .shard_params_data_parallel import ShardParamsDataParallel from .sharded_ddp import ShardedDataParallel diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py new file mode 100644 index 000000000..a2affe79d --- /dev/null +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -0,0 +1,453 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import functools +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn + +from fairscale.nn.misc import FlattenParamsWrapper +from fairscale.utils.containers import ( + apply_to_tensors, + pack_kwargs, + split_non_tensors, + unpack_kwargs, + unpack_non_tensors, +) + + +class ShardParamsDataParallel(nn.Module): + """ + A wrapper for sharding Module parameters. + + Usage:: + + sharded_module = ShardParamsDistributedWrapper(my_module) + x = sharded_module(x, y=3, z=torch.Tensor([1])) + loss = x.sum() + loss.backward() + + It is also possible to shard individual layers separately and have an outer + wrapper handle any leftover parameters:: + + model = nn.Sequential( + nn.Linear(5, 100), + ShardParamsDistributedWrapper(nn.Linear(100, 100)), + ShardParamsDistributedWrapper(nn.Linear(100, 100)), + nn.Linear(100, 5), + ) + sharded_model = ShardParamsDistributedWrapper(model) + x = sharded_model(x) + loss = x.sum() + loss.backward() + + Args: + module (nn.Module): module to checkpoint + process_group (Optional): process group for sharding + reshard_after_forward (bool, Optional): if True, reshard parameters + after the forward pass. This saves memory but slows training. + mixed_precision (bool, Optional): if True, inputs, activations and + gradients will be kept in FP16; computation and communication will + occur in FP16; and a (sharded) master copy of the model weights will + be maintained in FP32. + fp32_reduce_scatter (bool, Optional): if True, then reduce-scatter + gradients in FP32. This is only relevant when *mixed_precision* is + ``True``. + flatten_parameters (bool, Optional): if True, flatten parameters into a + single contiguous tensor, which improves training speed. + cpu_offload (bool, Optional): if True, offload FP32 params to CPU. This + is only relevant when *mixed_precision* is ``True``. + compute_device (torch.device, Optional): device to move params to for + computation. This is primarily relevant with *cpu_offload* and + defaults to "cuda". + compute_dtype (torch.dtype, Optional): dtype for full parameters for + computation. This defaults to torch.float32 unless *mixed_precision* + is set, in which case it defaults to torch.float16. + move_grads_to_cpu (bool, Optional): move grad shard to CPU after + reduction. This is useful when combined with CPU-based optimizers. 
+ """ + + def __init__( + self, + module: nn.Module, + process_group=None, + reshard_after_forward: bool = True, + mixed_precision: bool = False, + fp32_reduce_scatter: bool = False, + flatten_parameters: bool = True, + cpu_offload: bool = False, + compute_device: Optional[torch.device] = None, + compute_dtype: Optional[torch.dtype] = None, + move_grads_to_cpu: Optional[bool] = False, + ): + super().__init__() + self.process_group = process_group or dist.new_group() + self.rank = self.process_group.rank() + self.world_size = self.process_group.size() + self.reshard_after_forward = reshard_after_forward + self.mixed_precision = mixed_precision + self.fp32_reduce_scatter = fp32_reduce_scatter + self.flatten_parameters = flatten_parameters + self.cpu_offload = cpu_offload + self.compute_device = compute_device or torch.device("cuda") + self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32) + self.move_grads_to_cpu = move_grads_to_cpu + + if self.fp32_reduce_scatter and not self.mixed_precision: + raise ValueError("fp32_reduce_scatter requires mixed_precision=True") + if self.cpu_offload and not self.mixed_precision: + raise ValueError("cpu_offload requires mixed_precision=True") + + # Only handle params which are not already sharded. This enables + # sharding individual layers of a Module, with an outer wrapper to + # shard any leftover parameters. + params = list(p for p in module.parameters() if not getattr(p, "_is_sharded", False)) + + if self.flatten_parameters and len(params) > 0: + self.module = FlattenParamsWrapper(module, param_list=params) + del module + self.params = [self.module.flat_param] + else: + self.module = module + self.params = params + + # Shard module parameters. + self._shard_initial_params() + + if self.mixed_precision: + # Cast all module buffers to FP16 (buffers are not sharded). + self.apply(cast_buffers_to_fp16) + + # Make sure all parameters are sharded. + for n, p in self.named_parameters(): + assert getattr(p, "_is_sharded", False), f"found unsharded parameter: {n} ; {p.size()}" + + @torch.no_grad() + def _shard_initial_params(self): + for p in self.params: + assert not hasattr(p, "_is_sharded") + assert p.is_floating_point() + if self.mixed_precision: + assert p.dtype == torch.float32 + + p._is_sharded = True + p._orig_size = p.data.size() + + shard_size = p.data.numel() // self.world_size + s = self.rank * shard_size + e = (self.rank + 1) * shard_size + + orig_data = p.data + p.data = p.data.view(-1)[s:e].clone() + free_storage_(orig_data) + + def __getattr__(self, name): + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.module, name) + + def __getstate__(self): + state = copy.copy(self.__dict__) + state["orig_sizes"] = [p._orig_size for p in self.params] + if state["process_group"] is not None: + state["process_group"] = "MISSING" # raise error if used + if "_fp32_to_fp16_stream" in state: + del state["_fp32_to_fp16_stream"] + return state + + def __setstate__(self, state): + super().__setstate__(state) + + def fixup(p, size): + assert isinstance(p, torch.nn.Parameter) + p.data = p.data.clone() # move tensors out of shared memory + p._is_sharded = True + p._orig_size = size + return p + + self.params = [fixup(p, size) for p, size in zip(self.params, self.orig_sizes)] + del self.orig_sizes + + def state_dict(self, *args, **kwargs): + """ + Returns the whole (unsharded) state of the module. 
Parameters are not + sharded, so the resulting state_dict can be loaded directly by the + wrapped Module without any sharding-specific logic. + """ + torch.cuda.synchronize() + self._rebuild_full_params() + # We don't free the params after generating the state dict, since + # freeing is done in-place (via the Storagee) and would corrupt the + # returned state dict. + return self.module.state_dict(*args, **kwargs) + + def local_state_dict(self, *args, **kwargs): + """ + Returns the local (sharded) state of the module. Parameters are sharded, + so the resulting state_dict can only be loaded after the Module has been + wrapped with ShardParamsDistributedWrapper. + """ + if self.flatten_parameters: + kwargs["unflatten_params"] = False + return self.module.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + """Load a whole (unsharded) state_dict.""" + self._rebuild_full_params() + output = self.module.load_state_dict(*args, **kwargs) + self._free_full_params() + return output + + def load_local_state_dict(self, *args, **kwargs): + """Load a local (sharded) state_dict.""" + return self.module.load_state_dict(*args, **kwargs) + + @torch.no_grad() + def _pre_forward_init(self): + did_init = False + for p in self.params: + if not hasattr(p, "_full_param"): + did_init = True + assert p._is_sharded + + p._fp32_shard = p.data + + if self.mixed_precision: + assert p._fp32_shard.dtype == torch.float32 + if self.cpu_offload: + p._fp32_shard = p._fp32_shard.pin_memory() + p._fp16_shard = torch.zeros_like( + p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype, + ) + free_storage_(p._fp16_shard) + p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype) + else: + p._fp16_shard = None # use _fp32_shard + p._full_param = p._fp32_shard.new_empty(p._orig_size) + + p._full_param = p._full_param.to(dtype=self.compute_dtype, device=self.compute_device) + free_storage_(p._full_param) + + p.data = p._fp32_shard + + if self.move_grads_to_cpu: + if self.mixed_precision and not self.fp32_reduce_scatter: + grad_dtype = torch.float16 + else: + grad_dtype = torch.float32 + p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory() + + if did_init: + self._fp32_to_fp16_stream = torch.cuda.Stream() + self._fp32_to_fp16_stream.wait_stream(torch.cuda.current_stream()) + + def forward(self, *args, **kwargs): + self._pre_forward_init() + + if self.mixed_precision: + args, kwargs = cast_inputs_to_fp16(*args, **kwargs) + + # All-gather full parameters. + self._rebuild_full_params() + + # Register backward hooks to reshard params and reduce-scatter grads. + # These need to be re-registered every forward pass. + self._register_post_backward_hooks() + + outputs = self.module(*args, **kwargs) + + if self.reshard_after_forward: + self._free_full_params() + + # Switch to FP32 param shard after forward so that the optimizer will be + # initialized with the correct dtype and size. + self._use_fp32_param_shard() + + # Register pre-backward hook to run before the wrapped module's backward. + if torch.is_grad_enabled(): + pre_backward_hook_has_run = [False] + + def _pre_backward_hook(*unused): + if pre_backward_hook_has_run[0]: + return # only run once + pre_backward_hook_has_run[0] = True + + if self.reshard_after_forward: + self._rebuild_full_params() + else: + self._use_full_params() + + def _register_hook(t): + t.register_hook(_pre_backward_hook) + return t + + # Attach hooks to Tensor outputs. 
+ outputs = apply_to_tensors(_register_hook, outputs) + + return outputs + + def _register_post_backward_hooks(self): + # Register backward hooks to reshard params and reduce-scatter grads. + if not torch.is_grad_enabled(): + return # don't register grad hooks if grad isn't enabled + for p in self.params: + if p.requires_grad: + if hasattr(p, "_shard_bwd_hook"): + p._shard_bwd_hook[1].remove() # remove existing handle + p_tmp = p.expand_as(p) + grad_acc = p_tmp.grad_fn.next_functions[0][0] + handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p)) + p._shard_bwd_hook = (grad_acc, handle) + + @torch.no_grad() + def _post_backward_hook(self, param, *unused): + if param.grad is None: + return + if param.grad.requires_grad: + raise RuntimeError("ShardParamsDistributedWrapper only works with gradients that don't require grad") + + # Free full params and switch to FP32 shard after backward. + self._free_full_params([param]) + self._use_fp32_param_shard([param]) + + if self.mixed_precision: + self._free_fp16_param_shard([param]) + + if self.fp32_reduce_scatter: + # Cast grad to FP32. + param.grad.data = param.grad.data.to(param.dtype) + + # Average grad by world_size for consistency with PyTorch DDP. + param.grad.data.div_(self.world_size) + + # Reduce-scatter grad. + param.grad.data = self._reduce_scatter(param.grad.data.view(-1)) + + if self.move_grads_to_cpu: + param._cpu_grad.copy_(param.grad.data, non_blocking=True) + param.grad.data = param._cpu_grad + + # Cast grad to param's dtype (typically FP32). + if self.mixed_precision: + param.grad.data = param.grad.data.to(dtype=param.data.dtype) + + @torch.no_grad() + def _rebuild_full_params(self): + if self.mixed_precision: + self._cast_fp32_param_shards_to_fp16() + + for p in self.params: + # All-gather parameters + alloc_storage_(p._full_param, size=p._orig_size) + output_list = list(p._full_param.view(-1).chunk(self.world_size)) + dist.all_gather(output_list, p.data, group=self.process_group) + p.data = p._full_param + p.grad = None + + if self.mixed_precision: + self._free_fp16_param_shard([p]) + + @torch.no_grad() + def _use_full_params(self): + for p in self.params: + assert p._full_param.storage().size() != 0 + p.data = p._full_param + + @torch.no_grad() + def _free_full_params(self, params=None): + if params is None: + params = self.params + current_stream = torch.cuda.current_stream() + for p in params: + # There may be external references to the Tensor Storage that we + # can't modify, such as references that are created by + # ctx.save_for_backward in the forward pass. Thus when we unshard + # parameters, we should reuse the original Tensor Storage object + # and unshard it in-place. For now, just resize the Storage to 0 to + # save memory. 
+ p._full_param.record_stream(current_stream) + free_storage_(p._full_param) + + @torch.no_grad() + def _use_fp32_param_shard(self, params=None): + if params is None: + params = self.params + for p in params: + p.data = p._fp32_shard + + @torch.no_grad() + def _cast_fp32_param_shards_to_fp16(self, params=None): + if params is None: + params = self.params + with torch.cuda.stream(self._fp32_to_fp16_stream): + for p in params: + assert p._fp16_shard is not None + alloc_storage_(p._fp16_shard, size=p._fp32_shard.size()) + p._fp16_shard.copy_(p._fp32_shard, non_blocking=True) + p.data = p._fp16_shard + torch.cuda.current_stream().wait_stream(self._fp32_to_fp16_stream) + + @torch.no_grad() + def _free_fp16_param_shard(self, params=None): + if params is None: + params = self.params + current_stream = torch.cuda.current_stream() + for p in params: + if p._fp16_shard is not None: + # _fp16_shard is allocated in _fp32_to_fp16_stream, so we can't + # free it until the work in the current stream completes. + p._fp16_shard.record_stream(current_stream) + free_storage_(p._fp16_shard) + + @torch.no_grad() + def _reduce_scatter(self, tensor, output=None): + assert tensor.numel() % self.world_size == 0 + tensor = tensor.view(self.world_size, -1) + if output is None: + output = torch.zeros_like(tensor[0]) + dist.reduce_scatter(output, list(tensor.unbind(0)), group=self.process_group) + return output + + +@torch.no_grad() +def cast_inputs_to_fp16(*args, **kwargs): + """ + Cast any Tensors in *args or **kwargs to FP16. + + Doesn't currently support Tensors nested inside containers (e.g., dict). + """ + kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) + tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args) + tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t for t in tensor_inputs) + flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs) + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + return args, kwargs + + +def cast_buffers_to_fp16(module): + for key, buf in module.named_buffers(recurse=False): + if buf is not None: + setattr(module, key, buf.half()) + + +def free_storage_(data): + if data.storage().size() > 0: + # Since we're modifying the Tensor's Storage directly, make sure the Tensor + # is the sole occupant of the Storage. + assert data.storage_offset() == 0 + assert data.storage().size() == data.numel() + data.storage().resize_(0) + + +@torch.no_grad() +def alloc_storage_(data, size): + assert data.storage().size() == 0 + data.storage().resize_(size.numel()) + # data.set_(size=size) diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py new file mode 100644 index 000000000..f126e3756 --- /dev/null +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -0,0 +1,253 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
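Two mechanics of the wrapper above are worth seeing in isolation before the tests: each rank retains only a contiguous 1D slice of every parameter (_shard_initial_params), and full parameters are "freed" by resizing their underlying Storage to zero so that outstanding references keep pointing at the same Storage object (free_storage_ / alloc_storage_). A minimal single-process sketch, with world_size and rank hard-coded purely for illustration::

    import torch

    world_size, rank = 2, 0
    p = torch.nn.Parameter(torch.arange(8.0))   # stand-in for a module parameter
    shard_size = p.data.numel() // world_size   # 4 elements per rank
    s, e = rank * shard_size, (rank + 1) * shard_size
    p.data = p.data.view(-1)[s:e].clone()       # rank 0 keeps [0., 1., 2., 3.]
    assert p.data.numel() == shard_size

    full = torch.zeros(8)                       # stand-in for p._full_param
    full.storage().resize_(0)                   # what free_storage_ does
    assert full.storage().size() == 0
    full.storage().resize_(8)                   # what alloc_storage_ does, same Storage object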
+ +import functools +import itertools +import sys +import tempfile +import unittest +from unittest import mock + +import torch +from torch import nn + +from fairscale.nn.data_parallel import ShardParamsDataParallel +from fairscale.utils.testing import DeviceAndTypeCheckModule, objects_are_equal + + +class DistributedTest(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA not available, skipping test") + if sys.platform == "win32": + raise unittest.SkipTest("NCCL doesn't support Windows, skipping test") + if torch.cuda.device_count() < 2: + raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") + + +class TestMixedPrecision(DistributedTest): + def test_all_fp32(self): + spawn_and_init( + functools.partial( + self.__class__._test_dtypes, + {"mixed_precision": False}, + False, # autocast enabled + torch.float32, # expected_input_dtype + torch.float32, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float32, # expected_reduce_dtype + ), + world_size=2, + ) + + def test_mixed_precision(self): + spawn_and_init( + functools.partial( + self.__class__._test_dtypes, + {"mixed_precision": True}, + False, # autocast enabled + torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float16, # expected_loss_dtype + torch.float16, # expected_reduce_dtype + ), + world_size=2, + ) + + def test_mixed_precision_autocast(self): + spawn_and_init( + functools.partial( + self.__class__._test_dtypes, + {"mixed_precision": True}, + True, # autocast enabled + torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float16, # expected_reduce_dtype + ), + world_size=2, + ) + + def test_mixed_precision_autocast_fp32_compute(self): + spawn_and_init( + functools.partial( + self.__class__._test_dtypes, + {"mixed_precision": True, "compute_dtype": torch.float32}, + True, # autocast enabled + torch.float16, # expected_input_dtype + torch.float32, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float32, # expected_reduce_dtype + ), + world_size=2, + ) + + def test_fp32_reduce_scatter(self): + spawn_and_init( + functools.partial( + self.__class__._test_dtypes, + {"mixed_precision": True, "fp32_reduce_scatter": True}, + False, # autocast enabled + torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float16, # expected_loss_dtype + torch.float32, # expected_reduce_dtype + ), + world_size=2, + ) + + def test_fp32_reduce_scatter_autocast(self): + spawn_and_init( + functools.partial( + self.__class__._test_dtypes, + {"mixed_precision": True, "fp32_reduce_scatter": True}, + True, # autocast enabled + torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float32, # expected_reduce_dtype + ), + world_size=2, + ) + + @staticmethod + def _test_dtypes(cfg, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group): + # Patch _reduce_scatter op to check the dtype of the reduction + orig_reduce_scatter = ShardParamsDataParallel._reduce_scatter + + model = DeviceAndTypeCheckModule( + expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype, + ) + + def _reduce_scatter(self, tensor): + model._check("reduce_scatter.dtype", tensor.dtype, expected=reduce_dtype) + return orig_reduce_scatter(self, tensor) + + with mock.patch.object(ShardParamsDataParallel, "_reduce_scatter", 
new=_reduce_scatter): + model = ShardParamsDataParallel(model, group, **cfg).cuda() + device = next(model.parameters()).device + x = torch.rand(2, 5).to(device) + with torch.cuda.amp.autocast(enabled=autocast): + loss = model(x) + loss.backward() + + +class TestComparisonToPyTorchDDP(DistributedTest): + """ + Compare losses and parameter values after several updates when using + PyTorch DDP vs. ShardParamsDataParallel. + """ + + def test_transformer(self): + # Test every combination of these options: + keys = ["reshard_after_forward", "mixed_precision", "flatten_parameters"] + for config in itertools.product([True, False], repeat=len(keys)): + config = dict(zip(keys, config)) + spawn_and_init( + functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config), world_size=2, + ) + + @classmethod + def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3): + if config["mixed_precision"]: + autocast = True + # Force the compute dtype to be torch.float32 so that we get + # identical results as PyTorch DDP when using autocast. Note that + # this will cause the all-gather to happen in FP32, which is slower + # than necessary in most cases. + config["compute_dtype"] = torch.float32 + else: + autocast = False + + # Establish reference behavior with PyTorch DDP (+ optionally autocast). + model = nn.parallel.DistributedDataParallel( + model_cls().cuda(), device_ids=[rank], output_device=rank, process_group=group + ) + ref_loss = cls._train_for_several_steps(model, num_steps, autocast) + ref_state_dict = model.module.state_dict() + + # Confirm we get the same behavior using ShardParamsDataParallel. + model = ShardParamsDataParallel(model_cls(), group, **config).cuda() + shard_loss = cls._train_for_several_steps(model, num_steps, autocast) + shard_state_dict = model.state_dict() + + try: + torch.testing.assert_allclose(ref_loss, shard_loss) + assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True) + except (AssertionError, RuntimeError) as e: + raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}") + + @classmethod + def _train_for_several_steps(cls, model, num_steps, autocast): + optim = torch.optim.Adam(model.parameters(), lr=0.0001) + for _ in range(num_steps): + optim.zero_grad() + with torch.cuda.amp.autocast(enabled=autocast): + device = next(model.parameters()).device + input = model.module.get_input(device) + output = model(*input) + loss = model.module.get_loss(input, output) + assert loss.dtype == torch.float32 + loss.backward() + optim.step() + return loss.detach() + + +class TransformerWithSharedParams(nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) # keep everything deterministic + self.embed_tokens = nn.Embedding(50, 16) + self.transformer = nn.Transformer( + d_model=16, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=32, dropout=0.1, + ) + self.output_proj = nn.Linear(16, 50) + # share the embedding and output projection weights + self.output_proj.weight = self.embed_tokens.weight + + def get_input(self, device): + torch.manual_seed(1) # keep everything deterministic + src = torch.arange(12, device=device).view(6, 2) # T x B + tgt = torch.arange(8, device=device).view(4, 2) # T x B + return (src, tgt) + + def forward(self, src_ids, tgt_ids): + src = self.embed_tokens(src_ids) + tgt = self.embed_tokens(tgt_ids) + x = self.transformer(src, tgt) + return self.output_proj(x) + + def get_loss(self, input, output): + _, tgt = input + return 
nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum") + + +def spawn_and_init(fn, world_size, args=None): + if args is None: + args = () + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + torch.multiprocessing.spawn( + fn=functools.partial(init_and_run, fn, args), + args=(world_size, tmp_file.name), + nprocs=world_size, + join=True, + ) + + +def distributed_init(rank, world_size, tmp_file): + torch.distributed.init_process_group( + backend="nccl", init_method="file://{}".format(tmp_file), world_size=world_size, rank=rank, + ) + torch.cuda.set_device(rank) + + +def init_and_run(fn, args, rank, world_size, tmp_file): + distributed_init(rank, world_size, tmp_file) + group = torch.distributed.new_group() + fn(rank, group, *args) + + +if __name__ == "__main__": + unittest.main() From dd57e30a5217d139ca10a50b38982ab0d53608a1 Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Fri, 29 Jan 2021 07:10:34 -0800 Subject: [PATCH 04/48] [test]: skip autocast-needed tests on torch < 1.6 (#34) --- tests/nn/data_parallel/test_shard_params_data_parallel.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index f126e3756..b1b4d735f 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -19,6 +19,10 @@ class DistributedTest(unittest.TestCase): def setUp(self): + major, minor = torch.__version__.split(".")[:2] + major, minor = int(major), int(minor) + if major < 1 or minor < 6: + raise unittest.SkipTest("Need pytorch version >= 1.6 due to autocast") if not torch.cuda.is_available(): raise unittest.SkipTest("CUDA not available, skipping test") if sys.platform == "win32": From 80678fcf39d523b5982638f32050381805202292 Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Fri, 29 Jan 2021 07:12:35 -0800 Subject: [PATCH 05/48] [mypy]: fairscale/utils/containers.py (#33) --- fairscale/utils/containers.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/fairscale/utils/containers.py b/fairscale/utils/containers.py index 651931789..f705054e3 100644 --- a/fairscale/utils/containers.py +++ b/fairscale/utils/containers.py @@ -3,13 +3,17 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. 
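One detail made explicit by the type annotations in this patch: split_non_tensors returns None as the packed non-tensors when it is handed a bare tensor, and unpack_non_tensors simply returns the tensors untouched in that case. A small sketch::

    import torch
    from fairscale.utils.containers import split_non_tensors, unpack_non_tensors

    t = torch.ones(2)
    tensors, packed = split_non_tensors(t)
    assert packed is None and tensors == (t,)
    assert unpack_non_tensors(tensors, packed) == (t,)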
-from typing import Any, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch +"""Useful functions to deal with tensor types with other python container types.""" -def apply_to_tensors(fn, container): - def _apply(x): + +def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any: + """Recursively apply to all tensor in 4 kinds of container types.""" + + def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any: if torch.is_tensor(x): return fn(x) elif isinstance(x, dict): @@ -26,7 +30,7 @@ def _apply(x): return _apply(container) -def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]: +def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[List[str], List[Any]]: """ Usage:: @@ -44,6 +48,7 @@ def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]: def unpack_kwargs(kwarg_keys: List[str], flat_args: List[Any]) -> Tuple[List[Any], Dict[str, Any]]: + """See pack_kwargs.""" if len(kwarg_keys) == 0: return flat_args, {} args = flat_args[: -len(kwarg_keys)] @@ -51,7 +56,9 @@ def unpack_kwargs(kwarg_keys: List[str], flat_args: List[Any]) -> Tuple[List[Any return args, kwargs -def split_non_tensors(mixed: Union[torch.Tensor, Tuple[Any]]) -> Tuple[Tuple[torch.Tensor], Dict[str, List[Any]]]: +def split_non_tensors( + mixed: Union[torch.Tensor, Tuple[Any]] +) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]: """ Usage:: @@ -63,8 +70,8 @@ def split_non_tensors(mixed: Union[torch.Tensor, Tuple[Any]]) -> Tuple[Tuple[tor """ if isinstance(mixed, torch.Tensor): return (mixed,), None - tensors = [] - packed_non_tensors = {"is_tensor": [], "objects": []} + tensors: List[torch.Tensor] = [] + packed_non_tensors: Dict[str, List[Any]] = {"is_tensor": [], "objects": []} for o in mixed: if isinstance(o, torch.Tensor): packed_non_tensors["is_tensor"].append(True) @@ -75,7 +82,8 @@ def split_non_tensors(mixed: Union[torch.Tensor, Tuple[Any]]) -> Tuple[Tuple[tor return tuple(tensors), packed_non_tensors -def unpack_non_tensors(tensors: Tuple[torch.Tensor], packed_non_tensors: Dict[str, List[Any]],) -> Tuple[Any]: +def unpack_non_tensors(tensors: Tuple[torch.Tensor], packed_non_tensors: Dict[str, List[Any]]) -> Tuple[Any, ...]: + """See split_non_tensors.""" if packed_non_tensors is None: return tensors assert isinstance(packed_non_tensors, dict) From 5fb36d1d41fb0dab3a65a6088130183dfc157640 Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Fri, 29 Jan 2021 07:14:51 -0800 Subject: [PATCH 06/48] [mypy]: fixed fairscale/utils/testing.py (#32) --- fairscale/utils/testing.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py index a2b8b0658..f6214a071 100644 --- a/fairscale/utils/testing.py +++ b/fairscale/utils/testing.py @@ -32,11 +32,12 @@ import os import random import tempfile -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy import pytest import torch +from torch import Tensor import torch.distributed as dist from torch.distributed import rpc import torch.multiprocessing as mp @@ -45,6 +46,11 @@ from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed +if TYPE_CHECKING: + Base = nn.Module[Tensor] +else: + Base = 
nn.Module + skip_if_no_cuda = pytest.mark.skipif( not torch.cuda.is_available() or torch.cuda.device_count() < 1, reason="CUDA required" ) @@ -56,12 +62,12 @@ _, filename_mpi = tempfile.mkstemp() -class IdentityLayer(torch.nn.Module): +class IdentityLayer(Base): def __init__(self, size: int, scale: float = 1.0) -> None: super(IdentityLayer, self).__init__() self.weight = torch.nn.Parameter(scale * torch.randn(size)) - def forward(self, *_: Any, **__: Any) -> Any: + def forward(self, *_: Any, **__: Any) -> Tensor: return self.weight @@ -282,7 +288,7 @@ def replacement(*args: Any, **kwargs: Any) -> None: return prepare_test -class _Block(nn.Module): +class _Block(Base): def __init__(self, embed_dim: int, num_heads: int) -> None: super().__init__() self.ln_1 = nn.LayerNorm(embed_dim) @@ -290,7 +296,7 @@ def __init__(self, embed_dim: int, num_heads: int) -> None: self.attn = nn.MultiheadAttention(embed_dim, num_heads) # type: ignore self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),) - def forward(self, *inputs: Any, **kwargs: Any) -> Any: + def forward(self, *inputs: Any, **kwargs: Any) -> Tensor: x = inputs[0] attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype) attn_mask = torch.triu(attn_mask, diagonal=1) @@ -303,7 +309,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: return x -class GPT2(nn.Module): +class GPT2(Base): """ GPT2 pytorch implementation, for testing purposes in the image-GPT context Credits: https://github.com/teddykoker/image-gpt""" @@ -330,7 +336,7 @@ def __init__( self.head = nn.Linear(embed_dim, num_vocab, bias=False) self.clf_head = nn.Linear(embed_dim, num_classes) - def forward(self, x: torch.Tensor, classify=False) -> Any: # type: ignore + def forward(self, x: Tensor, classify: bool = False) -> Any: # type: ignore """ Expect input as shape [sequence len, batch] If classify, return classification logits @@ -398,7 +404,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool: return a == b -class DeviceAndTypeCheckModule(nn.Module): +class DeviceAndTypeCheckModule(Base): """A simple module for checking Tensor devices and dtypes.""" def __init__( @@ -420,10 +426,16 @@ def __init__( self.linear = nn.Linear(5, 5) - def _check(self, key, x, expected): + def _check( + self, + key: str, + x: Union[torch.device, torch.dtype], + expected: Union[Optional[torch.device], Optional[torch.dtype]], + ) -> None: assert expected in {None, x}, f"{key} ({x}) != expected ({expected})" - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: + x = input[0] self._check("input.dtype", x.dtype, self.expected_input_dtype) self._check("input.device", x.device, self.expected_input_device) From 774e1301965c00d8bc4824c4d0e7707ded955335 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 29 Jan 2021 11:57:09 -0500 Subject: [PATCH 07/48] More docs (#35) --- .../shard_params_data_parallel.py | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index a2affe79d..bf5e35346 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -28,23 +28,25 @@ class ShardParamsDataParallel(nn.Module): Usage:: sharded_module = ShardParamsDistributedWrapper(my_module) + optim = 
torch.optim.Adam(sharded_module.parameters(), lr=0.0001) x = sharded_module(x, y=3, z=torch.Tensor([1])) loss = x.sum() loss.backward() + optim.step() It is also possible to shard individual layers separately and have an outer - wrapper handle any leftover parameters:: - - model = nn.Sequential( - nn.Linear(5, 100), - ShardParamsDistributedWrapper(nn.Linear(100, 100)), - ShardParamsDistributedWrapper(nn.Linear(100, 100)), - nn.Linear(100, 5), + wrapper handle any leftover parameters. This can be helpful to further + reduce memory usage and to improve training speed by distributing the + unsharding (all-gather) across the forward pass. For example:: + + sharded_model = ShardParamsDistributedWrapper( + nn.Sequential( + nn.Linear(5, 100), + ShardParamsDistributedWrapper(nn.Linear(100, 100)), + ShardParamsDistributedWrapper(nn.Linear(100, 100)), + nn.Linear(100, 5), + ) ) - sharded_model = ShardParamsDistributedWrapper(model) - x = sharded_model(x) - loss = x.sum() - loss.backward() Args: module (nn.Module): module to checkpoint @@ -129,6 +131,27 @@ def __init__( @torch.no_grad() def _shard_initial_params(self): + """ + At initialization we wrap a module with full parameters and shard the + parameters in-place. Sharding is implemented by viewing each parameter + as a 1D Tensor and retaining only a single slice, where the slice size + is determined by the number of data parallel workers. + + Wrapping modules with many small parameters (or with a very large data + parallel world size) will result in many small parameter shards and slow + performance. In this case it's better to set *flatten_parameters* to + ``True``, so that all of the small parameters in the module are combined + into a single contiguous Tensor and sharded once. + + After this initial sharding is complete, the user can initialize a + ``torch.optim.Optimizer`` in the usual way, i.e.:: + + optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) + + The optimizer will see only a single slice of parameters and will thus + allocate less memory for optimizer state, avoiding redundancy across + data parallel workers. 
+ """ for p in self.params: assert not hasattr(p, "_is_sharded") assert p.is_floating_point() From b35a28a014f9cd4c97802ba5fd80ce53ea50c84d Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:21:46 -0800 Subject: [PATCH 08/48] [mypy]: fixed all the mypy errors (#37) * [mypy]: fixed all the mypy errors * make older version python happy with typing hints * [chore] Fix lint errors that broke master (#348) authored-by: Anjali Sridhar Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- .../shard_params_data_parallel.py | 88 ++++++++++++------- fairscale/nn/misc/flatten_params_wrapper.py | 13 ++- fairscale/utils/containers.py | 18 ++-- pyproject.toml | 2 +- stubs/torch/__init__.pyi | 1 + stubs/torch/distributed/__init__.pyi | 5 +- stubs/torch/nn/modules/module.pyi | 7 +- stubs/torch/nn/parameter.pyi | 9 ++ 9 files changed, 94 insertions(+), 51 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe653ae98..9644109f8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: rev: 19.10b0 hooks: - id: black - language_version: python3.6 + language_version: python3.7 - repo: https://gitlab.com/pycqa/flake8 rev: 3.7.9 diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index bf5e35346..00b194203 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -5,10 +5,11 @@ import copy import functools -from typing import Optional +from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union import torch import torch.distributed as dist +from torch.distributed import ProcessGroup import torch.nn as nn from fairscale.nn.misc import FlattenParamsWrapper @@ -20,6 +21,9 @@ unpack_non_tensors, ) +if TYPE_CHECKING: + from collections import OrderedDict # noqa: F401 + class ShardParamsDataParallel(nn.Module): """ @@ -77,7 +81,7 @@ class ShardParamsDataParallel(nn.Module): def __init__( self, module: nn.Module, - process_group=None, + process_group: Optional[ProcessGroup] = None, reshard_after_forward: bool = True, mixed_precision: bool = False, fp32_reduce_scatter: bool = False, @@ -85,7 +89,7 @@ def __init__( cpu_offload: bool = False, compute_device: Optional[torch.device] = None, compute_dtype: Optional[torch.dtype] = None, - move_grads_to_cpu: Optional[bool] = False, + move_grads_to_cpu: bool = False, ): super().__init__() self.process_group = process_group or dist.new_group() @@ -111,7 +115,7 @@ def __init__( params = list(p for p in module.parameters() if not getattr(p, "_is_sharded", False)) if self.flatten_parameters and len(params) > 0: - self.module = FlattenParamsWrapper(module, param_list=params) + self.module: nn.Module = FlattenParamsWrapper(module, param_list=params) del module self.params = [self.module.flat_param] else: @@ -130,7 +134,7 @@ def __init__( assert getattr(p, "_is_sharded", False), f"found unsharded parameter: {n} ; {p.size()}" @torch.no_grad() - def _shard_initial_params(self): + def _shard_initial_params(self) -> None: """ At initialization we wrap a module with full parameters and shard the parameters in-place. 
Sharding is implemented by viewing each parameter @@ -139,13 +143,15 @@ def _shard_initial_params(self): Wrapping modules with many small parameters (or with a very large data parallel world size) will result in many small parameter shards and slow - performance. In this case it's better to set *flatten_parameters* to + performance. In this case it's better to set *``flatten_parameters``* to ``True``, so that all of the small parameters in the module are combined into a single contiguous Tensor and sharded once. After this initial sharding is complete, the user can initialize a ``torch.optim.Optimizer`` in the usual way, i.e.:: + .. code-block:: python + optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) The optimizer will see only a single slice of parameters and will thus @@ -169,14 +175,14 @@ def _shard_initial_params(self): p.data = p.data.view(-1)[s:e].clone() free_storage_(orig_data) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """Forward missing attributes to wrapped module.""" try: return super().__getattr__(name) # defer to nn.Module's logic except AttributeError: return getattr(self.module, name) - def __getstate__(self): + def __getstate__(self) -> Dict[str, str]: state = copy.copy(self.__dict__) state["orig_sizes"] = [p._orig_size for p in self.params] if state["process_group"] is not None: @@ -185,12 +191,14 @@ def __getstate__(self): del state["_fp32_to_fp16_stream"] return state - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: + """Intercept state setting and perform needed changes on params.""" super().__setstate__(state) - def fixup(p, size): + def fixup(p: torch.nn.Parameter, size: int) -> torch.nn.Parameter: assert isinstance(p, torch.nn.Parameter) p.data = p.data.clone() # move tensors out of shared memory + # Ignore mypy error since we add additional fields to a param. p._is_sharded = True p._orig_size = size return p @@ -198,7 +206,8 @@ def fixup(p, size): self.params = [fixup(p, size) for p, size in zip(self.params, self.orig_sizes)] del self.orig_sizes - def state_dict(self, *args, **kwargs): + # TODO: figuring out how to do typing for this overloaded function. + def state_dict(self, *args, **kwargs): # type: ignore """ Returns the whole (unsharded) state of the module. Parameters are not sharded, so the resulting state_dict can be loaded directly by the @@ -211,7 +220,8 @@ def state_dict(self, *args, **kwargs): # returned state dict. return self.module.state_dict(*args, **kwargs) - def local_state_dict(self, *args, **kwargs): + # TODO: figuring out how to do typing for this overloaded function. + def local_state_dict(self, *args, **kwargs): # type: ignore """ Returns the local (sharded) state of the module. 
Parameters are sharded, so the resulting state_dict can only be loaded after the Module has been @@ -221,19 +231,23 @@ def local_state_dict(self, *args, **kwargs): kwargs["unflatten_params"] = False return self.module.state_dict(*args, **kwargs) - def load_state_dict(self, *args, **kwargs): + def load_state_dict( + self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True + ) -> NamedTuple: """Load a whole (unsharded) state_dict.""" self._rebuild_full_params() - output = self.module.load_state_dict(*args, **kwargs) + output = self.module.load_state_dict(state_dict, strict) self._free_full_params() return output - def load_local_state_dict(self, *args, **kwargs): + def load_local_state_dict( + self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True + ) -> NamedTuple: """Load a local (sharded) state_dict.""" - return self.module.load_state_dict(*args, **kwargs) + return self.module.load_state_dict(state_dict, strict) @torch.no_grad() - def _pre_forward_init(self): + def _pre_forward_init(self) -> None: did_init = False for p in self.params: if not hasattr(p, "_full_param"): @@ -271,7 +285,7 @@ def _pre_forward_init(self): self._fp32_to_fp16_stream = torch.cuda.Stream() self._fp32_to_fp16_stream.wait_stream(torch.cuda.current_stream()) - def forward(self, *args, **kwargs): + def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: self._pre_forward_init() if self.mixed_precision: @@ -297,7 +311,7 @@ def forward(self, *args, **kwargs): if torch.is_grad_enabled(): pre_backward_hook_has_run = [False] - def _pre_backward_hook(*unused): + def _pre_backward_hook(*unused: Any) -> None: if pre_backward_hook_has_run[0]: return # only run once pre_backward_hook_has_run[0] = True @@ -307,7 +321,7 @@ def _pre_backward_hook(*unused): else: self._use_full_params() - def _register_hook(t): + def _register_hook(t: torch.Tensor) -> torch.Tensor: t.register_hook(_pre_backward_hook) return t @@ -316,7 +330,7 @@ def _register_hook(t): return outputs - def _register_post_backward_hooks(self): + def _register_post_backward_hooks(self) -> None: # Register backward hooks to reshard params and reduce-scatter grads. 
if not torch.is_grad_enabled(): return # don't register grad hooks if grad isn't enabled @@ -330,7 +344,7 @@ def _register_post_backward_hooks(self): p._shard_bwd_hook = (grad_acc, handle) @torch.no_grad() - def _post_backward_hook(self, param, *unused): + def _post_backward_hook(self, param: torch.nn.Parameter, *unused: Any) -> None: if param.grad is None: return if param.grad.requires_grad: @@ -362,7 +376,8 @@ def _post_backward_hook(self, param, *unused): param.grad.data = param.grad.data.to(dtype=param.data.dtype) @torch.no_grad() - def _rebuild_full_params(self): + def _rebuild_full_params(self) -> None: + """Get all shards of params.""" if self.mixed_precision: self._cast_fp32_param_shards_to_fp16() @@ -378,13 +393,14 @@ def _rebuild_full_params(self): self._free_fp16_param_shard([p]) @torch.no_grad() - def _use_full_params(self): + def _use_full_params(self) -> None: for p in self.params: assert p._full_param.storage().size() != 0 p.data = p._full_param @torch.no_grad() - def _free_full_params(self, params=None): + def _free_full_params(self, params: Optional[List[torch.nn.Parameter]] = None) -> None: + """Free up storage for full parameters.""" if params is None: params = self.params current_stream = torch.cuda.current_stream() @@ -399,14 +415,16 @@ def _free_full_params(self, params=None): free_storage_(p._full_param) @torch.no_grad() - def _use_fp32_param_shard(self, params=None): + def _use_fp32_param_shard(self, params: Optional[List[torch.nn.Parameter]] = None) -> None: + """Use FP32 shard for a list of params.""" if params is None: params = self.params for p in params: p.data = p._fp32_shard @torch.no_grad() - def _cast_fp32_param_shards_to_fp16(self, params=None): + def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[torch.nn.Parameter]] = None) -> None: + """Cast FP32 param shard to FP16 for a list of params.""" if params is None: params = self.params with torch.cuda.stream(self._fp32_to_fp16_stream): @@ -418,7 +436,8 @@ def _cast_fp32_param_shards_to_fp16(self, params=None): torch.cuda.current_stream().wait_stream(self._fp32_to_fp16_stream) @torch.no_grad() - def _free_fp16_param_shard(self, params=None): + def _free_fp16_param_shard(self, params: Optional[List[torch.nn.Parameter]] = None) -> None: + """Free storage for FP16 shards for a list of params.""" if params is None: params = self.params current_stream = torch.cuda.current_stream() @@ -430,7 +449,7 @@ def _free_fp16_param_shard(self, params=None): free_storage_(p._fp16_shard) @torch.no_grad() - def _reduce_scatter(self, tensor, output=None): + def _reduce_scatter(self, tensor: torch.Tensor, output: Optional[torch.Tensor] = None) -> torch.Tensor: assert tensor.numel() % self.world_size == 0 tensor = tensor.view(self.world_size, -1) if output is None: @@ -440,7 +459,7 @@ def _reduce_scatter(self, tensor, output=None): @torch.no_grad() -def cast_inputs_to_fp16(*args, **kwargs): +def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: """ Cast any Tensors in *args or **kwargs to FP16. 
@@ -454,13 +473,15 @@ def cast_inputs_to_fp16(*args, **kwargs): return args, kwargs -def cast_buffers_to_fp16(module): +def cast_buffers_to_fp16(module: nn.Module) -> None: + """Cast buffers of a module to FP16.""" for key, buf in module.named_buffers(recurse=False): if buf is not None: setattr(module, key, buf.half()) -def free_storage_(data): +def free_storage_(data: torch.Tensor) -> None: + """Free underlying storage of a Tensor.""" if data.storage().size() > 0: # Since we're modifying the Tensor's Storage directly, make sure the Tensor # is the sole occupant of the Storage. @@ -470,7 +491,8 @@ def free_storage_(data): @torch.no_grad() -def alloc_storage_(data, size): +def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None: + """Allocate storage for a tensor.""" assert data.storage().size() == 0 data.storage().resize_(size.numel()) # data.set_(size=size) diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py index 095828f30..6e039f78b 100644 --- a/fairscale/nn/misc/flatten_params_wrapper.py +++ b/fairscale/nn/misc/flatten_params_wrapper.py @@ -2,12 +2,15 @@ # Licensed under the MIT License. from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union import torch from torch import Tensor import torch.nn as nn +if TYPE_CHECKING: + from collections import OrderedDict # noqa: F401 + class FlattenParamsWrapper(nn.Module): """ @@ -136,12 +139,14 @@ def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: """Return the flattened state_dict.""" return super().state_dict(*args, **kwargs) - def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> None: + def load_state_dict( + self, state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], strict: bool = True + ) -> NamedTuple: if "flat_param" in state_dict: - super().load_state_dict(state_dict, strict=True) + return super().load_state_dict(state_dict, strict=strict) else: with self.unflatten_params(): - return self.module.load_state_dict(state_dict, *args, **kwargs) + return self.module.load_state_dict(state_dict, strict) def forward(self, *inputs: Any, **kwinputs: Any) -> Any: self._unflatten_params_as_views() diff --git a/fairscale/utils/containers.py b/fairscale/utils/containers.py index f705054e3..0f5134c2e 100644 --- a/fairscale/utils/containers.py +++ b/fairscale/utils/containers.py @@ -30,7 +30,7 @@ def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any: return _apply(container) -def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[List[str], List[Any]]: +def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]: """ Usage:: @@ -39,15 +39,15 @@ def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[List[str], List[Any]]: assert args == [1, 2] assert kwargs == {"a": 3, "b": 4} """ - kwarg_keys = [] - flat_args = list(args) + kwarg_keys: List[str] = [] + flat_args: List[Any] = list(args) for k, v in kwargs.items(): kwarg_keys.append(k) flat_args.append(v) - return kwarg_keys, flat_args + return tuple(kwarg_keys), tuple(flat_args) -def unpack_kwargs(kwarg_keys: List[str], flat_args: List[Any]) -> Tuple[List[Any], Dict[str, Any]]: +def unpack_kwargs(kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """See pack_kwargs.""" if len(kwarg_keys) == 0: return flat_args, {} @@ -57,7 +57,7 @@ def unpack_kwargs(kwarg_keys: 
List[str], flat_args: List[Any]) -> Tuple[List[Any def split_non_tensors( - mixed: Union[torch.Tensor, Tuple[Any]] + mixed: Union[torch.Tensor, Tuple[Any, ...]] ) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]: """ Usage:: @@ -82,12 +82,14 @@ def split_non_tensors( return tuple(tensors), packed_non_tensors -def unpack_non_tensors(tensors: Tuple[torch.Tensor], packed_non_tensors: Dict[str, List[Any]]) -> Tuple[Any, ...]: +def unpack_non_tensors( + tensors: Tuple[torch.Tensor, ...], packed_non_tensors: Optional[Dict[str, List[Any]]] +) -> Tuple[Any, ...]: """See split_non_tensors.""" if packed_non_tensors is None: return tensors assert isinstance(packed_non_tensors, dict) - mixed = [] + mixed: List[Any] = [] is_tensor_list = packed_non_tensors["is_tensor"] objects = packed_non_tensors["objects"] assert len(tensors) + len(objects) == len(is_tensor_list) diff --git a/pyproject.toml b/pyproject.toml index 9b2d7cee8..363c814e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,4 +28,4 @@ use_parentheses = true skip_glob = ["build/*", "stubs/*"] # Don't split "import" and "from". force_sort_within_sections = true -known_third_party = ["benchmark_dataset", "dataclasses", "datasets", "golden_configs", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"] +known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"] diff --git a/stubs/torch/__init__.pyi b/stubs/torch/__init__.pyi index 64049ba9a..f22720d09 100644 --- a/stubs/torch/__init__.pyi +++ b/stubs/torch/__init__.pyi @@ -84,6 +84,7 @@ class Size(tuple): class Storage: def size(self) -> _int: ... def element_size(self) -> _int: ... + def resize_(self, int) -> None: ... #END # See https://github.com/python/mypy/issues/4146 for why these workarounds diff --git a/stubs/torch/distributed/__init__.pyi b/stubs/torch/distributed/__init__.pyi index 8c951b65e..56307fe61 100644 --- a/stubs/torch/distributed/__init__.pyi +++ b/stubs/torch/distributed/__init__.pyi @@ -37,12 +37,15 @@ def broadcast_object_list(object_list: List[Any], src: int, group:Optional[Proce def is_initialized() -> bool: ... def init_process_group(backend: Union[str, Backend], init_method: Optional[str] = None, timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ... -def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ... +def new_group(ranks: Optional[List[int]] = None, + timeout: Optional[datetime.timedelta] = datetime.timedelta(0, 1800), + backend: Optional[Union[str, Backend]] = None): ... def all_to_all(output: List[Tensor], input: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional[List[int]] = None, input_split_size: Optional[List[int]] = None, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ... 
+def reduce_scatter(tensor: Tensor, input_list: List[Tensor], op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def destroy_process_group() -> None: ... diff --git a/stubs/torch/nn/modules/module.pyi b/stubs/torch/nn/modules/module.pyi index 2838fdd72..458b32668 100644 --- a/stubs/torch/nn/modules/module.pyi +++ b/stubs/torch/nn/modules/module.pyi @@ -2,7 +2,7 @@ from ... import Tensor, device, dtype from .. import Parameter -from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, Generic +from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, Generic, NamedTuple from collections import OrderedDict from ...utils.hooks import RemovableHandle @@ -65,9 +65,10 @@ class Module(Generic[T_co]): def __getattr__(self, name: str) -> Union[Tensor, 'Module']: ... - # TODO double-check this def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: ... + def __setstate__(self, state: Dict[str, Any]) -> None: ... + # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns # back that same object. But if they pass nothing, an `OrederedDict` is created and returned. T_destination = TypeVar('T_destination', bound=Mapping[str, Tensor]) @@ -78,7 +79,7 @@ class Module(Generic[T_co]): @overload def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> OrderedDict[str, Tensor]: ... - def load_state_dict(self, state_dict: Union[Dict[str, Tensor], OrderedDict[str, Tensor]], strict: bool = ...): ... + def load_state_dict(self, state_dict: Union[Dict[str, Tensor], OrderedDict[str, Tensor]], strict: bool = ...) -> NamedTuple: ... def parameters(self, recurse: bool = ...) -> Iterator[Parameter]: ... diff --git a/stubs/torch/nn/parameter.pyi b/stubs/torch/nn/parameter.pyi index 14f5a0bf9..ae0ea0861 100644 --- a/stubs/torch/nn/parameter.pyi +++ b/stubs/torch/nn/parameter.pyi @@ -4,6 +4,15 @@ from .. import Tensor import builtins class Parameter(Tensor): + # These are dynamic attributes added by shard_params_data_parallel class. + # Added here for better type checking. + _is_sharded: bool + _orig_size: int + _cpu_grad: Parameter + _full_param: Parameter + _fp32_shard: Parameter + _fp16_shard: Parameter + def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ... ... 
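The ``free_storage_`` / ``alloc_storage_`` helpers documented above work by resizing a Tensor's underlying Storage in place: the backing memory is released (or re-allocated) while the Tensor object and its metadata stay alive. A minimal standalone sketch of the same trick, assuming a PyTorch version where ``Tensor.storage().resize_`` is available; the two functions here are illustrative re-implementations for the demo, not imports from fairscale, and the exact asserts in fairscale's versions may differ::

    import torch

    def free_storage_(data: torch.Tensor) -> None:
        """Free the backing memory; the Tensor object itself survives."""
        if data.storage().size() > 0:
            # We are modifying the Storage directly, so the Tensor must be its sole occupant.
            assert data.storage_offset() == 0
            data.storage().resize_(0)

    def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
        """Re-materialize backing memory for a previously freed Tensor."""
        assert data.storage().size() == 0
        data.storage().resize_(size.numel())

    t = torch.zeros(1024)
    orig_size = t.size()
    free_storage_(t)
    assert t.storage().size() == 0       # memory released, metadata kept
    alloc_storage_(t, orig_size)
    assert t.storage().size() == 1024    # storage restored (contents are uninitialized)

This is the mechanism the wrapper uses to keep ``p._full_param`` and ``p._fp16_shard`` allocated only while they are actually needed during the forward/backward pass.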
From 92c550bfd1975f35eadb3f0d4f2a01a08503d50c Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 1 Feb 2021 18:24:27 -0500 Subject: [PATCH 09/48] Sharded DDP: test cpu_offload arg (#40) * Test CPU offload * remove dead code --- .../shard_params_data_parallel.py | 60 +++--- fairscale/utils/testing.py | 6 +- .../test_shard_params_data_parallel.py | 195 ++++++++++-------- 3 files changed, 142 insertions(+), 119 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 00b194203..a38da0818 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -250,36 +250,42 @@ def load_local_state_dict( def _pre_forward_init(self) -> None: did_init = False for p in self.params: - if not hasattr(p, "_full_param"): - did_init = True - assert p._is_sharded - - p._fp32_shard = p.data - - if self.mixed_precision: - assert p._fp32_shard.dtype == torch.float32 - if self.cpu_offload: - p._fp32_shard = p._fp32_shard.pin_memory() - p._fp16_shard = torch.zeros_like( - p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype, - ) - free_storage_(p._fp16_shard) - p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype) - else: - p._fp16_shard = None # use _fp32_shard - p._full_param = p._fp32_shard.new_empty(p._orig_size) + if hasattr(p, "_full_param"): + continue + did_init = True + assert p._is_sharded + + p._fp32_shard = p.data + + if self.mixed_precision: + assert p._fp32_shard.dtype == torch.float32 + + if self.cpu_offload: + assert p._fp32_shard.device == torch.device("cpu") + p._fp32_shard = p._fp32_shard.pin_memory() - p._full_param = p._full_param.to(dtype=self.compute_dtype, device=self.compute_device) - free_storage_(p._full_param) + p._fp16_shard = torch.zeros_like( + p._fp32_shard, + device=self.compute_device, + dtype=self.compute_dtype, + ) + free_storage_(p._fp16_shard) + p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype) + else: + p._fp16_shard = None # use _fp32_shard + p._full_param = p._fp32_shard.new_empty(p._orig_size) - p.data = p._fp32_shard + p._full_param = p._full_param.to(dtype=self.compute_dtype, device=self.compute_device) + free_storage_(p._full_param) - if self.move_grads_to_cpu: - if self.mixed_precision and not self.fp32_reduce_scatter: - grad_dtype = torch.float16 - else: - grad_dtype = torch.float32 - p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory() + p.data = p._fp32_shard + + if self.move_grads_to_cpu: + if self.mixed_precision and not self.fp32_reduce_scatter: + grad_dtype = torch.float16 + else: + grad_dtype = torch.float32 + p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory() if did_init: self._fp32_to_fp16_stream = torch.cuda.Stream() diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py index f6214a071..683825d4c 100644 --- a/fairscale/utils/testing.py +++ b/fairscale/utils/testing.py @@ -294,7 +294,11 @@ def __init__(self, embed_dim: int, num_heads: int) -> None: self.ln_1 = nn.LayerNorm(embed_dim) self.ln_2 = nn.LayerNorm(embed_dim) self.attn = nn.MultiheadAttention(embed_dim, num_heads) # type: ignore - self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),) + self.mlp = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 4), + nn.GELU(), + nn.Linear(embed_dim * 4, embed_dim), + 
) def forward(self, *inputs: Any, **kwargs: Any) -> Tensor: x = inputs[0] diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index b1b4d735f..e184c2584 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -15,6 +15,7 @@ from fairscale.nn.data_parallel import ShardParamsDataParallel from fairscale.utils.testing import DeviceAndTypeCheckModule, objects_are_equal +from typing import Dict class DistributedTest(unittest.TestCase): @@ -30,99 +31,100 @@ def setUp(self): if torch.cuda.device_count() < 2: raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") + @staticmethod + def _train_for_several_steps(model, num_steps, autocast): + model_device = next(model.parameters()).device + optim = torch.optim.Adam(model.parameters(), lr=0.0001) + for _ in range(num_steps): + optim.zero_grad() + with torch.cuda.amp.autocast(enabled=autocast): + # Inputs always cuda regardless of move_grads_cpu, or model.device + input = model.module.get_input(torch.device("cuda")) + output = model(*input) + loss = model.module.get_loss(input, output).to(model_device) + print(f'loss device: {loss.device}') + assert loss.dtype == torch.float32 + loss.backward() + optim.step() + return loss.detach() + class TestMixedPrecision(DistributedTest): def test_all_fp32(self): - spawn_and_init( - functools.partial( - self.__class__._test_dtypes, - {"mixed_precision": False}, - False, # autocast enabled - torch.float32, # expected_input_dtype - torch.float32, # expected_param_dtype - torch.float32, # expected_loss_dtype - torch.float32, # expected_reduce_dtype - ), - world_size=2, + self._spawn_test_case( + {"mixed_precision": False}, + False, # autocast enabled + torch.float32, # expected_input_dtype + torch.float32, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float32, # expected_reduce_dtype ) def test_mixed_precision(self): - spawn_and_init( - functools.partial( - self.__class__._test_dtypes, - {"mixed_precision": True}, - False, # autocast enabled - torch.float16, # expected_input_dtype - torch.float16, # expected_param_dtype - torch.float16, # expected_loss_dtype - torch.float16, # expected_reduce_dtype - ), - world_size=2, + self._spawn_test_case( + {"mixed_precision": True}, + False, # autocast enabled + torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float16, # expected_loss_dtype + torch.float16, # expected_reduce_dtype ) def test_mixed_precision_autocast(self): - spawn_and_init( - functools.partial( - self.__class__._test_dtypes, - {"mixed_precision": True}, - True, # autocast enabled - torch.float16, # expected_input_dtype - torch.float16, # expected_param_dtype - torch.float32, # expected_loss_dtype - torch.float16, # expected_reduce_dtype - ), - world_size=2, + """If autocast enabled, loss should be fp32.""" + self._spawn_test_case( + {"mixed_precision": True}, + True, # autocast enabled + torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float16, # expected_reduce_dtype ) def test_mixed_precision_autocast_fp32_compute(self): - spawn_and_init( - functools.partial( - self.__class__._test_dtypes, - {"mixed_precision": True, "compute_dtype": torch.float32}, - True, # autocast enabled - torch.float16, # expected_input_dtype - torch.float32, # expected_param_dtype - torch.float32, # expected_loss_dtype - 
torch.float32, # expected_reduce_dtype - ), - world_size=2, + self._spawn_test_case( + {"mixed_precision": True, "compute_dtype": torch.float32}, + True, # autocast enabled + torch.float16, # expected_input_dtype + torch.float32, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float32, # expected_reduce_dtype ) def test_fp32_reduce_scatter(self): - spawn_and_init( - functools.partial( - self.__class__._test_dtypes, - {"mixed_precision": True, "fp32_reduce_scatter": True}, - False, # autocast enabled - torch.float16, # expected_input_dtype - torch.float16, # expected_param_dtype - torch.float16, # expected_loss_dtype - torch.float32, # expected_reduce_dtype - ), - world_size=2, + self._spawn_test_case( + {"mixed_precision": True, "fp32_reduce_scatter": True}, + False, # autocast enabled + torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float16, # expected_loss_dtype + torch.float32, # expected_reduce_dtype ) def test_fp32_reduce_scatter_autocast(self): - spawn_and_init( - functools.partial( - self.__class__._test_dtypes, - {"mixed_precision": True, "fp32_reduce_scatter": True}, - True, # autocast enabled - torch.float16, # expected_input_dtype - torch.float16, # expected_param_dtype - torch.float32, # expected_loss_dtype - torch.float32, # expected_reduce_dtype - ), - world_size=2, + self._spawn_test_case( + {"mixed_precision": True, "fp32_reduce_scatter": True}, + True, # autocast enabled + torch.float16, # expected_input_dtype + torch.float16, # expected_param_dtype + torch.float32, # expected_loss_dtype + torch.float32, # expected_reduce_dtype ) + def _spawn_test_case(self, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype, world_size=2): + """Call test_dtypes inside of torch.multiprocessing.spawn""" + fn = functools.partial(self._test_dtypes, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype) + spawn_and_init(fn, world_size=world_size) + @staticmethod - def _test_dtypes(cfg, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group): + def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group): # Patch _reduce_scatter op to check the dtype of the reduction orig_reduce_scatter = ShardParamsDataParallel._reduce_scatter model = DeviceAndTypeCheckModule( - expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype, + expected_input_dtype=in_dtype, + expected_param_dtype=p_dtype, + expected_loss_dtype=loss_dtype, ) def _reduce_scatter(self, tensor): @@ -150,11 +152,24 @@ def test_transformer(self): for config in itertools.product([True, False], repeat=len(keys)): config = dict(zip(keys, config)) spawn_and_init( - functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config), world_size=2, + functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config), + world_size=2, ) + def test_cpu_offload_and_cpu_grads(self): + config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True} + test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False) + spawn_and_init(test_fn) + + def test_cpu_offload_and_cuda_grads(self): + # If grads are on gpu, but model and optimizer are on cpu, backward breaks. 
+ config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": False} + with self.assertRaises(Exception): # RuntimeError inside spawn + test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False) + spawn_and_init(test_fn) + @classmethod - def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3): + def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3, use_cuda=True): if config["mixed_precision"]: autocast = True # Force the compute dtype to be torch.float32 so that we get @@ -173,7 +188,11 @@ def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3): ref_state_dict = model.module.state_dict() # Confirm we get the same behavior using ShardParamsDataParallel. - model = ShardParamsDataParallel(model_cls(), group, **config).cuda() + model = ShardParamsDataParallel(model_cls(), group, **config) + if use_cuda: + model = model.cuda() + else: + assert next(model.parameters()).device == torch.device("cpu") shard_loss = cls._train_for_several_steps(model, num_steps, autocast) shard_state_dict = model.state_dict() @@ -183,31 +202,22 @@ def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3): except (AssertionError, RuntimeError) as e: raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}") - @classmethod - def _train_for_several_steps(cls, model, num_steps, autocast): - optim = torch.optim.Adam(model.parameters(), lr=0.0001) - for _ in range(num_steps): - optim.zero_grad() - with torch.cuda.amp.autocast(enabled=autocast): - device = next(model.parameters()).device - input = model.module.get_input(device) - output = model(*input) - loss = model.module.get_loss(input, output) - assert loss.dtype == torch.float32 - loss.backward() - optim.step() - return loss.detach() - class TransformerWithSharedParams(nn.Module): def __init__(self): super().__init__() torch.manual_seed(0) # keep everything deterministic - self.embed_tokens = nn.Embedding(50, 16) + d_model = 16 + d_vocab = 32 + self.embed_tokens = nn.Embedding(d_vocab, d_model) self.transformer = nn.Transformer( - d_model=16, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=32, dropout=0.1, + d_model=d_model, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=8, + dropout=0.1, ) - self.output_proj = nn.Linear(16, 50) + self.output_proj = nn.Linear(d_model, d_vocab) # share the embedding and output projection weights self.output_proj.weight = self.embed_tokens.weight @@ -228,7 +238,7 @@ def get_loss(self, input, output): return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum") -def spawn_and_init(fn, world_size, args=None): +def spawn_and_init(fn, world_size=2, args=None): if args is None: args = () with tempfile.NamedTemporaryFile(delete=False) as tmp_file: @@ -242,7 +252,10 @@ def spawn_and_init(fn, world_size, args=None): def distributed_init(rank, world_size, tmp_file): torch.distributed.init_process_group( - backend="nccl", init_method="file://{}".format(tmp_file), world_size=world_size, rank=rank, + backend="nccl", + init_method="file://{}".format(tmp_file), + world_size=world_size, + rank=rank, ) torch.cuda.set_device(rank) From 5bb212f922e17b973e9e19327ef9cd24f2ed8682 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 1 Feb 2021 18:33:09 -0500 Subject: [PATCH 10/48] Misc comments from @anj-s (#43) --- .../shard_params_data_parallel.py | 49 ++++++++++--------- 1 file changed, 25 
insertions(+), 24 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index a38da0818..f79051baf 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -31,7 +31,7 @@ class ShardParamsDataParallel(nn.Module): Usage:: - sharded_module = ShardParamsDistributedWrapper(my_module) + sharded_module = ShardParamsDataParallel(my_module) optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) x = sharded_module(x, y=3, z=torch.Tensor([1])) loss = x.sum() @@ -43,11 +43,11 @@ class ShardParamsDataParallel(nn.Module): reduce memory usage and to improve training speed by distributing the unsharding (all-gather) across the forward pass. For example:: - sharded_model = ShardParamsDistributedWrapper( + sharded_model = ShardParamsDataParallel( nn.Sequential( nn.Linear(5, 100), - ShardParamsDistributedWrapper(nn.Linear(100, 100)), - ShardParamsDistributedWrapper(nn.Linear(100, 100)), + ShardParamsDataParallel(nn.Linear(100, 100)), + ShardParamsDataParallel(nn.Linear(100, 100)), nn.Linear(100, 5), ) ) @@ -186,7 +186,7 @@ def __getstate__(self) -> Dict[str, str]: state = copy.copy(self.__dict__) state["orig_sizes"] = [p._orig_size for p in self.params] if state["process_group"] is not None: - state["process_group"] = "MISSING" # raise error if used + state["process_group"] = "MISSING" # process_group isn't pickleable if "_fp32_to_fp16_stream" in state: del state["_fp32_to_fp16_stream"] return state @@ -216,7 +216,7 @@ def state_dict(self, *args, **kwargs): # type: ignore torch.cuda.synchronize() self._rebuild_full_params() # We don't free the params after generating the state dict, since - # freeing is done in-place (via the Storagee) and would corrupt the + # freeing is done in-place (via the Storage) and would corrupt the # returned state dict. return self.module.state_dict(*args, **kwargs) @@ -225,7 +225,7 @@ def local_state_dict(self, *args, **kwargs): # type: ignore """ Returns the local (sharded) state of the module. Parameters are sharded, so the resulting state_dict can only be loaded after the Module has been - wrapped with ShardParamsDistributedWrapper. + wrapped with ShardParamsDataParallel. """ if self.flatten_parameters: kwargs["unflatten_params"] = False @@ -313,26 +313,28 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: # initialized with the correct dtype and size. self._use_fp32_param_shard() + if not torch.is_grad_enabled(): + return outputs + # Register pre-backward hook to run before the wrapped module's backward. - if torch.is_grad_enabled(): - pre_backward_hook_has_run = [False] + pre_backward_hook_has_run = [False] - def _pre_backward_hook(*unused: Any) -> None: - if pre_backward_hook_has_run[0]: - return # only run once - pre_backward_hook_has_run[0] = True + def _pre_backward_hook(*unused: Any) -> None: + if pre_backward_hook_has_run[0]: + return # only run once + pre_backward_hook_has_run[0] = True - if self.reshard_after_forward: - self._rebuild_full_params() - else: - self._use_full_params() + if self.reshard_after_forward: + self._rebuild_full_params() + else: + self._use_full_params() - def _register_hook(t: torch.Tensor) -> torch.Tensor: - t.register_hook(_pre_backward_hook) - return t + def _register_hook(t: torch.Tensor) -> torch.Tensor: + t.register_hook(_pre_backward_hook) + return t - # Attach hooks to Tensor outputs. 
- outputs = apply_to_tensors(_register_hook, outputs) + # Attach hooks to Tensor outputs. + outputs = apply_to_tensors(_register_hook, outputs) return outputs @@ -354,7 +356,7 @@ def _post_backward_hook(self, param: torch.nn.Parameter, *unused: Any) -> None: if param.grad is None: return if param.grad.requires_grad: - raise RuntimeError("ShardParamsDistributedWrapper only works with gradients that don't require grad") + raise RuntimeError("ShardParamsDataParallel only works with gradients that don't require grad") # Free full params and switch to FP32 shard after backward. self._free_full_params([param]) @@ -501,4 +503,3 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None: """Allocate storage for a tensor.""" assert data.storage().size() == 0 data.storage().resize_(size.numel()) - # data.set_(size=size) From cbd243e128978109d6a746a933f0b3e4b50d055a Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 2 Feb 2021 10:06:10 -0500 Subject: [PATCH 11/48] Only sync fp32_to_fp16 stream for the top-most (root) ShardParams wrapper (#42) * Only sync fp32_to_fp16 stream for the top-most (root) ShardParams wrapper * Fix mypy, add test, address some comments * Add missing assert * Comments --- .../shard_params_data_parallel.py | 120 +++++++++++------- fairscale/nn/misc/flatten_params_wrapper.py | 4 +- fairscale/utils/testing.py | 22 +++- stubs/torch/nn/parameter.pyi | 13 +- .../test_shard_params_data_parallel.py | 81 ++++++++---- 5 files changed, 158 insertions(+), 82 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index f79051baf..52b080d63 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -11,6 +11,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup import torch.nn as nn +from torch.nn import Parameter from fairscale.nn.misc import FlattenParamsWrapper from fairscale.utils.containers import ( @@ -133,6 +134,11 @@ def __init__( for n, p in self.named_parameters(): assert getattr(p, "_is_sharded", False), f"found unsharded parameter: {n} ; {p.size()}" + # Flag to indicate if this instance is wrapped by any other + # ShardParamsDataParallel instances. This flag is only set after the + # first forward pass. + self._is_root: Optional[bool] = None + @torch.no_grad() def _shard_initial_params(self) -> None: """ @@ -195,8 +201,8 @@ def __setstate__(self, state: Dict[str, Any]) -> None: """Intercept state setting and perform needed changes on params.""" super().__setstate__(state) - def fixup(p: torch.nn.Parameter, size: int) -> torch.nn.Parameter: - assert isinstance(p, torch.nn.Parameter) + def fixup(p: Parameter, size: torch.Size) -> Parameter: + assert isinstance(p, Parameter) p.data = p.data.clone() # move tensors out of shared memory # Ignore mypy error since we add additional fields to a param. 
p._is_sharded = True @@ -247,49 +253,71 @@ def load_local_state_dict( return self.module.load_state_dict(state_dict, strict) @torch.no_grad() - def _pre_forward_init(self) -> None: - did_init = False - for p in self.params: - if hasattr(p, "_full_param"): - continue - did_init = True - assert p._is_sharded + def _init_param(self, p: Parameter) -> None: + assert p._is_sharded + assert not hasattr(p, "_full_param") - p._fp32_shard = p.data + p._fp32_shard = p.data - if self.mixed_precision: - assert p._fp32_shard.dtype == torch.float32 + if self.mixed_precision: + assert p._fp32_shard.dtype == torch.float32 - if self.cpu_offload: - assert p._fp32_shard.device == torch.device("cpu") - p._fp32_shard = p._fp32_shard.pin_memory() + if self.cpu_offload: + assert p._fp32_shard.device == torch.device("cpu") + p._fp32_shard = p._fp32_shard.pin_memory() - p._fp16_shard = torch.zeros_like( - p._fp32_shard, - device=self.compute_device, - dtype=self.compute_dtype, - ) - free_storage_(p._fp16_shard) - p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype) - else: - p._fp16_shard = None # use _fp32_shard - p._full_param = p._fp32_shard.new_empty(p._orig_size) + p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype) + free_storage_(p._fp16_shard) + else: + p._fp16_shard = None # use _fp32_shard - p._full_param = p._full_param.to(dtype=self.compute_dtype, device=self.compute_device) - free_storage_(p._full_param) + p.data = p._fp32_shard - p.data = p._fp32_shard + p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype) + free_storage_(p._full_param) - if self.move_grads_to_cpu: - if self.mixed_precision and not self.fp32_reduce_scatter: - grad_dtype = torch.float16 - else: - grad_dtype = torch.float32 - p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory() + if self.move_grads_to_cpu: + if self.mixed_precision and not self.fp32_reduce_scatter: + grad_dtype = torch.float16 + else: + grad_dtype = torch.float32 + p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory() - if did_init: - self._fp32_to_fp16_stream = torch.cuda.Stream() - self._fp32_to_fp16_stream.wait_stream(torch.cuda.current_stream()) + @torch.no_grad() + def _pre_forward_init(self) -> None: + first_time_params = [p for p in self.params if not hasattr(p, "_full_param")] + for p in first_time_params: + self._init_param(p) + + if len(first_time_params) > 0: + if self._is_root is None: + # This implies that no other ShardParamsDataParallel instance + # wraps this instance, otherwise it would have already set this + # flag to False. + self._is_root = True + + # As the root, we now set all children instances to False. + for n, m in self.named_modules(): + if n != "" and isinstance(m, ShardParamsDataParallel): + assert m._is_root is None + m._is_root = False + + if self._is_root: + # Stream for moving FP32 master params (which may be on CPU) to + # FP16 for computation. We share this stream with all children + # instances, which allows them to overlap transfers across the + # forward pass without synchronizing with the default stream. 
+ self._fp32_to_fp16_stream = torch.cuda.Stream() + + for n, m in self.named_modules(): + if n != "" and isinstance(m, ShardParamsDataParallel): + m._fp32_to_fp16_stream = self._fp32_to_fp16_stream + + assert self._is_root is not None + if self._is_root: + # The top-most (root) instance needs to synchronize with the default + # stream, so we don't move the FP32 master weights prematurely. + self._fp32_to_fp16_stream.wait_stream(torch.cuda.current_stream()) def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: self._pre_forward_init() @@ -313,8 +341,12 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: # initialized with the correct dtype and size. self._use_fp32_param_shard() - if not torch.is_grad_enabled(): - return outputs + if torch.is_grad_enabled(): + outputs = self._register_pre_backward_hooks(outputs) + + return outputs + + def _register_pre_backward_hooks(self, outputs: Any) -> Any: # Register pre-backward hook to run before the wrapped module's backward. pre_backward_hook_has_run = [False] @@ -352,7 +384,7 @@ def _register_post_backward_hooks(self) -> None: p._shard_bwd_hook = (grad_acc, handle) @torch.no_grad() - def _post_backward_hook(self, param: torch.nn.Parameter, *unused: Any) -> None: + def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: if param.grad is None: return if param.grad.requires_grad: @@ -407,7 +439,7 @@ def _use_full_params(self) -> None: p.data = p._full_param @torch.no_grad() - def _free_full_params(self, params: Optional[List[torch.nn.Parameter]] = None) -> None: + def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: """Free up storage for full parameters.""" if params is None: params = self.params @@ -423,7 +455,7 @@ def _free_full_params(self, params: Optional[List[torch.nn.Parameter]] = None) - free_storage_(p._full_param) @torch.no_grad() - def _use_fp32_param_shard(self, params: Optional[List[torch.nn.Parameter]] = None) -> None: + def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None: """Use FP32 shard for a list of params.""" if params is None: params = self.params @@ -431,7 +463,7 @@ def _use_fp32_param_shard(self, params: Optional[List[torch.nn.Parameter]] = Non p.data = p._fp32_shard @torch.no_grad() - def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[torch.nn.Parameter]] = None) -> None: + def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None: """Cast FP32 param shard to FP16 for a list of params.""" if params is None: params = self.params @@ -444,7 +476,7 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[torch.nn.Paramet torch.cuda.current_stream().wait_stream(self._fp32_to_fp16_stream) @torch.no_grad() - def _free_fp16_param_shard(self, params: Optional[List[torch.nn.Parameter]] = None) -> None: + def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None: """Free storage for FP16 shards for a list of params.""" if params is None: params = self.params diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py index 6e039f78b..d05f0f3a0 100644 --- a/fairscale/nn/misc/flatten_params_wrapper.py +++ b/fairscale/nn/misc/flatten_params_wrapper.py @@ -130,10 +130,10 @@ def __getattr__(self, name: str) -> Any: except AttributeError: return getattr(self.module, name) # fallback to wrapped module - def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]": # type: ignore + def 
state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, Tensor]": # type: ignore """Return an unflattened state_dict.""" with self.unflatten_params(): - return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) + return self.module.state_dict(*args, **kwargs) def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: """Return the flattened state_dict.""" diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py index 683825d4c..323c0e4e7 100644 --- a/fairscale/utils/testing.py +++ b/fairscale/utils/testing.py @@ -294,11 +294,7 @@ def __init__(self, embed_dim: int, num_heads: int) -> None: self.ln_1 = nn.LayerNorm(embed_dim) self.ln_2 = nn.LayerNorm(embed_dim) self.attn = nn.MultiheadAttention(embed_dim, num_heads) # type: ignore - self.mlp = nn.Sequential( - nn.Linear(embed_dim, embed_dim * 4), - nn.GELU(), - nn.Linear(embed_dim * 4, embed_dim), - ) + self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),) def forward(self, *inputs: Any, **kwargs: Any) -> Tensor: x = inputs[0] @@ -452,3 +448,19 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: self._check("loss.device", loss.device, self.expected_loss_device) return loss + + +@functools.lru_cache +def get_cycles_per_ms() -> float: + """Approximate number of cycles per millisecond for torch.cuda._sleep + + Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py + """ + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + torch.cuda._sleep(1000000) + end.record() + end.synchronize() + cycles_per_ms = 1000000 / start.elapsed_time(end) + return cycles_per_ms diff --git a/stubs/torch/nn/parameter.pyi b/stubs/torch/nn/parameter.pyi index ae0ea0861..05f24df38 100644 --- a/stubs/torch/nn/parameter.pyi +++ b/stubs/torch/nn/parameter.pyi @@ -1,17 +1,18 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .. import Tensor +from typing import Optional +from .. import Size, Tensor import builtins class Parameter(Tensor): # These are dynamic attributes added by shard_params_data_parallel class. # Added here for better type checking. _is_sharded: bool - _orig_size: int - _cpu_grad: Parameter - _full_param: Parameter - _fp32_shard: Parameter - _fp16_shard: Parameter + _orig_size: Size + _cpu_grad: Tensor + _full_param: Tensor + _fp32_shard: Tensor + _fp16_shard: Optional[Tensor] def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ... 
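The new ``get_cycles_per_ms`` helper exists so tests can inject a GPU-side delay of a roughly known wall-clock length via ``torch.cuda._sleep``, a private PyTorch helper that busy-waits the GPU for a given number of cycles. A rough usage sketch, assuming a CUDA device is available (this mirrors how the delayed-optimizer-step test below uses it)::

    import torch
    from fairscale.utils.testing import get_cycles_per_ms

    delay_ms = 250
    # Queue a ~250 ms busy-wait on the current CUDA stream.
    torch.cuda._sleep(int(delay_ms * get_cycles_per_ms()))
    # Block the host until the queued sleep has finished.
    torch.cuda.synchronize()

As the docstring notes, the calibration is only approximate (frequency scaling makes it vary between runs), which is good enough for exercising stream-synchronization logic in tests.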
diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index e184c2584..289bded37 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -7,6 +7,7 @@ import itertools import sys import tempfile +from typing import Dict import unittest from unittest import mock @@ -14,8 +15,7 @@ from torch import nn from fairscale.nn.data_parallel import ShardParamsDataParallel -from fairscale.utils.testing import DeviceAndTypeCheckModule, objects_are_equal -from typing import Dict +from fairscale.utils.testing import DeviceAndTypeCheckModule, get_cycles_per_ms, objects_are_equal class DistributedTest(unittest.TestCase): @@ -30,6 +30,7 @@ def setUp(self): raise unittest.SkipTest("NCCL doesn't support Windows, skipping test") if torch.cuda.device_count() < 2: raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") + torch.manual_seed(0) # keep everything deterministic @staticmethod def _train_for_several_steps(model, num_steps, autocast): @@ -42,7 +43,7 @@ def _train_for_several_steps(model, num_steps, autocast): input = model.module.get_input(torch.device("cuda")) output = model(*input) loss = model.module.get_loss(input, output).to(model_device) - print(f'loss device: {loss.device}') + print(f"loss device: {loss.device}") assert loss.dtype == torch.float32 loss.backward() optim.step() @@ -122,9 +123,7 @@ def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtyp orig_reduce_scatter = ShardParamsDataParallel._reduce_scatter model = DeviceAndTypeCheckModule( - expected_input_dtype=in_dtype, - expected_param_dtype=p_dtype, - expected_loss_dtype=loss_dtype, + expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype, ) def _reduce_scatter(self, tensor): @@ -152,8 +151,7 @@ def test_transformer(self): for config in itertools.product([True, False], repeat=len(keys)): config = dict(zip(keys, config)) spawn_and_init( - functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config), - world_size=2, + functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config), world_size=2, ) def test_cpu_offload_and_cpu_grads(self): @@ -164,12 +162,34 @@ def test_cpu_offload_and_cpu_grads(self): def test_cpu_offload_and_cuda_grads(self): # If grads are on gpu, but model and optimizer are on cpu, backward breaks. config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": False} - with self.assertRaises(Exception): # RuntimeError inside spawn - test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False) + with self.assertRaises(Exception): # RuntimeError inside spawn + test_fn = functools.partial( + self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False + ) spawn_and_init(test_fn) + def test_delayed_optim_step(self): + # We use a model with a long CUDA delay right before the optimizer step. + # This tests our streams logic, and that we don't start the FP32 -> FP16 + # transfer until after the optimization step completes. 
+ config = {"mixed_precision": True} + test_fn = functools.partial(self._test_identical_outputs, self._delayed_optim_step_model, config) + spawn_and_init(test_fn) + @classmethod - def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3, use_cuda=True): + def _delayed_optim_step_model(cls, rank, group, config=None): + def _maybe_wrap(layer): + if config is not None: + return ShardParamsDataParallel(layer, group, **config) + return layer + + model = nn.Sequential( + nn.Linear(8, 4), _maybe_wrap(nn.Linear(4, 16)), _maybe_wrap(nn.Linear(16, 4)), nn.Linear(4, 8), + ) + return ModuleWithDelay(model, delay_after_loss_ms=250) + + @classmethod + def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True): if config["mixed_precision"]: autocast = True # Force the compute dtype to be torch.float32 so that we get @@ -181,14 +201,13 @@ def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3, us autocast = False # Establish reference behavior with PyTorch DDP (+ optionally autocast). - model = nn.parallel.DistributedDataParallel( - model_cls().cuda(), device_ids=[rank], output_device=rank, process_group=group - ) + model = model_init_fn(rank, group).cuda() + model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group) ref_loss = cls._train_for_several_steps(model, num_steps, autocast) ref_state_dict = model.module.state_dict() # Confirm we get the same behavior using ShardParamsDataParallel. - model = ShardParamsDataParallel(model_cls(), group, **config) + model = ShardParamsDataParallel(model_init_fn(rank, group, config), group, **config) if use_cuda: model = model.cuda() else: @@ -204,18 +223,14 @@ def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3, us class TransformerWithSharedParams(nn.Module): - def __init__(self): + def __init__(self, *args, **kwargs): super().__init__() torch.manual_seed(0) # keep everything deterministic d_model = 16 d_vocab = 32 self.embed_tokens = nn.Embedding(d_vocab, d_model) self.transformer = nn.Transformer( - d_model=d_model, - num_encoder_layers=2, - num_decoder_layers=2, - dim_feedforward=8, - dropout=0.1, + d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1, ) self.output_proj = nn.Linear(d_model, d_vocab) # share the embedding and output projection weights @@ -238,6 +253,25 @@ def get_loss(self, input, output): return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum") +class ModuleWithDelay(nn.Module): + def __init__(self, module, delay_after_loss_ms): + super().__init__() + self.module = module + self.delay_after_loss_ms = delay_after_loss_ms + + def get_input(self, device): + torch.manual_seed(1) # keep everything deterministic + return (torch.rand(4, 8, device=device),) + + def forward(self, x): + return self.module(x) + + def get_loss(self, input, output): + loss = output.sum() + torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms())) + return loss + + def spawn_and_init(fn, world_size=2, args=None): if args is None: args = () @@ -252,10 +286,7 @@ def spawn_and_init(fn, world_size=2, args=None): def distributed_init(rank, world_size, tmp_file): torch.distributed.init_process_group( - backend="nccl", - init_method="file://{}".format(tmp_file), - world_size=world_size, - rank=rank, + backend="nccl", init_method="file://{}".format(tmp_file), world_size=world_size, rank=rank, ) torch.cuda.set_device(rank) From 
8db0cf6fe505393752c8aa9c33a29cb235bee7b9 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 2 Feb 2021 10:14:33 -0500 Subject: [PATCH 12/48] Fix streams test (#45) --- tests/nn/data_parallel/test_shard_params_data_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 289bded37..d34aec697 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -30,7 +30,6 @@ def setUp(self): raise unittest.SkipTest("NCCL doesn't support Windows, skipping test") if torch.cuda.device_count() < 2: raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") - torch.manual_seed(0) # keep everything deterministic @staticmethod def _train_for_several_steps(model, num_steps, autocast): @@ -183,6 +182,7 @@ def _maybe_wrap(layer): return ShardParamsDataParallel(layer, group, **config) return layer + torch.manual_seed(0) # keep everything deterministic model = nn.Sequential( nn.Linear(8, 4), _maybe_wrap(nn.Linear(4, 16)), _maybe_wrap(nn.Linear(16, 4)), nn.Linear(4, 8), ) From bc7e3375e46da5cd5b3980180e523404fcda4446 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 2 Feb 2021 10:47:32 -0500 Subject: [PATCH 13/48] move_grads_to_cpu defaults to same value as cpu_offload (#44) --- fairscale/nn/data_parallel/shard_params_data_parallel.py | 5 +++-- tests/nn/data_parallel/test_shard_params_data_parallel.py | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 52b080d63..e61372b0c 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -90,7 +90,7 @@ def __init__( cpu_offload: bool = False, compute_device: Optional[torch.device] = None, compute_dtype: Optional[torch.dtype] = None, - move_grads_to_cpu: bool = False, + move_grads_to_cpu: Optional[bool] = None, ): super().__init__() self.process_group = process_group or dist.new_group() @@ -103,13 +103,14 @@ def __init__( self.cpu_offload = cpu_offload self.compute_device = compute_device or torch.device("cuda") self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32) - self.move_grads_to_cpu = move_grads_to_cpu + self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu if self.fp32_reduce_scatter and not self.mixed_precision: raise ValueError("fp32_reduce_scatter requires mixed_precision=True") if self.cpu_offload and not self.mixed_precision: raise ValueError("cpu_offload requires mixed_precision=True") + # Only handle params which are not already sharded. This enables # sharding individual layers of a Module, with an outer wrapper to # shard any leftover parameters. 
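With this change ``move_grads_to_cpu`` becomes an ``Optional[bool]`` that follows ``cpu_offload`` when left unset, while the existing checks still require ``mixed_precision`` for both offload paths. A standalone sketch of the resolution logic; ``resolve_move_grads_to_cpu`` is an illustrative helper that mirrors the constructor shown above, not part of fairscale::

    from typing import Optional

    def resolve_move_grads_to_cpu(
        mixed_precision: bool,
        cpu_offload: bool,
        fp32_reduce_scatter: bool,
        move_grads_to_cpu: Optional[bool],
    ) -> bool:
        if fp32_reduce_scatter and not mixed_precision:
            raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
        if cpu_offload and not mixed_precision:
            raise ValueError("cpu_offload requires mixed_precision=True")
        # None means "inherit from cpu_offload"; an explicit value always wins.
        return cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu

    assert resolve_move_grads_to_cpu(True, True, False, None) is True
    assert resolve_move_grads_to_cpu(True, True, False, False) is False
    assert resolve_move_grads_to_cpu(False, False, False, None) is False

The test diff below exercises both the explicit ``True`` and the inherited ``None`` cases with ``cpu_offload=True``.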
diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index d34aec697..aa1d41ab7 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -154,9 +154,10 @@ def test_transformer(self): ) def test_cpu_offload_and_cpu_grads(self): - config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True} - test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False) - spawn_and_init(test_fn) + for move_grads_choice in (True, None): + config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": move_grads_choice} + test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False) + spawn_and_init(test_fn) def test_cpu_offload_and_cuda_grads(self): # If grads are on gpu, but model and optimizer are on cpu, backward breaks. From a1b3924c0ff66685483839f00b96cc5e299b16f4 Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Tue, 2 Feb 2021 10:01:18 -0800 Subject: [PATCH 14/48] formatting change (#46) --- fairscale/nn/data_parallel/shard_params_data_parallel.py | 1 - tests/nn/data_parallel/test_shard_params_data_parallel.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index e61372b0c..50258face 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -110,7 +110,6 @@ def __init__( if self.cpu_offload and not self.mixed_precision: raise ValueError("cpu_offload requires mixed_precision=True") - # Only handle params which are not already sharded. This enables # sharding individual layers of a Module, with an outer wrapper to # shard any leftover parameters. 
diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index aa1d41ab7..6c8c14f81 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -156,7 +156,9 @@ def test_transformer(self): def test_cpu_offload_and_cpu_grads(self): for move_grads_choice in (True, None): config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": move_grads_choice} - test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False) + test_fn = functools.partial( + self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False + ) spawn_and_init(test_fn) def test_cpu_offload_and_cuda_grads(self): From e10df73ac7d30f3b11517bbad87a3ffa186ef311 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 3 Feb 2021 10:20:25 -0500 Subject: [PATCH 15/48] Test that backward hooks are registered (#49) * Test backward hooks are registered * expand * fs_test * passing * assert again * add assert not called * naming --- .../test_shard_params_data_parallel.py | 57 ++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 6c8c14f81..f172705a7 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -11,6 +11,7 @@ import unittest from unittest import mock +from parameterized import parameterized import torch from torch import nn @@ -42,12 +43,19 @@ def _train_for_several_steps(model, num_steps, autocast): input = model.module.get_input(torch.device("cuda")) output = model(*input) loss = model.module.get_loss(input, output).to(model_device) - print(f"loss device: {loss.device}") assert loss.dtype == torch.float32 loss.backward() optim.step() return loss.detach() + @staticmethod + def get_wrapped_model(group, cuda_first=False, config={}, **model_kwargs) -> ShardParamsDataParallel: + if cuda_first: + model = ShardParamsDataParallel(TransformerWithSharedParams(**model_kwargs).cuda(), group, **config) + else: + model = ShardParamsDataParallel(TransformerWithSharedParams(**model_kwargs), group, **config).cuda() + return model + class TestMixedPrecision(DistributedTest): def test_all_fp32(self): @@ -225,6 +233,53 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3 raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}") +class TestHooks(DistributedTest): + # Feel free to modify these tests as the implementation changes. 
+ # They aspire to make sure that backward hooks are registered and used + + + + @parameterized.expand([[True], [False]]) + def test_output_backward_hooks(self, cuda_first): + fn = functools.partial(self._test_output_backward_hooks, cuda_first=cuda_first) + spawn_and_init(fn) + + @classmethod + def _test_output_backward_hooks(self, rank, group, cuda_first=False): + model = self.get_wrapped_model(group, cuda_first=cuda_first) + optim = torch.optim.Adam(model.parameters(), lr=0.0001) + optim.zero_grad() + # Inputs always cuda regardless of move_grads_cpu, or model.device + input = model.module.get_input(torch.device("cuda")) + output = model(*input) + assert len(output._backward_hooks) == 1 # this is pre-bwd hook + loss = model.module.get_loss(input, output).cuda() + for p in model.params: + assert p.grad is None # because of pre_backward_hook + loss.backward() + assert len(output._backward_hooks) == 1 # It doesn't get removed + optim.step() + assert len(output._backward_hooks) == 1 + + @parameterized.expand([[True], [False]]) + def test_register_functions_called(self, cuda_first): + fn = functools.partial(self._test_register_functions_called, cuda_first=cuda_first) + spawn_and_init(fn) + + @classmethod + def _test_register_functions_called(self, rank, group, cuda_first=False): + """Tests that _register_{pre|post}_backward_hooks called during forward.""" + model = self.get_wrapped_model(group, cuda_first=cuda_first) + input = model.module.get_input(torch.device("cuda")) + model._register_post_backward_hooks = mock.MagicMock(return_value=None) + model._register_pre_backward_hooks = mock.MagicMock(return_value=None) + assert not model._register_post_backward_hooks.called + assert not model._register_pre_backward_hooks.called + model(*input) + assert model._register_post_backward_hooks.called + assert model._register_pre_backward_hooks.called + + class TransformerWithSharedParams(nn.Module): def __init__(self, *args, **kwargs): super().__init__() From e139857d8b8a3d39c22b5a41dc571c92e16fb572 Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Wed, 3 Feb 2021 14:00:41 -0800 Subject: [PATCH 16/48] [test] add test for apply_to_tensors (#50) * format change * [test]: test apply_to_tensors * formatting * added some skeletons * name for TODO * fixing the use of lru_cache * formatting --- .../shard_params_data_parallel.py | 4 +- fairscale/utils/testing.py | 14 +++- .../test_shard_params_data_parallel.py | 2 - tests/utils/test_containers.py | 68 +++++++++++++++++++ 4 files changed, 83 insertions(+), 5 deletions(-) create mode 100644 tests/utils/test_containers.py diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 50258face..66fc822a0 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -212,7 +212,7 @@ def fixup(p: Parameter, size: torch.Size) -> Parameter: self.params = [fixup(p, size) for p, size in zip(self.params, self.orig_sizes)] del self.orig_sizes - # TODO: figuring out how to do typing for this overloaded function. + # TODO (Min): figuring out how to do typing for this overloaded function. def state_dict(self, *args, **kwargs): # type: ignore """ Returns the whole (unsharded) state of the module. Parameters are not @@ -226,7 +226,7 @@ def state_dict(self, *args, **kwargs): # type: ignore # returned state dict. 
return self.module.state_dict(*args, **kwargs) - # TODO: figuring out how to do typing for this overloaded function. + # TODO (Min): figuring out how to do typing for this overloaded function. def local_state_dict(self, *args, **kwargs): # type: ignore """ Returns the local (sharded) state of the module. Parameters are sharded, diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py index 323c0e4e7..2e395a362 100644 --- a/fairscale/utils/testing.py +++ b/fairscale/utils/testing.py @@ -450,11 +450,23 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: return loss -@functools.lru_cache +@functools.lru_cache() def get_cycles_per_ms() -> float: """Approximate number of cycles per millisecond for torch.cuda._sleep Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py + + ..note:: + This doesn't seems to return consistent cycles on desktop GPUs likely + due to frequency scaling. + >>> get_cycles_per_ms() + 227.6441091140009 + # new python process + >>> get_cycles_per_ms() + 564.652154766248 + # new python process + >>> get_cycles_per_ms() + 245.56459442962856 """ start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index f172705a7..83970a936 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -237,8 +237,6 @@ class TestHooks(DistributedTest): # Feel free to modify these tests as the implementation changes. # They aspire to make sure that backward hooks are registered and used - - @parameterized.expand([[True], [False]]) def test_output_backward_hooks(self, cuda_first): fn = functools.partial(self._test_output_backward_hooks, cuda_first=cuda_first) diff --git a/tests/utils/test_containers.py b/tests/utils/test_containers.py new file mode 100644 index 000000000..b0eebbeb1 --- /dev/null +++ b/tests/utils/test_containers.py @@ -0,0 +1,68 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring + +""" Test utility classes from containers.py. """ + +import random + +import pytest +import torch + +from fairscale.utils.containers import ( + apply_to_tensors, + pack_kwargs, + split_non_tensors, + unpack_kwargs, + unpack_non_tensors, +) + + +@pytest.mark.parametrize("devices", [["cpu"], ["cuda"], ["cpu", "cuda"]]) +def test_apply_to_tensors(devices): + """Test apply_to_tensors for both cpu & gpu""" + if "cuda" in devices and not torch.cuda.is_available() or torch.cuda.device_count() < 1: + pytest.skip("Skipped due to lack of GPU") + expected = 0 + + def get_a_tensor(): + """Return a random tensor on random device.""" + dev = random.choice(devices) + shape = random.choice(((1), (2, 3), (4, 5, 6), (7, 8, 9, 10))) + t = torch.rand(shape).to(dev) + nonlocal expected + expected += t.numel() + return t + + # create a mixed bag of data. 
+ data = [1, "str"] + data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3}) + data.insert(0, set(["x", get_a_tensor(), get_a_tensor()])) + data.append(([1], get_a_tensor(), (1), [get_a_tensor()], set((1, 2)))) + + total = 0 + + def fn(t, x=[[total]]): + nonlocal total + total += t.numel() + return t + + apply_to_tensors(fn, data) + assert total == expected, f"{total} vs. {expected}" + + +def test_pack_unpack(): + # tbd + p = pack_kwargs + up = unpack_kwargs + + +def test_split_unpack(): + # tbd + s = split_non_tensors + up = unpack_non_tensors From 6f153b0666aab1c114924e5727f3148744d32efa Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 3 Feb 2021 21:13:39 -0500 Subject: [PATCH 17/48] Test save/load state_dict V2 (#51) --- requirements-test.txt | 2 + .../test_shard_params_data_parallel.py | 114 +++++++++++++++++- 2 files changed, 110 insertions(+), 6 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index 77ea87544..47b7173cc 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -13,3 +13,5 @@ pytest-cov == 2.10.0 pytest-mpi == 0.4 pytest-timeout == 1.4.2 mpi4py == 3.0.3 +remote-pdb >= 2.1.0 +parameterized >= 0.8.1 diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 83970a936..215993ac5 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -18,6 +18,8 @@ from fairscale.nn.data_parallel import ShardParamsDataParallel from fairscale.utils.testing import DeviceAndTypeCheckModule, get_cycles_per_ms, objects_are_equal +# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4 + class DistributedTest(unittest.TestCase): def setUp(self): @@ -36,6 +38,7 @@ def setUp(self): def _train_for_several_steps(model, num_steps, autocast): model_device = next(model.parameters()).device optim = torch.optim.Adam(model.parameters(), lr=0.0001) + # If you set this higher implem differs from ddp in the 5th decimal place for _ in range(num_steps): optim.zero_grad() with torch.cuda.amp.autocast(enabled=autocast): @@ -157,9 +160,7 @@ def test_transformer(self): keys = ["reshard_after_forward", "mixed_precision", "flatten_parameters"] for config in itertools.product([True, False], repeat=len(keys)): config = dict(zip(keys, config)) - spawn_and_init( - functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config), world_size=2, - ) + spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config),) def test_cpu_offload_and_cpu_grads(self): for move_grads_choice in (True, None): @@ -233,6 +234,109 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3 raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}") +class TestSaveLoadLocalStateDict(DistributedTest): + def test_load_local_state_dict(self): + test_fn = functools.partial(self._load_local_and_train, {"flatten_parameters": False}) + spawn_and_init(test_fn) + + def test_local_state_dict_flatten_params_breaks(self): + test_fn_broken = functools.partial(self._load_local_and_train, {"flatten_parameters": True}) + with self.assertRaises(Exception): + spawn_and_init(test_fn_broken) + # RuntimeError: Traceback [1] + # [1] https://gist.github.com/sshleifer/612d8eb02dbbf357d6133b2700e02f5e + + def test_local_state_dict_odd_vocab_shape_breaks(self): + test_fn = 
functools.partial(self._load_local_and_train, {"flatten_parameters": False}, d_model=16, d_vocab=37) + with self.assertRaises(Exception): + spawn_and_init(test_fn) + + @classmethod + def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32): + """Check that local_state_dict can be saved and loaded for a given worker, and that training updates it""" + model = ShardParamsDataParallel( + TransformerWithSharedParams(d_model=d_model, d_vocab=d_vocab), group, **config + ).cuda() + state_1 = model.local_state_dict() + state_before_training = {k: v.cpu().clone() for k, v in state_1.items()} + model.load_local_state_dict(state_1) + state_1_weight = state_1["embed_tokens.weight"] + + # This weight will be sharded since we access module.state_dict directly + state_1_module_weight = model.module.state_dict()["embed_tokens.weight"] + torch.testing.assert_allclose(state_1_weight, state_1_module_weight) + torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight) + self._train_for_several_steps(model, 4, False) + + state_2 = model.local_state_dict() + state_after_training = {k: v.cpu().clone() for k, v in state_2.items()} + model.load_local_state_dict(state_2) + + assert state_1.keys() == state_2.keys() + + # Assert that parameters were updated since before training + unchanged = [] + for k in state_1: + if (state_before_training[k] == state_after_training[k]).all(): + unchanged.append(k) + if unchanged: + raise AssertionError(f"params {unchanged} not changed after training") + + +class TestSaveLoadStateDict(DistributedTest): + def test_calling_state_dict_twice_breaks(self): + test_fn = functools.partial(self._test_calling_state_dict_twice_breaks, {"flatten_parameters": False}) + spawn_and_init(test_fn) + + @classmethod + def _test_calling_state_dict_twice_breaks(self, config, rank, group): + ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config) + self._train_for_several_steps(ddp_model, 1, False) + ddp_model.state_dict() # Succeeds + try: + ddp_model.state_dict() + assert False, "Second state_dict call succeeded" + except Exception: + pass + + def test_state_dict_after_forward(self): + test_fn = functools.partial(self._test_module_state_dict, {"flatten_parameters": False}) + spawn_and_init(test_fn) + + @classmethod + def _test_module_state_dict(cls, config, rank, group): + ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config) + try: + ddp_model.state_dict() + assert False, "Calling state_dict before forward succeeded" + except Exception: + pass + cls._train_for_several_steps(ddp_model, 2, False) + state_1 = ddp_model.state_dict() + # You must make a new ShardParamsDataParallel instance to use module.load_state_dict + unwrapped_model = TransformerWithSharedParams() + unwrapped_model.load_state_dict(state_1) + new_ddp_model = ShardParamsDataParallel(unwrapped_model, group, **config).cuda() + cls._train_for_several_steps(new_ddp_model, 2, False) + try: + ddp_model.load_state_dict(new_ddp_model.state_dict()) + assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded" + except Exception: + pass + + +def get_sharded_model(): + sharded_model = ShardParamsDataParallel( + nn.Sequential( + nn.Linear(8, 100), + ShardParamsDataParallel(nn.Linear(100, 100)), + ShardParamsDataParallel(nn.Linear(100, 100)), + nn.Linear(100, 8), + ) + ) + return sharded_model + + class TestHooks(DistributedTest): # Feel free to modify these tests as the implementation changes. 
# They aspire to make sure that backward hooks are registered and used @@ -279,11 +383,9 @@ def _test_register_functions_called(self, rank, group, cuda_first=False): class TransformerWithSharedParams(nn.Module): - def __init__(self, *args, **kwargs): + def __init__(self, *unused_args, d_vocab=32, d_model=16, **unused_kwargs): super().__init__() torch.manual_seed(0) # keep everything deterministic - d_model = 16 - d_vocab = 32 self.embed_tokens = nn.Embedding(d_vocab, d_model) self.transformer = nn.Transformer( d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1, From 41147722018a63a41ae5ff8c1b256bb8b3ac7cda Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 4 Feb 2021 10:22:47 -0500 Subject: [PATCH 18/48] Replace x.view(-1) with torch.flatten(x) (#59) --- fairscale/nn/data_parallel/shard_params_data_parallel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 66fc822a0..5a2cc4d30 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -178,7 +178,7 @@ def _shard_initial_params(self) -> None: e = (self.rank + 1) * shard_size orig_data = p.data - p.data = p.data.view(-1)[s:e].clone() + p.data = torch.flatten(p.data)[s:e].clone() free_storage_(orig_data) def __getattr__(self, name: str) -> Any: @@ -405,7 +405,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: param.grad.data.div_(self.world_size) # Reduce-scatter grad. - param.grad.data = self._reduce_scatter(param.grad.data.view(-1)) + param.grad.data = self._reduce_scatter(torch.flatten(param.grad.data)) if self.move_grads_to_cpu: param._cpu_grad.copy_(param.grad.data, non_blocking=True) @@ -424,7 +424,7 @@ def _rebuild_full_params(self) -> None: for p in self.params: # All-gather parameters alloc_storage_(p._full_param, size=p._orig_size) - output_list = list(p._full_param.view(-1).chunk(self.world_size)) + output_list = list(torch.flatten(p._full_param).chunk(self.world_size)) dist.all_gather(output_list, p.data, group=self.process_group) p.data = p._full_param p.grad = None From 36b2d39392f5168009a6c2e6bc14bc2db4298a7a Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 4 Feb 2021 11:43:12 -0500 Subject: [PATCH 19/48] Add more comments + docstrings (#58) --- .../shard_params_data_parallel.py | 47 +++++++++++++++++-- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 5a2cc4d30..4a913935c 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -117,7 +117,7 @@ def __init__( if self.flatten_parameters and len(params) > 0: self.module: nn.Module = FlattenParamsWrapper(module, param_list=params) - del module + del module # free original module in case it helps garbage collection self.params = [self.module.flat_param] else: self.module = module @@ -189,6 +189,11 @@ def __getattr__(self, name: str) -> Any: return getattr(self.module, name) def __getstate__(self) -> Dict[str, str]: + """Serialize the state of the current ShardParamsDataParallel instance. + + Some properties are not serializable (e.g., process groups, streams), so + we remove them and try to reconstruct them in :func:`__setstate__`. 
+ """ state = copy.copy(self.__dict__) state["orig_sizes"] = [p._orig_size for p in self.params] if state["process_group"] is not None: @@ -204,7 +209,6 @@ def __setstate__(self, state: Dict[str, Any]) -> None: def fixup(p: Parameter, size: torch.Size) -> Parameter: assert isinstance(p, Parameter) p.data = p.data.clone() # move tensors out of shared memory - # Ignore mypy error since we add additional fields to a param. p._is_sharded = True p._orig_size = size return p @@ -257,6 +261,9 @@ def _init_param(self, p: Parameter) -> None: assert p._is_sharded assert not hasattr(p, "_full_param") + # _fp32_shard will correspond to a single shard of the parameters in + # full precision (typically FP32, but this is dependent on the dtype of + # the model as it's passed in by the user). p._fp32_shard = p.data if self.mixed_precision: @@ -264,15 +271,25 @@ def _init_param(self, p: Parameter) -> None: if self.cpu_offload: assert p._fp32_shard.device == torch.device("cpu") + # If we plan to keep the FP32 parameters on CPU, then pinning + # memory allows us to later use non-blocking transfers when moving + # the FP32 param shard to self.compute_device. p._fp32_shard = p._fp32_shard.pin_memory() + p.data = p._fp32_shard + # In mixed precision mode, we maintain a reduced precision + # (typically FP16) parameter shard on self.compute_device for + # performing the computation in the forward/backward pass. We resize + # the storage to size 0 at init and rematerialize this (by copying + # from _fp32_shard) as needed. p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype) free_storage_(p._fp16_shard) else: p._fp16_shard = None # use _fp32_shard - p.data = p._fp32_shard - + # We also maintain a full-sized parameter of type self.compute_dtype + # (typically FP16 for mixed_precision or FP32 otherwise). We resize the + # storage to size 0 at init and only materialize this when needed. p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype) free_storage_(p._full_param) @@ -281,6 +298,9 @@ def _init_param(self, p: Parameter) -> None: grad_dtype = torch.float16 else: grad_dtype = torch.float32 + # We can optionally move the grad shard to CPU during the backward + # pass. In this case, it's important to pre-allocate the CPU grad + # shard in pinned memory so that we can do a non-blocking transfer. p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory() @torch.no_grad() @@ -385,6 +405,25 @@ def _register_post_backward_hooks(self) -> None: @torch.no_grad() def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: + """ + At the start of _post_backward_hook, param.grad contains the full + gradient for the local batch. The reduce_scatter op will replace + param.grad with a single shard of the summed gradient across all GPUs. + This shard will align with the current GPU rank. For example:: + + before reduce_scatter: + param.grad (GPU #0): [1, 2, 3, 4] + param.grad (GPU #1): [5, 6, 7, 8] + + after reduce_scatter: + param.grad (GPU #0): [6, 8] + param.grad (GPU #1): [10, 12] + + The local GPU's optimizer.step is responsible for updating a single + shard of params, also corresponding to the current GPU's rank. This + alignment is created by _shard_initial_params, which ensures that the + local optimizer only sees the relevant single parameter shard. 
+ """ if param.grad is None: return if param.grad.requires_grad: From bc5190b8728e020f08bd1c8da3cbb1eaf902eab6 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 4 Feb 2021 13:10:48 -0500 Subject: [PATCH 20/48] Rearrange dtype and device change in post-backward hook (#61) --- .../shard_params_data_parallel.py | 29 ++++++++++++++----- .../test_shard_params_data_parallel.py | 3 ++ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 4a913935c..ab29c0418 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union import torch +from torch.autograd import Variable import torch.distributed as dist from torch.distributed import ProcessGroup import torch.nn as nn @@ -294,14 +295,10 @@ def _init_param(self, p: Parameter) -> None: free_storage_(p._full_param) if self.move_grads_to_cpu: - if self.mixed_precision and not self.fp32_reduce_scatter: - grad_dtype = torch.float16 - else: - grad_dtype = torch.float32 # We can optionally move the grad shard to CPU during the backward # pass. In this case, it's important to pre-allocate the CPU grad # shard in pinned memory so that we can do a non-blocking transfer. - p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory() + p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory() @torch.no_grad() def _pre_forward_init(self) -> None: @@ -394,6 +391,7 @@ def _register_post_backward_hooks(self) -> None: # Register backward hooks to reshard params and reduce-scatter grads. if not torch.is_grad_enabled(): return # don't register grad hooks if grad isn't enabled + self._post_backward_callback_queued = False for p in self.params: if p.requires_grad: if hasattr(p, "_shard_bwd_hook"): @@ -446,13 +444,28 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # Reduce-scatter grad. param.grad.data = self._reduce_scatter(torch.flatten(param.grad.data)) + # Cast grad to param's dtype (typically FP32). Note: we do this + # before the move_grads_to_cpu step so that this entire hook remains + # non-blocking. The downside is a bit more D2H transfer in that case. + if self.mixed_precision: + param.grad.data = param.grad.data.to(dtype=param.data.dtype) + if self.move_grads_to_cpu: param._cpu_grad.copy_(param.grad.data, non_blocking=True) param.grad.data = param._cpu_grad - # Cast grad to param's dtype (typically FP32). - if self.mixed_precision: - param.grad.data = param.grad.data.to(dtype=param.data.dtype) + # Enqueue a callback at the end of the backward pass to ensure that all + # post-backward work has finished. We only need one callback. + if not self._post_backward_callback_queued: + self._post_backward_callback_queued = True + Variable._execution_engine.queue_callback(self._wait_for_post_backward) + + @torch.no_grad() + def _wait_for_post_backward(self) -> None: + """Wait for all post-backward work to finish.""" + if self.move_grads_to_cpu: + # Wait for the non-blocking GPU -> CPU grad transfers to finish. 
+ torch.cuda.current_stream().synchronize() @torch.no_grad() def _rebuild_full_params(self) -> None: diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 215993ac5..9bd9346ab 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -163,6 +163,9 @@ def test_transformer(self): spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config),) def test_cpu_offload_and_cpu_grads(self): + # We only test True and None (which implies True). We don't test the + # False condition because that requires the optimizer to internally do + # the device transfer and PyTorch optimizers don't support this. for move_grads_choice in (True, None): config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": move_grads_choice} test_fn = functools.partial( From 8a5f81c7ba7cca3da8d0d95781b79e7ce6802a6a Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 4 Feb 2021 13:13:23 -0500 Subject: [PATCH 21/48] Do reduce-scatter in a separate CUDA stream (#62) * Do reduce-scatter in a separate CUDA stream * Add _post_backward_stream to stubs --- .../shard_params_data_parallel.py | 39 ++++++++++++------- stubs/torch/nn/parameter.pyi | 2 + .../test_shard_params_data_parallel.py | 15 +++++++ 3 files changed, 42 insertions(+), 14 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index ab29c0418..abd6f35ea 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -300,6 +300,9 @@ def _init_param(self, p: Parameter) -> None: # shard in pinned memory so that we can do a non-blocking transfer. p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory() + # Stream for overlapping the backward pass and gradient reductions. + p._post_backward_stream = torch.cuda.Stream() + @torch.no_grad() def _pre_forward_init(self) -> None: first_time_params = [p for p in self.params if not hasattr(p, "_full_param")] @@ -430,29 +433,35 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # Free full params and switch to FP32 shard after backward. self._free_full_params([param]) self._use_fp32_param_shard([param]) - if self.mixed_precision: + # This is a no-op if reshard_after_forward is True, since we already + # free the param shard when rebuilding the full params in the + # pre_backward_hook. self._free_fp16_param_shard([param]) - if self.fp32_reduce_scatter: + # Wait for all work in the current stream to finish, then start the + # reductions in _post_backward_stream. + param._post_backward_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(param._post_backward_stream): + if self.mixed_precision and self.fp32_reduce_scatter: # Cast grad to FP32. param.grad.data = param.grad.data.to(param.dtype) - # Average grad by world_size for consistency with PyTorch DDP. - param.grad.data.div_(self.world_size) + # Average grad by world_size for consistency with PyTorch DDP. + param.grad.data.div_(self.world_size) - # Reduce-scatter grad. - param.grad.data = self._reduce_scatter(torch.flatten(param.grad.data)) + # Reduce-scatter grad. + param.grad.data = self._reduce_scatter(torch.flatten(param.grad.data)) - # Cast grad to param's dtype (typically FP32). 
Note: we do this - # before the move_grads_to_cpu step so that this entire hook remains - # non-blocking. The downside is a bit more D2H transfer in that case. - if self.mixed_precision: - param.grad.data = param.grad.data.to(dtype=param.data.dtype) + # Cast grad to param's dtype (typically FP32). Note: we do this + # before the move_grads_to_cpu step so that this entire hook remains + # non-blocking. The downside is a bit more D2H transfer in that case. + if self.mixed_precision: + param.grad.data = param.grad.data.to(dtype=param.data.dtype) - if self.move_grads_to_cpu: - param._cpu_grad.copy_(param.grad.data, non_blocking=True) - param.grad.data = param._cpu_grad + if self.move_grads_to_cpu: + param._cpu_grad.copy_(param.grad.data, non_blocking=True) + param.grad.data = param._cpu_grad # Enqueue a callback at the end of the backward pass to ensure that all # post-backward work has finished. We only need one callback. @@ -463,6 +472,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: @torch.no_grad() def _wait_for_post_backward(self) -> None: """Wait for all post-backward work to finish.""" + for p in self.params: + torch.cuda.current_stream().wait_stream(p._post_backward_stream) if self.move_grads_to_cpu: # Wait for the non-blocking GPU -> CPU grad transfers to finish. torch.cuda.current_stream().synchronize() diff --git a/stubs/torch/nn/parameter.pyi b/stubs/torch/nn/parameter.pyi index 05f24df38..7cd76d9d2 100644 --- a/stubs/torch/nn/parameter.pyi +++ b/stubs/torch/nn/parameter.pyi @@ -2,6 +2,7 @@ from typing import Optional from .. import Size, Tensor +from ..cuda import Stream import builtins class Parameter(Tensor): @@ -13,6 +14,7 @@ class Parameter(Tensor): _full_param: Tensor _fp32_shard: Tensor _fp16_shard: Optional[Tensor] + _post_backward_stream: Stream def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ... diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 9bd9346ab..452f2421a 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -203,6 +203,21 @@ def _maybe_wrap(layer): ) return ModuleWithDelay(model, delay_after_loss_ms=250) + def test_delayed_reduce_scatter(self): + # We insert a delay in the torch.distributed.reduce_scatter op, so that + # the post_backward_stream takes much longer than the backward pass. + # This tests that we properly block at the end of the backward pass for + # the reductions to finish. 
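For reference, a minimal standalone sketch of the stream handshake this test exercises (assumes a CUDA device is available; the tensor sizes and variable names here are illustrative only, not taken from the patch)::

    import torch
    from fairscale.utils.testing import get_cycles_per_ms

    side_stream = torch.cuda.Stream()           # stand-in for _post_backward_stream
    x = torch.randn(1024, 1024, device="cuda")

    # The side stream must not start before previously queued work is done.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        torch.cuda._sleep(int(250 * get_cycles_per_ms()))  # simulate a slow reduce-scatter
        y = x * 2
    # Block the default stream on the side stream, so anything queued next
    # (e.g. optimizer.step) only runs after the "reduction" has finished.
    torch.cuda.current_stream().wait_stream(side_stream)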
+ config = {"mixed_precision": True} + with mock.patch("torch.distributed.reduce_scatter", wraps=self._delayed_reduce_scatter): + test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config) + spawn_and_init(test_fn) + + @classmethod + def _delayed_reduce_scatter(cls, *args, **kwargs): + torch.cuda._sleep(int(250 * get_cycles_per_ms())) + return torch.distributed.reduce_scatter(*args, **kwargs) + @classmethod def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True): if config["mixed_precision"]: From 72c1f63af828d8fcb94f2f674bebfc5a38c90d81 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 5 Feb 2021 11:48:07 -0500 Subject: [PATCH 22/48] tests use spawn_for_all_world_sizes (#63) --- .../test_shard_params_data_parallel.py | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 452f2421a..057b874bb 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -6,7 +6,6 @@ import functools import itertools import sys -import tempfile from typing import Dict import unittest from unittest import mock @@ -16,7 +15,13 @@ from torch import nn from fairscale.nn.data_parallel import ShardParamsDataParallel -from fairscale.utils.testing import DeviceAndTypeCheckModule, get_cycles_per_ms, objects_are_equal +from fairscale.utils.testing import ( + DeviceAndTypeCheckModule, + dist_init, + get_cycles_per_ms, + objects_are_equal, + spawn_for_all_world_sizes, +) # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4 @@ -125,7 +130,7 @@ def test_fp32_reduce_scatter_autocast(self): def _spawn_test_case(self, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype, world_size=2): """Call test_dtypes inside of torch.multiprocessing.spawn""" fn = functools.partial(self._test_dtypes, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype) - spawn_and_init(fn, world_size=world_size) + spawn_and_init(fn, world_sizes=[world_size]) @staticmethod def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group): @@ -448,27 +453,16 @@ def get_loss(self, input, output): return loss -def spawn_and_init(fn, world_size=2, args=None): +def spawn_and_init(fn, args=None, **spawn_kwargs): if args is None: args = () - with tempfile.NamedTemporaryFile(delete=False) as tmp_file: - torch.multiprocessing.spawn( - fn=functools.partial(init_and_run, fn, args), - args=(world_size, tmp_file.name), - nprocs=world_size, - join=True, - ) - -def distributed_init(rank, world_size, tmp_file): - torch.distributed.init_process_group( - backend="nccl", init_method="file://{}".format(tmp_file), world_size=world_size, rank=rank, - ) - torch.cuda.set_device(rank) + run_fn = functools.partial(init_and_run, fn, args) + spawn_for_all_world_sizes(run_fn, **spawn_kwargs) -def init_and_run(fn, args, rank, world_size, tmp_file): - distributed_init(rank, world_size, tmp_file) +def init_and_run(fn, args, rank, world_size, filename, filename_rpc): + dist_init(rank, world_size, filename, filename_rpc) group = torch.distributed.new_group() fn(rank, group, *args) From f481877a90c359eb0dc8faf06ec0e44e68d94292 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sat, 6 Feb 2021 15:16:04 -0500 Subject: [PATCH 23/48] Fix state_dict bugs (#60) --- 
.../shard_params_data_parallel.py | 101 ++++++++------- .../test_shard_params_data_parallel.py | 119 +++++++++++------- 2 files changed, 131 insertions(+), 89 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index abd6f35ea..2faa1b8d9 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -124,8 +124,8 @@ def __init__( self.module = module self.params = params - # Shard module parameters. - self._shard_initial_params() + # Shard module parameters in place + self._shard_parameters_() if self.mixed_precision: # Cast all module buffers to FP16 (buffers are not sharded). @@ -139,9 +139,10 @@ def __init__( # ShardParamsDataParallel instances. This flag is only set after the # first forward pass. self._is_root: Optional[bool] = None + self._init_full_param() @torch.no_grad() - def _shard_initial_params(self) -> None: + def _shard_parameters_(self) -> None: """ At initialization we wrap a module with full parameters and shard the parameters in-place. Sharding is implemented by viewing each parameter @@ -239,8 +240,9 @@ def local_state_dict(self, *args, **kwargs): # type: ignore wrapped with ShardParamsDataParallel. """ if self.flatten_parameters: - kwargs["unflatten_params"] = False - return self.module.state_dict(*args, **kwargs) + return self.module.flat_state_dict(*args, **kwargs) # type: ignore + else: + return self.module.state_dict(*args, **kwargs) def load_state_dict( self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True @@ -257,6 +259,13 @@ def load_local_state_dict( """Load a local (sharded) state_dict.""" return self.module.load_state_dict(state_dict, strict) + @torch.no_grad() + def _init_full_param(self) -> None: + """Set p._full_param if not already set.""" + first_time_params = [p for p in self.params if not hasattr(p, "_full_param")] + for p in first_time_params: + self._init_param(p) + @torch.no_grad() def _init_param(self, p: Parameter) -> None: assert p._is_sharded @@ -305,33 +314,19 @@ def _init_param(self, p: Parameter) -> None: @torch.no_grad() def _pre_forward_init(self) -> None: - first_time_params = [p for p in self.params if not hasattr(p, "_full_param")] - for p in first_time_params: - self._init_param(p) - - if len(first_time_params) > 0: - if self._is_root is None: - # This implies that no other ShardParamsDataParallel instance - # wraps this instance, otherwise it would have already set this - # flag to False. - self._is_root = True - - # As the root, we now set all children instances to False. - for n, m in self.named_modules(): - if n != "" and isinstance(m, ShardParamsDataParallel): - assert m._is_root is None - m._is_root = False - - if self._is_root: - # Stream for moving FP32 master params (which may be on CPU) to - # FP16 for computation. We share this stream with all children - # instances, which allows them to overlap transfers across the - # forward pass without synchronizing with the default stream. - self._fp32_to_fp16_stream = torch.cuda.Stream() - - for n, m in self.named_modules(): - if n != "" and isinstance(m, ShardParamsDataParallel): - m._fp32_to_fp16_stream = self._fp32_to_fp16_stream + self._set_is_root() + if not self.cpu_offload: + self._move_fp32_shard_to_cuda() + if self._is_root and not hasattr(self, "_fp32_to_fp16_stream"): + # Stream for moving FP32 master params (which may be on CPU) to + # FP16 for computation. 
We share this stream with all children + # instances, which allows them to overlap transfers across the + # forward pass without synchronizing with the default stream. + self._fp32_to_fp16_stream = torch.cuda.Stream() + + for n, m in self.named_modules(): + if n != "" and isinstance(m, ShardParamsDataParallel): + m._fp32_to_fp16_stream = self._fp32_to_fp16_stream assert self._is_root is not None if self._is_root: @@ -339,6 +334,18 @@ def _pre_forward_init(self) -> None: # stream, so we don't move the FP32 master weights prematurely. self._fp32_to_fp16_stream.wait_stream(torch.cuda.current_stream()) + def _set_is_root(self) -> None: + """If true, implies that no other ShardParamsDataParallel instance. Called by forward.""" + if self._is_root is not None: + return + # No ShardParamsDataParallel instance wraps this, else _is_root would be set to False + self._is_root = True + # As the root, we now set all children instances to False. + for n, m in self.named_modules(): + if n != "" and isinstance(m, ShardParamsDataParallel): + assert m._is_root is None + m._is_root = False + def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: self._pre_forward_init() @@ -417,12 +424,12 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: param.grad (GPU #1): [5, 6, 7, 8] after reduce_scatter: - param.grad (GPU #0): [6, 8] - param.grad (GPU #1): [10, 12] + param.grad (GPU #0): [6, 8] # 1+5, 2+6 + param.grad (GPU #1): [10, 12] # 3+7, 4+8 The local GPU's optimizer.step is responsible for updating a single shard of params, also corresponding to the current GPU's rank. This - alignment is created by _shard_initial_params, which ensures that the + alignment is created by _shard_parameters_, which ensures that the local optimizer only sees the relevant single parameter shard. 
""" if param.grad is None: @@ -480,19 +487,21 @@ def _wait_for_post_backward(self) -> None: @torch.no_grad() def _rebuild_full_params(self) -> None: - """Get all shards of params.""" - if self.mixed_precision: + """Gather all shards of params.""" + if self.mixed_precision and hasattr(self, "_fp32_to_fp16_stream"): self._cast_fp32_param_shards_to_fp16() for p in self.params: - # All-gather parameters - alloc_storage_(p._full_param, size=p._orig_size) - output_list = list(torch.flatten(p._full_param).chunk(self.world_size)) - dist.all_gather(output_list, p.data, group=self.process_group) + if p._full_param.storage().size() != p._orig_size.numel(): + # All-gather parameters + alloc_storage_(p._full_param, size=p._orig_size) + output_list = list(torch.flatten(p._full_param).chunk(self.world_size)) + dist.all_gather(output_list, p.data, group=self.process_group) + p.data = p._full_param p.grad = None - if self.mixed_precision: + if self.mixed_precision and hasattr(self, "_fp32_to_fp16_stream"): self._free_fp16_param_shard([p]) @torch.no_grad() @@ -517,6 +526,12 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: p._full_param.record_stream(current_stream) free_storage_(p._full_param) + @torch.no_grad() + def _move_fp32_shard_to_cuda(self) -> None: + assert not self.cpu_offload + for p in self.params: + p._fp32_shard = p._fp32_shard.to(p.data.device) + @torch.no_grad() def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None: """Use FP32 shard for a list of params.""" @@ -596,5 +611,7 @@ def free_storage_(data: torch.Tensor) -> None: @torch.no_grad() def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None: """Allocate storage for a tensor.""" + if data.storage().size() == size.numel(): # no need to reallocate + return assert data.storage().size() == 0 data.storage().resize_(size.numel()) diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 057b874bb..59e9c9e98 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -42,8 +42,7 @@ def setUp(self): @staticmethod def _train_for_several_steps(model, num_steps, autocast): model_device = next(model.parameters()).device - optim = torch.optim.Adam(model.parameters(), lr=0.0001) - # If you set this higher implem differs from ddp in the 5th decimal place + optim = torch.optim.Adam(model.parameters(), lr=0.01) for _ in range(num_steps): optim.zero_grad() with torch.cuda.amp.autocast(enabled=autocast): @@ -168,17 +167,13 @@ def test_transformer(self): spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config),) def test_cpu_offload_and_cpu_grads(self): - # We only test True and None (which implies True). We don't test the - # False condition because that requires the optimizer to internally do + # We don't test the False condition because that requires the optimizer to internally do # the device transfer and PyTorch optimizers don't support this. 
- for move_grads_choice in (True, None): - config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": move_grads_choice} - test_fn = functools.partial( - self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False - ) - spawn_and_init(test_fn) + config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True} + test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False) + spawn_and_init(test_fn) - def test_cpu_offload_and_cuda_grads(self): + def test_cpu_offload_and_cuda_grads_breaks(self): # If grads are on gpu, but model and optimizer are on cpu, backward breaks. config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": False} with self.assertRaises(Exception): # RuntimeError inside spawn @@ -257,18 +252,14 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3 raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}") -class TestSaveLoadLocalStateDict(DistributedTest): - def test_load_local_state_dict(self): - test_fn = functools.partial(self._load_local_and_train, {"flatten_parameters": False}) +class TestLocalStateDict(DistributedTest): + @parameterized.expand([[True, True], [False, False]]) + def test_load_local_state_dict(self, flatten_params, mixed_precision): + test_fn = functools.partial( + self._load_local_and_train, {"flatten_parameters": flatten_params, "mixed_precision": mixed_precision} + ) spawn_and_init(test_fn) - def test_local_state_dict_flatten_params_breaks(self): - test_fn_broken = functools.partial(self._load_local_and_train, {"flatten_parameters": True}) - with self.assertRaises(Exception): - spawn_and_init(test_fn_broken) - # RuntimeError: Traceback [1] - # [1] https://gist.github.com/sshleifer/612d8eb02dbbf357d6133b2700e02f5e - def test_local_state_dict_odd_vocab_shape_breaks(self): test_fn = functools.partial(self._load_local_and_train, {"flatten_parameters": False}, d_model=16, d_vocab=37) with self.assertRaises(Exception): @@ -282,14 +273,18 @@ def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32): ).cuda() state_1 = model.local_state_dict() state_before_training = {k: v.cpu().clone() for k, v in state_1.items()} + assert len(state_1) > 0 model.load_local_state_dict(state_1) - state_1_weight = state_1["embed_tokens.weight"] + weight_key = "flat_param" if model.flatten_parameters else "embed_tokens.weight" - # This weight will be sharded since we access module.state_dict directly - state_1_module_weight = model.module.state_dict()["embed_tokens.weight"] - torch.testing.assert_allclose(state_1_weight, state_1_module_weight) - torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight) - self._train_for_several_steps(model, 4, False) + state_1_weight = state_1[weight_key] + assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32" + if not model.flatten_parameters: + # This weight will be sharded since we access module.state_dict directly + state_1_module_weight = model.module.state_dict()[weight_key] + torch.testing.assert_allclose(state_1_weight, state_1_module_weight) + torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight) + self._train_for_several_steps(model, 4, model.mixed_precision) state_2 = model.local_state_dict() state_after_training = {k: v.cpu().clone() for k, v in state_2.items()} @@ -307,40 +302,57 @@ def _load_local_and_train(self, 
config, rank, group, d_model=32, d_vocab=32): class TestSaveLoadStateDict(DistributedTest): - def test_calling_state_dict_twice_breaks(self): - test_fn = functools.partial(self._test_calling_state_dict_twice_breaks, {"flatten_parameters": False}) + @parameterized.expand([[False], [True]]) + def test_calling_state_dict_twice(self, mixed_precision): + test_fn = functools.partial( + self._test_calling_state_dict_twice, {"flatten_parameters": False, "mixed_precision": mixed_precision} + ) spawn_and_init(test_fn) @classmethod - def _test_calling_state_dict_twice_breaks(self, config, rank, group): + def _test_calling_state_dict_twice(self, config, rank, group): ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config) - self._train_for_several_steps(ddp_model, 1, False) - ddp_model.state_dict() # Succeeds - try: - ddp_model.state_dict() - assert False, "Second state_dict call succeeded" - except Exception: - pass + autocast = ddp_model.mixed_precision + self._train_for_several_steps(ddp_model, 1, autocast) + ddp_model.state_dict() + ddp_model.state_dict() # second call + + @parameterized.expand([[False], [True]]) + def test_state_dict_after_forward(self, mixed_precision): + test_fn = functools.partial( + self._test_module_state_dict, {"flatten_parameters": False, "mixed_precision": mixed_precision} + ) + spawn_and_init(test_fn) - def test_state_dict_after_forward(self): - test_fn = functools.partial(self._test_module_state_dict, {"flatten_parameters": False}) + @parameterized.expand([[False], [True]]) + def test_state_dict_before_forward(self, mixed_precision): + test_fn = functools.partial( + self._test_state_dict_before_forward, {"flatten_parameters": False, "mixed_precision": mixed_precision} + ) spawn_and_init(test_fn) + @classmethod + def _test_state_dict_before_forward(cls, config, rank, group): + ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config) + assert not hasattr(ddp_model, "_fp32_to_fp16_stream") + sd = ddp_model.state_dict() + assert not hasattr(ddp_model, "_fp32_to_fp16_stream") + expected_dtype = torch.float16 if ddp_model.mixed_precision else torch.float32 + wt = sd["embed_tokens.weight"] + assert wt.dtype == expected_dtype, f"got dtype {wt.dtype} expected {expected_dtype}" + cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision) + @classmethod def _test_module_state_dict(cls, config, rank, group): ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config) - try: - ddp_model.state_dict() - assert False, "Calling state_dict before forward succeeded" - except Exception: - pass - cls._train_for_several_steps(ddp_model, 2, False) + autocast = ddp_model.mixed_precision + cls._train_for_several_steps(ddp_model, 2, autocast) state_1 = ddp_model.state_dict() # You must make a new ShardParamsDataParallel instance to use module.load_state_dict unwrapped_model = TransformerWithSharedParams() unwrapped_model.load_state_dict(state_1) new_ddp_model = ShardParamsDataParallel(unwrapped_model, group, **config).cuda() - cls._train_for_several_steps(new_ddp_model, 2, False) + cls._train_for_several_steps(new_ddp_model, 2, autocast) try: ddp_model.load_state_dict(new_ddp_model.state_dict()) assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded" @@ -369,9 +381,22 @@ def test_output_backward_hooks(self, cuda_first): fn = functools.partial(self._test_output_backward_hooks, cuda_first=cuda_first) spawn_and_init(fn) + def test_backward_hooks_after_save(self): + fn = 
functools.partial(self._test_backward_hooks_after_save, cuda_first=False) + spawn_and_init(fn) + @classmethod - def _test_output_backward_hooks(self, rank, group, cuda_first=False): + def _test_backward_hooks_after_save(self, rank, group, cuda_first=False): model = self.get_wrapped_model(group, cuda_first=cuda_first) + self._train_for_several_steps(model, 2, model.mixed_precision) + state_1 = model.local_state_dict() + model.load_local_state_dict(state_1) + self._test_output_backward_hooks(rank, group, cuda_first=cuda_first, model=model) + + @classmethod + def _test_output_backward_hooks(self, rank, group, cuda_first=False, model=None): + if model is None: + model = self.get_wrapped_model(group, cuda_first=cuda_first) optim = torch.optim.Adam(model.parameters(), lr=0.0001) optim.zero_grad() # Inputs always cuda regardless of move_grads_cpu, or model.device From 515411bab4caf0490dd892181c0a47a60dcda21a Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sat, 6 Feb 2021 21:21:26 -0500 Subject: [PATCH 24/48] update comments to reflect where we are in stack (#69) --- fairscale/nn/data_parallel/shard_params_data_parallel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 2faa1b8d9..1dd6b50f7 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -290,16 +290,16 @@ def _init_param(self, p: Parameter) -> None: # In mixed precision mode, we maintain a reduced precision # (typically FP16) parameter shard on self.compute_device for # performing the computation in the forward/backward pass. We resize - # the storage to size 0 at init and rematerialize this (by copying - # from _fp32_shard) as needed. + # the storage to size 0 at init (here) and re-materialize + # (by copying from _fp32_shard) as needed. p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype) free_storage_(p._fp16_shard) else: p._fp16_shard = None # use _fp32_shard # We also maintain a full-sized parameter of type self.compute_dtype - # (typically FP16 for mixed_precision or FP32 otherwise). We resize the - # storage to size 0 at init and only materialize this when needed. + # (FP16 for mixed_precision or FP32 otherwise). We resize the + # storage to size 0 at init (here) and only materialize as needed. 
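As a standalone illustration of the storage-resizing trick that free_storage_ / alloc_storage_ rely on (tensor shape chosen arbitrarily)::

    import torch

    t = torch.zeros(4, 4)
    size = t.size()
    t.storage().resize_(0)             # frees the memory; t keeps its size metadata
    assert t.storage().size() == 0
    t.storage().resize_(size.numel())  # re-materializes storage (contents undefined)
    assert t.storage().size() == 16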
p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype) free_storage_(p._full_param) From ec4e75ec65e1b9a9a122b750acfef0fe9d5aa20a Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 7 Feb 2021 13:58:16 -0500 Subject: [PATCH 25/48] [CI] use parameterized.expand to make each test faster (#68) --- .../test_shard_params_data_parallel.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 59e9c9e98..c13196f11 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -153,18 +153,24 @@ def _reduce_scatter(self, tensor): loss.backward() +keys = ["reshard_after_forward", "mixed_precision", "flatten_parameters"] +CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))] + + +def rename_test(testcase_func, param_num, param): + return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),) + + class TestComparisonToPyTorchDDP(DistributedTest): """ Compare losses and parameter values after several updates when using PyTorch DDP vs. ShardParamsDataParallel. """ - def test_transformer(self): + @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) + def test_transformer_parameterized(self, config): # Test every combination of these options: - keys = ["reshard_after_forward", "mixed_precision", "flatten_parameters"] - for config in itertools.product([True, False], repeat=len(keys)): - config = dict(zip(keys, config)) - spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config),) + spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config)) def test_cpu_offload_and_cpu_grads(self): # We don't test the False condition because that requires the optimizer to internally do @@ -253,7 +259,7 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3 class TestLocalStateDict(DistributedTest): - @parameterized.expand([[True, True], [False, False]]) + @parameterized.expand([[True, True], [False, False]], name_func=rename_test) def test_load_local_state_dict(self, flatten_params, mixed_precision): test_fn = functools.partial( self._load_local_and_train, {"flatten_parameters": flatten_params, "mixed_precision": mixed_precision} @@ -302,8 +308,8 @@ def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32): class TestSaveLoadStateDict(DistributedTest): - @parameterized.expand([[False], [True]]) - def test_calling_state_dict_twice(self, mixed_precision): + @parameterized.expand([[False], [True]], name_func=rename_test) + def test_calling_state_dict_twice_mixed_precision(self, mixed_precision): test_fn = functools.partial( self._test_calling_state_dict_twice, {"flatten_parameters": False, "mixed_precision": mixed_precision} ) @@ -317,14 +323,14 @@ def _test_calling_state_dict_twice(self, config, rank, group): ddp_model.state_dict() ddp_model.state_dict() # second call - @parameterized.expand([[False], [True]]) - def test_state_dict_after_forward(self, mixed_precision): + @parameterized.expand([[False], [True]], name_func=rename_test) + def test_state_dict_after_forward_mixed_precision(self, mixed_precision): test_fn = functools.partial( self._test_module_state_dict, {"flatten_parameters": False, "mixed_precision": mixed_precision} ) 
spawn_and_init(test_fn) - @parameterized.expand([[False], [True]]) + @parameterized.expand([[False], [True]], name_func=rename_test) def test_state_dict_before_forward(self, mixed_precision): test_fn = functools.partial( self._test_state_dict_before_forward, {"flatten_parameters": False, "mixed_precision": mixed_precision} From b1460d34cbd4f1e50e843340c626e29b900c7090 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sun, 7 Feb 2021 14:39:49 -0500 Subject: [PATCH 26/48] Fix delayed reduce_scatter test (#74) * Fix delayed_reduce_scatter test * Decompose NestedWrappedModule from ModuleWithDelay --- .../test_shard_params_data_parallel.py | 102 +++++++++++------- 1 file changed, 61 insertions(+), 41 deletions(-) diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index c13196f11..5ca98c8f1 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -51,7 +51,7 @@ def _train_for_several_steps(model, num_steps, autocast): output = model(*input) loss = model.module.get_loss(input, output).to(model_device) assert loss.dtype == torch.float32 - loss.backward() + model.module.run_backward(loss) optim.step() return loss.detach() @@ -193,37 +193,20 @@ def test_delayed_optim_step(self): # This tests our streams logic, and that we don't start the FP32 -> FP16 # transfer until after the optimization step completes. config = {"mixed_precision": True} - test_fn = functools.partial(self._test_identical_outputs, self._delayed_optim_step_model, config) + model_fn = functools.partial(NestedWrappedModuleWithDelay, delay_after_loss_ms=250) + test_fn = functools.partial(self._test_identical_outputs, model_fn, config) spawn_and_init(test_fn) - @classmethod - def _delayed_optim_step_model(cls, rank, group, config=None): - def _maybe_wrap(layer): - if config is not None: - return ShardParamsDataParallel(layer, group, **config) - return layer - - torch.manual_seed(0) # keep everything deterministic - model = nn.Sequential( - nn.Linear(8, 4), _maybe_wrap(nn.Linear(4, 16)), _maybe_wrap(nn.Linear(16, 4)), nn.Linear(4, 8), - ) - return ModuleWithDelay(model, delay_after_loss_ms=250) - def test_delayed_reduce_scatter(self): # We insert a delay in the torch.distributed.reduce_scatter op, so that # the post_backward_stream takes much longer than the backward pass. # This tests that we properly block at the end of the backward pass for # the reductions to finish. config = {"mixed_precision": True} - with mock.patch("torch.distributed.reduce_scatter", wraps=self._delayed_reduce_scatter): - test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config) + model_fn = functools.partial(NestedWrappedModuleWithDelay, delay_before_reduction_ms=250) + test_fn = functools.partial(self._test_identical_outputs, model_fn, config) spawn_and_init(test_fn) - @classmethod - def _delayed_reduce_scatter(cls, *args, **kwargs): - torch.cuda._sleep(int(250 * get_cycles_per_ms())) - return torch.distributed.reduce_scatter(*args, **kwargs) - @classmethod def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True): if config["mixed_precision"]: @@ -237,13 +220,13 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3 autocast = False # Establish reference behavior with PyTorch DDP (+ optionally autocast). 
- model = model_init_fn(rank, group).cuda() + model = model_init_fn(group=group, wrapper_config=None).cuda() model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group) ref_loss = cls._train_for_several_steps(model, num_steps, autocast) ref_state_dict = model.module.state_dict() # Confirm we get the same behavior using ShardParamsDataParallel. - model = ShardParamsDataParallel(model_init_fn(rank, group, config), group, **config) + model = ShardParamsDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config) if use_cuda: model = model.cuda() else: @@ -366,18 +349,6 @@ def _test_module_state_dict(cls, config, rank, group): pass -def get_sharded_model(): - sharded_model = ShardParamsDataParallel( - nn.Sequential( - nn.Linear(8, 100), - ShardParamsDataParallel(nn.Linear(100, 100)), - ShardParamsDataParallel(nn.Linear(100, 100)), - nn.Linear(100, 8), - ) - ) - return sharded_model - - class TestHooks(DistributedTest): # Feel free to modify these tests as the implementation changes. # They aspire to make sure that backward hooks are registered and used @@ -464,12 +435,23 @@ def get_loss(self, input, output): _, tgt = input return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum") + def run_backward(self, loss): + loss.backward() + -class ModuleWithDelay(nn.Module): - def __init__(self, module, delay_after_loss_ms): +class NestedWrappedModule(nn.Module): + def __init__(self, group, wrapper_config): super().__init__() - self.module = module - self.delay_after_loss_ms = delay_after_loss_ms + + def _maybe_wrap(layer): + if wrapper_config is not None: + return ShardParamsDataParallel(layer, group, **wrapper_config) + return layer + + torch.manual_seed(0) # keep everything deterministic + self.module = nn.Sequential( + nn.Linear(8, 4), _maybe_wrap(nn.Linear(4, 16)), _maybe_wrap(nn.Linear(16, 4)), nn.Linear(4, 8), + ) def get_input(self, device): torch.manual_seed(1) # keep everything deterministic @@ -480,9 +462,47 @@ def forward(self, x): def get_loss(self, input, output): loss = output.sum() - torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms())) return loss + def run_backward(self, loss): + loss.backward() + + +class ModuleWithDelay(nn.Module): + def __init__(self, module, delay_after_loss_ms=0, delay_before_reduction_ms=0): + super().__init__() + self.delay_after_loss_ms = delay_after_loss_ms + self.delay_before_reduction_ms = delay_before_reduction_ms + self.module = module + + def get_input(self, device): + return self.module.get_input(device) + + def forward(self, x): + return self.module(x) + + def get_loss(self, input, output): + loss = self.module.get_loss(input, output) + if self.delay_after_loss_ms > 0: + torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms())) + return loss + + def run_backward(self, loss): + orig_reduce_scatter = torch.distributed.reduce_scatter + + def _delayed_reduce_scatter(*args, **kwargs): + if self.delay_before_reduction_ms > 0: + torch.cuda._sleep(int(self.delay_before_reduction_ms * get_cycles_per_ms())) + return orig_reduce_scatter(*args, **kwargs) + + with mock.patch("torch.distributed.reduce_scatter", _delayed_reduce_scatter): + self.module.run_backward(loss) + + +class NestedWrappedModuleWithDelay(ModuleWithDelay): + def __init__(self, group, wrapper_config, **kwargs): + super().__init__(NestedWrappedModule(group, wrapper_config), **kwargs) + def spawn_and_init(fn, args=None, **spawn_kwargs): if args is None: From 
014ad05afaa8da4717f25c368e76321ac4494c35 Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Sun, 7 Feb 2021 15:38:02 -0800 Subject: [PATCH 27/48] add unit test pack/unpack kwargs (#65) * add unit test pack/unpack kwargs * added two more corner cases * more doc and more tests * more corner cases * formatting * Update fairscale/utils/containers.py Co-authored-by: Sam Shleifer * with pytest.raises is awesome * addressed comment * add tuple to be tested Co-authored-by: Sam Shleifer --- fairscale/utils/containers.py | 21 +++++++++-- tests/utils/test_containers.py | 69 +++++++++++++++++++++++++++++++--- 2 files changed, 81 insertions(+), 9 deletions(-) diff --git a/fairscale/utils/containers.py b/fairscale/utils/containers.py index 0f5134c2e..dbad26337 100644 --- a/fairscale/utils/containers.py +++ b/fairscale/utils/containers.py @@ -32,11 +32,15 @@ def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any: def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]: """ + Turn argument list into separate key list and value list (unpack_kwargs does the opposite) + Usage:: kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) + assert kwarg_keys == ("a", "b") + assert flat_args == (1, 2, 3, 4) args, kwargs = unpack_kwargs(kwarg_keys, flat_args) - assert args == [1, 2] + assert args == (1, 2) assert kwargs == {"a": 3, "b": 4} """ kwarg_keys: List[str] = [] @@ -49,6 +53,7 @@ def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, def unpack_kwargs(kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """See pack_kwargs.""" + assert len(kwarg_keys) <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" if len(kwarg_keys) == 0: return flat_args, {} args = flat_args[: -len(kwarg_keys)] @@ -60,11 +65,19 @@ def split_non_tensors( mixed: Union[torch.Tensor, Tuple[Any, ...]] ) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]: """ + Split a tuple into a list of tensors and the rest with information + for later reconstruction. 
+ Usage:: x = torch.Tensor([1]) y = torch.Tensor([2]) tensors, packed_non_tensors = split_non_tensors((x, y, None, 3)) + assert tensors == (x, y) + assert packed_non_tensors == { + "is_tensor": [True, True, False, False], + "objects": [None, 3], + } recon = unpack_non_tensors(tensors, packed_non_tensors) assert recon == (x, y, None, 3) """ @@ -88,11 +101,13 @@ def unpack_non_tensors( """See split_non_tensors.""" if packed_non_tensors is None: return tensors - assert isinstance(packed_non_tensors, dict) + assert isinstance(packed_non_tensors, dict), type(packed_non_tensors) mixed: List[Any] = [] is_tensor_list = packed_non_tensors["is_tensor"] objects = packed_non_tensors["objects"] - assert len(tensors) + len(objects) == len(is_tensor_list) + assert len(tensors) + len(objects) == len(is_tensor_list), ( + f"len(tensors) {len(tensors)} len(objects) {len(objects)} " f"len(is_tensor_list) {len(is_tensor_list)}" + ) obj_i = tnsr_i = 0 for is_tensor in is_tensor_list: if is_tensor: diff --git a/tests/utils/test_containers.py b/tests/utils/test_containers.py index b0eebbeb1..d304478cb 100644 --- a/tests/utils/test_containers.py +++ b/tests/utils/test_containers.py @@ -57,12 +57,69 @@ def fn(t, x=[[total]]): def test_pack_unpack(): - # tbd - p = pack_kwargs - up = unpack_kwargs + """Test pack_kwargs and unpack_kwargs.""" + kwarg_keys, flat_args = pack_kwargs(1, 2, 3, 4) + assert kwarg_keys == tuple() + assert flat_args == (1, 2, 3, 4) + + kwarg_keys, flat_args = pack_kwargs(a=1, b={2: "2"}, c={3}, d=[4], e=(5,)) + assert kwarg_keys == ("a", "b", "c", "d", "e") + assert flat_args == (1, {2: "2"}, {3}, [4], (5,)) + + kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) + assert kwarg_keys == ("a", "b") + assert flat_args == (1, 2, 3, 4) + + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + assert args == (1, 2) + assert kwargs == {"a": 3, "b": 4} + + args, kwargs = unpack_kwargs([], flat_args) + assert kwargs == {} + assert args == (1, 2, 3, 4) + + args, kwargs = unpack_kwargs(["a", "b", "c", "d"], flat_args) + assert kwargs == {"a": 1, "b": 2, "c": 3, "d": 4} + assert args == tuple() + + with pytest.raises(AssertionError): + # too many keys should assert. + args, kwargs = unpack_kwargs(["a", "b", "c", "d", "e"], flat_args) def test_split_unpack(): - # tbd - s = split_non_tensors - up = unpack_non_tensors + """Test split_non_tensors and unpack_non_tensors.""" + x = torch.Tensor([1]) + y = torch.Tensor([2]) + + tensors, packed_non_tensors = split_non_tensors((x, y, None, 3)) + assert tensors == (x, y) + assert packed_non_tensors == { + "is_tensor": [True, True, False, False], + "objects": [None, 3], + } + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (x, y, None, 3) + + tensors, packed_non_tensors = split_non_tensors((None, 3, x, y)) + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (None, 3, x, y) + + tensors, packed_non_tensors = split_non_tensors((None, 3)) + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (None, 3) + + tensors, packed_non_tensors = split_non_tensors((x, y)) + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (x, y) + + recon = unpack_non_tensors(tensors, None) + assert recon == (x, y) + + with pytest.raises(AssertionError): + # assert the second arg should be a dict. + recon = unpack_non_tensors(tensors, set()) + + with pytest.raises(AssertionError): + # assert the content of the second arg should be sane. 
+ recon = unpack_non_tensors(tensors, {"is_tensor": [], "objects": []}) From 0ca378b8bba7d4cbd738986b98fc509a9904d48e Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sun, 7 Feb 2021 18:52:33 -0500 Subject: [PATCH 28/48] Refactor param init and streams logic (#73) - Add two new tests (TestParamInit and TestSerialization) which would have failed previously. These mostly cover the fairseq usage that was not captured by tests before. - Deprecate the `compute_device` option, since we don't actually use it anywhere (nor do we need it in fairseq) - Remove `_move_fp32_shard_to_cuda` and embrace a stronger invariant: p.data and p._fp32_shard should always be the same at the start and end of each function (namely, state_dict, and also forward/backward). - Slightly unrelated, but refactor streams logic a bit, so we have a single `self._streams` dictionary -- this will make an upcoming PR that adds more streams easier --- .../shard_params_data_parallel.py | 242 ++++++++++-------- fairscale/utils/testing.py | 12 + stubs/torch/nn/parameter.pyi | 1 - .../test_shard_params_data_parallel.py | 79 +++++- 4 files changed, 228 insertions(+), 106 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 1dd6b50f7..10d9fb8f0 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -57,27 +57,26 @@ class ShardParamsDataParallel(nn.Module): Args: module (nn.Module): module to checkpoint process_group (Optional): process group for sharding - reshard_after_forward (bool, Optional): if True, reshard parameters + reshard_after_forward (bool, Optional): if ``True``, reshard parameters after the forward pass. This saves memory but slows training. - mixed_precision (bool, Optional): if True, inputs, activations and + mixed_precision (bool, Optional): if ``True``, inputs, activations and gradients will be kept in FP16; computation and communication will occur in FP16; and a (sharded) master copy of the model weights will be maintained in FP32. - fp32_reduce_scatter (bool, Optional): if True, then reduce-scatter - gradients in FP32. This is only relevant when *mixed_precision* is - ``True``. - flatten_parameters (bool, Optional): if True, flatten parameters into a - single contiguous tensor, which improves training speed. - cpu_offload (bool, Optional): if True, offload FP32 params to CPU. This - is only relevant when *mixed_precision* is ``True``. - compute_device (torch.device, Optional): device to move params to for - computation. This is primarily relevant with *cpu_offload* and - defaults to "cuda". + fp32_reduce_scatter (bool, Optional): if ``True``, then reduce-scatter + gradients in FP32. This is only relevant when *``mixed_precision``* + is ``True``. + flatten_parameters (bool, Optional): if ``True``, flatten parameters + into a single contiguous tensor, which improves training speed. + cpu_offload (bool, Optional): if ``True``, offload FP32 params to CPU. + This is only relevant when *``mixed_precision``* is ``True``. compute_dtype (torch.dtype, Optional): dtype for full parameters for - computation. This defaults to torch.float32 unless *mixed_precision* - is set, in which case it defaults to torch.float16. - move_grads_to_cpu (bool, Optional): move grad shard to CPU after + computation. This defaults to ``torch.float32`` unless + *``mixed_precision``* is set, in which case it defaults to + ``torch.float16``. 
+ move_grads_to_cpu (bool, Optional): move gradient shard to CPU after reduction. This is useful when combined with CPU-based optimizers. + It defaults to the value of *``cpu_offload``*. """ def __init__( @@ -89,7 +88,6 @@ def __init__( fp32_reduce_scatter: bool = False, flatten_parameters: bool = True, cpu_offload: bool = False, - compute_device: Optional[torch.device] = None, compute_dtype: Optional[torch.dtype] = None, move_grads_to_cpu: Optional[bool] = None, ): @@ -102,7 +100,6 @@ def __init__( self.fp32_reduce_scatter = fp32_reduce_scatter self.flatten_parameters = flatten_parameters self.cpu_offload = cpu_offload - self.compute_device = compute_device or torch.device("cuda") self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32) self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu @@ -135,11 +132,7 @@ def __init__( for n, p in self.named_parameters(): assert getattr(p, "_is_sharded", False), f"found unsharded parameter: {n} ; {p.size()}" - # Flag to indicate if this instance is wrapped by any other - # ShardParamsDataParallel instances. This flag is only set after the - # first forward pass. - self._is_root: Optional[bool] = None - self._init_full_param() + self._reset_lazy_init() @torch.no_grad() def _shard_parameters_(self) -> None: @@ -200,8 +193,7 @@ def __getstate__(self) -> Dict[str, str]: state["orig_sizes"] = [p._orig_size for p in self.params] if state["process_group"] is not None: state["process_group"] = "MISSING" # process_group isn't pickleable - if "_fp32_to_fp16_stream" in state: - del state["_fp32_to_fp16_stream"] + self._reset_lazy_init() return state def __setstate__(self, state: Dict[str, Any]) -> None: @@ -217,6 +209,7 @@ def fixup(p: Parameter, size: torch.Size) -> Parameter: self.params = [fixup(p, size) for p, size in zip(self.params, self.orig_sizes)] del self.orig_sizes + self._reset_lazy_init() # TODO (Min): figuring out how to do typing for this overloaded function. def state_dict(self, *args, **kwargs): # type: ignore @@ -226,11 +219,15 @@ def state_dict(self, *args, **kwargs): # type: ignore wrapped Module without any sharding-specific logic. """ torch.cuda.synchronize() + self._lazy_init() self._rebuild_full_params() + state_dict = self.module.state_dict(*args, **kwargs) # We don't free the params after generating the state dict, since # freeing is done in-place (via the Storage) and would corrupt the - # returned state dict. - return self.module.state_dict(*args, **kwargs) + # returned state dict. However, we need to maintain the invariant that + # p.data corresponds to the FP32 param shard, so we do that here. + self._use_fp32_param_shard() + return state_dict # TODO (Min): figuring out how to do typing for this overloaded function. 
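# A minimal illustrative sketch (not part of this diff) of the invariant the
# refactor establishes: ``p.data`` aliases the sharded FP32 shard at the end of
# every public entry point, even though ``state_dict()`` returns full-size
# tensors. ``sharded_module`` below is a hypothetical, already-wrapped instance.
def _assert_fp32_shard_invariant(sharded_module):
    full_sd = sharded_module.state_dict()  # full (unsharded) tensors
    for p in sharded_module.params:
        # p.data should still be the small FP32 shard, not the rebuilt full param.
        assert p.data.storage().data_ptr() == p._fp32_shard.storage().data_ptr()
    return full_sd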
def local_state_dict(self, *args, **kwargs): # type: ignore @@ -248,6 +245,8 @@ def load_state_dict( self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True ) -> NamedTuple: """Load a whole (unsharded) state_dict.""" + torch.cuda.synchronize() + self._lazy_init() self._rebuild_full_params() output = self.module.load_state_dict(state_dict, strict) self._free_full_params() @@ -259,21 +258,55 @@ def load_local_state_dict( """Load a local (sharded) state_dict.""" return self.module.load_state_dict(state_dict, strict) - @torch.no_grad() - def _init_full_param(self) -> None: - """Set p._full_param if not already set.""" - first_time_params = [p for p in self.params if not hasattr(p, "_full_param")] - for p in first_time_params: - self._init_param(p) + def _reset_lazy_init(self) -> None: + """Reset instance so :func:`_lazy_init` will run on the next forward.""" + self._is_root: Optional[bool] = None + self._streams: Dict[str, torch.cuda.Stream] = {} + + def _lazy_init(self) -> None: + """Initialization steps that should happen lazily, typically right + before the first forward pass.""" + # Initialize param attributes lazily, in case the param's dtype or + # device changes after __init__. + for p in self.params: + self._init_param_attributes(p) + + # Initialize _is_root and setup streams. These steps would ideally + # happen in __init__, but _is_root can only be determined after the + # entire model hierarchy is setup, thus we run it lazily. + if self._is_root is None: + self._set_is_root() + self._setup_streams() @torch.no_grad() - def _init_param(self, p: Parameter) -> None: - assert p._is_sharded - assert not hasattr(p, "_full_param") + def _init_param_attributes(self, p: Parameter) -> None: + """ + We manage several attributes on each Parameter instance. The first two + are set by :func:`_shard_parameters_`: + + ``_is_sharded``: ``True`` after the Parameter is initially sharded + ``_orig_size``: the size of the original Parameter (before sharding) + + The remaining attributes are set here: + ``_fp32_shard``: a single shard of the parameters in full precision + (typically FP32, but this is dependent on the dtype of the model + as it's passed in by the user). This can be on CPU or GPU + depending on the value of *``cpu_offload``*. + ``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be + a single shard of the parameters in FP16, used for all-gather. + ``_full_param``: the full weight, used for computation in the + forward/backward pass. This will be resized in place and only + materialized (via all-gather) as needed. + """ + assert p._is_sharded and hasattr(p, "_orig_size") + if hasattr(p, "_full_param"): + return + + # Compute device defaults to CUDA when *cpu_offload* is enabled, or the + # param's current device otherwise (could be CPU). + compute_device = torch.device("cuda") if self.cpu_offload else p.device - # _fp32_shard will correspond to a single shard of the parameters in - # full precision (typically FP32, but this is dependent on the dtype of - # the model as it's passed in by the user). + # A single shard of the parameters in full precision. p._fp32_shard = p.data if self.mixed_precision: @@ -283,16 +316,16 @@ def _init_param(self, p: Parameter) -> None: assert p._fp32_shard.device == torch.device("cpu") # If we plan to keep the FP32 parameters on CPU, then pinning # memory allows us to later use non-blocking transfers when moving - # the FP32 param shard to self.compute_device. 
+ # the FP32 param shard to compute_device. p._fp32_shard = p._fp32_shard.pin_memory() p.data = p._fp32_shard # In mixed precision mode, we maintain a reduced precision - # (typically FP16) parameter shard on self.compute_device for - # performing the computation in the forward/backward pass. We resize - # the storage to size 0 at init (here) and re-materialize - # (by copying from _fp32_shard) as needed. - p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype) + # (typically FP16) parameter shard on compute_device for performing + # the computation in the forward/backward pass. We resize the + # storage to size 0 at init (here) and re-materialize (by copying + # from _fp32_shard) as needed. + p._fp16_shard = torch.zeros_like(p._fp32_shard, device=compute_device, dtype=self.compute_dtype) free_storage_(p._fp16_shard) else: p._fp16_shard = None # use _fp32_shard @@ -300,7 +333,7 @@ def _init_param(self, p: Parameter) -> None: # We also maintain a full-sized parameter of type self.compute_dtype # (FP16 for mixed_precision or FP32 otherwise). We resize the # storage to size 0 at init (here) and only materialize as needed. - p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype) + p._full_param = torch.zeros(p._orig_size, device=compute_device, dtype=self.compute_dtype) free_storage_(p._full_param) if self.move_grads_to_cpu: @@ -309,33 +342,9 @@ def _init_param(self, p: Parameter) -> None: # shard in pinned memory so that we can do a non-blocking transfer. p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory() - # Stream for overlapping the backward pass and gradient reductions. - p._post_backward_stream = torch.cuda.Stream() - - @torch.no_grad() - def _pre_forward_init(self) -> None: - self._set_is_root() - if not self.cpu_offload: - self._move_fp32_shard_to_cuda() - if self._is_root and not hasattr(self, "_fp32_to_fp16_stream"): - # Stream for moving FP32 master params (which may be on CPU) to - # FP16 for computation. We share this stream with all children - # instances, which allows them to overlap transfers across the - # forward pass without synchronizing with the default stream. - self._fp32_to_fp16_stream = torch.cuda.Stream() - - for n, m in self.named_modules(): - if n != "" and isinstance(m, ShardParamsDataParallel): - m._fp32_to_fp16_stream = self._fp32_to_fp16_stream - - assert self._is_root is not None - if self._is_root: - # The top-most (root) instance needs to synchronize with the default - # stream, so we don't move the FP32 master weights prematurely. - self._fp32_to_fp16_stream.wait_stream(torch.cuda.current_stream()) - def _set_is_root(self) -> None: - """If true, implies that no other ShardParamsDataParallel instance. Called by forward.""" + """If ``True``, implies that no other :class:`ShardParamsDataParallel` + instance wraps this one. Called once by :func:`_lazy_init`.""" if self._is_root is not None: return # No ShardParamsDataParallel instance wraps this, else _is_root would be set to False @@ -346,13 +355,42 @@ def _set_is_root(self) -> None: assert m._is_root is None m._is_root = False + def _setup_streams(self) -> None: + """Create streams to overlap data transfer and computation.""" + if len(self._streams) > 0 or not self._is_root: + return + # Stream to move main FP32 params (may be on CPU) to FP16 for forward. + self._streams["fp32_to_fp16"] = torch.cuda.Stream() + # Stream for overlapping grad reduction with the backward pass. 
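# A rough sketch (not part of this diff) of the three tensors attached to each
# Parameter once ``_init_param_attributes`` has run with ``mixed_precision=True``;
# the helper name is illustrative only.
def _param_storage_sizes(p):
    return {
        "fp32_shard": p._fp32_shard.numel(),            # ~1/world_size of the param
        "fp16_shard": p._fp16_shard.storage().size(),   # 0 until cast for the forward
        "full_param": p._full_param.storage().size(),   # 0 until all-gathered
    }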
+ self._streams["post_backward"] = torch.cuda.Stream() + # We share streams with all children instances, which allows them to + # overlap transfers across the forward pass without synchronizing with + # the default stream. + for n, m in self.named_modules(): + if n != "" and isinstance(m, ShardParamsDataParallel): + m._streams = self._streams + + def _wait_for_previous_optim_step(self) -> None: + """ + The outer-most :class:`ShardParamsDataParallel` instance (i.e., the root + instance) needs to synchronize with the default stream to ensure the + previous optimizer step is done. + """ + self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream()) + def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: - self._pre_forward_init() + self._lazy_init() + + # Due to the use of streams, we need to make sure the previous + # ``optim.step()`` is done before we all-gather parameters. + if self._is_root: + self._wait_for_previous_optim_step() if self.mixed_precision: args, kwargs = cast_inputs_to_fp16(*args, **kwargs) - # All-gather full parameters. + # All-gather full parameters. This will also transfer FP32 parameters to + # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``). self._rebuild_full_params() # Register backward hooks to reshard params and reduce-scatter grads. @@ -364,8 +402,11 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: if self.reshard_after_forward: self._free_full_params() - # Switch to FP32 param shard after forward so that the optimizer will be - # initialized with the correct dtype and size. + # Switch to main FP32 param shard. We maintain this invariant throughout + # the code, i.e., ``p.data == p._fp32_shard`` after each function. This + # also ensures that after the first forward, the optimizer state will be + # initialized with the correct dtype and (sharded) size, since optimizer + # state is typically initialized lazily in ``optim.step()``. self._use_fp32_param_shard() if torch.is_grad_enabled(): @@ -374,8 +415,8 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: return outputs def _register_pre_backward_hooks(self, outputs: Any) -> Any: - - # Register pre-backward hook to run before the wrapped module's backward. + """Register pre-backward hook to run before the wrapped module's + backward. Hooks should be attached to all outputs from the forward.""" pre_backward_hook_has_run = [False] def _pre_backward_hook(*unused: Any) -> None: @@ -398,7 +439,7 @@ def _register_hook(t: torch.Tensor) -> torch.Tensor: return outputs def _register_post_backward_hooks(self) -> None: - # Register backward hooks to reshard params and reduce-scatter grads. + """Register backward hooks to reshard params and reduce-scatter grads.""" if not torch.is_grad_enabled(): return # don't register grad hooks if grad isn't enabled self._post_backward_callback_queued = False @@ -414,10 +455,10 @@ def _register_post_backward_hooks(self) -> None: @torch.no_grad() def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: """ - At the start of _post_backward_hook, param.grad contains the full - gradient for the local batch. The reduce_scatter op will replace - param.grad with a single shard of the summed gradient across all GPUs. - This shard will align with the current GPU rank. For example:: + At the start of :func:`_post_backward_hook`, ``param.grad`` contains the + full gradient for the local batch. The reduce-scatter op will replace + ``param.grad`` with a single shard of the summed gradient across all + GPUs. 
This shard will align with the current GPU rank. For example:: before reduce_scatter: param.grad (GPU #0): [1, 2, 3, 4] @@ -427,10 +468,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: param.grad (GPU #0): [6, 8] # 1+5, 2+6 param.grad (GPU #1): [10, 12] # 3+7, 4+8 - The local GPU's optimizer.step is responsible for updating a single + The local GPU's ``optim.step`` is responsible for updating a single shard of params, also corresponding to the current GPU's rank. This - alignment is created by _shard_parameters_, which ensures that the - local optimizer only sees the relevant single parameter shard. + alignment is created by :func:`_shard_parameters_`, which ensures that + the local optimizer only sees the relevant parameter shard. """ if param.grad is None: return @@ -447,9 +488,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: self._free_fp16_param_shard([param]) # Wait for all work in the current stream to finish, then start the - # reductions in _post_backward_stream. - param._post_backward_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(param._post_backward_stream): + # reductions in post_backward stream. + self._streams["post_backward"].wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._streams["post_backward"]): if self.mixed_precision and self.fp32_reduce_scatter: # Cast grad to FP32. param.grad.data = param.grad.data.to(param.dtype) @@ -479,8 +520,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: @torch.no_grad() def _wait_for_post_backward(self) -> None: """Wait for all post-backward work to finish.""" - for p in self.params: - torch.cuda.current_stream().wait_stream(p._post_backward_stream) + if self._is_root: + torch.cuda.current_stream().wait_stream(self._streams["post_backward"]) if self.move_grads_to_cpu: # Wait for the non-blocking GPU -> CPU grad transfers to finish. 
torch.cuda.current_stream().synchronize() @@ -488,7 +529,7 @@ def _wait_for_post_backward(self) -> None: @torch.no_grad() def _rebuild_full_params(self) -> None: """Gather all shards of params.""" - if self.mixed_precision and hasattr(self, "_fp32_to_fp16_stream"): + if self.mixed_precision and len(self._streams) > 0: self._cast_fp32_param_shards_to_fp16() for p in self.params: @@ -501,7 +542,7 @@ def _rebuild_full_params(self) -> None: p.data = p._full_param p.grad = None - if self.mixed_precision and hasattr(self, "_fp32_to_fp16_stream"): + if self.mixed_precision and len(self._streams) > 0: self._free_fp16_param_shard([p]) @torch.no_grad() @@ -526,12 +567,6 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: p._full_param.record_stream(current_stream) free_storage_(p._full_param) - @torch.no_grad() - def _move_fp32_shard_to_cuda(self) -> None: - assert not self.cpu_offload - for p in self.params: - p._fp32_shard = p._fp32_shard.to(p.data.device) - @torch.no_grad() def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None: """Use FP32 shard for a list of params.""" @@ -545,13 +580,13 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No """Cast FP32 param shard to FP16 for a list of params.""" if params is None: params = self.params - with torch.cuda.stream(self._fp32_to_fp16_stream): + with torch.cuda.stream(self._streams["fp32_to_fp16"]): for p in params: assert p._fp16_shard is not None alloc_storage_(p._fp16_shard, size=p._fp32_shard.size()) p._fp16_shard.copy_(p._fp32_shard, non_blocking=True) p.data = p._fp16_shard - torch.cuda.current_stream().wait_stream(self._fp32_to_fp16_stream) + torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"]) @torch.no_grad() def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None: @@ -567,11 +602,12 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> No free_storage_(p._fp16_shard) @torch.no_grad() - def _reduce_scatter(self, tensor: torch.Tensor, output: Optional[torch.Tensor] = None) -> torch.Tensor: + def _reduce_scatter(self, tensor: torch.Tensor) -> torch.Tensor: + """Reduce-scatter a Tensor (gradient from the local worker) and return + the result (a single shard of the summed gradient across workers).""" assert tensor.numel() % self.world_size == 0 tensor = tensor.view(self.world_size, -1) - if output is None: - output = torch.zeros_like(tensor[0]) + output = torch.zeros_like(tensor[0]) dist.reduce_scatter(output, list(tensor.unbind(0)), group=self.process_group) return output diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py index 2e395a362..7c44e33a7 100644 --- a/fairscale/utils/testing.py +++ b/fairscale/utils/testing.py @@ -476,3 +476,15 @@ def get_cycles_per_ms() -> float: end.synchronize() cycles_per_ms = 1000000 / start.elapsed_time(end) return cycles_per_ms + + +class DummyProcessGroup: + def __init__(self, rank: int, size: int): + self._rank = rank + self._size = size + + def rank(self) -> int: + return self._rank + + def size(self) -> int: + return self._size diff --git a/stubs/torch/nn/parameter.pyi b/stubs/torch/nn/parameter.pyi index 7cd76d9d2..e9515b224 100644 --- a/stubs/torch/nn/parameter.pyi +++ b/stubs/torch/nn/parameter.pyi @@ -14,7 +14,6 @@ class Parameter(Tensor): _full_param: Tensor _fp32_shard: Tensor _fp16_shard: Optional[Tensor] - _post_backward_stream: Stream def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ... 
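As a hedged aside before the test changes below: ``DummyProcessGroup`` only has
to answer ``rank()`` and ``size()``, which is enough to construct (and pickle) a
wrapped module in a worker process that has no real ``torch.distributed``
backend; the real process group is re-attached after unpickling, as the new
serialization tests exercise. A minimal sketch::

    pg = DummyProcessGroup(rank=0, size=2)
    assert pg.rank() == 0 and pg.size() == 2
    # A module wrapped with this stand-in can be pickled; the real process
    # group is then assigned back onto each wrapped submodule after loading.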
diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 5ca98c8f1..dc3b2ed0e 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -5,6 +5,7 @@ import functools import itertools +import pickle import sys from typing import Dict import unittest @@ -17,6 +18,7 @@ from fairscale.nn.data_parallel import ShardParamsDataParallel from fairscale.utils.testing import ( DeviceAndTypeCheckModule, + DummyProcessGroup, dist_init, get_cycles_per_ms, objects_are_equal, @@ -241,6 +243,81 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3 raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}") +class TestParamInit(DistributedTest): + def test_param_change_after_init(self): + test_fn = functools.partial(self._test_param_change_after_init, config={"mixed_precision": True}) + spawn_and_init(test_fn) + + @classmethod + def _test_param_change_after_init(self, rank, group, config): + # Establish reference behavior. + model = self.get_wrapped_model(group, cuda_first=False, config=config) + model.eval() # no dropout for this test + input = model.module.get_input(torch.device("cuda")) + ref_output = model(*input) + + # Change the weights in place. + model = self.get_wrapped_model(group, cuda_first=False, config=config) + model.eval() # no dropout for this test + first_param = next(model.parameters()) + nn.init.normal_(first_param.data) + new_output = model(*input) + + assert not objects_are_equal(ref_output, new_output), "new_output did not reflect change to param after init" + + +class TestSerialization(DistributedTest): + @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test) + def test_pickle(self, mixed_precision, cpu_offload): + """Ensure that wrapped modules can be pickled/unpickled.""" + config = {"mixed_precision": mixed_precision, "cpu_offload": cpu_offload} + test_fn = functools.partial(self._test_pickle, config=config) + spawn_and_init(test_fn, world_sizes=[2]) + + @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test) + def test_multiprocessing(self, mixed_precision, cpu_offload): + """Ensure that wrapped modules can be sent via multiprocessing.""" + config = {"mixed_precision": mixed_precision, "cpu_offload": cpu_offload} + test_fn = functools.partial(self._test_multiprocessing, config=config) + spawn_and_init(test_fn, world_sizes=[2]) + + @classmethod + def _test_pickle(self, rank, group, config): + model = self._get_model(group, config) + model = pickle.loads(pickle.dumps(model)) + if not config["cpu_offload"]: + model = model.cuda() + self._one_step(model, group) + + @classmethod + def _test_multiprocessing(self, rank, group, config): + mp = torch.multiprocessing.Pool(1) + dummy_group = DummyProcessGroup(rank=group.rank(), size=group.size()) + model = mp.apply(self._get_model, (dummy_group, config)) + if not config["cpu_offload"]: + model = model.cuda() + self._one_step(model, group) + + @classmethod + def _get_model(self, group, config): + with torch.no_grad(): # required for multiprocessing + model = NestedWrappedModule(group, wrapper_config=config) + return ShardParamsDataParallel(model, group, **config) + + @classmethod + def _one_step(self, model, group): + # reset the process group (required after unpickling) + for m in model.modules(): + if isinstance(m, ShardParamsDataParallel): 
+ m.process_group = group + optim = torch.optim.Adam(model.parameters(), lr=0.0001) + input = model.module.get_input(torch.device("cuda")) + output = model(*input) + loss = model.module.get_loss(input, output) + model.module.run_backward(loss) + optim.step() + + class TestLocalStateDict(DistributedTest): @parameterized.expand([[True, True], [False, False]], name_func=rename_test) def test_load_local_state_dict(self, flatten_params, mixed_precision): @@ -323,9 +400,7 @@ def test_state_dict_before_forward(self, mixed_precision): @classmethod def _test_state_dict_before_forward(cls, config, rank, group): ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config) - assert not hasattr(ddp_model, "_fp32_to_fp16_stream") sd = ddp_model.state_dict() - assert not hasattr(ddp_model, "_fp32_to_fp16_stream") expected_dtype = torch.float16 if ddp_model.mixed_precision else torch.float32 wt = sd["embed_tokens.weight"] assert wt.dtype == expected_dtype, f"got dtype {wt.dtype} expected {expected_dtype}" From bebe7fd6b3c3143d31f09171646d0f1d53ac1ac1 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 8 Feb 2021 08:01:55 -0800 Subject: [PATCH 29/48] Add test for NoGrad mode --- .../shard_params_data_parallel.py | 8 ++++-- .../test_shard_params_data_parallel.py | 27 +++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 10d9fb8f0..7f24ee5e4 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -409,14 +409,18 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: # state is typically initialized lazily in ``optim.step()``. self._use_fp32_param_shard() - if torch.is_grad_enabled(): - outputs = self._register_pre_backward_hooks(outputs) + # Register pre-backward hooks to all-gather the params for the backward + # pass (if needed). + outputs = self._register_pre_backward_hooks(outputs) return outputs def _register_pre_backward_hooks(self, outputs: Any) -> Any: """Register pre-backward hook to run before the wrapped module's backward. 
Hooks should be attached to all outputs from the forward.""" + if not torch.is_grad_enabled(): + return outputs # don't register hooks if grad isn't enabled + pre_backward_hook_has_run = [False] def _pre_backward_hook(*unused: Any) -> None: diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index dc3b2ed0e..88d78f0de 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -482,6 +482,33 @@ def _test_register_functions_called(self, rank, group, cuda_first=False): assert model._register_pre_backward_hooks.called +class TestNoGrad(DistributedTest): + @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) + def test_transformer_parameterized(self, config): + test_fn = functools.partial(self._test_transformer, config=config) + spawn_and_init(test_fn) + + @classmethod + def _test_transformer(self, rank, group, config): + autocast = config["mixed_precision"] + + # Train model for a step + model = self.get_wrapped_model(group, cuda_first=False, config=config) + self._train_for_several_steps(model, 1, autocast) + + model.eval() # no dropout for this test + + # Eval in standard mode (i.e., without no_grad) + input = model.module.get_input(torch.device("cuda")) + ref_output = model(*input) + + # Eval with no_grad and compare + with torch.no_grad(): + no_grad_output = model(*input) + + assert objects_are_equal(ref_output, no_grad_output), "no_grad_output did not match ref_output" + + class TransformerWithSharedParams(nn.Module): def __init__(self, *unused_args, d_vocab=32, d_model=16, **unused_kwargs): super().__init__() From 74b0223b605c21e45ae9c626997b64f92cee025c Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 8 Feb 2021 16:13:37 -0500 Subject: [PATCH 30/48] Add all_gather stream and disable reshard_after_forward on root instance (#75) --- .../shard_params_data_parallel.py | 73 ++++++++++++------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 7f24ee5e4..9676db895 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -58,7 +58,8 @@ class ShardParamsDataParallel(nn.Module): module (nn.Module): module to checkpoint process_group (Optional): process group for sharding reshard_after_forward (bool, Optional): if ``True``, reshard parameters - after the forward pass. This saves memory but slows training. + after the forward pass. This saves memory but slows training. This + is only relevant when resharding individual layers. mixed_precision (bool, Optional): if ``True``, inputs, activations and gradients will be kept in FP16; computation and communication will occur in FP16; and a (sharded) master copy of the model weights will @@ -278,6 +279,11 @@ def _lazy_init(self) -> None: self._set_is_root() self._setup_streams() + # Don't free the full params for the outer-most (root) instance, since + # those params will be needed immediately after for the backward pass. + if self._is_root: + self.reshard_after_forward = False + @torch.no_grad() def _init_param_attributes(self, p: Parameter) -> None: """ @@ -361,6 +367,8 @@ def _setup_streams(self) -> None: return # Stream to move main FP32 params (may be on CPU) to FP16 for forward. self._streams["fp32_to_fp16"] = torch.cuda.Stream() + # Stream for all-gathering parameters. 
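# A rough check (not part of this diff) of the behavior added above: because
# _lazy_init forces reshard_after_forward = False on the root instance, only
# inner wrapped modules free their full params after forward; the root keeps
# them resident since backward needs them immediately. ``model`` below is a
# hypothetical nested wrapping (as in the class docstring), inspected after one
# forward pass.
for name, m in model.named_modules():
    if isinstance(m, ShardParamsDataParallel) and m._is_root is not None:
        print(name or "<root>", "reshard_after_forward =", m.reshard_after_forward)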
+ self._streams["all_gather"] = torch.cuda.Stream() # Stream for overlapping grad reduction with the backward pass. self._streams["post_backward"] = torch.cuda.Stream() # We share streams with all children instances, which allows them to @@ -376,7 +384,10 @@ def _wait_for_previous_optim_step(self) -> None: instance) needs to synchronize with the default stream to ensure the previous optimizer step is done. """ - self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream()) + if self.mixed_precision: + self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream()) + else: + self._streams["all_gather"].wait_stream(torch.cuda.current_stream()) def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: self._lazy_init() @@ -511,13 +522,16 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: if self.mixed_precision: param.grad.data = param.grad.data.to(dtype=param.data.dtype) + # Optionally move gradients to CPU, typically used if one is running + # the optimizer on the CPU. if self.move_grads_to_cpu: param._cpu_grad.copy_(param.grad.data, non_blocking=True) param.grad.data = param._cpu_grad # Enqueue a callback at the end of the backward pass to ensure that all - # post-backward work has finished. We only need one callback. - if not self._post_backward_callback_queued: + # post-backward work has finished. We only need one callback and it only + # needs to be called from the outer-most (root) instance. + if self._is_root and not self._post_backward_callback_queued: self._post_backward_callback_queued = True Variable._execution_engine.queue_callback(self._wait_for_post_backward) @@ -533,21 +547,23 @@ def _wait_for_post_backward(self) -> None: @torch.no_grad() def _rebuild_full_params(self) -> None: """Gather all shards of params.""" - if self.mixed_precision and len(self._streams) > 0: - self._cast_fp32_param_shards_to_fp16() + with torch.cuda.stream(self._streams["all_gather"]): + if self.mixed_precision: + self._cast_fp32_param_shards_to_fp16() - for p in self.params: - if p._full_param.storage().size() != p._orig_size.numel(): - # All-gather parameters - alloc_storage_(p._full_param, size=p._orig_size) - output_list = list(torch.flatten(p._full_param).chunk(self.world_size)) - dist.all_gather(output_list, p.data, group=self.process_group) + for p in self.params: + if p._full_param.storage().size() != p._orig_size.numel(): + # All-gather parameters + alloc_storage_(p._full_param, size=p._orig_size) + output_list = list(torch.flatten(p._full_param).chunk(self.world_size)) + dist.all_gather(output_list, p.data, group=self.process_group) - p.data = p._full_param - p.grad = None + p.data = p._full_param + p.grad = None - if self.mixed_precision and len(self._streams) > 0: - self._free_fp16_param_shard([p]) + if self.mixed_precision: + self._free_fp16_param_shard([p]) + torch.cuda.current_stream().wait_stream(self._streams["all_gather"]) @torch.no_grad() def _use_full_params(self) -> None: @@ -561,15 +577,16 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: if params is None: params = self.params current_stream = torch.cuda.current_stream() - for p in params: - # There may be external references to the Tensor Storage that we - # can't modify, such as references that are created by - # ctx.save_for_backward in the forward pass. Thus when we unshard - # parameters, we should reuse the original Tensor Storage object - # and unshard it in-place. For now, just resize the Storage to 0 to - # save memory. 
- p._full_param.record_stream(current_stream) - free_storage_(p._full_param) + with torch.cuda.stream(self._streams["all_gather"]): + for p in params: + # There may be external references to the Tensor Storage that we + # can't modify, such as references that are created by + # ctx.save_for_backward in the forward pass. Thus when we + # unshard parameters, we should reuse the original Tensor + # Storage object and unshard it in-place. For now, just resize + # the Storage to 0 to save memory. + p._full_param.record_stream(current_stream) + free_storage_(p._full_param) @torch.no_grad() def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None: @@ -588,7 +605,11 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No for p in params: assert p._fp16_shard is not None alloc_storage_(p._fp16_shard, size=p._fp32_shard.size()) - p._fp16_shard.copy_(p._fp32_shard, non_blocking=True) + p._fp16_shard.copy_( + # If cpu_offload is True, this will be non-blocking because + # _fp32_shard is pinned, otherwise it's a no-op. + p._fp32_shard.to(p._fp16_shard.device, non_blocking=True) + ) p.data = p._fp16_shard torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"]) From 679796434a294f039de5e91603b078d07cf0feb3 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 9 Feb 2021 12:54:38 -0500 Subject: [PATCH 31/48] Leave buffers on self.compute_device (#67) --- .../shard_params_data_parallel.py | 27 ++++++++++++----- .../test_shard_params_data_parallel.py | 30 ++++++++++++------- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 9676db895..d73e1399d 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -125,16 +125,18 @@ def __init__( # Shard module parameters in place self._shard_parameters_() - if self.mixed_precision: - # Cast all module buffers to FP16 (buffers are not sharded). - self.apply(cast_buffers_to_fp16) - # Make sure all parameters are sharded. for n, p in self.named_parameters(): assert getattr(p, "_is_sharded", False), f"found unsharded parameter: {n} ; {p.size()}" self._reset_lazy_init() + @torch.no_grad() + def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None: + """Move all buffers to the specified device and dtype, recursively.""" + cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype) + self.apply(cast_fn) + @torch.no_grad() def _shard_parameters_(self) -> None: """ @@ -217,17 +219,19 @@ def state_dict(self, *args, **kwargs): # type: ignore """ Returns the whole (unsharded) state of the module. Parameters are not sharded, so the resulting state_dict can be loaded directly by the - wrapped Module without any sharding-specific logic. + wrapped Module without any sharding-specific logic. Returned tensors will always be typed float32 """ torch.cuda.synchronize() self._lazy_init() self._rebuild_full_params() + self._all_buffers_to(dtype=torch.float32) # Buffers dtype stays consistent with parameters. state_dict = self.module.state_dict(*args, **kwargs) # We don't free the params after generating the state dict, since # freeing is done in-place (via the Storage) and would corrupt the # returned state dict. However, we need to maintain the invariant that # p.data corresponds to the FP32 param shard, so we do that here. 
self._use_fp32_param_shard() + self._all_buffers_to(dtype=self.compute_dtype) return state_dict # TODO (Min): figuring out how to do typing for this overloaded function. @@ -278,6 +282,10 @@ def _lazy_init(self) -> None: if self._is_root is None: self._set_is_root() self._setup_streams() + if self.cpu_offload: # Buffers stay on GPU, and dont get sharded + self._all_buffers_to(device=torch.device("cuda"), dtype=self.compute_dtype) + else: + self._all_buffers_to(dtype=self.compute_dtype) # Don't free the full params for the outer-most (root) instance, since # those params will be needed immediately after for the backward pass. @@ -652,11 +660,14 @@ def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: return args, kwargs -def cast_buffers_to_fp16(module: nn.Module) -> None: - """Cast buffers of a module to FP16.""" +def cast_buffers_( + module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None +) -> None: + """Cast all of module.named_buffers to device, dtype.""" + # if buffers are already on the right device and/or dtype this is just python loop cost for key, buf in module.named_buffers(recurse=False): if buf is not None: - setattr(module, key, buf.half()) + setattr(module, key, buf.to(dtype=dtype, device=device)) def free_storage_(data: torch.Tensor) -> None: diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 88d78f0de..94956f7f7 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -27,6 +27,8 @@ # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4 +_BUFFER_NAME = "vocab_bias" + class DistributedTest(unittest.TestCase): def setUp(self): @@ -42,9 +44,9 @@ def setUp(self): raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") @staticmethod - def _train_for_several_steps(model, num_steps, autocast): + def _train_for_several_steps(model, num_steps, autocast, lr=0.01): model_device = next(model.parameters()).device - optim = torch.optim.Adam(model.parameters(), lr=0.01) + optim = torch.optim.Adam(model.parameters(), lr=lr) for _ in range(num_steps): optim.zero_grad() with torch.cuda.amp.autocast(enabled=autocast): @@ -178,7 +180,11 @@ def test_cpu_offload_and_cpu_grads(self): # We don't test the False condition because that requires the optimizer to internally do # the device transfer and PyTorch optimizers don't support this. config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True} - test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False) + test_fn = functools.partial( + self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.001 + ) + # We use lower lr to reduce this test's sensitivity to slightly different CPU vs CUDA behavior of pytorch. + # With lr=0.01, it fails on torch 1.6.0. 
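# A small illustration (not part of this diff) of the buffer handling added in
# this patch: unlike parameters, buffers are never sharded -- cast_buffers_
# just moves/casts them in place, and state_dict() temporarily casts them back
# to FP32. The BatchNorm module here is a hypothetical stand-in that happens to
# own buffers.
bn = torch.nn.BatchNorm1d(8)                  # has running_mean / running_var buffers
cast_buffers_(bn, dtype=torch.float16)
assert bn.running_mean.dtype == torch.float16
cast_buffers_(bn, dtype=torch.float32)
assert bn.running_mean.dtype == torch.float32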
spawn_and_init(test_fn) def test_cpu_offload_and_cuda_grads_breaks(self): @@ -210,7 +216,7 @@ def test_delayed_reduce_scatter(self): spawn_and_init(test_fn) @classmethod - def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True): + def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True, lr=0.01): if config["mixed_precision"]: autocast = True # Force the compute dtype to be torch.float32 so that we get @@ -224,7 +230,7 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3 # Establish reference behavior with PyTorch DDP (+ optionally autocast). model = model_init_fn(group=group, wrapper_config=None).cuda() model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group) - ref_loss = cls._train_for_several_steps(model, num_steps, autocast) + ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr) ref_state_dict = model.module.state_dict() # Confirm we get the same behavior using ShardParamsDataParallel. @@ -233,14 +239,14 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3 model = model.cuda() else: assert next(model.parameters()).device == torch.device("cpu") - shard_loss = cls._train_for_several_steps(model, num_steps, autocast) + shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr) shard_state_dict = model.state_dict() try: torch.testing.assert_allclose(ref_loss, shard_loss) assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True) except (AssertionError, RuntimeError) as e: - raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}") + raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}") class TestParamInit(DistributedTest): @@ -332,7 +338,7 @@ def test_local_state_dict_odd_vocab_shape_breaks(self): spawn_and_init(test_fn) @classmethod - def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32): + def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=16): """Check that local_state_dict can be saved and loaded for a given worker, and that training updates it""" model = ShardParamsDataParallel( TransformerWithSharedParams(d_model=d_model, d_vocab=d_vocab), group, **config @@ -346,11 +352,11 @@ def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32): state_1_weight = state_1[weight_key] assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32" if not model.flatten_parameters: - # This weight will be sharded since we access module.state_dict directly + # The weight will be sharded since we access module.state_dict directly state_1_module_weight = model.module.state_dict()[weight_key] torch.testing.assert_allclose(state_1_weight, state_1_module_weight) torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight) - self._train_for_several_steps(model, 4, model.mixed_precision) + self._train_for_several_steps(model, 1, model.mixed_precision) state_2 = model.local_state_dict() state_after_training = {k: v.cpu().clone() for k, v in state_2.items()} @@ -361,7 +367,7 @@ def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32): # Assert that parameters were updated since before training unchanged = [] for k in state_1: - if (state_before_training[k] == state_after_training[k]).all(): + if (state_before_training[k] == 
state_after_training[k]).all() and (_BUFFER_NAME not in k): unchanged.append(k) if unchanged: raise AssertionError(f"params {unchanged} not changed after training") @@ -520,6 +526,7 @@ def __init__(self, *unused_args, d_vocab=32, d_model=16, **unused_kwargs): self.output_proj = nn.Linear(d_model, d_vocab) # share the embedding and output projection weights self.output_proj.weight = self.embed_tokens.weight + self.register_buffer(_BUFFER_NAME, self.embed_tokens.weight.new_ones((d_model,))) def get_input(self, device): torch.manual_seed(1) # keep everything deterministic @@ -529,6 +536,7 @@ def get_input(self, device): def forward(self, src_ids, tgt_ids): src = self.embed_tokens(src_ids) + src = src + self.vocab_bias tgt = self.embed_tokens(tgt_ids) x = self.transformer(src, tgt) return self.output_proj(x) From 0937594d0f3d693105e57a524bf95ff4d3a87d45 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 9 Feb 2021 18:46:15 -0500 Subject: [PATCH 32/48] Pad parameters inside: right before gather, scatter (#76) --- .../shard_params_data_parallel.py | 58 ++++++++++++++++--- .../test_shard_params_data_parallel.py | 19 +++--- 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index d73e1399d..223636bdc 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -5,6 +5,7 @@ import copy import functools +import math from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union import torch @@ -13,6 +14,7 @@ from torch.distributed import ProcessGroup import torch.nn as nn from torch.nn import Parameter +import torch.nn.functional as F from fairscale.nn.misc import FlattenParamsWrapper from fairscale.utils.containers import ( @@ -170,15 +172,48 @@ def _shard_parameters_(self) -> None: p._is_sharded = True p._orig_size = p.data.size() - - shard_size = p.data.numel() // self.world_size + # shard p.data such that all elements are part of a shard and the last shard is <= all other shards + shard_size = math.ceil(p.data.numel() / self.world_size) s = self.rank * shard_size - e = (self.rank + 1) * shard_size + e = min(s + shard_size, p.data.numel()) orig_data = p.data p.data = torch.flatten(p.data)[s:e].clone() free_storage_(orig_data) + @property + def is_last_rank(self) -> bool: + return self.rank == (self.world_size - 1) + + def _calc_num_to_pad(self, numel: int) -> int: + num_extra = numel % self.world_size + num_to_pad = self.world_size - num_extra if num_extra > 0 else 0 + return num_to_pad + + @torch.no_grad() + def _all_gather_full_param(self, p: nn.Parameter) -> None: + """Fill p.full_param with gathered p.data values (using torch.distributed.all_gather). + If the last shard is smaller than the other shards, we pad it with zeroes. + """ + full_param_chunks = list(torch.flatten(p._full_param).chunk(self.world_size)) + # We overwrite the last chunk with zeroes if it is smaller than the other entries. 
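# A worked example (not part of this diff) of the padded sharding introduced
# here, assuming world_size = 4 and a parameter with 10 elements:
shard_size = math.ceil(10 / 4)        # == 3 -> ranks 0-2 each hold 3 elements
last_shard = 10 - 3 * shard_size      # == 1 -> rank 3 holds the remaining 1
num_to_pad = 4 - (10 % 4)             # == 2 -> rank 3 pads its shard up to 3
assert last_shard + num_to_pad == shard_size
# Every rank then contributes an equal-size tensor to all_gather, and the two
# padding elements are dropped when copying into p._full_param.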
+ num_to_pad = self._calc_num_to_pad(p._orig_size.numel()) + pointer_to_last_chunk = full_param_chunks[-1] + assert pointer_to_last_chunk.numel() + num_to_pad == full_param_chunks[0].numel() + + param_shard = p.data # we will gather this from each worker + if num_to_pad > 0: # add padding to param_shard and full_param_chunks[-1] + full_param_chunks[-1] = torch.zeros_like(full_param_chunks[0]) # no longer shares memory with full_param + if self.is_last_rank: + param_shard = F.pad(p.data, [0, num_to_pad]) + dist.all_gather(full_param_chunks, param_shard, group=self.process_group) + # ^ updates p._full_param + + # remove padding from full_param_chunks[-1] and p.data + if num_to_pad > 0: # copy shard associated with the padded chunk to full_param + pointer_to_last_chunk.copy_(full_param_chunks[-1][:-num_to_pad]) + # ^ updates p._full_param + def __getattr__(self, name: str) -> Any: """Forward missing attributes to wrapped module.""" try: @@ -522,7 +557,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: param.grad.data.div_(self.world_size) # Reduce-scatter grad. - param.grad.data = self._reduce_scatter(torch.flatten(param.grad.data)) + param.grad.data = self._reduce_scatter(param.grad.data) # Cast grad to param's dtype (typically FP32). Note: we do this # before the move_grads_to_cpu step so that this entire hook remains @@ -561,10 +596,8 @@ def _rebuild_full_params(self) -> None: for p in self.params: if p._full_param.storage().size() != p._orig_size.numel(): - # All-gather parameters alloc_storage_(p._full_param, size=p._orig_size) - output_list = list(torch.flatten(p._full_param).chunk(self.world_size)) - dist.all_gather(output_list, p.data, group=self.process_group) + self._all_gather_full_param(p) # Fill p._full_param with (p.data for each shard in self.world_size) p.data = p._full_param p.grad = None @@ -638,10 +671,17 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> No def _reduce_scatter(self, tensor: torch.Tensor) -> torch.Tensor: """Reduce-scatter a Tensor (gradient from the local worker) and return the result (a single shard of the summed gradient across workers).""" + tensor = torch.flatten(tensor) + num_to_pad = self._calc_num_to_pad(tensor.numel()) + if num_to_pad > 0: # pad the gradient to be divisible by world_size + tensor = F.pad(tensor, [0, num_to_pad]) assert tensor.numel() % self.world_size == 0 tensor = tensor.view(self.world_size, -1) - output = torch.zeros_like(tensor[0]) - dist.reduce_scatter(output, list(tensor.unbind(0)), group=self.process_group) + output = torch.zeros_like(tensor[0]) # filled with gradient summed across workers + to_scatter = list(tensor.unbind(0)) # world size tensors of shape (shard_size,) + dist.reduce_scatter(output, to_scatter, group=self.process_group) + if self.is_last_rank and num_to_pad > 0: + output = output[:-num_to_pad] return output diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 94956f7f7..6814f3796 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -26,6 +26,7 @@ ) # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4 +# All helper functions called by spawn must be either @classmethod, @staticmethod _BUFFER_NAME = "vocab_bias" @@ -332,17 +333,10 @@ def test_load_local_state_dict(self, flatten_params, mixed_precision): ) spawn_and_init(test_fn) - def 
test_local_state_dict_odd_vocab_shape_breaks(self): - test_fn = functools.partial(self._load_local_and_train, {"flatten_parameters": False}, d_model=16, d_vocab=37) - with self.assertRaises(Exception): - spawn_and_init(test_fn) - @classmethod - def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=16): + def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23): """Check that local_state_dict can be saved and loaded for a given worker, and that training updates it""" - model = ShardParamsDataParallel( - TransformerWithSharedParams(d_model=d_model, d_vocab=d_vocab), group, **config - ).cuda() + model = self.get_wrapped_model(group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model) state_1 = model.local_state_dict() state_before_training = {k: v.cpu().clone() for k, v in state_1.items()} assert len(state_1) > 0 @@ -382,8 +376,8 @@ def test_calling_state_dict_twice_mixed_precision(self, mixed_precision): spawn_and_init(test_fn) @classmethod - def _test_calling_state_dict_twice(self, config, rank, group): - ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config) + def _test_calling_state_dict_twice(self, config, rank, group, **model_kwargs): + ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config, **model_kwargs) autocast = ddp_model.mixed_precision self._train_for_several_steps(ddp_model, 1, autocast) ddp_model.state_dict() @@ -516,9 +510,10 @@ def _test_transformer(self, rank, group, config): class TransformerWithSharedParams(nn.Module): - def __init__(self, *unused_args, d_vocab=32, d_model=16, **unused_kwargs): + def __init__(self, *unused_args, d_vocab=23, d_model=16, **unused_kwargs): super().__init__() torch.manual_seed(0) # keep everything deterministic + assert d_vocab >= 12 # we use torch.arange(12) as input self.embed_tokens = nn.Embedding(d_vocab, d_model) self.transformer = nn.Transformer( d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1, From 93670cbdaf9a7827f57b12fc88c778a18c136e3b Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 10 Feb 2021 10:25:05 -0500 Subject: [PATCH 33/48] Add no_sync() context manager (#77) --- .../shard_params_data_parallel.py | 49 +++++++++++-- .../test_shard_params_data_parallel.py | 69 ++++++++++++++++++- 2 files changed, 110 insertions(+), 8 deletions(-) diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py index 223636bdc..5cc08452b 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py @@ -3,10 +3,11 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +import contextlib import copy import functools import math -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union import torch from torch.autograd import Variable @@ -133,6 +134,10 @@ def __init__( self._reset_lazy_init() + # Flag to indicate if we require gradient reduction in the backward + # pass. This will be False when inside the no_sync context manager. 
+ self.require_backward_grad_sync: bool = True + @torch.no_grad() def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None: """Move all buffers to the specified device and dtype, recursively.""" @@ -298,6 +303,27 @@ def load_local_state_dict( """Load a local (sharded) state_dict.""" return self.module.load_state_dict(state_dict, strict) + @contextlib.contextmanager + def no_sync(self) -> Generator: + """ + A context manager to disable gradient synchronizations across DDP + processes. Within this context, gradients will be accumulated on module + variables, which will later be synchronized in the first + forward-backward pass exiting the context. + """ + # This instance may wrap other ShardParamsDataParallel instances and we + # need to set all of them to accumulate gradients. + old_flags = [] + for m in self.modules(): # includes self + if isinstance(m, ShardParamsDataParallel): + old_flags.append((m, m.require_backward_grad_sync)) + m.require_backward_grad_sync = False + try: + yield + finally: + for m, old_flag in old_flags: + m.require_backward_grad_sync = old_flag + def _reset_lazy_init(self) -> None: """Reset instance so :func:`_lazy_init` will run on the next forward.""" self._is_root: Optional[bool] = None @@ -481,11 +507,13 @@ def _pre_backward_hook(*unused: Any) -> None: if pre_backward_hook_has_run[0]: return # only run once pre_backward_hook_has_run[0] = True - + # All-gather full parameters. if self.reshard_after_forward: self._rebuild_full_params() else: self._use_full_params() + # Make sure p.grad has the correct size/device (or set it to None). + self._prep_grads_for_backward() def _register_hook(t: torch.Tensor) -> torch.Tensor: t.register_hook(_pre_backward_hook) @@ -545,6 +573,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # pre_backward_hook. self._free_fp16_param_shard([param]) + if not self.require_backward_grad_sync: + return + # Wait for all work in the current stream to finish, then start the # reductions in post_backward stream. self._streams["post_backward"].wait_stream(torch.cuda.current_stream()) @@ -580,9 +611,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: @torch.no_grad() def _wait_for_post_backward(self) -> None: - """Wait for all post-backward work to finish.""" - if self._is_root: - torch.cuda.current_stream().wait_stream(self._streams["post_backward"]) + """Wait for post-backward work to finish. Only called on root instance.""" + assert self._is_root + torch.cuda.current_stream().wait_stream(self._streams["post_backward"]) if self.move_grads_to_cpu: # Wait for the non-blocking GPU -> CPU grad transfers to finish. 
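# A usage sketch (not part of this diff) for the no_sync() context manager
# added in this patch; ``model``, ``batches`` and ``optim`` are hypothetical
# stand-ins for a wrapped module, a list of input tuples and an optimizer.
for i, batch in enumerate(batches):
    if i < len(batches) - 1:
        with model.no_sync():                 # accumulate grads locally,
            model(*batch).sum().backward()    # skipping reduce-scatter
    else:
        model(*batch).sum().backward()        # grads are reduced here
optim.step()                                  # one step per accumulation window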
torch.cuda.current_stream().synchronize() @@ -600,7 +631,6 @@ def _rebuild_full_params(self) -> None: self._all_gather_full_param(p) # Fill p._full_param with (p.data for each shard in self.world_size) p.data = p._full_param - p.grad = None if self.mixed_precision: self._free_fp16_param_shard([p]) @@ -612,6 +642,13 @@ def _use_full_params(self) -> None: assert p._full_param.storage().size() != 0 p.data = p._full_param + @torch.no_grad() + def _prep_grads_for_backward(self) -> None: + """Make sure p.grad has the correct size/device, otherwise set it to None.""" + for p in self.params: + if p.grad is not None and (p.grad.size() != p._orig_size or p.grad.device != p.data.device): + p.grad = None + @torch.no_grad() def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: """Free up storage for full parameters.""" diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py index 6814f3796..dfd7abc95 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py @@ -456,8 +456,6 @@ def _test_output_backward_hooks(self, rank, group, cuda_first=False, model=None) output = model(*input) assert len(output._backward_hooks) == 1 # this is pre-bwd hook loss = model.module.get_loss(input, output).cuda() - for p in model.params: - assert p.grad is None # because of pre_backward_hook loss.backward() assert len(output._backward_hooks) == 1 # It doesn't get removed optim.step() @@ -509,6 +507,73 @@ def _test_transformer(self, rank, group, config): assert objects_are_equal(ref_output, no_grad_output), "no_grad_output did not match ref_output" +class TestNoSync(DistributedTest): + def test_transformer(self): + fn = functools.partial(self._test_transformer, config={}) + spawn_and_init(fn) + + def test_transformer_no_flat_params(self): + config = {"flatten_parameters": False} + fn = functools.partial(self._test_transformer, config=config) + spawn_and_init(fn) + + def test_nested_wrapper(self): + fn = functools.partial(self._test_nested_wrapper, config={}) + spawn_and_init(fn) + + @classmethod + def _test_transformer(self, rank, group, config): + model = self.get_wrapped_model(group, config=config) + model.eval() # turn off dropout for the test + self._test_no_sync(model, batch_dim=1) + + @classmethod + def _test_nested_wrapper(self, rank, group, config): + model = NestedWrappedModule(group, config) + model = ShardParamsDataParallel(model, group, **config).cuda() + self._test_no_sync(model, batch_dim=0) + + @classmethod + def _test_no_sync(self, model, batch_dim): + # Generate two input batches. We'll test that we get the same grads if + # we train on them sequentially while accumulating grads (with no_sync) + # vs. concatenating the batches and training in one go. + batch1 = model.module.get_input(torch.device("cuda")) + assert isinstance(batch1, tuple) + batch2 = tuple( + # This randomly permutes the values in a multi-dim tensor. + x.view(-1)[torch.randperm(x.numel())].view_as(x) + for x in batch1 + ) + for x, y in zip(batch1, batch2): + assert not torch.all(x == y) + + # Concat the batches along batch dimension. + concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2)) + + # Establish reference behavior on the concat batch. 
+ model.zero_grad() + output = model(*concat_batch) + ref_loss = model.module.get_loss(concat_batch, output) + ref_loss.backward() + ref_grads = [p.grad.detach().clone() for p in model.parameters()] + + # Test that we get the same results by accumulating grads. + model.zero_grad() + with model.no_sync(): # accumulate gradients from the first batch + output = model(*batch1) + loss1 = model.module.get_loss(batch1, output) + loss1.backward() + output = model(*batch2) + loss2 = model.module.get_loss(batch2, output) + loss2.backward() + accumulated_loss = loss1 + loss2 + accumulated_grads = [p.grad.detach().clone() for p in model.parameters()] + + torch.testing.assert_allclose(ref_loss, accumulated_loss) + assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True) + + class TransformerWithSharedParams(nn.Module): def __init__(self, *unused_args, d_vocab=23, d_model=16, **unused_kwargs): super().__init__() From 1d0bf732b5e928f6760d8463f596af644af3b516 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 10 Feb 2021 16:35:20 -0500 Subject: [PATCH 34/48] rename --- fairscale/nn/data_parallel/__init__.py | 2 +- ...llel.py => fully_sharded_data_parallel.py} | 30 ++++++++-------- ...py => test_fully_sharded_data_parallel.py} | 34 +++++++++---------- 3 files changed, 33 insertions(+), 33 deletions(-) rename fairscale/nn/data_parallel/{shard_params_data_parallel.py => fully_sharded_data_parallel.py} (97%) rename tests/nn/data_parallel/{test_shard_params_data_parallel.py => test_fully_sharded_data_parallel.py} (95%) diff --git a/fairscale/nn/data_parallel/__init__.py b/fairscale/nn/data_parallel/__init__.py index c9f70e949..d119dfb00 100644 --- a/fairscale/nn/data_parallel/__init__.py +++ b/fairscale/nn/data_parallel/__init__.py @@ -3,5 +3,5 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -from .shard_params_data_parallel import ShardParamsDataParallel +from .fully_sharded_data_parallel import FullyShardedDataParallel from .sharded_ddp import ShardedDataParallel diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py similarity index 97% rename from fairscale/nn/data_parallel/shard_params_data_parallel.py rename to fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 5cc08452b..43f390438 100644 --- a/fairscale/nn/data_parallel/shard_params_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -30,13 +30,13 @@ from collections import OrderedDict # noqa: F401 -class ShardParamsDataParallel(nn.Module): +class FullyShardedDataParallel(nn.Module): """ A wrapper for sharding Module parameters. Usage:: - sharded_module = ShardParamsDataParallel(my_module) + sharded_module = FullyShardedDataParallel(my_module) optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) x = sharded_module(x, y=3, z=torch.Tensor([1])) loss = x.sum() @@ -48,11 +48,11 @@ class ShardParamsDataParallel(nn.Module): reduce memory usage and to improve training speed by distributing the unsharding (all-gather) across the forward pass. 
For example:: - sharded_model = ShardParamsDataParallel( + sharded_model = FullyShardedDataParallel( nn.Sequential( nn.Linear(5, 100), - ShardParamsDataParallel(nn.Linear(100, 100)), - ShardParamsDataParallel(nn.Linear(100, 100)), + FullyShardedDataParallel(nn.Linear(100, 100)), + FullyShardedDataParallel(nn.Linear(100, 100)), nn.Linear(100, 5), ) ) @@ -227,7 +227,7 @@ def __getattr__(self, name: str) -> Any: return getattr(self.module, name) def __getstate__(self) -> Dict[str, str]: - """Serialize the state of the current ShardParamsDataParallel instance. + """Serialize the state of the current FullyShardedDataParallel instance. Some properties are not serializable (e.g., process groups, streams), so we remove them and try to reconstruct them in :func:`__setstate__`. @@ -279,7 +279,7 @@ def local_state_dict(self, *args, **kwargs): # type: ignore """ Returns the local (sharded) state of the module. Parameters are sharded, so the resulting state_dict can only be loaded after the Module has been - wrapped with ShardParamsDataParallel. + wrapped with FullyShardedDataParallel. """ if self.flatten_parameters: return self.module.flat_state_dict(*args, **kwargs) # type: ignore @@ -311,11 +311,11 @@ def no_sync(self) -> Generator: variables, which will later be synchronized in the first forward-backward pass exiting the context. """ - # This instance may wrap other ShardParamsDataParallel instances and we + # This instance may wrap other FullyShardedDataParallel instances and we # need to set all of them to accumulate gradients. old_flags = [] for m in self.modules(): # includes self - if isinstance(m, ShardParamsDataParallel): + if isinstance(m, FullyShardedDataParallel): old_flags.append((m, m.require_backward_grad_sync)) m.require_backward_grad_sync = False try: @@ -418,15 +418,15 @@ def _init_param_attributes(self, p: Parameter) -> None: p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory() def _set_is_root(self) -> None: - """If ``True``, implies that no other :class:`ShardParamsDataParallel` + """If ``True``, implies that no other :class:`FullyShardedDataParallel` instance wraps this one. Called once by :func:`_lazy_init`.""" if self._is_root is not None: return - # No ShardParamsDataParallel instance wraps this, else _is_root would be set to False + # No FullyShardedDataParallel instance wraps this, else _is_root would be set to False self._is_root = True # As the root, we now set all children instances to False. for n, m in self.named_modules(): - if n != "" and isinstance(m, ShardParamsDataParallel): + if n != "" and isinstance(m, FullyShardedDataParallel): assert m._is_root is None m._is_root = False @@ -444,12 +444,12 @@ def _setup_streams(self) -> None: # overlap transfers across the forward pass without synchronizing with # the default stream. for n, m in self.named_modules(): - if n != "" and isinstance(m, ShardParamsDataParallel): + if n != "" and isinstance(m, FullyShardedDataParallel): m._streams = self._streams def _wait_for_previous_optim_step(self) -> None: """ - The outer-most :class:`ShardParamsDataParallel` instance (i.e., the root + The outer-most :class:`FullyShardedDataParallel` instance (i.e., the root instance) needs to synchronize with the default stream to ensure the previous optimizer step is done. 
""" @@ -562,7 +562,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: if param.grad is None: return if param.grad.requires_grad: - raise RuntimeError("ShardParamsDataParallel only works with gradients that don't require grad") + raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad") # Free full params and switch to FP32 shard after backward. self._free_full_params([param]) diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_fully_sharded_data_parallel.py similarity index 95% rename from tests/nn/data_parallel/test_shard_params_data_parallel.py rename to tests/nn/data_parallel/test_fully_sharded_data_parallel.py index dfd7abc95..d0557424c 100644 --- a/tests/nn/data_parallel/test_shard_params_data_parallel.py +++ b/tests/nn/data_parallel/test_fully_sharded_data_parallel.py @@ -15,7 +15,7 @@ import torch from torch import nn -from fairscale.nn.data_parallel import ShardParamsDataParallel +from fairscale.nn.data_parallel import FullyShardedDataParallel from fairscale.utils.testing import ( DeviceAndTypeCheckModule, DummyProcessGroup, @@ -61,11 +61,11 @@ def _train_for_several_steps(model, num_steps, autocast, lr=0.01): return loss.detach() @staticmethod - def get_wrapped_model(group, cuda_first=False, config={}, **model_kwargs) -> ShardParamsDataParallel: + def get_wrapped_model(group, cuda_first=False, config={}, **model_kwargs) -> FullyShardedDataParallel: if cuda_first: - model = ShardParamsDataParallel(TransformerWithSharedParams(**model_kwargs).cuda(), group, **config) + model = FullyShardedDataParallel(TransformerWithSharedParams(**model_kwargs).cuda(), group, **config) else: - model = ShardParamsDataParallel(TransformerWithSharedParams(**model_kwargs), group, **config).cuda() + model = FullyShardedDataParallel(TransformerWithSharedParams(**model_kwargs), group, **config).cuda() return model @@ -139,7 +139,7 @@ def _spawn_test_case(self, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, @staticmethod def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group): # Patch _reduce_scatter op to check the dtype of the reduction - orig_reduce_scatter = ShardParamsDataParallel._reduce_scatter + orig_reduce_scatter = FullyShardedDataParallel._reduce_scatter model = DeviceAndTypeCheckModule( expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype, @@ -149,8 +149,8 @@ def _reduce_scatter(self, tensor): model._check("reduce_scatter.dtype", tensor.dtype, expected=reduce_dtype) return orig_reduce_scatter(self, tensor) - with mock.patch.object(ShardParamsDataParallel, "_reduce_scatter", new=_reduce_scatter): - model = ShardParamsDataParallel(model, group, **cfg).cuda() + with mock.patch.object(FullyShardedDataParallel, "_reduce_scatter", new=_reduce_scatter): + model = FullyShardedDataParallel(model, group, **cfg).cuda() device = next(model.parameters()).device x = torch.rand(2, 5).to(device) with torch.cuda.amp.autocast(enabled=autocast): @@ -169,7 +169,7 @@ def rename_test(testcase_func, param_num, param): class TestComparisonToPyTorchDDP(DistributedTest): """ Compare losses and parameter values after several updates when using - PyTorch DDP vs. ShardParamsDataParallel. + PyTorch DDP vs. FullyShardedDataParallel. 
""" @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) @@ -234,8 +234,8 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3 ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr) ref_state_dict = model.module.state_dict() - # Confirm we get the same behavior using ShardParamsDataParallel. - model = ShardParamsDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config) + # Confirm we get the same behavior using FullyShardedDataParallel. + model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config) if use_cuda: model = model.cuda() else: @@ -247,7 +247,7 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3 torch.testing.assert_allclose(ref_loss, shard_loss) assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True) except (AssertionError, RuntimeError) as e: - raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}") + raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}") class TestParamInit(DistributedTest): @@ -309,13 +309,13 @@ def _test_multiprocessing(self, rank, group, config): def _get_model(self, group, config): with torch.no_grad(): # required for multiprocessing model = NestedWrappedModule(group, wrapper_config=config) - return ShardParamsDataParallel(model, group, **config) + return FullyShardedDataParallel(model, group, **config) @classmethod def _one_step(self, model, group): # reset the process group (required after unpickling) for m in model.modules(): - if isinstance(m, ShardParamsDataParallel): + if isinstance(m, FullyShardedDataParallel): m.process_group = group optim = torch.optim.Adam(model.parameters(), lr=0.0001) input = model.module.get_input(torch.device("cuda")) @@ -412,10 +412,10 @@ def _test_module_state_dict(cls, config, rank, group): autocast = ddp_model.mixed_precision cls._train_for_several_steps(ddp_model, 2, autocast) state_1 = ddp_model.state_dict() - # You must make a new ShardParamsDataParallel instance to use module.load_state_dict + # You must make a new FullyShardedDataParallel instance to use module.load_state_dict unwrapped_model = TransformerWithSharedParams() unwrapped_model.load_state_dict(state_1) - new_ddp_model = ShardParamsDataParallel(unwrapped_model, group, **config).cuda() + new_ddp_model = FullyShardedDataParallel(unwrapped_model, group, **config).cuda() cls._train_for_several_steps(new_ddp_model, 2, autocast) try: ddp_model.load_state_dict(new_ddp_model.state_dict()) @@ -530,7 +530,7 @@ def _test_transformer(self, rank, group, config): @classmethod def _test_nested_wrapper(self, rank, group, config): model = NestedWrappedModule(group, config) - model = ShardParamsDataParallel(model, group, **config).cuda() + model = FullyShardedDataParallel(model, group, **config).cuda() self._test_no_sync(model, batch_dim=0) @classmethod @@ -615,7 +615,7 @@ def __init__(self, group, wrapper_config): def _maybe_wrap(layer): if wrapper_config is not None: - return ShardParamsDataParallel(layer, group, **wrapper_config) + return FullyShardedDataParallel(layer, group, **wrapper_config) return layer torch.manual_seed(0) # keep everything deterministic From 5fc1f1270c26bf4435e671bbdff0e0f31099120a Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 11 Feb 2021 16:53:06 -0500 Subject: [PATCH 35/48] Slightly faster execution when world_size == 1 (#81) --- .../fully_sharded_data_parallel.py | 19 
+++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 43f390438..94f17d8e0 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -197,7 +197,7 @@ def _calc_num_to_pad(self, numel: int) -> int: @torch.no_grad() def _all_gather_full_param(self, p: nn.Parameter) -> None: - """Fill p.full_param with gathered p.data values (using torch.distributed.all_gather). + """Fill p._full_param with gathered p.data values (using torch.distributed.all_gather). If the last shard is smaller than the other shards, we pad it with zeroes. """ full_param_chunks = list(torch.flatten(p._full_param).chunk(self.world_size)) @@ -584,11 +584,14 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # Cast grad to FP32. param.grad.data = param.grad.data.to(param.dtype) - # Average grad by world_size for consistency with PyTorch DDP. - param.grad.data.div_(self.world_size) + if self.world_size > 1: + # Average grad by world_size for consistency with PyTorch DDP. + param.grad.data.div_(self.world_size) - # Reduce-scatter grad. - param.grad.data = self._reduce_scatter(param.grad.data) + # Reduce-scatter grad. + param.grad.data = self._reduce_scatter(param.grad.data) + else: + param.grad.data = torch.flatten(param.grad.data) # Cast grad to param's dtype (typically FP32). Note: we do this # before the move_grads_to_cpu step so that this entire hook remains @@ -628,7 +631,11 @@ def _rebuild_full_params(self) -> None: for p in self.params: if p._full_param.storage().size() != p._orig_size.numel(): alloc_storage_(p._full_param, size=p._orig_size) - self._all_gather_full_param(p) # Fill p._full_param with (p.data for each shard in self.world_size) + if self.world_size > 1: + # Fill p._full_param with (p.data for each shard in self.world_size). + self._all_gather_full_param(p) + else: + torch.flatten(p._full_param).copy_(p.data) p.data = p._full_param From 36812422ae3e06150facdf69e53391a624f0c211 Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Thu, 11 Feb 2021 13:54:20 -0800 Subject: [PATCH 36/48] merge new base (which is public/master) (#82) * [chore] Fix lint errors that broke master (#348) authored-by: Anjali Sridhar * [fix] ShardedDDP - cpu testfix - remove Gloo/CPU (#350) * no idea about the root issue, but it proved to be fairly narrowed (gloo+cpu+python3.8+no cuda installed) so I guess that's out of scope for fairscale * [feat][OSS] elastic and pytorch compatible checkpoints (#310) * adding a test to prove the inter operability with upstream pytorch * updating the changelog * eager state pruning * pytorch 1.5 compat * [fix] ShardedDDP - properly handle post device change (#353) * adding the .to(device) support + unit testing * doc update * [feat] Add AdaScaleWrapper (#347) * [feat] Add AdaScaleWrapper - This enables a different API for wrapping an optimizer with AdaScale. - This also enables AdaScale to be wrapped by OSS. - However, OSS wrapping AdaScale results in different optimization, which future research will be needed to study its effects. testing: add unit tests. * addressed comment: typo * [refactor] Refactor and enable multiprocess nn.Pipe benchmarks. 
(#319) * mp cleanup * round of multiprocess refactoring * test golden run * print cuda stats * fix lint errors * enable multiprocess pipe benchmarks * set world size to be available gpus * more changes * use synthetic loaders for intermediate pipeline stages * merged master * fix for the devices property * dataloader fix * modify rank check * print wps stats * enable verification * fix logging * fix flag name * fix flag name * check for rank * fix indent * pass args * pass args * modify golden data * remove unused print messsage * fix lint errors * add comments * fix benchmarks Co-authored-by: Anjali Sridhar * [refactor] pipe: simplify balance and module checks (#346) * [chore] v0.1.5 (#355) * [chore] disheartening switch off of a OSS cpu test (#356) * precise skip, only if agent has only cpu * [feat][minor] OSS Benchmark - regression test + background testing new optims (#352) * restoring the regression test, adding a test of the for_each optims * fix the regression test on circleci * removing unused flags * [refactor] multiprocess_pipe: cleanup __init__ (#357) * [refactor] multiprocess_pipe: remove retain_graph __init__ param (#358) It is not currently being used so we can simplify the interface by removing it. * [refactor] multiprocess_pipe: focus on LazyModule usage (#360) * [feat] ShardedDDP : Adding a proper DDP parity / AMP unit test, overdue (#361) * Adding a proper ddp parity / AMP unit test, overdue * catch non-AMP pytorch * [perf][OSS] Clip grad norm : minor obvious speedup (#363) cache this iterator, easy speed up * [refactor] multiprocess_pipe: remove pipelined_backward (#362) * [perf] ShardedDDP - small memory use reduction - minor speedup (#366) * minor * minor * [fix] repro+fix (#365) fix a broken earlier commit, only worked for the first step * [refactor] OSS only use flat buffers (#371) * flat params all along, way simpler * updating the docstring * [refactor] AsyncPipe: do not sub-class MultiProcessPipe (#370) * [refactor] remove multiprocess dependency on async (#373) * [fix] Workaround need for pip --no-build-isolation (#375) * Add fairscale.nn.misc.checkpoint_activations (#376) * Add fairscale.utils.containers Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com> * Add fairscale.nn.misc.checkpoint_activations Co-authored-by: Sam Shleifer Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Co-authored-by: Sam Shleifer * [chore] v0.1.6 (#377) * v0.1.6 Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com> Co-authored-by: Benjamin Lefaudeux Co-authored-by: Anjali Sridhar Co-authored-by: msbaines <35972327+msbaines@users.noreply.github.com> Co-authored-by: Leonard Lausen Co-authored-by: Myle Ott Co-authored-by: Sam Shleifer --- .circleci/config.yml | 17 +- .github/workflows/build_wheels.yml | 19 +- .gitignore | 1 + CHANGELOG.md | 24 +- benchmarks/datasets/wikitext2_data.py | 4 +- benchmarks/experimental_ampnet.py | 1 - benchmarks/golden_configs/lm_wikitext2.py | 21 +- benchmarks/golden_configs/oss_mnist.py | 6 +- benchmarks/oss.py | 22 +- benchmarks/pipe.py | 204 +++++----- fairscale/__init__.py | 2 +- fairscale/nn/data_parallel/sharded_ddp.py | 111 ++++-- fairscale/nn/misc/checkpoint_activations.py | 158 ++++++++ fairscale/nn/pipe/async_pipe.py | 264 ++++++++++++- fairscale/nn/pipe/async_pipeline.py | 44 ++- fairscale/nn/pipe/async_schedule.py | 5 +- fairscale/nn/pipe/multiprocess_pipe.py | 288 ++------------ fairscale/nn/pipe/multiprocess_pipeline.py | 12 +- fairscale/nn/pipe/rpc.py | 11 +- fairscale/optim/__init__.py | 2 +- 
fairscale/optim/adascale.py | 53 ++- fairscale/optim/oss.py | 381 +++++++++---------- fairscale/utils/testing.py | 10 + pyproject.toml | 2 +- stubs/torch/utils/checkpoint.pyi | 3 +- tests/nn/data_parallel/test_sharded_ddp.py | 160 +++++--- tests/nn/misc/test_checkpoint_activations.py | 80 ++++ tests/nn/model_parallel/test_layers.py | 1 - tests/nn/pipe_process/skip/__init__.py | 18 - tests/nn/pipe_process/skip/test_gpipe.py | 180 --------- tests/nn/pipe_process/skip/test_leak.py | 133 ------- tests/nn/pipe_process/test_pipe.py | 227 +---------- tests/nn/pipe_process/test_transparency.py | 4 +- tests/optim/test_oss.py | 307 +++++++++------ tests/optim/test_oss_adascale.py | 28 +- 35 files changed, 1421 insertions(+), 1382 deletions(-) create mode 100644 fairscale/nn/misc/checkpoint_activations.py create mode 100644 tests/nn/misc/test_checkpoint_activations.py delete mode 100644 tests/nn/pipe_process/skip/__init__.py delete mode 100644 tests/nn/pipe_process/skip/test_gpipe.py delete mode 100644 tests/nn/pipe_process/skip/test_leak.py diff --git a/.circleci/config.yml b/.circleci/config.yml index a9dd7e0df..e947c0d32 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -168,12 +168,18 @@ run_pipe_benchmark: &run_pipe_benchmark command: | python benchmarks/pipe.py +run_mp_pipe_benchmark: &run_mp_pipe_benchmark + - run: + name: Run Multiprocess Pipe Benchmark + command: | + python benchmarks/pipe.py --multiprocess --lazy-construction + run_oss_benchmark: &run_oss_benchmark - run: name: Run OSS Benchmark command: | python benchmarks/oss.py --world_size 4 --epochs 2 - python benchmarks/oss.py --check_regression --world_size 4 --optim_type oss_sharded_ddp --reference_speed 660 --reference_memory 930 --reference_loss 0.023 + python benchmarks/oss.py --check_regression --world_size 4 --optim_type oss_sharded_ddp run_oss_gloo: &run_oss_gloo - run: @@ -188,6 +194,12 @@ run_oss_amp: &run_oss_amp command: | python benchmarks/oss.py --amp --epochs 3 --optim_type oss_sharded_ddp +run_oss_for_each: &run_oss_for_each + - run: + name: Run OSS with Torch AMP and ForEach optmizer + command: | + python benchmarks/oss.py --amp --epochs 3 --optim_type oss_sharded_ddp --multi_tensor_optim + run_doc_build: &run_doc_build - run: @@ -444,12 +456,15 @@ jobs: - <<: *run_pipe_benchmark + - <<: *run_mp_pipe_benchmark + - <<: *run_oss_benchmark - <<: *run_oss_gloo - <<: *run_oss_amp + - <<: *run_oss_for_each diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 26921c0e3..99f254bf9 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -13,28 +13,17 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-20.04, macOS-10.15] steps: - uses: actions/checkout@v2 - name: Install Python uses: actions/setup-python@v2 - with: - python-version: '3.7' - - name: Install cibuildwheel - run: | - python -m pip install cibuildwheel - - name: Build wheels for CPython - run: | - python -m cibuildwheel --output-dir dist - env: - CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64 cp39-*64" - CIBW_MANYLINUX_X86_64_IMAGE: manylinux1 - CIBW_BEFORE_BUILD: pip install . + - name: Build wheel + run: pip install . && pip wheel -w wheelhouse . 
- uses: actions/upload-artifact@v2 with: - name: wheels - path: ./dist/*.whl + path: ./wheelhouse/*.whl diff --git a/.gitignore b/.gitignore index 28155aa60..11b4ad6b2 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ venv/ ENV/ env.bak/ venv.bak/ +.vscode/* diff --git a/CHANGELOG.md b/CHANGELOG.md index 565cb0071..6acfd06fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,9 +5,31 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [next rel] - TBD + +## [0.1.6] - 2021-02-10 +### Added +- Checkpointing model wrapper (#376) +- Faster OSS, flatbuffers (#371) +- Small speedup in OSS clipgradnorm (#363) + +### Fixed +- Bug in ShardedDDP with 0.1.5 depending the init (KeyError / OSS) +- Much refactoring in Pipe (#357, #358, #360, #362, #370, #373) +- Better pip integration / resident pytorch (#375) + +## [0.1.5] - 2021-02-03 ### Added +- Pytorch compatibility for OSS checkpoints (#310) +- Elastic checkpoints for OSS, world size can vary in between save and loads (#310) +- Tensor views for OSS bucketing, reduced CPU use (#300) - Bucket calls in ShardedDDP, for faster inter node communications (#327) -- Tensor views for OSS bucketing, reduced CPU use +- FlattenParamWrapper, which flattens module parameters into a single tensor seamlessly (#317) +- AMPnet experimental support (#304) + +### Fixed +- ShardedDDP properly handles device changes via `.to()` (#353) +- Add a new interface for AdaScale, AdaScaleWrapper, which makes it compatible with OSS (#347) + ## [0.1.4] - 2021-01-07 ### Fixed diff --git a/benchmarks/datasets/wikitext2_data.py b/benchmarks/datasets/wikitext2_data.py index ce32d1231..931dc07e2 100644 --- a/benchmarks/datasets/wikitext2_data.py +++ b/benchmarks/datasets/wikitext2_data.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
import io +import tempfile import torch from torch.utils.data import DataLoader @@ -28,7 +29,8 @@ def get_real_dataloaders(args, benchmark_config): """Return real dataloaders for training, testing and validation.""" url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip" - test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root="/tmp")) + tmpdir = tempfile.TemporaryDirectory() + test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=tmpdir.name)) tokenizer = get_tokenizer("basic_english") def data_process(raw_text_iter): diff --git a/benchmarks/experimental_ampnet.py b/benchmarks/experimental_ampnet.py index d952cacc9..dffc5a621 100644 --- a/benchmarks/experimental_ampnet.py +++ b/benchmarks/experimental_ampnet.py @@ -423,7 +423,6 @@ def run_mp_worker(args, available_workers): chunks=args.chunks, worker_map=get_worker_map(), input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), - pipelined_backward=False, checkpoint=args.checkpoint, ) if torch.cuda.is_available(): diff --git a/benchmarks/golden_configs/lm_wikitext2.py b/benchmarks/golden_configs/lm_wikitext2.py index aa3a8da95..8a80fbcb2 100644 --- a/benchmarks/golden_configs/lm_wikitext2.py +++ b/benchmarks/golden_configs/lm_wikitext2.py @@ -20,17 +20,24 @@ def get_benchmark_config(): "scaler": GradScaler(), "clip_value": 0.05, "batch_size": 8, + "num_decoder_layers": 10, "seq_len": 32, } -def get_golden_real_stats(): - - return { - "avg_wps": 703.778, - "std_dev_wps": 5.732, - "peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496], - } +def get_golden_real_stats(multiprocess=False): + if not multiprocess: + return { + "avg_wps": 703.778, + "std_dev_wps": 5.732, + "peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496], + } + else: + return { + "avg_wps": 647.404, + "std_dev_wps": 14.51, + "peak_mem_usage": [3305007616, 2578692608, 3304524288, 2578692608], + } def get_golden_synthetic_stats(): diff --git a/benchmarks/golden_configs/oss_mnist.py b/benchmarks/golden_configs/oss_mnist.py index 92778cd55..0bac2b496 100644 --- a/benchmarks/golden_configs/oss_mnist.py +++ b/benchmarks/golden_configs/oss_mnist.py @@ -4,9 +4,9 @@ def get_golden_real_stats(): return { - "reference_speed": 1430, - "reference_memory": 1220, - "reference_loss": 0.006, + "reference_speed": 660, + "reference_memory": 1000, + "reference_loss": 0.026, } diff --git a/benchmarks/oss.py b/benchmarks/oss.py index a5de2fbf4..a4e743ae1 100755 --- a/benchmarks/oss.py +++ b/benchmarks/oss.py @@ -28,7 +28,6 @@ from fairscale.optim import OSS from fairscale.optim.grad_scaler import ShardedGradScaler -OPTIM = torch.optim.RMSprop TEMPDIR = tempfile.gettempdir() @@ -78,7 +77,7 @@ class OptimType(str, Enum): everyone = "everyone" -def validate_benchmark(measurements, args, check_regression): +def validate_benchmark(measurements, final_loss, args, check_regression): """Validate the measurments against the golden benchmark config.""" golden_data = oss_mnist.get_golden_real_stats() @@ -118,6 +117,10 @@ def train( ): logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG) + use_multi_tensor = args.multi_tensor_optim and hasattr(torch.optim, "_multi_tensor") + OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop # type: ignore # attr is checked but mypy misses that + logging.info("Multi tensor optimizer: {}".format(use_multi_tensor)) + # DDP dist_init(rank=rank, 
world_size=args.world_size, backend=backend) @@ -260,7 +263,7 @@ def run_closure(closure, scaler, optimizer): img_per_sec = n_items / (training_stop - training_start) * args.epochs logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint") - validate_benchmark(measurements, args, check_regression) + validate_benchmark(measurements, final_loss, args, check_regression) dist.destroy_process_group() # type: ignore @@ -273,9 +276,6 @@ def run_closure(closure, scaler, optimizer): parser.add_argument("--epochs", action="store", default=10, type=int) parser.add_argument("--batch_size", action="store", default=256, type=int) parser.add_argument("--check_regression", action="store_true", default=False) - parser.add_argument("--reference_speed", action="store", default=1430, type=float) - parser.add_argument("--reference_memory", action="store", default=1220, type=float) - parser.add_argument("--reference_loss", action="store", default=0.006, type=float) parser.add_argument( "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone ) @@ -285,6 +285,9 @@ def run_closure(closure, scaler, optimizer): parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101") parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information") parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP") + parser.add_argument( + "--multi_tensor_optim", action="store_true", default=False, help="Use the faster multi-tensor optimizers" + ) args = parser.parse_args() @@ -332,12 +335,7 @@ def run_closure(closure, scaler, optimizer): logging.info("\n*** Benchmark OSS with ShardedDDP") mp.spawn( train, # type: ignore - args=( - args, - BACKEND, - OptimType.oss_sharded_ddp, - False, - ), # FIXME: @lefaudeux - SDP should give the same results + args=(args, BACKEND, OptimType.oss_sharded_ddp, args.check_regression,), nprocs=args.world_size, join=True, ) diff --git a/benchmarks/pipe.py b/benchmarks/pipe.py index bdfe9a8a5..04cee6fdd 100644 --- a/benchmarks/pipe.py +++ b/benchmarks/pipe.py @@ -7,7 +7,6 @@ import logging import math import operator -import os import pprint import time @@ -17,6 +16,7 @@ from models import transformer_lm import numpy as np import torch +import torch.distributed as dist from torch.distributed import rpc import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP @@ -29,6 +29,9 @@ from fairscale.optim.oss import OSS from fairscale.utils.testing import dist_init, get_worker_map +MPI_PORT = 29500 +RPC_PORT = 29501 + def init_random_seed(seed: int): @@ -64,7 +67,7 @@ def get_lm_model(args, device, config): dropout = config["dropout"] vocab_size = config["vocab_size"] nhid = config["nhid"] - ndecoder = args.num_decoder_layers + ndecoder = config["num_decoder_layers"] if args.lazy_construction: layers = [ @@ -179,13 +182,13 @@ def get_device(model, index): if not torch.cuda.is_available(): return torch.device("cpu") - if model.devices: + if hasattr(model, "devices"): return model.devices[index] else: return torch.cuda.current_device() -def get_fake_dataloader(lm_dataloader_len): +def get_fake_dataloader(lm_dataloader_len, args): fake_input = {"input": torch.zeros(args.batch_size)} class FakeDataset: @@ -224,7 +227,7 @@ def train(model_config, model, benchmark_config, args): # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one. 
if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1): - lm_dataloader = get_fake_dataloader(len(lm_dataloader)) + lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config) total_tokens = 0 total_tokens_per_log_interval = 0 @@ -288,11 +291,12 @@ def get_batch(source): if i % log_interval == 0 and i > 0: cur_loss = total_loss / log_interval elapsed = time.time() - start_time - print( - "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format( - i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss) + if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1: + print( + "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format( + i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss) + ) ) - ) total_tokens_per_log_interval = 0 total_loss = 0 start_time = time.time() @@ -303,8 +307,10 @@ def get_batch(source): raise RuntimeError( "Unable to benchmark on a single batch. Increase the size " " of the dataset and rerun the benchmark." ) - - return wps, loss.item() + if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1: + return wps, loss.item() + else: + return 0.0, 0.0 # TODO(anj-s): Add an option for users to be able to benchmark evaluate. @@ -334,52 +340,64 @@ def get_number_of_words(data): return data.size()[0] * data.size()[1] -def verify_lm_run(wps, golden_config): - """Verify that words per second for a given benchmark run matches the golden data.""" - - # Assert that words per second is within 3 standard deviations of the average - # of five golden runs - print("Throughput(wps) is {:.2f}.".format(wps)) - if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])): +def verify_peak_memory(rank, golden_config, std_dev): + print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"])) + current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"] + golden_ref = golden_config["peak_mem_usage"][rank] + if not current_device_usage < golden_ref * std_dev: raise RuntimeError( - "Throughput(wps):{:.2f} is below the golden threshold of an " - "average value of {:.2f} and standard dev of {:.2f}.".format( - wps, golden_config["avg_wps"], golden_config["std_dev_wps"] - ) + "Peak memory usage for cuda device {:d} is {:d} which" + "is less than golden reference value of {:d}".format(rank, current_device_usage, golden_ref) ) - for i in range(4): - print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(i)["allocated_bytes.all.peak"])) - # Assert that memory usage on each GPU is within 10% of golden run - # Right-hand-side is golden run bytes * 110% - for i, golden_ref in zip(range(4), golden_config["peak_mem_usage"]): - current_device_usage = torch.cuda.memory_stats(i)["allocated_bytes.all.peak"] - if not current_device_usage < golden_ref * 1.1: +def verify_lm_run(wps, golden_config, args): + """Verify that words per second for a given benchmark run matches the golden data.""" + + # Verify wps only on the last rank in multiprocess pipe + if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1: + # Assert that words per second is within 3 standard deviations of the average + # of five golden runs + print("Throughput(wps) is {:.2f}.".format(wps)) + if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])): raise RuntimeError( - "Peak memory usage for cuda device {:d} is {:d} which" - "is less than golden reference value of 
{:d}".format(i, current_device_usage, golden_ref) + "Throughput(wps):{:.2f} is below the golden threshold of an " + "average value of {:.2f} and standard dev of {:.2f}.".format( + wps, golden_config["avg_wps"], golden_config["std_dev_wps"] + ) ) + if args.multiprocess: + verify_peak_memory(dist.get_rank(), golden_config, 1.5) + else: + for i in range(4): + verify_peak_memory(i, golden_config, 1.1) + def benchmark_language_model(model_config, model, benchmark_config, args): - golden_config = get_golden_config(args.model_name) + golden_config = get_golden_config(args.model_name, args) epoch = benchmark_config["epochs"] - print("-" * 110) - print("| start of epoch {:1d}".format(epoch)) - print("-" * 110) start_time = time.time() + if dist.get_rank() == dist.get_world_size() - 1: + print("-" * 110) + print("| start of epoch {:1d}".format(epoch)) + print("-" * 110) wps, loss = train(model_config, model, benchmark_config, args) elapsed_time = time.time() - start_time - print("-" * 110) - print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss)) - print("-" * 110) + if dist.get_rank() == dist.get_world_size() - 1: + print("-" * 110) + print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss)) + print("-" * 110) + print("Throughput(wps) is {:.2f}.".format(wps)) + print( + "Peak allocated bytes on cuda:{}: {:1d}".format( + dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"] + ) + ) - print("wps ", wps) if len(model.balance) == 4: - if args.model_name == "lm": - verify_lm_run(wps, golden_config) + verify_lm_run(wps, golden_config, args) else: raise RuntimeError("Unrecognized args.model_name " % args.model_name) @@ -458,11 +476,11 @@ def create_benchmark_config(model_name): raise RuntimeError("Unrecognized args.model_mame " % args.model_name) -def get_golden_config(model_name): +def get_golden_config(model_name, args): """Return a dict with the golden data for throughput and memory usage.""" if model_name == "lm": - return lm_wikitext2.get_golden_real_stats() + return lm_wikitext2.get_golden_real_stats(args.multiprocess) else: raise RuntimeError("Unrecognized args.model_mame " % args.model_name) @@ -470,6 +488,9 @@ def get_golden_config(model_name): def benchmark_single_process(args): """Benchmark a given model using a single process and multiple devices.""" + init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT) + torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1, init_method=init_method_pgroup) + num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 assert num_devices > 0 init_random_seed(0) @@ -492,17 +513,16 @@ def benchmark_single_process(args): def run_mp_worker(args, available_workers): benchmark_config = create_benchmark_config(args.model_name) - model_config = create_model_config(args, config=benchmark_config) + model_config = create_model_config(args, benchmark_config=benchmark_config) model = model_config["model"] - balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8) + balance = generate_balance(get_pipeline_parallel_group().size(), len(model)) pipe_model = MultiProcessPipe( model, balance, chunks=args.chunks, worker_map=get_worker_map(), input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), - pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint, # TODO(anj-s): Do we need to comment this out? 
loss_fn=benchmark_config["criterion"], ) @@ -512,7 +532,7 @@ def run_mp_worker(args, available_workers): print(f"running all at once") pipe_model.pipeline.all_at_once = True - if args.use_synthetic_data: + if args.dry_run: train(model_config, pipe_model, benchmark_config, args) else: benchmark_language_model(model_config, pipe_model, benchmark_config, args) @@ -530,63 +550,27 @@ def run_worker(rank, world_size, args): torch.distributed.destroy_process_group() -def bench_multi_process(args, all_at_once=False): - if args.local_world_size != 0: - world_size = args.local_world_size - else: - world_size = min(torch.cuda.device_count(), 2) - mp.spawn(run_worker, args=(world_size, args), nprocs=world_size, join=True) - - -best_device_map = { - 0: "mlx5_0:1", - 1: "mlx5_0:1", - 2: "mlx5_1:1", - 3: "mlx5_1:1", - 4: "mlx5_2:1", - 5: "mlx5_2:1", - 6: "mlx5_3:1", - 7: "mlx5_3:1", -} - - -def bench_mpi(args): - guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) - world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) - local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) - os.environ["UCX_NET_DEVICES"] = best_device_map[local_rank] - - os.environ["MASTER_ADDR"] = args.host - os.environ["MASTER_PORT"] = "10638" - if args.socket_name: - os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name - os.environ["TP_SOCKET_IFNAME"] = args.socket_name - - torch.distributed.init_process_group(backend="gloo", rank=guess_rank, world_size=world_size) - - os.environ["MASTER_ADDR"] = args.host - os.environ["MASTER_PORT"] = "10639" - init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - torch.cuda.set_device(local_rank % torch.cuda.device_count()) +def benchmark_multiprocess(rank, world_size, args): + + init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT) + # TODO(anj-s): Add regression benchmarks for nccl as well. + torch.distributed.init_process_group( + backend="gloo", rank=rank, world_size=world_size, init_method=init_method_pgroup + ) + torch.cuda.set_device(rank % torch.cuda.device_count()) + # TODO(anj-s): Move to TensorPipeRpcBackendOptions. 
rpc.init_rpc( f"Test{rank}", rank=rank, world_size=world_size, backend=rpc.BackendType.PROCESS_GROUP, - rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method), + rpc_backend_options=rpc.ProcessGroupRpcBackendOptions( + rpc_timeout=20, init_method="tcp://localhost:{}".format(RPC_PORT) + ), ) - - backends = {"model_parallel_backend": "nccl", "pipeline_backend": "mpi", "ddp_backend": "nccl"} - - if args.ddp_zero: - initialize_model_parallel(1, 4, **backends) - else: - initialize_model_parallel(1, world_size, **backends) + initialize_model_parallel(1, world_size) init_random_seed(0) - run_mp_worker(args, world_size) rpc.shutdown() @@ -594,17 +578,12 @@ def bench_mpi(args): parser = argparse.ArgumentParser(description="benchmark") -parser.add_argument("--local-world-size", "-l", type=int, default=0, help="local world size") -parser.add_argument("--world-size", "-w", type=int, default=0, help="world size") -parser.add_argument("--rank-base", "-r", type=int, help="rank base", default=0) +parser.add_argument("--multiprocess", action="store_true", help="Runs single process benchmarks.") parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname") -parser.add_argument("--no-mpi", action="store_true", default=False, help="disable mpi") parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch") parser.add_argument("--batch-size", type=int, default=8, help="size of a batch") parser.add_argument("--all-at-once", action="store_true", default=False, help="do backward pass on whole batch at once") parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches") -parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp") -parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model") parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable ddp") parser.add_argument( "--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model" @@ -612,12 +591,6 @@ def bench_mpi(args): parser.add_argument( "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe" ) -parser.add_argument( - "--pipelined-backward", dest="pipelined_backward", action="store_true", help="Pipelined backward pass" -) -parser.add_argument( - "--no-pipelined-backward", dest="pipelined_backward", action="store_false", help="Pipelined backward pass" -) parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.") parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.") parser.add_argument( @@ -626,15 +599,16 @@ def bench_mpi(args): default="lm", help="Language Model(LM) used to benchmark nn.pipe.", ) -parser.set_defaults(pipelined_backward=True) if __name__ == "__main__": args = parser.parse_args() - # TODO(anj-s): Add support for multiprocess benchmarking. - if args.no_mpi or "OMPI_COMM_WORLD_RANK" not in os.environ: - print(f"Running benchmark with args: {args}") + + # TODO(anj-s): Remove print statements and introduce logging levels. 
+ + if not args.multiprocess: + print(f"Running single process benchmark with args: {args}") benchmark_single_process(args) else: - if os.environ["OMPI_COMM_WORLD_RANK"] == "0": - print(f"Running benchmark with args: {args}") - bench_mpi(args) + world_size = max(torch.cuda.device_count(), 1) + print(f"Running multiprocess benchmark with args: {args}") + mp.spawn(benchmark_multiprocess, args=(world_size, args), nprocs=world_size, join=True) diff --git a/fairscale/__init__.py b/fairscale/__init__.py index 82cd113eb..91b69ba3d 100644 --- a/fairscale/__init__.py +++ b/fairscale/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -__version__ = "0.1.4" +__version__ = "0.1.6" ################################################################################ # Import most common subpackages diff --git a/fairscale/nn/data_parallel/sharded_ddp.py b/fairscale/nn/data_parallel/sharded_ddp.py index a46dc039b..ad2cdef6e 100644 --- a/fairscale/nn/data_parallel/sharded_ddp.py +++ b/fairscale/nn/data_parallel/sharded_ddp.py @@ -11,7 +11,7 @@ import contextlib from itertools import chain import logging -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import torch from torch import nn @@ -142,7 +142,7 @@ def __init__( self.buckets: Dict[OSS, Dict[torch.device, List[Bucket]]] = {o: {} for o in self.sharded_optimizers} self._should_bucket_grad: List[bool] = [] - self._bucket_iterator: Optional[Iterable[Bucket]] = None + self._bucket_list: Optional[List[Bucket]] = None self._setup_bucket_strategy() # - setup backward hooks which will be called by Torch's autograd in due time @@ -156,6 +156,8 @@ def __init__( if sync_models_at_startup: self._sync_params_and_buffers() + self._clear_counters() + def forward(self, *inputs: Any, **kwargs: Any) -> Any: """ Module forward pass, handles any DDP-specific work in the background. Primes the @@ -172,6 +174,44 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: # Normal FW on the base model return self.module(*inputs, **kwargs) + def to( # type: ignore + self, + device: Optional[Union[int, torch.device]], + dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, + ) -> "ShardedDataParallel": + """ + Moves and/or casts the parameters and buffers. + + Its signature is similar to :meth:`torch.Tensor.to`, but only accepts + floating point desired :attr:`dtype` s. In addition, this method will + only cast the floating point parameters and buffers to :attr:`dtype` + (if given). The integral parameters and buffers will be moved + :attr:`device`, if that is given, but with dtypes unchanged. When + :attr:`non_blocking` is set, it tries to convert/move asynchronously + with respect to the host if possible, e.g., moving CPU Tensors with + pinned memory to CUDA devices. + + .. note:: + This method modifies the module in-place. + + Arguments: + device (:class:`torch.device`): the desired device of the parameters and buffers in this module. + dtype (:class:`torch.dtype`): the desired floating point type of the floating point parameters and buffers. + non_blocking (bool): make it an asynchronous call. + + Returns: + Module: self. 
+ + """ + + for optimizer in self.buckets.keys(): + for device in self.buckets[optimizer].keys(): + for bucket in self.buckets[optimizer][device]: + bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking) + + self.module.to(device) + def reduce(self) -> None: """.. deprecated:: 0.0.4 @@ -215,18 +255,20 @@ def no_sync(self) -> Generator: def _clear_counters(self) -> None: """Reset all the grad reduce and call counters""" if not self.should_accumulate_grads: + self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced] self._reduced_grads = {o: 0 for o in self.sharded_optimizers} - for o in self.buckets.keys(): - for d in self.buckets[o].keys(): - for bucket in self.buckets[o][d]: - assert bucket.sent, ( - "A bucket failed being sent, probably unused parameters." - + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-" - ) + if self.use_buckets: + assert self._bucket_list is not None + + for bucket in self._bucket_list: + assert bucket.sent, ( + "A bucket failed to be sent, probably unused parameters." + + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-" + ) - bucket.reset() + bucket.reset() def _find_rank(self, param: Parameter) -> Tuple[OSS, int]: """ Look up where this parameter belongs to """ @@ -336,19 +378,16 @@ def _setup_backward_hooks(self) -> None: def bucket_flush(*unused: Any) -> None: handle = None - for bucket_optim in self.buckets.values(): - for bucket_rank in bucket_optim.values(): - for bucket in bucket_rank: - if not bucket.sent: - # Reduce the bucket. Some parameters went unused and this bucket was not flushed - bucket.buffer.mul_(self.world_size_scaling) - bucket.sent = True - handle = dist.reduce( - tensor=bucket.buffer, - dst=bucket.destination, - group=self.process_group, - async_op=True, - ) + assert self._bucket_list is not None + + for bucket in self._bucket_list: + if not bucket.sent: + # Reduce the bucket. 
Some parameters went unused and this bucket was not flushed + bucket.buffer.mul_(self.world_size_scaling) + bucket.sent = True + handle = dist.reduce( + tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True, + ) # Only wait on the last handle if handle: @@ -392,19 +431,19 @@ def _setup_bucket_strategy(self) -> None: if not self.use_buckets: return - # - Allocate one buffer per rank and per device to group the small parameters - for sharded_optimizer in self.sharded_optimizers: - for device, per_device in sharded_optimizer.per_device_params.items(): - self.buckets[sharded_optimizer][device] = [ - Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=per_device[0][0].dtype, device=device)) - for _ in per_device - ] - # Devise the bucketing strategy for sharded_optimizer in self.sharded_optimizers: for device, per_rank_params in sharded_optimizer.per_device_params.items(): + self.buckets[sharded_optimizer][device] = [] + for dst_rank, params in enumerate(per_rank_params): offset = 0 + + self.buckets[sharded_optimizer][device].append( + Bucket( + buffer=torch.zeros(self.buffer_max_size, dtype=per_rank_params[0][0].dtype, device=device) + ) + ) bucket = self.buckets[sharded_optimizer][device][dst_rank] bucket.destination = dst_rank @@ -435,3 +474,13 @@ def _setup_bucket_strategy(self) -> None: bucket.buffer.resize_(offset) if bucket.max_params_checked_in > 0: self._reduced_grads_max[sharded_optimizer] += 1 # one reduce call per bucket + + self._bucket_list = list( + chain( + *[ + self.buckets[sharded_optimizer][device] + for sharded_optimizer in self.sharded_optimizers + for device in sharded_optimizer.per_device_params.keys() + ] + ) + ) diff --git a/fairscale/nn/misc/checkpoint_activations.py b/fairscale/nn/misc/checkpoint_activations.py new file mode 100644 index 000000000..44f29c612 --- /dev/null +++ b/fairscale/nn/misc/checkpoint_activations.py @@ -0,0 +1,158 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from typing import Any, Dict, Optional, Tuple + +import torch +from torch import Tensor +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors + + +def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module: + """ + A friendlier wrapper for performing activation checkpointing. + + Compared to the PyTorch version, this version: + - wraps an nn.Module, so that all subsequent calls will use checkpointing + - handles keyword arguments in the forward + - handles non-Tensor outputs from the forward + - supports offloading activations to CPU + + Usage:: + + checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True) + a, b = checkpointed_module(x, y=3, z=torch.Tensor([1])) + + Args: + module (nn.Module): module to wrap + offload_to_cpu (Optional, bool): whether to offload activations to CPU + """ + module.forward = functools.partial(_checkpointed_forward, module.forward, offload_to_cpu) # type: ignore + return module + + +def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any) -> Any: + # Autograd Functions in PyTorch work best with positional args, since + # the backward must return gradients (or None) for every input argument. + # We can flatten keyword arguments to make this easier. 
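The kwargs flattening below feeds everything to ``CheckpointFunction`` as positional arguments. As an end-to-end sketch of ``checkpoint_wrapper`` itself (module sizes here are arbitrary, chosen only for illustration): the wrapped block's forward is recomputed during backward instead of keeping its activations alive::

    import torch
    import torch.nn as nn
    from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper

    block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
    model = nn.Sequential(checkpoint_wrapper(block), nn.Linear(32, 4))

    x = torch.randn(8, 32, requires_grad=True)
    model(x).sum().backward()   # block's forward runs a second time here, during backward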
+ kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) + parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu} + output = CheckpointFunction.apply(original_forward, parent_ctx_dict, kwarg_keys, *flat_args) + if isinstance(output, torch.Tensor): + return output + else: + packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"] + if packed_non_tensor_outputs: + output = unpack_non_tensors(output, packed_non_tensor_outputs) + return output + + +def get_rng_state() -> Dict[str, Any]: + state = {"torch_rng_state": torch.get_rng_state()} + if torch.cuda.is_available(): + state["cuda_rng_state"] = torch.cuda.get_rng_state() + return state + + +def set_rng_state(state: Dict[str, Any]) -> None: + torch.set_rng_state(state["torch_rng_state"]) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(state["cuda_rng_state"]) + + +class CheckpointFunction(torch.autograd.Function): + """Similar to the torch version, but support non-Tensor outputs. + + The caller is expected to provide a dict (*parent_ctx_dict*) that will hold + the non-Tensor outputs. These should be combined with the Tensor *outputs* + by calling :func:`unpack_non_tensors`. + """ + + @staticmethod + def forward( # type: ignore + ctx: Any, + run_function: Any, + parent_ctx_dict: Dict[str, Any], + kwarg_keys: Tuple[str, ...], + *args: Any, + **kwargs: Any + ) -> Any: + if torch.is_grad_enabled(): # grad may be disabled, e.g., during validation + checkpoint.check_backward_validity(args) + + ctx.run_function = run_function + ctx.kwarg_keys = kwarg_keys + ctx.fwd_rng_state = get_rng_state() + + tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) + if parent_ctx_dict["offload"]: + ctx.fwd_device = tuple(x.device for x in tensor_inputs) + ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs) + tensor_inputs = tuple(x.cpu() for x in tensor_inputs) + + else: + ctx.fwd_device, ctx.grad_requirements = None, None + + ctx.save_for_backward(*tensor_inputs) + ctx.packed_non_tensor_inputs = packed_non_tensor_inputs + + with torch.no_grad(): + unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args) + outputs = run_function(*unpacked_args, **unpacked_kwargs) + + if isinstance(outputs, torch.Tensor): + return outputs + else: + # Autograd Functions don't like non-Tensor outputs. We can split the + # non-Tensor and Tensor outputs, returning the former by reference + # through *parent_ctx_dict* and returning the latter directly. + outputs, packed_non_tensor_outputs = split_non_tensors(outputs) + parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs + return outputs + + @staticmethod + def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]: + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible") + + tensor_inputs: Tuple = ctx.saved_tensors + tensor_inputs = checkpoint.detach_variable(tensor_inputs) + if ctx.fwd_device is not None: + tensor_inputs = tuple(t.to(ctx.fwd_device[i]) for i, t in enumerate(tensor_inputs)) + for i, need_grad in enumerate(ctx.grad_requirements): + tensor_inputs[i].requires_grad = need_grad + inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs) + + # Store the current states. + bwd_rng_state = get_rng_state() + + # Set the states to what it used to be before the forward pass. 
+ set_rng_state(ctx.fwd_rng_state) + + with torch.enable_grad(): + unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs) + outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) + tensor_outputs, _ = split_non_tensors(outputs) + # Set the states back to what it was at the start of this function. + set_rng_state(bwd_rng_state) + + # Run backward() with only Tensors that require grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(tensor_outputs)): + if tensor_outputs[i].requires_grad: + outputs_with_grad.append(tensor_outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError("None of the outputs have requires_grad=True, " "this checkpoint() is not necessary") + + torch.autograd.backward(outputs_with_grad, args_with_grad) + + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs) + return (None, None, None) + grads diff --git a/fairscale/nn/pipe/async_pipe.py b/fairscale/nn/pipe/async_pipe.py index 2e1868650..7cc6ebd52 100644 --- a/fairscale/nn/pipe/async_pipe.py +++ b/fairscale/nn/pipe/async_pipe.py @@ -6,14 +6,20 @@ from collections import OrderedDict from dataclasses import dataclass, field import itertools -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union +import threading +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union +import warnings import torch from torch import Tensor, nn +from fairscale.nn.model_parallel import get_pipeline_parallel_group + +from . import microbatch from .async_pipeline import AsyncPipeline from .async_schedule import Invocation, Location, ModuleWrapper -from .multiprocess_pipe import MultiProcessPipe, check_balance +from .batchnorm import DeferredBatchNorm +from .skip.layout import SkipLayout from .skip.skippable import Skippable from .types import LazyModule @@ -38,7 +44,169 @@ def __len__(self) -> int: return len(self.modules) -class AsyncPipe(MultiProcessPipe): +def verify_module(module: Union[nn.Sequential, List[LazyModule]]) -> None: + if len(set(map(id, module))) != len(module): + raise ValueError("module with duplicate children is not supported") + + +def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[int]) -> None: + if len(module) != sum(balance): + raise ValueError( + f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})" + ) + + if any(x <= 0 for x in balance): + raise ValueError(f"all balance numbers must be positive integer (balance: {balance})") + + +MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement") + + +class AsyncPipe(Module): + """Wraps an arbitrary :class:`nn.Sequential ` module + to train on Pipe_. If the module requires lots of memory, Pipe will be + very efficient. + + Pipe combines pipeline parallelism with checkpointing to reduce peak + memory required to train while minimizing device under-utilization. + + You should determine the balance when defining a :class:`AsyncPipe` module, as + balancing will not be done automatically. The module will be partitioned + into multiple devices according to the given balance. You may rely on + heuristics to find your own optimal configuration. 
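A hedged usage sketch of the balance argument described above; the model, the [2, 2] split and the chunks value are made up for illustration, and the required torch.distributed.rpc plus pipeline process-group initialization is assumed to have been done beforehand:

    import torch.nn as nn
    from fairscale.nn.pipe.async_pipe import AsyncPipe

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())
    # Two pipeline ranks, two layers each; each rank only instantiates its own partition.
    pipe = AsyncPipe(model, balance=[2, 2], chunks=4)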
+ + Args: + module (torch.nn.Sequential): + sequential module to be parallelized + balance (ints): + list of number of layers in each partition + + Keyword Args: + group (ProcessGroup): + the process group that all + pipeline stages are a member of. Defaults to + `get_pipeline_parallel_group()` + worker_map (Dict[int, str]): + a map from worker name (the first argument to + `torch.distributed.rpc.init_rpc`) to global rank (i.e. + `torch.distributed.get_rank()`) needed in order for pipeline stages + to communicate with each other + input_device (device): + the device on which tensors should be located before being passed to + the first module in a given pipeline stage + chunks (int): + number of micro-batches (default: ``1``) + checkpoint (str): + when to enable checkpointing, one of ``'always'``, + ``'except_last'``, or ``'never'`` (default: ``'except_last'``) + deferred_batch_norm (bool): + whether to use deferred BatchNorm moving statistics (default: + :data:`False`, see :class:`DeferredBatchNorm` for more + details) + + Raises: + TypeError: + the module is not a :class:`nn.Sequential `. + ValueError: + invalid arguments, or wrong balance + IndexError: + the number of devices is fewer than the number of partitions. + + """ + + #: The number of layers in each partition. + balance: List[int] = [] + # ^^ + # The default value [] required for Sphinx's autoattribute. + + #: The devices mapped to each partition. + #: + #: ``devices[-1]`` refers to the device of the last partition, which means + #: it is the output device. Probably, you need to use it to transfer the + #: target to calculate the loss without a device mismatch + #: :exc:`RuntimeError`. For example:: + #: + #: out_device = pipe.devices[-1] + #: + #: for input, target in loader: + #: target = target.to(out_device, non_blocking=True) + #: output = pipe(input) + #: loss = F.cross_entropy(output, target) + #: + + #: The number of micro-batches. + chunks: int = 1 + + #: The checkpoint mode to determine when to enable checkpointing. It is one + #: of ``'always'``, ``'except_last'``, or ``'never'``. 
+ checkpoint: str = "except_last" + + def __init__( + self, + module: Union[nn.Sequential, List[LazyModule]], + balance: Iterable[int], + *, + group: Optional[torch.distributed.ProcessGroup] = None, + worker_map: Optional[Dict[int, str]] = None, + input_device: Union[None, int, str, torch.device] = None, + chunks: int = chunks, + checkpoint: str = checkpoint, + deferred_batch_norm: bool = False, + ) -> None: + super().__init__() + + if chunks <= 0: + raise ValueError("number of chunks must be positive integer") + if checkpoint not in ["always", "except_last", "never"]: + raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'") + + self.balance = list(balance) + verify_module(module) + check_balance(module, self.balance) + + self.chunks = chunks + self.checkpoint = checkpoint + self.pipeline: Optional[AsyncPipeline] + self.lock = threading.Lock() + + self.worker_map = worker_map + self.input_device = input_device + + self.group: torch.distributed.ProcessGroup + if group is None: + self.group = get_pipeline_parallel_group() + else: + self.group = group + + if self.group.size() < len(self.balance): + raise IndexError( + f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:" + f" {len(self.balance)})" + ) + + self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom) + + rank = self.group.rank() + self.final_stage = rank == len(self.balance) - 1 + if rank >= len(self.balance): + warnings.warn("More ranks than partitions, some ranks unused") + self.partitions: List[ModuleWrapper] = [] + self.pipeline = None + # TODO(msb) remove this hack + self.partition = None + else: + self.partitions = self.instantiate_partition(module, self.balance, self.group) + if deferred_batch_norm: + for part in self.partitions: + part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks) + for name, part in enumerate(self.partitions): + self.add_module(str(name), part.module) + self.create_pipeline() + # TODO(msb) remove this hack + self.partition = self.partitions[0].module + + del module + def create_pipeline(self) -> None: # The micro-batch index where the checkpointing stops. 
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] @@ -54,14 +222,8 @@ def create_pipeline(self) -> None: ) def instantiate_partition( - self, - module: Union[nn.Sequential, List[LazyModule]], - balance: Iterable[int], - group: torch.distributed.ProcessGroup, + self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup, ) -> List[ModuleWrapper]: - balance = list(balance) - check_balance(module, balance, True) - layers: NamedModules = OrderedDict() def maybe_realize(layer: Any) -> nn.Module: @@ -152,3 +314,85 @@ def append_module(mod: "OrderedDict[str, nn.Module]") -> None: result.append(wrapper) return result + + def __len__(self) -> int: + """Counts the length of the underlying sequential module.""" + return sum(len(p) for p in self.partitions) + + def __getitem__(self, index: int) -> nn.Module: + """Gets a layer in the underlying sequential module.""" + partitions: List[Any] + partitions = self.partitions + + if index < 0: + partitions = partitions[::-1] + + for partition in partitions: + try: + if isinstance(partition, ModuleWrapper): + return partition.module[index] + else: + return partition[index] + except IndexError: + pass + + shift = len(partition) + + if index < 0: + index += shift + else: + index -= shift + + raise IndexError + + def __iter__(self) -> Iterable[nn.Module]: + """Iterates over children of the underlying sequential module.""" + for partition in self.partitions: + yield from partition.module + + def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors: # type: ignore + """:class:`AsyncPipe` is a fairly transparent module wrapper. It doesn't + modify the input and output signature of the underlying module. But + there's type restriction. Input and output have to be a + :class:`~torch.Tensor` or a tuple of tensors. This restriction is + applied at partition boundaries too. + + Args: + input (torch.Tensor or tensors): input mini-batch + + Returns: + tensor or tensors: output mini-batch + + Raises: + TypeError: input is not a tensor or tensors. + + """ + microbatch.check(input) + + if not self.pipeline: + # No pipeline is not illegal, more ranks than partitions + return input + + # Divide a mini-batch into micro-batches. + batches = microbatch.scatter(input, self.chunks) + + # Run pipeline parallelism. + with self.lock: + self.pipeline.run(self.training, batches, event) + + if self.final_stage: + output = microbatch.gather(batches) + else: + # Don't merge micro-batches to avoid unnecessary edges in autograd + # graph + # FIXME(tom) should figure out a proper type here + output = batches # type: ignore + + return output + + def back_helper(self, output: List[microbatch.Batch]) -> None: + if self.final_stage: + raise ValueError("back_helper should only be called on non-final stages") + + if self.pipeline: + self.pipeline.back_helper(output) diff --git a/fairscale/nn/pipe/async_pipeline.py b/fairscale/nn/pipe/async_pipeline.py index fdb64dd95..0e6939f08 100644 --- a/fairscale/nn/pipe/async_pipeline.py +++ b/fairscale/nn/pipe/async_pipeline.py @@ -4,18 +4,54 @@ # LICENSE file in the root directory of this source tree. 
import logging +import os from threading import Event -from typing import List, Optional +from typing import Dict, List, Optional, Union import torch -from .async_schedule import AsyncEventLoop +from .async_schedule import AsyncEventLoop, ModuleWrapper +from .messages import MakeTransport from .microbatch import Batch -from .multiprocess_pipeline import MultiProcessPipeline +from .skip.layout import SkipLayout from .skip.tracker import SkipTrackerThroughPotals -class AsyncPipeline(MultiProcessPipeline): +class AsyncPipeline: + """The async pipeline parallelism for Pipe.""" + + def __init__( + self, + partitions: List[ModuleWrapper], + skip_layout: SkipLayout, + checkpoint_stop: int, + group: torch.distributed.ProcessGroup, + *, + worker_map: Optional[Dict[int, str]] = None, + input_device: Union[None, int, str, torch.device] = None, + final_stage: bool = False, + ) -> None: + self.partitions = partitions + self.skip_layout = skip_layout + self.__checkpoint_stop = checkpoint_stop + self.group = group + self.training: bool + self.transport = MakeTransport( + use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ), + worker_map=worker_map, + input_device=input_device, + ) + self.input_device = input_device + self.final_stage = final_stage + + @property + def checkpoint_stop(self) -> int: + # Disable checkpointing if in eval mode. + training = self.partitions[0].module.training + if not training: + return 0 + return self.__checkpoint_stop + def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None: """Runs pipeline parallelism. diff --git a/fairscale/nn/pipe/async_schedule.py b/fairscale/nn/pipe/async_schedule.py index 1bfd1fef0..166d7bbcc 100644 --- a/fairscale/nn/pipe/async_schedule.py +++ b/fairscale/nn/pipe/async_schedule.py @@ -17,6 +17,7 @@ from .messages import Transport from .microbatch import Batch +from .multiprocess_pipeline import create_task from .skip.tracker import SkipTrackerThroughPotals from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors @@ -191,10 +192,6 @@ def run_invocation( """Actually run the forward pass for a given module, and send the result to the next stage in the pipeline if needed.""" - # We import here to avoid a cyclic dependency. - # TODO(msb) Break the cyclic dependency. - from .multiprocess_pipeline import create_task - task = create_task( self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers, ) diff --git a/fairscale/nn/pipe/multiprocess_pipe.py b/fairscale/nn/pipe/multiprocess_pipe.py index 289a0c5ad..4f1149ce1 100644 --- a/fairscale/nn/pipe/multiprocess_pipe.py +++ b/fairscale/nn/pipe/multiprocess_pipe.py @@ -20,7 +20,7 @@ """The MultiProcessPipe interface.""" from collections import OrderedDict import threading -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union import warnings import torch @@ -31,12 +31,10 @@ from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group from . 
import microbatch -from .async_schedule import Location, ModuleWrapper from .batchnorm import DeferredBatchNorm from .multiprocess_pipeline import MultiProcessPipeline from .phony import get_phony -from .skip.layout import SkipLayout, inspect_skip_layout -from .skip.skippable import Skippable, verify_skippables +from .skip.layout import SkipLayout from .types import LazyModule __all__ = ["MultiProcessPipe", "LazyModule"] @@ -53,123 +51,19 @@ NamedModules = OrderedDict -def recommend_auto_balance(message: str) -> str: - """Expands a message with recommendation to :mod:`torchpipe.balance`.""" - return f"""{message} - -If your model is still under development, its optimal balance would change -frequently. In this case, we highly recommend 'fairscale.nn.pipe.balance' for -naive automatic balancing: - - from fairscale.nn import Pipe - from fairscale.nn.pipe.balance import balance_by_time - - partitions = torch.cuda.device_count() - sample = torch.empty(...) - balance = balance_by_time(partitions, model, sample) - - model = MultiProcessPipe(model, balance, ...) -""" - - -# FIXME(tom) make this a valid way to call -def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None: - for layer in module: - if isinstance(layer, nn.Module): - pass - elif isinstance(layer, LazyModule): - pass - else: - raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned") - - def verify_module(module: Union[nn.Sequential, List[LazyModule]]) -> None: - if isinstance(module, Iterable) and not isinstance(module, nn.Sequential): - verify_list_of_callable(module) - else: - if not isinstance(module, nn.Sequential): - raise TypeError("module must be nn.Sequential to be partitioned") - - named_children = list(module.named_children()) - if len(named_children) != len(module): - raise ValueError("module with duplicate children is not supported") - - -def verify_splitting(module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int],) -> None: - num_parameters = len(list(module.parameters())) - num_child_parameters = sum(len(list(child.parameters())) for child in module.children()) - if num_parameters == num_child_parameters: - return - - for i in range(len(partitions)): - for j in range(i + 1, len(partitions)): - parti = partitions[i] - partj = partitions[j] - for p in parti.parameters(): - for q in partj.parameters(): - if p is q: - raise ValueError("module with duplicate parameters on distinct devices is not supported") - - -class BalanceError(ValueError): - pass - - -def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None: + if len(set(map(id, module))) != len(module): + raise ValueError("module with duplicate children is not supported") - if filter_unique: - module_len = len(set(map(id, module))) - else: - module_len = len(module) - if module_len != sum(balance): - raise BalanceError( +def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[int]) -> None: + if len(module) != sum(balance): + raise ValueError( f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})" ) if any(x <= 0 for x in balance): - raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})") - - -def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequential]: - """Splits a module into multiple partitions. - - Returns: - partitions - - Partitions are represented as a :class:`~torch.nn.ModuleList` whose - item is a partition. 
All layers in a partition are placed in the - same device. - - Raises: - BalanceError: - wrong balance - IndexError: - the number of devices is fewer than the number of partitions. - - """ - balance = list(balance) - - check_balance(module, balance) - - j = 0 - partitions = [] - layers: NamedModules = OrderedDict() - - for name, layer in module.named_children(): - layers[name] = layer - - if len(layers) == balance[j]: - # Group buffered layers as a partition. - partition = nn.Sequential(layers) - - partitions.append(partition) - - # Prepare for the next partition. - layers.clear() - j += 1 - - return partitions + raise ValueError(f"all balance numbers must be positive integer (balance: {balance})") MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement") @@ -223,16 +117,6 @@ class MultiProcessPipe(Module): whether to use deferred BatchNorm moving statistics (default: :data:`False`, see :class:`DeferredBatchNorm` for more details) - pipelined_backward (bool, optional): - if True, call torch.autograd.backward once per microbatch on the - backward pass (instead of once for the whole batch). This works - around a potential deadlock in pytorch when using tensor parallelism - at the same time. Defaults to `True` if - `get_model_parallel_world_size() > 1` - (default: `None`) - retain_graph (bool): - The value passed to `torch.autograd.backwards(..., retain_graph=) - (default: = `True`) Raises: TypeError: @@ -274,7 +158,7 @@ class MultiProcessPipe(Module): def __init__( self, module: Union[nn.Sequential, List[LazyModule]], - balance: Optional[Iterable[int]] = None, + balance: Iterable[int], *, group: Optional[torch.distributed.ProcessGroup] = None, worker_map: Optional[Dict[int, str]] = None, @@ -282,32 +166,25 @@ def __init__( chunks: int = chunks, checkpoint: str = checkpoint, deferred_batch_norm: bool = False, - pipelined_backward: bool = None, - retain_graph: bool = False, ) -> None: super().__init__() - chunks = int(chunks) - checkpoint = str(checkpoint) - - if balance is None: - raise ValueError(recommend_auto_balance("balance is required")) if chunks <= 0: raise ValueError("number of chunks must be positive integer") if checkpoint not in ["always", "except_last", "never"]: raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'") - verify_module(module) + if get_model_parallel_world_size() > 1: + self.pipelined_backward = True + else: + self.pipelined_backward = False - # Verify if the underlying skippable modules satisfy integrity. The - # integrity can be verified before forward() because it is static. 
-        if isinstance(module, nn.Sequential):
-            verify_skippables(module)
+        self.balance = list(balance)
+        verify_module(module)
+        check_balance(module, self.balance)
 
         self.chunks = chunks
         self.checkpoint = checkpoint
-        self.pipelined_backward = pipelined_backward
-        self.retain_graph = retain_graph
         self.pipeline: Optional[MultiProcessPipeline]
         self.lock = threading.Lock()
 
@@ -320,55 +197,35 @@ def __init__(
         else:
             self.group = group
 
-        self.balance = list(balance)
-
         if self.group.size() < len(self.balance):
             raise IndexError(
                 f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
                 f" {len(self.balance)})"
             )
 
-        try:
-            rank = self.group.rank()
-            if rank >= len(self.balance):
-                warnings.warn("More ranks than partitions, some ranks unused")
-                self.partitions: List[ModuleWrapper] = []
-            else:
-                self.partitions = self.instantiate_partition(module, balance, self.group)
-                if deferred_batch_norm:
-                    for part in self.partitions:
-                        part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
-                for name, part in enumerate(self.partitions):
-                    self.add_module(str(name), part.module)
-            if isinstance(module, nn.Sequential):
-                local_partitions = split_module(module, balance)
-                self._skip_layout = inspect_skip_layout(local_partitions)
-            else:
-                self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
-        except BalanceError as exc:
-            raise ValueError(recommend_auto_balance(str(exc)))
+        self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
 
         rank = self.group.rank()
+        self.final_stage = rank == len(self.balance) - 1
         if rank >= len(self.balance):
+            warnings.warn("More ranks than partitions, some ranks unused")
+            self.partition = nn.Sequential()
             self.pipeline = None
-            self.final_stage = False
         else:
-            self.final_stage = rank == len(self.balance) - 1
-
+            self.partition = self.instantiate_partition(module, self.balance, self.group)
+            if deferred_batch_norm:
+                self.partition = DeferredBatchNorm.convert_deferred_batch_norm(self.partition, chunks)
+            self.add_module(str(0), self.partition)
             self.create_pipeline()
-            del module
 
-        if self.pipelined_backward is None:
-            if get_model_parallel_world_size() > 1:
-                self.pipelined_backward = True
-            else:
-                self.pipelined_backward = False
+
+        del module
 
     def create_pipeline(self) -> None:
         # The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] self.pipeline = MultiProcessPipeline( - self.partitions, + self.partition, self._skip_layout, checkpoint_stop, group=self.group, @@ -378,88 +235,26 @@ def create_pipeline(self) -> None: ) def instantiate_partition( - self, - module: Union[nn.Sequential, List[LazyModule]], - balance: Iterable[int], - group: torch.distributed.ProcessGroup, - ) -> List[ModuleWrapper]: - balance = list(balance) - check_balance(module, balance, True) - - layers: NamedModules = OrderedDict() - - def maybe_realize(layer: Any) -> nn.Module: - if isinstance(layer, nn.Module): - return layer - elif callable(layer): - return layer() - else: - raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}") - - def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]: - if isinstance(module, nn.Sequential): - yield from module.named_children() - else: - yield from ((str(k), v) for k, v in enumerate(module)) - - j = 0 - - for name, layer in iterate_module(module): - layers[name] = layer - - if len(layers) == balance[j]: - if j == group.rank(): - for key in layers: - layers[key] = maybe_realize(layers[key]) - if not isinstance(module, nn.Sequential): - for layer in layers.values(): - if isinstance(layer, Skippable): - raise ValueError( - "Can't use Skippable layers with multi-process pipe and lazy construction" - ) - - return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))] - - # Prepare for the next partition. - layers.clear() - j += 1 - - raise ValueError("Souldn't get here, more ranks than partitions") + self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup, + ) -> nn.Sequential: + rank = group.rank() + first_layer = sum(balance[:rank]) + num_layers = balance[rank] + layers = module[first_layer : first_layer + num_layers] + instantiated_layers = [l if isinstance(l, nn.Module) else l() for l in layers] + return nn.Sequential(*instantiated_layers) def __len__(self) -> int: """Counts the length of the underlying sequential module.""" - return sum(len(p) for p in self.partitions) + return self.partition.__len__() def __getitem__(self, index: int) -> nn.Module: """Gets a layer in the underlying sequential module.""" - partitions: List[Any] - partitions = self.partitions - - if index < 0: - partitions = partitions[::-1] - - for partition in partitions: - try: - if isinstance(partition, ModuleWrapper): - return partition.module[index] - else: - return partition[index] - except IndexError: - pass - - shift = len(partition) - - if index < 0: - index += shift - else: - index -= shift - - raise IndexError + return self.partition.__getitem__(index) def __iter__(self) -> Iterable[nn.Module]: """Iterates over children of the underlying sequential module.""" - for partition in self.partitions: - yield from partition.module + return self.partition.__iter__() def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors: # type: ignore """:class:`MultiProcessPipe` is a fairly transparent module wrapper. 
It doesn't @@ -501,7 +296,7 @@ def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors: # torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"), requires_grad=True, ) - output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph) + output = PipelinedBackwardPass.apply(output, batches, phony) else: output = microbatch.gather(batches) else: @@ -523,9 +318,8 @@ def back_helper(self, output: List[microbatch.Batch]) -> None: class PipelinedBackwardPass(torch.autograd.Function): @staticmethod # type: ignore - def forward(ctx, input: TensorOrTensors, batches, phony, retain_graph) -> TensorOrTensors: + def forward(ctx, input: TensorOrTensors, batches, phony) -> TensorOrTensors: ctx.batches = batches - ctx.retain_graph = retain_graph return input @staticmethod @@ -536,7 +330,7 @@ def backward(ctx, *grads) -> Tuple: for grad, batch in reversed(list(zip(grad_batches, ctx.batches))): for t in batch: t.retain_grad() - torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=ctx.retain_graph) + torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=True) with torch.no_grad(): if ctx.batches[0].atomic: diff --git a/fairscale/nn/pipe/multiprocess_pipeline.py b/fairscale/nn/pipe/multiprocess_pipeline.py index 1b9e091fc..173027c11 100644 --- a/fairscale/nn/pipe/multiprocess_pipeline.py +++ b/fairscale/nn/pipe/multiprocess_pipeline.py @@ -30,7 +30,6 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks -from .async_schedule import ModuleWrapper from .checkpoint import Checkpointing from .messages import MakeTransport, Transport from .microbatch import Batch @@ -162,7 +161,7 @@ class MultiProcessPipeline: def __init__( self, - partitions: List[ModuleWrapper], + partition: nn.Sequential, skip_layout: SkipLayout, checkpoint_stop: int, group: torch.distributed.ProcessGroup, @@ -171,7 +170,7 @@ def __init__( input_device: Union[None, int, str, torch.device] = None, final_stage: bool = False, ) -> None: - self.partitions = partitions + self.partition = partition self.skip_layout = skip_layout self.__checkpoint_stop = checkpoint_stop self.group = group @@ -187,7 +186,7 @@ def __init__( @property def checkpoint_stop(self) -> int: # Disable checkpointing if in eval mode. 
- training = self.partitions[0].module.training + training = self.partition.training if not training: return 0 return self.__checkpoint_stop @@ -208,15 +207,12 @@ def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> N schedule = [(i, self.group.rank()) for i in range(m)] for i, j in schedule: - assert len(self.partitions) == 1 - partition = self.partitions[0] - if self.group.rank() != 0: batch = self.get_batch_from_previous_stage(i, skip_trackers, batches) else: batch = batches[i] - task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers) + task = create_task(self.checkpoint_stop, i, j, batch, self.partition, skip_trackers) batches[i] = self.execute_task(task, i, skip_trackers) diff --git a/fairscale/nn/pipe/rpc.py b/fairscale/nn/pipe/rpc.py index a9a803fab..f7d7f37b6 100644 --- a/fairscale/nn/pipe/rpc.py +++ b/fairscale/nn/pipe/rpc.py @@ -14,13 +14,12 @@ from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group from .async_pipe import AsyncPipe -from .multiprocess_pipe import MultiProcessPipe from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 -PipeModel: MultiProcessPipe +PipeModel: AsyncPipe PipeResult: TensorOrTensors @@ -72,7 +71,7 @@ def backward(ctx, *grad): return (None, None, None, None, None, None) -def callback_with_model(callback: Callable[[Any, MultiProcessPipe], None], ctx: Any) -> None: +def callback_with_model(callback: Callable[[Any, AsyncPipe], None], ctx: Any) -> None: try: group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group set_device_based_on_group(group) @@ -121,7 +120,7 @@ def _foreach_worker(self, callback: Callable, args: Any = None) -> None: futures = [f.wait() for f in futures] def foreach_worker( - self, callback: Callable[[Any, MultiProcessPipe], None], ctx: Any = None, *, include_self: bool = False + self, callback: Callable[[Any, AsyncPipe], None], ctx: Any = None, *, include_self: bool = False ) -> None: """Call `callback` on each worker with the `ctx` and model local to that worker. e.g. @@ -197,7 +196,7 @@ def final_stage(self) -> bool: @staticmethod def _recv_result( - model: MultiProcessPipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage + model: AsyncPipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage ) -> TensorOrTensors: group = get_pipeline_parallel_group() set_device_based_on_group(group) @@ -245,7 +244,7 @@ def _register_remote_model(args: List[Any], kwargs: Dict[str, Any]) -> None: set_device_based_on_group(group) kwargs["group"] = group kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) - model = MultiProcessPipe(*args, **kwargs) + model = AsyncPipe(*args, **kwargs) model.cuda() global PipeModel PipeModel = model diff --git a/fairscale/optim/__init__.py b/fairscale/optim/__init__.py index 9fe42fb6f..d2a5bc7af 100644 --- a/fairscale/optim/__init__.py +++ b/fairscale/optim/__init__.py @@ -8,7 +8,7 @@ """ import logging -from .adascale import AdaScale +from .adascale import AdaScale, AdaScaleWrapper from .oss import OSS try: diff --git a/fairscale/optim/adascale.py b/fairscale/optim/adascale.py index 3b2b81350..8df8c4184 100644 --- a/fairscale/optim/adascale.py +++ b/fairscale/optim/adascale.py @@ -32,13 +32,18 @@ # POSSIBILITY OF SUCH DAMAGE. 
 import functools
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
 import numpy as np
 import torch
 from torch.autograd import Variable
 import torch.distributed as dist
-from torch.optim import Optimizer
+from torch.optim import SGD, Optimizer
+
+if TYPE_CHECKING:  # pragma: no cover
+    from torch.optim.optimizer import _params_t
+else:
+    _params_t = Any
 
 
 class AdaScale(Optimizer):
@@ -582,3 +587,47 @@ def set_num_gradients_to_accumulate(self, num_gradients_to_accumulate: int, upda
         # When effective world size is large enough, smoothing is probably
         # not needed, so the smoothing factor is 0.
         self._smoothing = max(1 - self._world_size * self._num_grads_to_accum / 1000, 0)
+
+
+class AdaScaleWrapper(AdaScale):
+    """
+    A thin wrapper for AdaScale so that the constructor resembles a
+    standard optimizer. This allows it to work with other Optimizer
+    Wrappers, like `OSS`.
+
+    .. warning::
+        OSS(AdaScaleWrapper) (i.e. OSS wrapping AdaScale) results in each
+        rank's AdaScale operating on a different set of parameters. The
+        ranks will get different gain values and it is unclear how to
+        adjust the effective step size in that case. We have not validated
+        the effectiveness or benefit of this composition.
+
+        On the other hand, AdaScale(OSS) (i.e. AdaScale wrapping OSS) is
+        recommended and is numerically identical to AdaScale without OSS.
+        Since AdaScale doesn't incur per-parameter state, the memory
+        benefit of OSS is still the same.
+
+    Args:
+        params (list of tensors):
+            parameters to be optimized
+        optim (class subtyping torch.optim.Optimizer):
+            an optimizer class to be wrapped.
+        additional_optim_args (argument dict):
+            keyword arguments passed to the `optim` class above.
+
+    The remaining arguments are the same as for the `AdaScale` class above.
+    """
+
+    def __init__(
+        self,
+        params: _params_t,
+        world_size: Optional[int] = None,
+        scale: Optional[float] = None,
+        smoothing: Optional[float] = None,
+        num_gradients_to_accumulate: int = 1,
+        debias_ewma: bool = True,
+        optim_cls: Type[Optimizer] = SGD,
+        **additional_optim_args: Any,
+    ):
+        optim_obj = optim_cls(params, **additional_optim_args)
+        super().__init__(optim_obj, world_size, scale, smoothing, num_gradients_to_accumulate, debias_ewma)
diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py
index a51f18ebf..cdf9f6425 100644
--- a/fairscale/optim/oss.py
+++ b/fairscale/optim/oss.py
@@ -5,11 +5,10 @@
 from collections import OrderedDict, deque
 import copy
-import itertools
 from itertools import chain
 import logging
 from math import inf
-from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union
 
 import torch
 import torch.distributed as dist
@@ -52,9 +51,7 @@ class OSS(Optimizer):
         group (group):
             torch.distributed group (default: group.WORLD)
         broadcast_buffer_size (int):
-            the max size of the buffer used to batch the small parameter tensors, in number of elements (default 16M).
-            this will not impact the long term memory consumption, but the peak memory can be impacted by the moment
-            when the buffers are allocated and the bucketed params have not yet been relocated to them.
+            (deprecated) formerly used to cap the size of the broadcast buffers; no longer used.
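To make the recommended composition above concrete, here is a hedged sketch of AdaScale wrapping OSS; the model and learning rate are placeholders, and a torch.distributed process group is assumed to have been initialized already:

    import torch
    from fairscale.optim import AdaScale, OSS

    model = torch.nn.Linear(16, 16)
    # AdaScale wrapping OSS, the numerically-safe order discussed above.
    optimizer = AdaScale(OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1))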
""" #: The optimizer used for a given shard @@ -67,7 +64,7 @@ def __init__( params: _params_t, optim: Type[Optimizer] = SGD, group: Optional[Any] = None, - broadcast_buffer_size: int = 2 ** 24, + broadcast_buffer_size: int = -1, **default: Any, ): @@ -80,6 +77,9 @@ def __init__( self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict() # device, rank, params self._param_rank: Dict[torch.Tensor, int] = {} self._partition_parameters: List[List[dict]] = [] + self._index_to_param: Dict[int, torch.Tensor] = {} + self._param_to_index: Dict[int, int] = {} + self._local_params: Optional[List[torch.Tensor]] = None # Build the wrapped optimizer, responsible for a shard of the params self.group = group if group is not None else dist.group.WORLD @@ -99,12 +99,9 @@ def __init__( # Current default device is set by the parameters allocated to this rank self._device = list(self.per_device_params.keys())[0] - self.buckets: Dict[torch.device, List[torch.Tensor]] = {} - self.buffer_max_size = broadcast_buffer_size - - self.should_bucket_param: List[bool] = [] self.work_handles: Deque[Workhandle] = deque() - self._setup_bucket_strategy() + self.buckets: Dict[torch.device, List[torch.Tensor]] = {} + self._setup_flat_buffers() # Partition helpers def partition_parameters(self) -> List[List[dict]]: @@ -142,6 +139,41 @@ def partition_parameters(self) -> List[List[dict]]: return self._partition_parameters + @property + def local_params(self) -> List[torch.Tensor]: + """ Iterable which goes through the parameters that this rank owns + """ + if self._local_params is None: + self._local_params = list( + chain( + *[ + list(filter(lambda x: x.grad is not None, device_params[self.rank])) + for device_params in self.per_device_params.values() + ] + ) + ) + + # Make sure that the iterator is not consumed, only expose a copy + return self._local_params + + @property + def index_to_param(self) -> Dict[int, torch.Tensor]: + """ Hash table in between parameter indices in the global optimizer scheme, and the actual params + """ + if len(self._index_to_param) == 0: + self._index_to_param = {i: p for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))} + + return self._index_to_param + + @property + def param_to_index(self) -> Dict[int, int]: + """ Hash table in between parameter indices in the global optimizer scheme, and the actual params + """ + if len(self._param_to_index) == 0: + self._param_to_index = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))} + + return self._param_to_index + @property def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]: """Sorted list of all the params, first per device then per rank. @@ -191,7 +223,7 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> .. note: Any extra parameter is passed to the base optimizer as-is""" # Sync oss param_groups attributes in case they've been updated by a scheduler. 
- self._sync_param_groups() + OSS._sync_param_groups(self.param_groups, self.optim.param_groups) # Run the optimizer step on this shard only: if closure is not None: @@ -203,7 +235,7 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> self._broadcast_params() # Sync hypothethical new results from the wrapped optimizer to the exposed param_groups - self._sync_param_groups(local_to_global=True) + OSS._sync_param_groups(self.optim.param_groups, self.param_groups) return loss @@ -236,25 +268,16 @@ def clip_grad_norm( max_norm = float(max_norm) norm_type = float(norm_type) - # Filter out the grad-less params, concatenate params from all devices - local_params = itertools.chain( - *[ - list(filter(lambda x: x.grad is not None, device_params[self.rank])) - for device_params in self.per_device_params.values() - ] - ) - # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism. # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel' # 'model_parallel' flag is set in Megatron-LM: # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54 - if filter_params_fn is not None: - local_params = filter_params_fn(local_params) + local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params # Compute the norm on this grad set, # then sync all the norms from all ranks if norm_type == inf: - total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params) # type: ignore + total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params) # all reduce over data parallel and model parallel workers dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD) else: @@ -280,26 +303,13 @@ def clip_grad_norm( return total_norm # State dict interfaces - def local_state_dict(self) -> dict: - """Gets this rank's state_dict. - - Returns: - The state of the optimizer as a :class:`dict`. - It contains two entries: - - * state - a dict holding current optimization state. Its content - differs between optimizer classes. - * param_groups - a dict containing all parameter groups - """ - return self.optim.state_dict() - def consolidate_state_dict(self, recipient_rank: int = 0) -> None: """Update the consolidated state_dict list, one per rank. .. warning: This needs to be called on all replicas""" # Sync lr and other attributes in case its been updated - self._sync_param_groups() + OSS._sync_param_groups(self.param_groups, self.optim.param_groups) if self.rank == recipient_rank: # Pull the sharded state from all the other replicas @@ -310,12 +320,104 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None: # Acknowledge broadcasts, and send this rank's shard when needed self._broadcast_state_dict() + def local_state_dict(self) -> dict: + """ .. deprecated:: 0.1.5 + + Returns this rank's state_dict as a :class:`dict` which contains two entries: + + * state - a dict holding current optimization state. Its content + differs between optimizer classes. + + * param_groups - a dict containing all parameter groups + + .. warning: This does not represent the optimizer state dict, only a shard. + """ + return self.optim.state_dict() + + def state_dict(self) -> Dict[str, Any]: + """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the + sharded properties are not exposed. 
It contains two entries: + + * state - a dict holding current optimization state. Its content + differs between optimizer classes. + + * param_groups - a dict containing all parameter groups + + .. warning: + If the state has not been consolidated, this returns a shard's worth, not the global state. + + .. warning: + Returning the global state is limited to the replica which was responsible for the consolidation. + The state may also not be up to date, depending on when `consolidate_state_dict` was last called. + """ + + if len(self._all_states) == 0: + raise RuntimeError( + "Optimizer state has not been consolidated on this rank. \ + Please call `consolidate_state_dict()` on all ranks beforehand if you meant to save the global state" + ) + + # Unify the shard states and the state that pytorch would expect, given the model. + # Indexation needs several redirections, since each shard only knows a limited scope of the model + # - get the pytorch compliant parameter indexing + state_dict = super().state_dict() + + # - go through the per-shard states, which are all indexed locally + for rank, s in enumerate(self._all_states): + # -- match the local indexing and the global partition, update the corresponding saved state globally + for local_pg, global_pg in zip(s["param_groups"], self.partition_parameters()[rank]): + local_index_to_param_id = { + i_param: id(global_pg["params"][i]) for i, i_param in enumerate(local_pg["params"]) + } + + for local_param_index in local_pg["params"]: + # Update the state, if any + if local_param_index in s["state"].keys(): + global_id = self.param_to_index[local_index_to_param_id[local_param_index]] + state_dict["state"][global_id] = s["state"][local_param_index] + + return state_dict + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Restore the global parameter groups as well as the shard. + + Arguments: + state_dict (dict): optimizer state. 
Should be an object returned + from a call to :meth:`state_dict` + """ + + # NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time + # we work around that here by using the fact that the params are ordered as in the param_groups + + for i_param, (key, value) in enumerate(state_dict["state"].items()): + param = self.index_to_param[i_param] + + # Populate the sharded optimizer state on the fly + if self.param_to_rank[param] != self.rank: + state_dict["state"][key] = None + + if key in self.index_to_param: + param = self.index_to_param[i_param] + + # Only add this state to the sharded optimizer if it owns this param + for pg in self.optim.param_groups: + if id(param) in [id(p) for p in pg["params"]]: + self.optim.state[param] = recursive_copy_to_device( + value, non_blocking=True, device=param.device + ) + + super().load_state_dict(state_dict) + + # Sync with the optimizer param groups + OSS._sync_param_groups(state_dict["param_groups"], self.param_groups) + OSS._sync_param_groups(self.param_groups, self.optim.param_groups) + def _broadcast_state_dict(self) -> None: """Broadcast this rank's state shard, discard others""" # Default to CPU space to gain some memory headroom local_cpu_state = recursive_copy_to_device( - self.local_state_dict(), non_blocking=True, device=torch.device("cpu") + self.optim.state_dict(), non_blocking=True, device=torch.device("cpu") ) # Tensor cannot be really empty, even if its size is meaningless @@ -350,7 +452,7 @@ def _collect_sharded_states(self) -> List[Dict[str, Any]]: if rank == self.rank: logging.debug("Saving self state") all_states.append( - recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu")) + recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu")) ) # Sync with other replicas @@ -378,103 +480,6 @@ def _collect_sharded_states(self) -> List[Dict[str, Any]]: return all_states - def state_dict(self) -> Dict[str, Any]: - """Return the last known global optimizer state, which consist of a list of the shards. - - .. warning: - If the state has not been consolidated, this returns a shard's worth, not the global state. - - .. warning: - Returning the global state is limited to the replica which was responsible for the consolidation. - The state may also not be up to date, depending on when `consolidate_state_dict` was last called. - """ - - if len(self._all_states) == 0: - logging.warning("Optimizer state has not been consolidated. Returning the local state") - logging.warning("Please call `consolidate_state_dict()` beforehand if you meant to save the global state") - state_dict = self.local_state_dict() - state_dict["local_state_dict"] = True - return state_dict - - # Flatten the param_groups, save the partition which logs the rank <> shard correspondence - partition: List[Tuple[int, int]] = [] - param_groups: List[Dict[Any, Any]] = [] - - start = 0 - for i, s in enumerate(self._all_states): - param_groups.extend(s["param_groups"]) - end = start + len(s["param_groups"]) - partition.append((start, end)) - start = end - - return { - "state": [s["state"] for s in self._all_states], - "param_groups": param_groups, - "partition": partition, - "local_state_dict": False, - } - - @staticmethod - def rank_local_state_dict(rank: int, state_dict: dict) -> dict: - """Returns the local_state_dict for a given rank. 
- - Arguments: - rank (int): rank to get local_state_dict for - state_dict (dict): global state_dict - """ - param_groups = state_dict["param_groups"][state_dict["partition"][rank][0] : state_dict["partition"][rank][1]] - return {"state": state_dict["state"][rank], "param_groups": param_groups} - - def load_local_state_dict(self, state_dict: dict) -> None: - """Loads this rank's state_dict. - - .. warning: This is not meant to load the global state dict. - """ - - self.optim.load_state_dict(state_dict) - - # Workaround PyTorch bug that casts state (https://github.com/pytorch/pytorch/issues/43706) - # Copied from https://github.com/pytorch/fairseq/blob/v0.9.0/fairseq/optim/fp16_optimizer.py#L251-L268 - groups = self.optim.param_groups - saved_groups = state_dict["param_groups"] - id_map = { - old_id: p - for old_id, p in zip(chain(*(g["params"] for g in saved_groups)), chain(*(g["params"] for g in groups))) - } - for k, v in state_dict["state"].items(): - if k in id_map: - param = id_map[k] - self.optim.state[param] = recursive_copy_to_device(v, non_blocking=True, device=param.device) - - # Restore the global param_groups (the params themselves are already correct) - for global_group, local_group in zip(self.param_groups, groups): - for k, v in local_group.items(): - if k != "params": - global_group[k] = v - - # Force a re-partitioning, in case the model changed with the new state - self._partition_parameters.clear() - self._per_device_params.clear() - self._param_rank.clear() - - # Update the bucketing strategy accordingly - self._setup_bucket_strategy() - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - """Restore the global parameter groups as well as the shard. - - Arguments: - state_dict (dict): optimizer state. Should be an object returned - from a call to :meth:`state_dict` - """ - - # Check whether we got a local or global dict - if "local_state_dict" in state_dict and state_dict["local_state_dict"]: - self.load_local_state_dict(state_dict) - else: - # Dispatch this rank's state dictionary to the wrapped shard optimizer - self.load_local_state_dict(OSS.rank_local_state_dict(self.rank, state_dict)) - def add_param_group(self, param_group: dict) -> None: """Add a param group to the :class:`Optimizer` s `param_groups`. @@ -491,16 +496,23 @@ def add_param_group(self, param_group: dict) -> None: super().add_param_group(param_group) if not self.in_super_constructor: # Force a re-partitioning - self._partition_parameters.clear() - self._per_device_params.clear() - self._param_rank.clear() + self._clear_cache() + # Update the partition param_groups = self.partition_parameters()[self.rank] if len(param_groups) == len(self.optim.param_groups) + 1: self.optim.add_param_group(param_groups[-1]) # Update the bucketing strategy accordingly - self._setup_bucket_strategy() + self._setup_flat_buffers() + + def _clear_cache(self) -> None: + self._partition_parameters.clear() + self._per_device_params.clear() + self._param_rank.clear() + self._index_to_param.clear() + self._param_to_index.clear() + self._local_params = None @staticmethod def get_global_rank(group: Any, rank: int) -> int: @@ -510,44 +522,24 @@ def get_global_rank(group: Any, rank: int) -> int: global_rank = dist.distributed_c10d._get_global_rank(group, rank) return global_rank - @torch.no_grad() - def _sync_param_groups(self, local_to_global: bool = False) -> None: - """Sync learning rate and other optimizer attributes (needed to support schedulers). 
- If the global param groups have been altered, and we want to make sure that the - wrapped optimizer uses the up to date version. - Conversely if the wrapped optimizer has new keys, we expose them through the global param groups""" + @staticmethod + def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None: + """Sync learning rate and other optimizer attributes (needed to support schedulers).""" - for global_group, local_group in zip(self.param_groups, self.optim.param_groups): + for source_group, destination_group in zip(source, destination): # Sync everything but the parameters - for k in filter(lambda x: x != "params", local_group.keys()): - if local_to_global: - global_group[k] = local_group[k] - elif k in global_group.keys(): - local_group[k] = global_group[k] + for k in filter(lambda x: x != "params", source_group.keys()): + destination_group[k] = source_group[k] @torch.no_grad() def _broadcast_params(self) -> None: """Helper function to broadcast all the parameters from a given device""" - i_param = 0 last_work_handle = None # Work handles are consumed within this scope, no callback - for (device, device_params,) in self.per_device_params.items(): # all the params on this device (inc all ranks) - buckets = self.buckets[device] - # Bucket and issue all the async calls - for (src_rank, params), bucket in zip(enumerate(device_params), buckets): + for device in self.buckets.keys(): + for src_rank, bucket in enumerate(self.buckets[device]): global_src_rank = self.get_global_rank(self.group, src_rank) - - # Direct broadcasts only - for param in params: - if not self.should_bucket_param[i_param]: - last_work_handle = dist.broadcast( - tensor=param.data, src=global_src_rank, group=self.group, async_op=True - ) - - i_param += 1 - - # Bucket broadcasts last_work_handle = dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True) # Only check on the last handle, they're all inlined on the same CUDA stream @@ -558,7 +550,6 @@ def _consume_work_handles(self) -> None: """Consume all the futures which are tied to this optimizer's buckets. We start from the first/older ones, since they are the most likely to be ready and non-blocking """ - while len(self.work_handles) > 0: work_handle = self.work_handles.popleft() work_handle.handle.wait() @@ -572,52 +563,26 @@ def _try_consume_work_handle(self) -> None: if work_handle.callback is not None: work_handle.callback() - def _setup_bucket_strategy(self) -> None: - """Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered - (smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent - over the wire. - - Generating the partition once and for all allows us to save some time at runtime, and to know when all the - network requests have been issued. + def _setup_flat_buffers(self) -> None: + """Make all params which are on the same device and tied to the same rank views of a single buffer. + This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and + `refresh_trainability` is called. 
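A minimal stand-alone sketch of the flat-buffer idea described in the docstring above; the tensors here are placeholders, and the real code additionally groups parameters by device and by owning rank:

    import torch

    params = [torch.randn(3, 4), torch.randn(5)]
    flat = torch.empty(sum(p.numel() for p in params), dtype=params[0].dtype)

    offset = 0
    for p in params:
        nxt = offset + p.numel()
        flat[offset:nxt].copy_(p.data.flatten())
        p.data = flat[offset:nxt].view_as(p.data)  # p is now a view of the flat buffer
        offset = nxt

    # A single broadcast of `flat` now updates every parameter at once.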
""" - # (re) allocate the buckets - # - Get the correct size for the buckets, cannot be bigger than the model - model_size = sum([p.numel() for p in self.param_to_rank.keys()]) - self.bucket_size = min(self.buffer_max_size, model_size) - logging.info( - "Bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format( - self.bucket_size / 2 ** 20, model_size / 2 ** 20 - ) - ) - - # - Allocate one buffer per rank and per device to group the small parameters - for device, per_device in self.per_device_params.items(): - self.buckets[device] = [ - torch.zeros(self.bucket_size, dtype=per_device[0][0].dtype, device=device) - for _ in range(len(per_device)) - ] - - # Devise the bucketing strategy for device, per_rank_params in self.per_device_params.items(): - for dst_rank, params in enumerate(per_rank_params): - offset = 0 + self.buckets[device] = [] - for param in params: - # Criteria to decide whether this parameter is to be bucketed or not: - # - enough room in the bucket - if param.requires_grad and (offset + param.numel()) < self.bucket_size: - self.should_bucket_param.append(True) + for dst_rank, params in enumerate(per_rank_params): + if len(params) > 0: + trainable_params = list(filter(lambda x: x.requires_grad, params)) + buffer_size = sum(map(lambda x: x.numel(), trainable_params)) + self.buckets[device].append(torch.empty(buffer_size, dtype=params[0].dtype, device=device)) + offset = 0 + for param in trainable_params: # This parameter becomes a view of the bucket offset_next = offset + param.numel() self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten()) param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data) - offset = offset_next - else: - self.should_bucket_param.append(False) - - # Resize the bucket to remove lost space in the end - self.buckets[device][dst_rank].resize_(offset) diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py index 7c44e33a7..85530ec12 100644 --- a/fairscale/utils/testing.py +++ b/fairscale/utils/testing.py @@ -31,6 +31,7 @@ import multiprocessing import os import random +import sys import tempfile from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union @@ -59,6 +60,15 @@ not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="multiple GPUs required" ) +skip_if_py38 = pytest.mark.skipif( + sys.version_info.major == 3 and sys.version_info.minor == 8, reason="Python3.8 is skipped" +) + +skip_if_py39_no_cuda = pytest.mark.skipif( + not torch.cuda.is_available() and sys.version_info.major == 3 and sys.version_info.minor == 9, + reason="Python3.9 wo CUDA is skipped", +) + _, filename_mpi = tempfile.mkstemp() diff --git a/pyproject.toml b/pyproject.toml index 363c814e5..5395cca7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = [ "torch >= 1.4.0", "wheel >= 0.30.0" ] -build-backend = "setuptools.build_meta" +build-backend = "setuptools.build_meta:__legacy__" [tool.black] line-length = 120 diff --git a/stubs/torch/utils/checkpoint.pyi b/stubs/torch/utils/checkpoint.pyi index f37a23ddd..003be48ae 100644 --- a/stubs/torch/utils/checkpoint.pyi +++ b/stubs/torch/utils/checkpoint.pyi @@ -1,8 +1,9 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from typing import Tuple +from typing import Any, Iterable, Tuple from .. import Tensor from torch.nn.modules.module import Module def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ... def checkpoint(function: Module, *args, **kwargs): ... 
+def check_backward_validity(inputs: Iterable[Any]): ... diff --git a/tests/nn/data_parallel/test_sharded_ddp.py b/tests/nn/data_parallel/test_sharded_ddp.py index 2277067a7..aa26b2323 100644 --- a/tests/nn/data_parallel/test_sharded_ddp.py +++ b/tests/nn/data_parallel/test_sharded_ddp.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. """ -Testing OssDdp class. +Testing ShardedDDP """ from contextlib import suppress @@ -14,6 +14,7 @@ import numpy as np import torch +from torch.cuda.amp import GradScaler as TorchGradScaler import torch.distributed as dist import torch.multiprocessing as mp from torch.nn import Linear, Sequential @@ -21,7 +22,8 @@ from fairscale.nn.data_parallel import ShardedDataParallel from fairscale.optim import OSS -from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_single_gpu +from fairscale.optim.grad_scaler import ShardedGradScaler +from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu def run_one_step(rank, world_size, backend, device, temp_file_name): @@ -112,16 +114,17 @@ def run_test(backend, device, world_size=2): mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True) -def test_step_on_cpu(): - run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4) - - @skip_if_no_cuda @skip_if_single_gpu -def test_step_on_gpu(): +def test_step_gpu(): run_test(backend=dist.Backend.NCCL, device=torch.device("cuda")) +@skip_if_py38 +def test_step_cpu(): + run_test(backend=dist.Backend.GLOO, device=torch.device("cpu")) + + def run_ddp_parity(rank, world_size, backend, temp_file_name): url = "file://" + temp_file_name dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size) @@ -131,52 +134,90 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): torch.manual_seed(rank) np.random.seed(rank) - # Any model works. Add one different buffer per rank - model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3)) - model.register_buffer("test_buffer", torch.ones((1)) * rank) - model.to(device) - - sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99) - sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True) - - ddp_model_single = copy.deepcopy(model) - ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99) - ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True) + def check_parity(amp: bool): + # Any model works. 
Add one different buffer per rank + model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3)) + model.register_buffer("test_buffer", torch.ones((1)) * rank) + model.to(device) - def check_same_model_params(): - for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups): - for p, ddp_p in zip(pg["params"], ddp_pg["params"]): - assert torch.allclose( - p, ddp_p, atol=1e-3 - ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}" + sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99) + sharded_ddp_model = ShardedDataParallel( + module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True + ) - for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()): - assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP" + ddp_model_single = copy.deepcopy(model) + ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99) + ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True) - # The model should be synchronized in between the ranks at construction time, check that - check_same_model_params() + ddp_scaler = TorchGradScaler() if amp else None + sharded_ddp_scaler = ShardedGradScaler() if amp else None - # The models should stay the same in between the ranks - for i in range(20): - input_tensor = torch.rand((64, 2)).to(device) + def check_same_model_params(): + for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups): + for p, ddp_p in zip(pg["params"], ddp_pg["params"]): + assert torch.allclose( + p, ddp_p, atol=1e-3 + ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}" - def closure_ddp(input_tensor=input_tensor): - ddp_optimizer.zero_grad() - ddp_loss = ddp_model(input_tensor).abs().sum() - ddp_loss.backward() - return ddp_loss - - def closure_sharded(input_tensor=input_tensor): - sharded_optimizer.zero_grad() - sharded_loss = sharded_ddp_model(input_tensor).abs().sum() - sharded_loss.backward() - return sharded_loss - - _ = ddp_optimizer.step(closure=closure_ddp) - _ = sharded_optimizer.step(closure=closure_sharded) + for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()): + assert torch.allclose( + b, ddp_b, atol=1e-3 + ), f"Model buffers differ in between DDP and ShardedDDP. 
AMP {amp}" + # The model should be synchronized in between the ranks at construction time, check that check_same_model_params() + # The models should stay the same in between the ranks + for i in range(10): + input_tensor = torch.rand((64, 2)).to(device) + + def closure_ddp(input_tensor=input_tensor): + ddp_optimizer.zero_grad() + + if ddp_scaler is not None: + with torch.cuda.amp.autocast(): + ddp_loss = ddp_model(input_tensor).abs().sum() + ddp_scaler.scale(ddp_loss).backward() + else: + ddp_loss = ddp_model(input_tensor).abs().sum() + ddp_loss.backward() + return ddp_loss + + def closure_sharded(input_tensor=input_tensor): + sharded_optimizer.zero_grad() + + if sharded_ddp_scaler is not None: + with torch.cuda.amp.autocast(): + sharded_loss = sharded_ddp_model(input_tensor).abs().sum() + sharded_ddp_scaler.scale(sharded_loss).backward() + else: + sharded_loss = sharded_ddp_model(input_tensor).abs().sum() + sharded_loss.backward() + return sharded_loss + + # Step/scale both + if ddp_scaler is not None: + _ = closure_ddp(input_tensor) + ddp_scaler.step(ddp_optimizer) + ddp_scaler.update() + else: + ddp_optimizer.step(closure=closure_ddp) + + if sharded_ddp_scaler is not None: + _ = closure_sharded(input_tensor) + sharded_ddp_scaler.step(sharded_optimizer) + sharded_ddp_scaler.update() + else: + sharded_optimizer.step(closure=closure_sharded) + + check_same_model_params() + + check_parity(amp=False) + + # Catch a version of pytorch which would not support AMP + if hasattr(torch.cuda.amp, "autocast"): + check_parity(amp=True) + dist.destroy_process_group() @@ -341,6 +382,35 @@ def test_random_attributes(): dist.destroy_process_group() +def run_test_device_change(rank, world_size, backend, device, temp_file_name): + # Check that the wrapped module can change devices + + url = "file://" + temp_file_name + dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size) + + model = Sequential(Linear(2, 3), Linear(3, 3)).cpu() + optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99) + ddp_model = ShardedDataParallel(model, optimizer) + ddp_model.to(device) + + inputs = torch.rand((10, 2), device=device) + outputs = ddp_model(inputs) # assert if the module has not been changed properly + loss = outputs.norm().backward() + + dist.destroy_process_group() + + +@skip_if_no_cuda +@skip_if_single_gpu +def test_device_change(): + # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs + world_size = 2 + backend = "gloo" + temp_file_name = tempfile.mkstemp()[1] + device = "cuda" + mp.spawn(run_test_device_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True) + + def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name): url = "file://" + temp_file_name dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size) diff --git a/tests/nn/misc/test_checkpoint_activations.py b/tests/nn/misc/test_checkpoint_activations.py new file mode 100644 index 000000000..84fdc13d5 --- /dev/null +++ b/tests/nn/misc/test_checkpoint_activations.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
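
An aside on the AMP parity test above: GradScaler.step does not support the backward-running closures used elsewhere in the test, which appears to be why the closure is evaluated first and scaler.step/scaler.update are called afterwards. A minimal sketch of that pattern (assumes a CUDA device; not part of any patch in this series):

    import torch

    model = torch.nn.Linear(2, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.rand(4, 2, device="cuda")).abs().sum()
    scaler.scale(loss).backward()   # backward runs outside optimizer.step
    scaler.step(optimizer)          # skips the step if gradients overflowed
    scaler.update()
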
+ +""" +Test fairscale.nn.misc.checkpoint_activations +""" + +import unittest + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper + + +class Model(nn.Module): + def __init__(self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs): + super().__init__() + torch.manual_seed(0) + self.use_pytorch_checkpoint = use_pytorch_checkpoint + self.ffn = nn.Sequential( + nn.Linear(32, 128), + # add a Dropout layer to test RNG save/restore + nn.Dropout(p=0.5), + nn.Linear(128, 32), + ) + if use_fairseq_checkpoint: + self.ffn = checkpoint_wrapper(self.ffn, **kwargs) + self.out = nn.Linear(32, 1) + + def forward(self, x): + if self.use_pytorch_checkpoint: + x = checkpoint(self.ffn, x) + else: + x = self.ffn(x) + return self.out(x) + + +class TestComparisonToPyTorch(unittest.TestCase): + def _test_checkpoint_wrapper(self, device, log_memory_usage=False): + def get_loss_and_gnorm(model): + torch.manual_seed(1) + input = torch.rand(2, 16, 32).requires_grad_(True).to(device) + model.zero_grad() + loss = model(input).sum() + loss.backward() + gnorm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()])) + return {"loss": loss, "gnorm": gnorm} + + model = Model().to(device) + no_cpt = get_loss_and_gnorm(model) + + model = Model(use_pytorch_checkpoint=True).to(device) + pyt_cpt = get_loss_and_gnorm(model) + torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"]) + torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"]) + + model = Model(use_fairseq_checkpoint=True).to(device) + fairseq_cpt = get_loss_and_gnorm(model) + torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"]) + torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"]) + + model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device) + fairseq_cpt_offload = get_loss_and_gnorm(model) + torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"]) + torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"]) + + def test_checkpoint_wrapper_cpu(self): + self._test_checkpoint_wrapper(device=torch.device("cpu")) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_checkpoint_wrapper_cuda(self): + self._test_checkpoint_wrapper(device=torch.device("cuda")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/nn/model_parallel/test_layers.py b/tests/nn/model_parallel/test_layers.py index 0f28f8c63..6d915a7a6 100644 --- a/tests/nn/model_parallel/test_layers.py +++ b/tests/nn/model_parallel/test_layers.py @@ -443,7 +443,6 @@ def forward_model(model_, target, step=False): worker_map=worker_map, input_device=torch.cuda.current_device(), chunks=chunk_size, - pipelined_backward=True, ).cuda() torch.distributed.barrier() pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group()) diff --git a/tests/nn/pipe_process/skip/__init__.py b/tests/nn/pipe_process/skip/__init__.py deleted file mode 100644 index 5a5df5f80..000000000 --- a/tests/nn/pipe_process/skip/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -# Copyright 2019 Kakao Brain -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/nn/pipe_process/skip/test_gpipe.py b/tests/nn/pipe_process/skip/test_gpipe.py deleted file mode 100644 index 2f5a1cdc8..000000000 --- a/tests/nn/pipe_process/skip/test_gpipe.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -# Copyright 2019 Kakao Brain -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import pytest -import torch -from torch import nn - -from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe -from fairscale.nn.pipe.skip import pop, skippable, stash -from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange -from fairscale.utils.testing import get_worker_map, torch_spawn - - -@torch_spawn([3]) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"]) -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) -@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") -def x1to3(balance, checkpoint, pipe_class): - torch.manual_seed(0) - - if pipe_class == AsyncPipe and len(balance) > 1: - print(f"skipping yarg") - pytest.skip("Skip tensors NYI for AsyncPipe") - - @skippable(stash=["1to3"]) - class Layer1(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - yield stash("1to3", input) - output = self.conv(input) - return output - - class Layer2(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - output = self.conv(input) - return output - - @skippable(pop=["1to3"]) - class Layer3(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - skip_1to3 = yield pop("1to3") - output = self.conv(input) + skip_1to3 - return output - - model = nn.Sequential(Layer1(), Layer2(), Layer3()) - model = pipe_class( - model, - balance, - chunks=3, - checkpoint=checkpoint, - input_device=torch.cuda.current_device(), - worker_map=get_worker_map(), - pipelined_backward=False, - ).cuda() - - input = torch.rand(30, 3, 224, 224, requires_grad=True).cuda() - input.retain_grad() - output = model(input) - if model.group.rank() == len(balance) - 1: - loss = output.mean() - loss.backward() - elif model.group.rank() < 
len(balance) - 1: - model.back_helper(output) - if model.group.rank() == len(balance) - 1: - # TODO(tom) the single-process test uses 2e-1 but for some reason - # mutli-process is more noisy, need to investigate why - assert torch.allclose(output.norm(), torch.tensor(1039.0).cuda(), atol=4e-1) - if model.group.rank() == 0: - assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053).cuda()) - - torch.distributed.barrier() - - -@torch_spawn([2]) -@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) -@pytest.mark.skip(reason="flaky test") -def none_skip(pipe_class): - if pipe_class == AsyncPipe: - pytest.skip("Skip tensors NYI for AsyncPipe") - - @skippable(stash=["none"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("none", None) - return input - - @skippable(pop=["none"]) - class Pop(nn.Module): - def forward(self, input): - none = yield pop("none") - assert none is None - return input - - model = nn.Sequential(Stash(), Pop()) - model = pipe_class( - model, [1, 1], worker_map=get_worker_map(), input_device=torch.cuda.current_device(), chunks=5, - ).cuda() - - input = torch.rand(10, requires_grad=True).cuda() - input.retain_grad() - output = model(input) - - def assert_grad_fn_is_not_portal(grad_fn, visited=set()): - if grad_fn in visited or grad_fn is None: - return - - assert not isinstance(grad_fn, PortalBlue._backward_cls) - assert not isinstance(grad_fn, PortalCopy._backward_cls) - assert not isinstance(grad_fn, PortalOrange._backward_cls) - - visited.add(grad_fn) - for next_grad_fn, _ in grad_fn.next_functions: - assert_grad_fn_is_not_portal(next_grad_fn, visited) - - if model.group.rank() == 1: - assert_grad_fn_is_not_portal(output.grad_fn) - - output.sum().backward() - else: - model.back_helper(output) - assert input.grad.mean().item() == 1 - - -@torch_spawn([2]) -@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) -def lazy_skippable_error(pipe_class): - """Using skippable layers in combination with lazy construction is currently - not supported, check that it raises an Exception""" - - @skippable(stash=["1to3"]) - class Layer1(nn.Linear): - pass - - @skippable(pop=["1to3"]) - class Layer3(nn.Linear): - pass - - model = [ - LazyModule(lambda: Layer1(10, 10)), - LazyModule(lambda: nn.Linear(10, 10)), - LazyModule(lambda: Layer3(10, 10)), - ] - - with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"): - pipe_class( - model, [2, 1], worker_map=get_worker_map(), - ) diff --git a/tests/nn/pipe_process/skip/test_leak.py b/tests/nn/pipe_process/skip/test_leak.py deleted file mode 100644 index e8370317a..000000000 --- a/tests/nn/pipe_process/skip/test_leak.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -# Copyright 2019 Kakao Brain -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import pytest -import torch -from torch import nn - -from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe, is_checkpointing, is_recomputing -from fairscale.nn.pipe.skip import pop, skippable, stash -from fairscale.nn.pipe.skip.tracker import current_skip_tracker -from fairscale.utils.testing import get_worker_map, torch_spawn - - -@skippable(stash=["skip"]) -class Stash(nn.Module): - def forward(self, input): - yield stash("skip", input) - return input - - -@skippable(pop=["skip"]) -class Pop(nn.Module): - def forward(self, input): - skip = yield pop("skip") - return input + skip - - -@torch_spawn([2]) -@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) -@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) -@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def delete_portal_tensor(train, checkpoint, pipe_class): - # Without checkpointing: - # +- Stash --+ +--- Pop ----+ - - - layers - # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function - # +----------+ +------------+ - # - # With checkpointing: - # +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+ - # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | - # +----------+ +------------+ +------------+ +----------+ - - if pipe_class == AsyncPipe: - pytest.skip("Skip tensors NYI for AsyncPipe") - - def portal_tensor_life_is(tensor_life, skip_tracker=None): - if skip_tracker is None: - skip_tracker = current_skip_tracker() - - # Get the current portal. - portal = list(skip_tracker.portals.values())[0] - - if tensor_life == 0: - return portal.tensor_life == 0 and portal.tensor is None - else: - return portal.tensor_life == tensor_life and portal.tensor is not None - - # Check the portal tensor after 'Stash'. 
- stash_ = Stash() - - @stash_.register_forward_hook - def check_portal_tensor_after_stash(*_): - if is_checkpointing(): - assert portal_tensor_life_is(2) - elif is_recomputing(): - assert portal_tensor_life_is(0) - else: - assert portal_tensor_life_is(1) - - pop_ = Pop() - - @pop_.register_forward_hook - def check_portal_tensor_after_pop(*_): - if is_checkpointing(): - assert portal_tensor_life_is(1) - elif is_recomputing(): - assert portal_tensor_life_is(0) - else: - assert portal_tensor_life_is(0) - - class NoPortalTensorAtBackward(nn.Module): - class F(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - ctx.skip_tracker = current_skip_tracker() - return input.detach() - - @staticmethod - def backward(ctx, grad): - assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker) - return grad - - def forward(self, input): - return self.F.apply(input) - - model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) - model = pipe_class(model, balance=[2, 1], worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,) - - input = torch.rand(10, requires_grad=True) - - if train: - model.train() - output = model(input) - if model.group.rank() == 1: - output.norm().backward() - else: - model.back_helper(output) - else: - model.eval() - with torch.no_grad(): - model(input) - - torch.distributed.barrier() diff --git a/tests/nn/pipe_process/test_pipe.py b/tests/nn/pipe_process/test_pipe.py index 13b8a636f..b14a1befb 100644 --- a/tests/nn/pipe_process/test_pipe.py +++ b/tests/nn/pipe_process/test_pipe.py @@ -32,7 +32,7 @@ initialize_model_parallel, ) from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe -from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version +from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version @torch_spawn([2]) @@ -109,16 +109,9 @@ def mpi(): @torch_spawn([1]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) def public_attrs(pipe_class): - class MyString: - def __init__(self, value): - self.value = value - - def __str__(self): - return self.value - model = nn.Sequential(nn.Linear(1, 1)) - pipe = pipe_class(model, balance=(1,), worker_map=get_worker_map(), chunks=42.000, checkpoint=MyString("always"),) + pipe = pipe_class(model, balance=(1,), worker_map=get_worker_map(), chunks=42, checkpoint="always",) assert pipe.balance == [1] assert pipe.chunks == 42 @@ -266,15 +259,9 @@ def count_grad_fn(grad_fn, name, visited=set()): model = nn.Sequential(nn.Linear(1, 1)) input = torch.rand(2, 1) - always = pipe_class( - model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="always", pipelined_backward=False, - ) - except_last = pipe_class( - model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="except_last", pipelined_backward=False, - ) - never = pipe_class( - model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="never", pipelined_backward=False, - ) + always = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="always",) + except_last = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="except_last",) + never = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="never",) always_output = always(input) except_last_output = except_last(input) @@ -313,7 +300,7 @@ def checkpoint_mode_when_chunks_1(pipe_class): @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) def checkpoint_eval(pipe_class): model = 
nn.Sequential(nn.Linear(1, 1)) - model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,) + model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,) input = torch.rand(2, 1) def find_grad_fn(grad_fn, name): @@ -350,9 +337,7 @@ def forward(self, input): return input[0] * 2 model = nn.Sequential(ForkNonFloat(), JoinNonFloat()) - model = pipe_class( - model, balance=[1, 1], worker_map=get_worker_map(), chunks=1, checkpoint="always", pipelined_backward=False, - ) + model = pipe_class(model, balance=[1, 1], worker_map=get_worker_map(), chunks=1, checkpoint="always",) input = torch.rand(1, requires_grad=True) output = model(input) @@ -381,8 +366,8 @@ def hook(module, input, output): nonlocal latent latent = output - partition = model.partitions[0] - partition.module.register_forward_hook(hook) + partition = model.partition + partition.register_forward_hook(hook) with torch.no_grad(): model(input) @@ -463,7 +448,7 @@ def forward(self, a_and_b): return (self.fc_a(a), self.fc_b(b)) model = nn.Sequential(Two()) - model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,) + model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,) a = torch.rand(10, 1, requires_grad=True) b = torch.rand(10, 1, requires_grad=True) @@ -489,7 +474,7 @@ def forward(self, only_a): return (self.fc(a),) model = nn.Sequential(One()) - model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,) + model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,) a = torch.rand(10, 1, requires_grad=True) @@ -631,14 +616,12 @@ def partitions(pipe_class): model = nn.Sequential(a, b) model = pipe_class(model, [1, 1], worker_map=get_worker_map()) - assert isinstance(model.partitions, list) - assert len(model) == 1 - assert isinstance(model.partitions[0].module, nn.Sequential) + assert isinstance(model.partition, nn.Sequential) if model.group.rank() == 0: - assert "0.0.weight" in model.state_dict() + assert model[0].weight == a.weight else: - assert "0.1.weight" in model.state_dict() + assert model[0].weight == b.weight @torch_spawn([2]) @@ -684,6 +667,7 @@ def empty_module(pipe_class): @torch_spawn([2]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) +@pytest.mark.skip(reason="TODO(msb) handle named_children") def named_children(pipe_class): a = nn.Linear(1, 1) b = nn.Linear(1, 1) @@ -706,15 +690,11 @@ def named_children(pipe_class): @torch_spawn([1]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) def recommend_auto_balance(pipe_class): - with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"): - # balance is required - pipe_class(nn.Sequential()) - - with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"): + with pytest.raises(ValueError): # module and sum of balance have differen length (module: 0, sum of balance: 1) pipe_class(nn.Sequential(), [1]) - with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"): + with pytest.raises(ValueError): # module and sum of balance have different length (module: 2, sum of balance: 1) pipe_class(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1]) @@ -776,7 +756,7 @@ def __init__(self, module): @torch_spawn([4]) -@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) +@pytest.mark.parametrize("pipe_class", [MultiProcessPipe]) def pipelined_backward(pipe_class): model = nn.Sequential(nn.ReLU(), nn.ReLU()) @@ 
-805,174 +785,3 @@ def async_event_loop(): if pipe.final_stage: loss = output.mean() loss.backward() - - -@torch_spawn([4]) -def reuse_lazy(): - if False: # speed - reused = LazyModule(lambda: nn.Linear(10, 10)) - model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] - # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()] - pipe = AsyncPipe(model, [3, 1, 1], worker_map=get_worker_map()) - pipe.eval() - output = pipe(torch.rand(10)) - - print(f"output on {pipe.group.rank()}, {output}") - torch.distributed.barrier() - - set_random_seed(1234) - # test both foward - reused = nn.Linear(10, 10) - layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] - model = nn.Sequential(*layers) - model.eval() - - set_random_seed(1234) - # ensure identical weights but no sharing between model and pipe - reused = nn.Linear(10, 10) - layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] - pipe = AsyncPipe(layers, [3, 1, 1], worker_map=get_worker_map()) - pipe.eval() - model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None - inputs = torch.rand(10) - if False: # speed - model_out = model(inputs) - pipe_out = pipe(inputs) - - torch.distributed.barrier() - - if pipe.final_stage: - assert torch.equal(model_out, pipe_out) - - model.train() - pipe.train() - model_out = model(inputs) - pipe_out = pipe(inputs) - if pipe.final_stage: - pipe_loss = pipe_out.mean() - pipe_loss.backward() - - model_loss = model_out.mean() - model_loss.backward() - - model_optimizer.step() - if pipe_optimizer: - pipe_optimizer.step() - - model.eval() - pipe.eval() - model_out = model(inputs) - pipe_out = pipe(inputs) - - print(f"before barrier on {torch.distributed.get_rank()}") - torch.distributed.barrier() - print(f"after barrier on {torch.distributed.get_rank()}") - - if pipe.final_stage: - assert torch.equal(model_out, pipe_out) - - -@torch_spawn([1]) -def instantiate_partition(): - from fairscale.nn.pipe.async_schedule import Location - - model = nn.Sequential(nn.Linear(1, 1)) - pipe = AsyncPipe(model, balance=[1], worker_map=get_worker_map(), chunks=1) - - class FakeGroup: - def __init__(self, rank, size): - self._rank = rank - self._size = size - - def rank(self): - return self._rank - - def size(self): - return self._size - - def check_partitions(model, balance, expected_order, expected_ranks): - """Check the instantiated model matches expectation of order and rank - model: a list of modules or an nn.Sequential - balance: the balance argument to MultiProcessPipe - expected_order: the index of modules in `model` in the order they will - be executed, grouped by nn.Sequential - expected_rank: the rank that each module will be executed on - """ - - invocations = [] - invocation_wrapper = dict() - - # Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from - # instantiated model - for rank in range(len(balance)): - instantiated = pipe.instantiate_partition(model, balance, FakeGroup(rank, len(balance))) - for part in instantiated: - assert isinstance(part.module, nn.Sequential) - for inv in part.invocations: - invocations.append(inv) - invocation_wrapper[inv] = part - - modules = [] - prev = None - current = Location(0, 0) - ranks = [] - - for order, inv in enumerate(sorted(invocations, key=lambda x: x.order)): - # Check integrity of 
Location chain - assert inv.order == order - assert inv.source == prev - assert inv.this == current - prev = inv.this - current = inv.dest - modules.append(list(invocation_wrapper[inv].module.children())) - ranks.append(inv.this.stage) - - # assert len(modules) == len(expected_order) - for left, right in zip(modules, expected_order): - assert len(left) == len(right), f"{right}" - assert list(map(id, left)) == list(map(id, (model[e] for e in right))), f"{right}" - - assert ranks == expected_ranks - - reused = nn.Linear(20, 20) - model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] - balance = [3, 1, 1] - - check_partitions( - model, balance, expected_order=[[0], [1, 2], [0], [4], [0], [6]], expected_ranks=[0, 0, 0, 1, 0, 2] - ) - - reused2 = nn.Linear(5, 5) - model = [reused, reused2, nn.Linear(10, 10), nn.ReLU(), reused, reused2, nn.ReLU(), reused, reused2, nn.ReLU()] - balance = [4, 1, 1] - - check_partitions( - model, - balance, - expected_order=[[0], [1], [2, 3], [0], [1], [6], [0], [1], [9]], - expected_ranks=[0, 0, 0, 0, 0, 1, 0, 0, 2], - ) - - reused2 = nn.Linear(5, 5) - model = [ - nn.Linear(10, 10), - reused, - nn.Linear(10, 10), - nn.ReLU(), - reused, - reused2, - nn.ReLU(), - reused, - reused2, - nn.ReLU(), - ] - # 0 1 2 3 1 5 6 1 5 9 - balance = [4, 2, 1] - - check_partitions( - model, - balance, - expected_order=[[0], [1], [2, 3], [1], [5], [6], [1], [5], [9]], - expected_ranks=[0, 0, 0, 0, 1, 1, 0, 1, 2], - ) diff --git a/tests/nn/pipe_process/test_transparency.py b/tests/nn/pipe_process/test_transparency.py index 262dfc60e..9de481584 100644 --- a/tests/nn/pipe_process/test_transparency.py +++ b/tests/nn/pipe_process/test_transparency.py @@ -60,13 +60,13 @@ def zero_grad(parameters): if model.group.rank() == 1: loss = outputs.mean() loss.backward() - grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters()) + grad_with_pipe = sum_grad(model.partition.parameters()) # Both grads should be identical. assert torch.allclose(grad_with_pipe, grad_without_pipe[1]) else: model.back_helper(outputs) - grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters()) + grad_with_pipe = sum_grad(model.partition.parameters()) # Both grads should be identical. 
assert torch.allclose(grad_with_pipe, grad_without_pipe[0]) diff --git a/tests/optim/test_oss.py b/tests/optim/test_oss.py index 1c7b0ad45..7f7cb058e 100644 --- a/tests/optim/test_oss.py +++ b/tests/optim/test_oss.py @@ -11,7 +11,7 @@ import copy from math import inf import tempfile -from typing import Type, cast +from typing import Any, Type, cast import unittest import numpy as np @@ -22,10 +22,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP import fairscale.optim as optim -from fairscale.utils.testing import skip_if_no_cuda, skip_if_single_gpu +from fairscale.utils.testing import skip_if_no_cuda, skip_if_py39_no_cuda, skip_if_single_gpu BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu") +RECIPIENT_RANK = 1 try: from torch.distributed import broadcast_object_list # noqa @@ -42,6 +43,19 @@ def dist_init(rank, world_size, tempfile_name, backend=BACKEND): dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size) +def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch.device) -> Any: + if _torch_broadcast_object: + package = [something_to_sync] + dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD) + package_sync = package[0] + else: + package_sync = optim.utils.broadcast_object( + something_to_sync, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device + ) + + return package_sync + + class TestSingleRank(unittest.TestCase): """ All the following tests do not check for inter-process communication @@ -158,31 +172,11 @@ def step(self): o.step() assert x == torch.tensor([0.9], device=DEVICE) - def test_local_state_dict(self): - x = torch.tensor([1.0], device=DEVICE, requires_grad=True) - o = optim.OSS([x], lr=0.1) - local_state_dict = o.local_state_dict() - o = optim.OSS([x], lr=0.01) - o.load_local_state_dict(local_state_dict) - # We should now be using a lr of 0.1. - assert o.optim.param_groups[0]["lr"] == 0.1 - assert o.param_groups[0]["lr"] == 0.1 - x.backward() - o.step() - assert x == torch.tensor([0.9], device=DEVICE) - def test_implicit_local_state_dict(self): x = torch.tensor([1.0], device=DEVICE, requires_grad=True) o = optim.OSS([x], lr=0.1) - local_state_dict = o.state_dict() - o = optim.OSS([x], lr=0.01) - o.load_state_dict(local_state_dict) - # We should now be using a lr of 0.1. 
- assert o.optim.param_groups[0]["lr"] == 0.1 - assert o.param_groups[0]["lr"] == 0.1 - x.backward() - o.step() - assert x == torch.tensor([0.9], device=DEVICE) + with pytest.raises(RuntimeError): + _ = o.state_dict() def run_test_add_param_group(rank, world_size, tempfile_name): @@ -348,7 +342,10 @@ def test_step_with_closure(): def run_test_sharding(rank, world_size, tempfile_name): dist_init(rank, world_size, tempfile_name) params = [] - for size in [5, 4, 2, 6, 4, 3]: + sizes = [9, 7, 5, 3] + sizes_world = sizes * world_size + + for size in sizes_world: params.append(torch.rand(size, 1)) # Make sure that the params are trainable, enforces size-based partitioning @@ -356,17 +353,17 @@ def run_test_sharding(rank, world_size, tempfile_name): p.requires_grad = True o = optim.OSS(params, lr=0.1) - assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8 + assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == sum(sizes) dist.destroy_process_group() def test_sharding(): - world_size = 3 - if not torch.cuda.is_available() or torch.cuda.device_count() < world_size: - pytest.skip("Not enough GPUs for NCCL-based test") - temp_file_name = tempfile.mkstemp()[1] + world_size = 4 + if torch.cuda.is_available(): + world_size = min(world_size, torch.cuda.device_count()) + _, temp_file_name = tempfile.mkstemp() mp.spawn(run_test_sharding, args=(world_size, temp_file_name), nprocs=world_size, join=True) @@ -405,18 +402,12 @@ def closure(): # - load it again if rank == reference_rank: optimizer_state_dict = optimizer.state_dict() - assert len(optimizer_state_dict["state"]) == world_size + assert len(optimizer_state_dict["state"]) == len(list(model.parameters())) else: optimizer_state_dict = {} - optim_state = [optimizer_state_dict] - if _torch_broadcast_object: - dist.broadcast_object_list(optim_state, src=reference_rank, group=dist.group.WORLD) - optimizer_state_dict = optim_state[0] - else: - optimizer_state_dict = optim.utils.broadcast_object( - optimizer_state_dict, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device - ) + # distribute to the other ranks + optimizer_state_dict = sync_object_ranks(optimizer_state_dict, reference_rank, device) # Load the optimizer state dict optimizer.load_state_dict(optimizer_state_dict) @@ -436,6 +427,72 @@ def test_collect_shards(): ) +def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name): + dist_init(rank, world_size, tempfile_name) + device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE + + # Run a dummy step so that the optimizer state dict exists + batch, input_width, hidden, target_width = 3, 3, 3, 5 + target = torch.rand((batch, target_width), device=device) + inputs = torch.rand((batch, input_width), device=device) + + model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)) + model.to(device) + + loss_fn = torch.nn.L1Loss() + loss_fn.to(device) + + optimizer = optim.OSS(model.parameters(), optim=torch.optim.RMSprop, lr=0.1) + + def closure(): + optimizer.zero_grad() + output = model(inputs) + loss = loss_fn(output, target) + loss.backward() + return loss + + _ = optimizer.step(closure=closure) + + # Update the optimizer state on the reference rank + optimizer.consolidate_state_dict(recipient_rank=reference_rank) + + # Fetch the state on the reference rank, broadcast to the other ones + if rank == reference_rank: + optimizer_state_dict = optimizer.state_dict() + else: + optimizer_state_dict = {} + + # Run two steps, log the 
loss + _ = optimizer.step(closure=closure) + reference_loss = optimizer.step(closure=closure) + + # Load the optimizer state dict, rewind the state two steps back + optimizer.load_state_dict(optimizer_state_dict) + + # Run two new steps, log the loss again and check that we get the same + _ = optimizer.step(closure=closure) + test_loss = optimizer.step(closure=closure) + + assert torch.allclose(reference_loss, test_loss) + + dist.destroy_process_group() + + +def test_reproducibility(): + world_size = 2 + temp_file_name = tempfile.mkstemp()[1] + + if torch.cuda.is_available() and torch.cuda.device_count() < world_size: + # Bail out if not enough devices + return + + reference_rank = 0 + + mp.spawn( + run_test_collect_shards, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True, + ) + + def run_test_multiple_groups(rank, world_size, tempfile_name): # Only work with the even ranks, to check that the global_rank indexing is properly used dist_init(rank=rank, world_size=world_size, tempfile_name=tempfile_name, backend="gloo") @@ -507,6 +564,7 @@ def closure(): dist.destroy_process_group(process_group) +@skip_if_py39_no_cuda def test_multiple_groups(): world_size = 6 temp_file_name = tempfile.mkstemp()[1] @@ -574,6 +632,9 @@ def check(norm): print(f"Checking norm {norm}") check(norm) + # Check twice, catch an hypothetic iterator dumb mistake + check(norm) + dist.destroy_process_group() @@ -593,15 +654,17 @@ def test_gradient_clipping(): def run_state_dict_distributed(rank, world_size, tempfile_name): dist_init(rank, world_size, tempfile_name, backend="gloo") + device = torch.device(rank) torch.manual_seed(rank) # make sure that the different rank get different data - # Run a dummy step so that the optimizer state dict exists + # Setup two problems in parallel, we'll make sure that the second track (with save/load) follows the first one(untouched) + # We split the model in two to test the multiple param groups support batch, input_width, hidden, target_width = 3, 20, 10, 5 target = torch.rand((batch, target_width), device=device) inputs = torch.rand((batch, input_width), device=device) - model_oss1 = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, hidden),).to(device) + model_oss1 = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, hidden)).to(device) head_oss1 = torch.nn.Linear(hidden, target_width).to(device) model_oss2 = copy.deepcopy(model_oss1) @@ -619,48 +682,36 @@ def run_state_dict_distributed(rank, world_size, tempfile_name): sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99) sharded_optimizer2.add_param_group({"params": head_oss2.parameters()}) - def run_grad_step(device, model, head, optimizer): - loss_fn = torch.nn.L1Loss() - loss_fn.to(device) + loss_fn = torch.nn.L1Loss().to(device) + def run_grad_step(model, head, optimizer): model.zero_grad() - outputs = head(model(inputs)) - loss = loss_fn(outputs, target) - loss.backward() - - optimizer.step() - optimizer.zero_grad() + def check_equal_models(message: str): + for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()): + assert torch.allclose(param1, param2), message - # save and reload without taking any steps - sharded_optimizer2.consolidate_state_dict() - state_dict2 = sharded_optimizer2.state_dict() + # pull the current state, broadcast it to all ranks + sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks + state_dict2 = sharded_optimizer2.state_dict() if 
rank == RECIPIENT_RANK else {} + state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device) - sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99) + # re-create a new optimizer from scratch with absurd values, load the previous state + sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=1e6, momentum=0.0001) sharded_optimizer2.add_param_group({"params": head_oss2.parameters()}) sharded_optimizer2.load_state_dict(state_dict2) + check_equal_models("parameters of the two identical models have diverged (before any steps)") # now take a step and check that parameters are equal - # take a step - run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1) - run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2) - - # check that model parameters are equal - for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()): - assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before any steps)" - - # take a step - run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1) - run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2) + run_grad_step(model_oss1, head_oss1, sharded_optimizer1) + run_grad_step(model_oss2, head_oss2, sharded_optimizer2) + check_equal_models("parameters of the two identical models have diverged (after stepping)") - # check that model parameters are equal - for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()): - assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before saving)" - - # save the state dict for one model only - sharded_optimizer2.consolidate_state_dict() - state_dict2 = sharded_optimizer2.state_dict() + # save the state dict for one model only, then distribute to the other ranks + sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks + state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {} + state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device) # Check that the pulled state and the .param_groups attribute are in sync for replica in range(len(state_dict2["param_groups"])): @@ -669,18 +720,14 @@ def run_grad_step(device, model, head, optimizer): assert state_dict2["param_groups"][replica][k] == sharded_optimizer2.param_groups[0][k] # take a step - run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1) - run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2) - - # check that saving did not cause a change in the parameters - for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()): - assert torch.allclose( - param1, param2 - ), "parameters of the two identical models have diverged (after consolidating)" + run_grad_step(model_oss1, head_oss1, sharded_optimizer1) + run_grad_step(model_oss2, head_oss2, sharded_optimizer2) + check_equal_models("parameters of the two identical models have diverged (after consolidating)") - # save again - sharded_optimizer2.consolidate_state_dict() - state_dict2 = sharded_optimizer2.state_dict() + # save again for one rank, then distribute to the others + sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks + state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {} + state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device) # reload the state_dict sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99) @@ -688,23 +735,20 @@ def run_grad_step(device, model, 
head, optimizer): sharded_optimizer2.load_state_dict(state_dict2) # take a step - run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1) - run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2) - - # check that reloading a saved state dict does not change the parameters - for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()): - assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (after reloading)" + run_grad_step(model_oss1, head_oss1, sharded_optimizer1) + run_grad_step(model_oss2, head_oss2, sharded_optimizer2) + check_equal_models("parameters of the two identical models have diverged (after reloading)") dist.destroy_process_group() @skip_if_no_cuda def test_state_dict_distributed(): - world_size = 8 + world_size = 2 temp_file_name = tempfile.mkstemp()[1] if torch.cuda.is_available(): - world_size = min(world_size, torch.cuda.device_count()) + world_size = max(world_size, torch.cuda.device_count()) mp.spawn( run_state_dict_distributed, args=(world_size, temp_file_name), nprocs=world_size, join=True, @@ -719,19 +763,51 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): torch.cuda.set_device(rank) torch.manual_seed(rank) np.random.seed(rank) + hidden = 5 + in_channels = 3 + out_channels = 3 + batch = 64 def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]): # Any model works. Add one different buffer per rank - model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 3), torch.nn.Linear(3, 3),) - model.register_buffer("test_buffer", torch.ones((1)) * rank) - model.to(device) + trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden)) + trunk.register_buffer("test_buffer", torch.ones((1)) * rank) + trunk.to(device) + + head = torch.nn.Linear(hidden, out_channels).to(device) + + # Define a model to be trained by OSS + oss_model = torch.nn.Sequential(trunk, head) + oss_trainable_params = [ + {"params": trunk.parameters(), "lr": 1e-5}, + {"params": head.parameters(), "lr": 1e-4}, + ] + + optimizer_settings = {} + if isinstance(optim, torch.optim.SGD): + optimizer_settings["momentum"] = 0.9 + + sharded_optimizer = optim.OSS( + params=oss_trainable_params, + optim=optimizer, + group=None, + broadcast_buffer_size=2 ** 10, + **optimizer_settings, + ) + + oss_ddp_model = DDP(module=oss_model, device_ids=[rank], broadcast_buffers=True) - sharded_optimizer = optim.OSS(params=model.parameters(), optim=optimizer, lr=1e-3) - sharded_ddp_model = DDP(module=model, device_ids=[rank], broadcast_buffers=True) + # Define a model to be trained by normal pytorch + DDP + ddp_trunk = copy.deepcopy(trunk) + ddp_head = copy.deepcopy(head) + ddp_module = torch.nn.Sequential(ddp_trunk, ddp_head) - ddp_model_single = copy.deepcopy(model) - ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3) - ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True) + ddp_trainable_params = [ + {"params": ddp_trunk.parameters(), "lr": 1e-5}, + {"params": ddp_head.parameters(), "lr": 1e-4}, + ] + ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings) # type: ignore + ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True) def check_same_model_params(): for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups): @@ -740,17 +816,13 @@ def check_same_model_params(): p, ddp_p, atol=1e-3 ), f"Model parameters differ in between Pytorch optim and OSS \n{p} {ddp_p}\nworld size 
{world_size}" - for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()): + for b, ddp_b in zip(oss_ddp_model.buffers(), ddp_model.buffers()): assert torch.allclose( b, ddp_b ), f"Model buffers differ in between Pytorch optim and OSS\nworld size {world_size}" - # The model should be synchronized in between the ranks at construction time, check that - check_same_model_params() - - # The models should stay the same in between the ranks - for i in range(20): - input_tensor = torch.rand((64, 2)).to(device) + def check_step(): + input_tensor = torch.rand((batch, in_channels)).to(device) def closure_ddp(input_tensor=input_tensor): ddp_optimizer.zero_grad() @@ -760,7 +832,7 @@ def closure_ddp(input_tensor=input_tensor): def closure_sharded(input_tensor=input_tensor): sharded_optimizer.zero_grad() - sharded_loss = sharded_ddp_model(input_tensor).abs().sum() + sharded_loss = oss_ddp_model(input_tensor).abs().sum() sharded_loss.backward() return sharded_loss @@ -771,8 +843,29 @@ def closure_sharded(input_tensor=input_tensor): loss_ddp, loss_sharded_optim ), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}" + # The model should be synchronized in between the ranks at construction time, check that + check_same_model_params() + + # The models should stay the same in between ddp and sharded optimizer + for i in range(5): + check_step() check_same_model_params() + # Check that the checkpoints are compatible + # - get states + ddp_state_dict = ddp_optimizer.state_dict() + sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) + sharded_optim_state_dict = sharded_optimizer.state_dict() if rank == RECIPIENT_RANK else {} + sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device) + + # - cross load the states + ddp_optimizer.load_state_dict(sharded_optim_state_dict) # mixup on purpose ! + sharded_optimizer.load_state_dict(ddp_state_dict) + + # - run one step and check that the models are still the same + check_step() + check_same_model_params() + for opt in [torch.optim.SGD, torch.optim.Adam]: check_optimizer_equivalence(opt) diff --git a/tests/optim/test_oss_adascale.py b/tests/optim/test_oss_adascale.py index 7161e0b21..0febcc9af 100644 --- a/tests/optim/test_oss_adascale.py +++ b/tests/optim/test_oss_adascale.py @@ -22,7 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import SGD -from fairscale.optim import OSS, AdaScale +from fairscale.optim import OSS, AdaScale, AdaScaleWrapper from fairscale.utils.golden_testing_data import adascale_test_data from fairscale.utils.testing import skip_if_single_gpu @@ -40,11 +40,16 @@ def _test_basic_func(rank, world_size, tempfile_name, test_case, oss, model=None model = Linear(2, 2, bias=False) model.to("cuda") model = DDP(model, device_ids=[rank]) - if oss: - # For now, we can only wrap AdaScale over OSS. If we do it the other way around, - # AdaScale needs to take different parameter types, i.e. the parameter list, etc. 
+ + assert oss in ["none", "ada-oss", "wrapper-oss", "oss-wrapper"] + if oss == "ada-oss": optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1)) + elif oss == "wrapper-oss": + optim = AdaScaleWrapper(model.parameters(), optim_cls=OSS, optim=SGD, lr=0.1) + elif oss == "oss-wrapper": + optim = OSS(model.parameters(), AdaScaleWrapper, optim_cls=SGD, lr=0.1) else: + assert oss == "none" optim = AdaScale(SGD(model.parameters(), lr=0.1)) if "input" in test_case: @@ -59,7 +64,8 @@ def _test_basic_func(rank, world_size, tempfile_name, test_case, oss, model=None optim.step() optim.zero_grad() - assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain() + if "expected_gain" in test_case: + assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain() if "expected_mean_weight" in test_case: mean_weight = mean([model.module[i].weight.data.mean().item() for i in range(4)]) @@ -75,11 +81,11 @@ def test_basic(test_case): world_size = 2 temp_file_name = tempfile.mkstemp()[1] - mp.spawn(_test_basic_func, args=(world_size, temp_file_name, test_case, True), nprocs=world_size, join=True) + mp.spawn(_test_basic_func, args=(world_size, temp_file_name, test_case, "ada-oss"), nprocs=world_size, join=True) @skip_if_single_gpu -@pytest.mark.parametrize("oss", [True, False]) +@pytest.mark.parametrize("oss", ["none", "ada-oss", "wrapper-oss", "oss-wrapper"]) def test_sequential(oss): """Test adascale with DDP + OSS with a sequential model""" world_size = 2 @@ -98,6 +104,14 @@ def test_sequential(oss): "expected_mean_weight": 52.92657661437988, } + if oss == "oss-wrapper": + # When OSS wraps AdaScale, the training is numerically different + # and it exists only to enable future research. So we don't check + # the gain (OSS doesn't have a gain() function, different rank's + # gains are different). We just ensure the mean_weight is expected. + del test_case["expected_gain"] + test_case["expected_mean_weight"] = 94.93386840820312 + # The model. model = Sequential( Linear(2, 3, bias=False), Linear(3, 4, bias=False), Linear(4, 5, bias=False), Linear(5, 6, bias=False) From 7bd82d1daf0968b0219e75c44c76765326465c0d Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Fri, 12 Feb 2021 07:13:07 -0800 Subject: [PATCH 37/48] two small changes (#83) * two small changes - link to deepspeed in the doc - added a quick test in the init to catch a common user error * addressed comments * remove an overly strong assert * addressed comments --- .../fully_sharded_data_parallel.py | 10 ++++++- fairscale/utils/parallel.py | 30 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 fairscale/utils/parallel.py diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 94f17d8e0..8a7e9c048 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -25,6 +25,7 @@ unpack_kwargs, unpack_non_tensors, ) +from fairscale.utils.parallel import validate_process_group if TYPE_CHECKING: from collections import OrderedDict # noqa: F401 @@ -32,7 +33,11 @@ class FullyShardedDataParallel(nn.Module): """ - A wrapper for sharding Module parameters. + A wrapper for sharding Module parameters. This is inspired by `Xu et al.`_ + as well as the ZeRO Stage 3 from the DeepSpeed_ work. + + .. _`Xu et al.`: https://arxiv.org/abs/2004.13336 + .. 
_DeepSpeed: https://www.deepspeed.ai/ Usage:: @@ -112,6 +117,9 @@ def __init__( if self.cpu_offload and not self.mixed_precision: raise ValueError("cpu_offload requires mixed_precision=True") + compute_device = torch.device("cuda") if self.cpu_offload else next(module.parameters()).device + validate_process_group(compute_device, self.process_group) + # Only handle params which are not already sharded. This enables # sharding individual layers of a Module, with an outer wrapper to # shard any leftover parameters. diff --git a/fairscale/utils/parallel.py b/fairscale/utils/parallel.py new file mode 100644 index 000000000..2a76b9b8d --- /dev/null +++ b/fairscale/utils/parallel.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +"""Useful functions for parallel training.""" + + +def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None: + """Do a quick test in case user called FSDP without calling torch.cuda.set_device() + correctly. This can easily happen in cpu_offload case where the model resides on + the CPU. + """ + if not hasattr(process_group, "allgather"): + # Likely a dummy pg for unit test, skip checking. + return + + world_size = process_group.size() + if "cuda" in str(device): + input_tensor = torch.ones(1).to(device) + output = list(torch.zeros(world_size).to(device).chunk(world_size)) + dist.all_gather(output, input_tensor, group=process_group) + assert torch.cat(output).sum() == float(world_size), ( + f"found {torch.cat(output).sum()} devices in process group but " + f"world_size={world_size}. 
Check torch.cuda.set_device is called properly" + ) From d8f33496b3943cdf7d50446e3838d1c7e3a17e44 Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Sun, 14 Feb 2021 13:57:22 -0800 Subject: [PATCH 38/48] fixing uneven shard support and add uneven shard unit test (#80) * added uneven shard unit test * updates * fixed * change the inputs * skip torch 1.5 * merged after renaming * add assert output * overlly small param test * fixed the overly small param corner case * another issue with _calc_num_to_pad * one more fix * a lot of more test cases and simplified the sharding logic, but still not fully right * make sure we cover 8 GPUs too, run more than one iter for the smaller param test * byte -> element * fixing an assert in the test * add a todo comment * moved compute_shard_size into a separate module and added unit test * fixed the issue of free_storage * clean up * two more bugs * one more bug * one more bug * not more F.pad() in reduce_scatter --- .../fully_sharded_data_parallel.py | 129 +++++++++++------- fairscale/utils/parallel.py | 25 +++- stubs/torch/nn/parameter.pyi | 1 + tests/nn/data_parallel/test_fsdp_uneven.py | 113 +++++++++++++++ .../test_fully_sharded_data_parallel.py | 4 +- tests/utils/test_parallel.py | 26 ++++ 6 files changed, 248 insertions(+), 50 deletions(-) create mode 100644 tests/nn/data_parallel/test_fsdp_uneven.py create mode 100644 tests/utils/test_parallel.py diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 8a7e9c048..9fe77fa78 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -6,7 +6,6 @@ import contextlib import copy import functools -import math from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union import torch @@ -25,7 +24,7 @@ unpack_kwargs, unpack_non_tensors, ) -from fairscale.utils.parallel import validate_process_group +from fairscale.utils.parallel import compute_shard_size, validate_process_group if TYPE_CHECKING: from collections import OrderedDict # noqa: F401 @@ -185,48 +184,45 @@ def _shard_parameters_(self) -> None: p._is_sharded = True p._orig_size = p.data.size() - # shard p.data such that all elements are part of a shard and the last shard is <= all other shards - shard_size = math.ceil(p.data.numel() / self.world_size) - s = self.rank * shard_size + + # shard p.data such that all elements are part of a shard and the + # last shard is <= all other shards or, all shards are 1-element + # in size in case total size is smaller than the world_size. + # + # This way, we don't have holes when shards are reconstructed and + # only extra padding elements need to added/removed during sharding. + shard_size = compute_shard_size(p.data.numel(), self.world_size) + s = min(self.rank * shard_size, p.data.numel()) e = min(s + shard_size, p.data.numel()) + assert ( + 0 <= s <= e <= p.data.numel() + ), f"_shard_parameters_: {p.data.numel()} {self.world_size} {shard_size} {s} {e}" orig_data = p.data p.data = torch.flatten(p.data)[s:e].clone() + if p.data.numel() < shard_size: + p.data = F.pad(p.data, [0, shard_size - p.data.numel()]) # pad zeros to the right size. 
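
A small sketch of the padded sharding scheme described in the comment above, using a plain ceil-based shard size (the actual compute_shard_size in fairscale.utils.parallel may differ in corner cases); reconstruction is a concatenation followed by trimming the padding. Illustrative only, not part of the patch.

    import math

    import torch
    import torch.nn.functional as F

    def shard_for_rank(flat: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        shard_size = max(1, math.ceil(flat.numel() / world_size))
        start = min(rank * shard_size, flat.numel())
        end = min(start + shard_size, flat.numel())
        data = flat[start:end].clone()
        # Right-pad so every rank holds exactly shard_size elements.
        return F.pad(data, [0, shard_size - data.numel()])

    flat = torch.arange(10.0)
    world_size = 4
    shards = [shard_for_rank(flat, r, world_size) for r in range(world_size)]
    assert all(s.numel() == 3 for s in shards)      # ceil(10 / 4) == 3, last shard is zero-padded
    rebuilt = torch.cat(shards)[: flat.numel()]     # what all_gather plus trimming recovers
    assert torch.equal(rebuilt, flat)
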
+ assert p.data.numel() == shard_size, f"{p.data.numel()} {shard_size}" free_storage_(orig_data) - @property - def is_last_rank(self) -> bool: - return self.rank == (self.world_size - 1) - - def _calc_num_to_pad(self, numel: int) -> int: - num_extra = numel % self.world_size - num_to_pad = self.world_size - num_extra if num_extra > 0 else 0 - return num_to_pad - @torch.no_grad() def _all_gather_full_param(self, p: nn.Parameter) -> None: """Fill p._full_param with gathered p.data values (using torch.distributed.all_gather). - If the last shard is smaller than the other shards, we pad it with zeroes. - """ - full_param_chunks = list(torch.flatten(p._full_param).chunk(self.world_size)) - # We overwrite the last chunk with zeroes if it is smaller than the other entries. - num_to_pad = self._calc_num_to_pad(p._orig_size.numel()) - pointer_to_last_chunk = full_param_chunks[-1] - assert pointer_to_last_chunk.numel() + num_to_pad == full_param_chunks[0].numel() + The p._full_param is already allocated and have the size equal + to shard_size * world_size. + + It is up to the caller to do necessary resize/reshape to the + unpadded _full_param. + """ + full_param_chunks = list(p._full_param.chunk(self.world_size)) + assert len(full_param_chunks) == self.world_size + assert full_param_chunks[-1].numel() == p.data.numel(), f"{full_param_chunks[-1].numel()} {p.data.numel()}" param_shard = p.data # we will gather this from each worker - if num_to_pad > 0: # add padding to param_shard and full_param_chunks[-1] - full_param_chunks[-1] = torch.zeros_like(full_param_chunks[0]) # no longer shares memory with full_param - if self.is_last_rank: - param_shard = F.pad(p.data, [0, num_to_pad]) + dist.all_gather(full_param_chunks, param_shard, group=self.process_group) # ^ updates p._full_param - # remove padding from full_param_chunks[-1] and p.data - if num_to_pad > 0: # copy shard associated with the padded chunk to full_param - pointer_to_last_chunk.copy_(full_param_chunks[-1][:-num_to_pad]) - # ^ updates p._full_param - def __getattr__(self, name: str) -> Any: """Forward missing attributes to wrapped module.""" try: @@ -597,7 +593,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: param.grad.data.div_(self.world_size) # Reduce-scatter grad. - param.grad.data = self._reduce_scatter(param.grad.data) + param.grad.data = self._reduce_scatter(param.grad.data, param.data.numel()) else: param.grad.data = torch.flatten(param.grad.data) @@ -637,13 +633,24 @@ def _rebuild_full_params(self) -> None: self._cast_fp32_param_shards_to_fp16() for p in self.params: - if p._full_param.storage().size() != p._orig_size.numel(): - alloc_storage_(p._full_param, size=p._orig_size) + p_size = p.data.numel() * self.world_size + if p._full_param.storage().size() != p_size: + # Allocate based on full size from all shards. + p._full_param.resize_(p_size) + alloc_storage_(p._full_param, torch.Size((p_size,))) if self.world_size > 1: - # Fill p._full_param with (p.data for each shard in self.world_size). + # Fill p._full_param with (p.data for each shard in self.world_size) self._all_gather_full_param(p) + if p._orig_size.numel() < p._full_param.numel(): + # We need a smaller view into _full_param and save + # _full_param_padded. + p._full_param_padded = p._full_param + # Note, full size can be >> orig_size when world_size is + # large and param size is tiny. 
+ p._full_param = p._full_param.split(p._orig_size.numel())[0] else: torch.flatten(p._full_param).copy_(p.data) + p._full_param = p._full_param.reshape(p._orig_size) p.data = p._full_param @@ -679,7 +686,11 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: # Storage object and unshard it in-place. For now, just resize # the Storage to 0 to save memory. p._full_param.record_stream(current_stream) - free_storage_(p._full_param) + if hasattr(p, "_full_param_padded"): + free_storage_(p._full_param_padded) + delattr(p, "_full_param_padded") + else: + free_storage_(p._full_param) @torch.no_grad() def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None: @@ -720,20 +731,44 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> No free_storage_(p._fp16_shard) @torch.no_grad() - def _reduce_scatter(self, tensor: torch.Tensor) -> torch.Tensor: + def _reduce_scatter(self, tensor: torch.Tensor, shard_size: int) -> torch.Tensor: """Reduce-scatter a Tensor (gradient from the local worker) and return - the result (a single shard of the summed gradient across workers).""" + the result (a single "flattened" shard of the summed gradient across workers). + + Shard_size is passed in to compute the padding, but we don't use F.pad since + it reallocates the tensor, which can be a big chunk of memory consumed. Instead, + we allocate only for missing and incomplete shards and copy only the needed + data to the first allocated shard. (The remining allocated shards are just + padding for reduce_scatter.) + """ tensor = torch.flatten(tensor) - num_to_pad = self._calc_num_to_pad(tensor.numel()) - if num_to_pad > 0: # pad the gradient to be divisible by world_size - tensor = F.pad(tensor, [0, num_to_pad]) - assert tensor.numel() % self.world_size == 0 - tensor = tensor.view(self.world_size, -1) - output = torch.zeros_like(tensor[0]) # filled with gradient summed across workers - to_scatter = list(tensor.unbind(0)) # world size tensors of shape (shard_size,) - dist.reduce_scatter(output, to_scatter, group=self.process_group) - if self.is_last_rank and num_to_pad > 0: - output = output[:-num_to_pad] + full_shards, rem = divmod(tensor.numel(), shard_size) + assert full_shards <= self.world_size, ( + f"incorrect shard_size {shard_size} " f"full_shards {full_shards} " f"world_size {self.world_size}" + ) + full_shards_view = tensor + if rem > 0: + # Get two views in to the tensor. + full_shards_view, rem_view = tensor.split(full_shards * shard_size) + + # This is first part of to_scatter list. + to_scatter = list(full_shards_view.view(-1, shard_size).unbind(0)) + + tail = [] + if full_shards < self.world_size: + # This is the second part of the to_scatter list. + tail = [torch.zeros_like(to_scatter[0]) for i in range(full_shards, self.world_size)] + + if rem > 0: + # Copy the right data in to the first partial shard. + tail[0][:rem].copy_(rem_view) + + assert len(to_scatter) + len(tail) == self.world_size, ( + f"incorrect length {len(to_scatter)} + {len(tail)} vs. 
" f"{self.world_size}" + ) + + output = torch.zeros_like(to_scatter[0]) # will be filled with gradient summed across workers + dist.reduce_scatter(output, to_scatter + tail, group=self.process_group) return output diff --git a/fairscale/utils/parallel.py b/fairscale/utils/parallel.py index 2a76b9b8d..eba13546f 100644 --- a/fairscale/utils/parallel.py +++ b/fairscale/utils/parallel.py @@ -3,11 +3,34 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +"""Useful functions for parallel training.""" + import torch import torch.distributed as dist from torch.distributed import ProcessGroup -"""Useful functions for parallel training.""" + +def compute_shard_size(numel: int, world_size: int) -> int: + """Compute shard size like the behavior of torch.chunk().""" + assert numel > 0 and world_size > 0, "invalid inputs" + if numel % world_size == 0: + # easy case, including world_size == 1. + shard_size = numel // world_size + else: + if world_size == 2: + # two shards, shard size is the size of the bigger one. + shard_size = numel // world_size + 1 + else: + # find the equal chunks until reminder is smaller than shard_size + for div in range(world_size - 1, 1, -1): + shard_size, rem = divmod(numel, div) + if shard_size >= rem: + break + # corner case: bunch of 1 elements and rest are 0s. + if shard_size == 0: + shard_size = 1 + assert shard_size > 0, f"bug: {shard_size}" + return shard_size def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None: diff --git a/stubs/torch/nn/parameter.pyi b/stubs/torch/nn/parameter.pyi index e9515b224..82658bbc3 100644 --- a/stubs/torch/nn/parameter.pyi +++ b/stubs/torch/nn/parameter.pyi @@ -11,6 +11,7 @@ class Parameter(Tensor): _is_sharded: bool _orig_size: Size _cpu_grad: Tensor + _full_param_padded: Tensor _full_param: Tensor _fp32_shard: Tensor _fp16_shard: Optional[Tensor] diff --git a/tests/nn/data_parallel/test_fsdp_uneven.py b/tests/nn/data_parallel/test_fsdp_uneven.py new file mode 100644 index 000000000..9dadd5eb9 --- /dev/null +++ b/tests/nn/data_parallel/test_fsdp_uneven.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring + +""" Test FSDP with uneven parameter shards. 
""" + +import tempfile + +import pytest +import torch +from torch import Tensor +import torch.multiprocessing as mp +from torch.nn import Linear, Sequential +from torch.optim import SGD + +from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP +from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version + + +def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case): + result = dist_init(rank, world_size, tempfile_name, unused) + assert result, "Dist init failed" + + if test_case["assert_ref_out"]: + with torch.no_grad(): + weight = model.weight.T.clone().cuda() + v = torch.Tensor(test_case["inputs"][0][rank]).cuda() + ref_out = torch.matmul(v, weight) + model.to("cuda") + assert isinstance(fsdp_config, dict), str(fsdp_config) + model = FSDP(model, **fsdp_config) + optim = SGD(model.parameters(), lr=0.1) + inputs = test_case["inputs"] + assert len(inputs) == 1 or not test_case["assert_ref_out"] + assert len(inputs[0]) >= world_size + for in_data in inputs: + in_data = Tensor(in_data[rank]).cuda() + out = model(in_data) + out.sum().backward() + optim.step() + optim.zero_grad() + + if test_case["assert_ref_out"]: + torch.testing.assert_allclose(ref_out, out) + + teardown() + + +@skip_if_single_gpu +@pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3)], "assert_ref_out": True}]) +@pytest.mark.parametrize( + "fsdp_config", [{}, {"flatten_parameters": False}], +) +@pytest.mark.parametrize("world_size", list(range(2, 9))) +def test_one_iteration(world_size, test_case, fsdp_config): + """Test FSDP with uneven divide of parameter shards.""" + if torch_version() < (1, 6, 0): + pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend") + + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs.") + + temp_file_name = tempfile.mkstemp()[1] + unused = tempfile.mkstemp()[1] + + # TODO (Min): we may want to extend this to a simple 2 layer model so that it covers + # more cases in FSDP. Also, assert_ref_out can be extended to multiple + # iterations. This could be a good bootcamp task. I should file a github + # issue once we merge. + model = Linear(3, 3, bias=False) + mp.spawn( + _test_func, + args=(world_size, model, fsdp_config, temp_file_name, unused, test_case), + nprocs=world_size, + join=True, + ) + + +@skip_if_single_gpu +@pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3), torch.rand(8, 3)], "assert_ref_out": False}]) +@pytest.mark.parametrize("fsdp_config", [{}, {"flatten_parameters": False}]) +@pytest.mark.parametrize("world_size", list(range(2, 9))) +def test_smaller_than_world_size(world_size, test_case, fsdp_config): + """Test FSDP with uneven divide of parameter shards.""" + if torch_version() < (1, 6, 0): + pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend") + + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs.") + + temp_file_name = tempfile.mkstemp()[1] + unused = tempfile.mkstemp()[1] + + model = Sequential( + Linear(3, 3, bias=False), + Linear(3, 4, bias=False), + Linear(4, 5, bias=False), + Linear(5, 4, bias=False), + Linear(4, 3, bias=False), + Linear(3, 1, bias=False), + Linear(1, 1, bias=False), # param here is smaller than world_size if unflattened. 
+ ) + mp.spawn( + _test_func, + args=(world_size, model, fsdp_config, temp_file_name, unused, test_case), + nprocs=world_size, + join=True, + ) diff --git a/tests/nn/data_parallel/test_fully_sharded_data_parallel.py b/tests/nn/data_parallel/test_fully_sharded_data_parallel.py index d0557424c..368292031 100644 --- a/tests/nn/data_parallel/test_fully_sharded_data_parallel.py +++ b/tests/nn/data_parallel/test_fully_sharded_data_parallel.py @@ -145,9 +145,9 @@ def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtyp expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype, ) - def _reduce_scatter(self, tensor): + def _reduce_scatter(self, tensor, shard_size): model._check("reduce_scatter.dtype", tensor.dtype, expected=reduce_dtype) - return orig_reduce_scatter(self, tensor) + return orig_reduce_scatter(self, tensor, shard_size) with mock.patch.object(FullyShardedDataParallel, "_reduce_scatter", new=_reduce_scatter): model = FullyShardedDataParallel(model, group, **cfg).cuda() diff --git a/tests/utils/test_parallel.py b/tests/utils/test_parallel.py new file mode 100644 index 000000000..00a62d521 --- /dev/null +++ b/tests/utils/test_parallel.py @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring + +""" Test utility classes from containers.py. """ + +import pytest +import torch + +from fairscale.utils.parallel import compute_shard_size + + +@pytest.mark.parametrize( + "test_case", [(1, 2), (2, 2), (3, 2), (4, 2), (4, 4), (3, 4), (9, 4), (9, 6), (10, 5), (14, 5)] +) +def test_compute_shard_size(test_case): + """Test compute_shard_size, verify using torch.chunk()""" + numel, world_size = test_case + result = compute_shard_size(numel, world_size) + expected = torch.zeros(numel).chunk(world_size)[0].numel() + assert result == expected, f"{result} == {expected}" From 366c38e3015052390a0e12633f86233ae8967f6c Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Mon, 15 Feb 2021 07:12:39 -0800 Subject: [PATCH 39/48] rename test file (#86) - this way, it is slightly easier to run all tests by python -m pytest tests/nn/data_parallel/test_fsdp* --- .../{test_fully_sharded_data_parallel.py => test_fsdp.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/nn/data_parallel/{test_fully_sharded_data_parallel.py => test_fsdp.py} (100%) diff --git a/tests/nn/data_parallel/test_fully_sharded_data_parallel.py b/tests/nn/data_parallel/test_fsdp.py similarity index 100% rename from tests/nn/data_parallel/test_fully_sharded_data_parallel.py rename to tests/nn/data_parallel/test_fsdp.py From 87d28c9f588a1295b1fc00f2363254d3696b23ae Mon Sep 17 00:00:00 2001 From: Min Xu <24926999+min-xu-ai@users.noreply.github.com> Date: Thu, 18 Feb 2021 16:14:07 -0800 Subject: [PATCH 40/48] adding TrainingState enum for asserting purpose (#87) * adding TrainingState enum for asserting purpose * addressed Sam's suggestions --- .../fully_sharded_data_parallel.py | 42 +++++++++++++++++++ tests/nn/data_parallel/test_fsdp.py | 2 + tests/nn/data_parallel/test_fsdp_uneven.py | 1 + 3 files changed, 45 insertions(+) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py 
b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 9fe77fa78..378dbdd51 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -5,6 +5,7 @@ import contextlib import copy +from enum import Enum, auto import functools from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union @@ -30,6 +31,25 @@ from collections import OrderedDict # noqa: F401 +class TrainingState(Enum): + """ + Simple enum to indicate what state FSDP is in. Used for asserting + to make sure APIs are called in the correct state. + + TODO (Min): It would be nice to capture the stepping state as well. + Maybe we can use the model.zero_grad() call, but not sure if it + is called if optim.zero_grad() is used instead. + It would be nice to have clear state transition be explicit like: + + zero_grad -> fwd -> bwd -> optionally accum grad by repeating + fwd/bwd -> stepping -> loop back to zero_grad + """ + + IDLE = auto() + FORWARD = auto() + BACKWARD = auto() + + class FullyShardedDataParallel(nn.Module): """ A wrapper for sharding Module parameters. This is inspired by `Xu et al.`_ @@ -145,6 +165,8 @@ def __init__( # pass. This will be False when inside the no_sync context manager. self.require_backward_grad_sync: bool = True + self.training_state = TrainingState.IDLE + @torch.no_grad() def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None: """Move all buffers to the specified device and dtype, recursively.""" @@ -315,6 +337,8 @@ def no_sync(self) -> Generator: variables, which will later be synchronized in the first forward-backward pass exiting the context. """ + assert self._is_root, "no_sync on inner FSDP is not tested." + self.assert_idle() # This instance may wrap other FullyShardedDataParallel instances and we # need to set all of them to accumulate gradients. old_flags = [] @@ -465,6 +489,9 @@ def _wait_for_previous_optim_step(self) -> None: def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: self._lazy_init() + # Start of a forward pass. + self.training_state = TrainingState.FORWARD + # Due to the use of streams, we need to make sure the previous # ``optim.step()`` is done before we all-gather parameters. if self._is_root: @@ -497,6 +524,9 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: # pass (if needed). outputs = self._register_pre_backward_hooks(outputs) + # Done with a forward pass. + self.training_state = TrainingState.IDLE + return outputs def _register_pre_backward_hooks(self, outputs: Any) -> Any: @@ -511,6 +541,10 @@ def _pre_backward_hook(*unused: Any) -> None: if pre_backward_hook_has_run[0]: return # only run once pre_backward_hook_has_run[0] = True + + # Start of a backward pass. + self.training_state = TrainingState.BACKWARD + # All-gather full parameters. if self.reshard_after_forward: self._rebuild_full_params() @@ -624,6 +658,8 @@ def _wait_for_post_backward(self) -> None: if self.move_grads_to_cpu: # Wait for the non-blocking GPU -> CPU grad transfers to finish. torch.cuda.current_stream().synchronize() + # A backward pass is done. 
+ self.training_state = TrainingState.IDLE @torch.no_grad() def _rebuild_full_params(self) -> None: @@ -771,6 +807,12 @@ def _reduce_scatter(self, tensor: torch.Tensor, shard_size: int) -> torch.Tensor dist.reduce_scatter(output, to_scatter + tail, group=self.process_group) return output + def assert_idle(self) -> None: + """Assert we are in the idle state.""" + assert ( + self.training_state == TrainingState.IDLE + ), f"wrong state to call no_sync. current state is {self.training_state}" + @torch.no_grad() def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 368292031..abddc8eb6 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -58,6 +58,8 @@ def _train_for_several_steps(model, num_steps, autocast, lr=0.01): assert loss.dtype == torch.float32 model.module.run_backward(loss) optim.step() + if hasattr(model, "assert_idle"): + model.assert_idle() return loss.detach() @staticmethod diff --git a/tests/nn/data_parallel/test_fsdp_uneven.py b/tests/nn/data_parallel/test_fsdp_uneven.py index 9dadd5eb9..0a5b5f761 100644 --- a/tests/nn/data_parallel/test_fsdp_uneven.py +++ b/tests/nn/data_parallel/test_fsdp_uneven.py @@ -48,6 +48,7 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test if test_case["assert_ref_out"]: torch.testing.assert_allclose(ref_out, out) + model.assert_idle() teardown() From 72b59671103ba221f023af467a9f3d2a7e3b5671 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 19 Feb 2021 18:22:31 -0500 Subject: [PATCH 41/48] Misc (#88) * Refactor sharding logic * Make world_size 1 a bit faster * Add MoE test, make other tests more sensitive - replace Adam with SGD w/ momentum, since Adam is scale invariant (less sensitive for tests) - change data across ranks (previously all ranks had the same batch) * CR comments; reduce LR in tests to see if it fixes CI * lint * default num_steps 3 -> 2 --- .../fully_sharded_data_parallel.py | 222 +++++++++--------- fairscale/utils/testing.py | 2 +- stubs/torch/nn/parameter.pyi | 1 - tests/nn/data_parallel/test_fsdp.py | 115 +++++++-- 4 files changed, 202 insertions(+), 138 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 378dbdd51..20f794e17 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -25,7 +25,7 @@ unpack_kwargs, unpack_non_tensors, ) -from fairscale.utils.parallel import compute_shard_size, validate_process_group +from fairscale.utils.parallel import validate_process_group if TYPE_CHECKING: from collections import OrderedDict # noqa: F401 @@ -142,7 +142,7 @@ def __init__( # Only handle params which are not already sharded. This enables # sharding individual layers of a Module, with an outer wrapper to # shard any leftover parameters. - params = list(p for p in module.parameters() if not getattr(p, "_is_sharded", False)) + params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded")) if self.flatten_parameters and len(params) > 0: self.module: nn.Module = FlattenParamsWrapper(module, param_list=params) @@ -157,7 +157,7 @@ def __init__( # Make sure all parameters are sharded. 
for n, p in self.named_parameters(): - assert getattr(p, "_is_sharded", False), f"found unsharded parameter: {n} ; {p.size()}" + assert hasattr(p, "_is_sharded"), f"found unsharded parameter: {n} ; {p.size()}" self._reset_lazy_init() @@ -204,46 +204,41 @@ def _shard_parameters_(self) -> None: if self.mixed_precision: assert p.dtype == torch.float32 - p._is_sharded = True + # If world_size is 1, then we all-reduce grads instead of sharding. + p._is_sharded = self.world_size > 1 p._orig_size = p.data.size() - # shard p.data such that all elements are part of a shard and the - # last shard is <= all other shards or, all shards are 1-element - # in size in case total size is smaller than the world_size. - # - # This way, we don't have holes when shards are reconstructed and - # only extra padding elements need to added/removed during sharding. - shard_size = compute_shard_size(p.data.numel(), self.world_size) - s = min(self.rank * shard_size, p.data.numel()) - e = min(s + shard_size, p.data.numel()) - assert ( - 0 <= s <= e <= p.data.numel() - ), f"_shard_parameters_: {p.data.numel()} {self.world_size} {shard_size} {s} {e}" - - orig_data = p.data - p.data = torch.flatten(p.data)[s:e].clone() - if p.data.numel() < shard_size: - p.data = F.pad(p.data, [0, shard_size - p.data.numel()]) # pad zeros to the right size. - assert p.data.numel() == shard_size, f"{p.data.numel()} {shard_size}" - free_storage_(orig_data) + if not p._is_sharded: + continue + p._is_sharded = True - @torch.no_grad() - def _all_gather_full_param(self, p: nn.Parameter) -> None: - """Fill p._full_param with gathered p.data values (using torch.distributed.all_gather). + # Shard using torch.chunk to match all-gather/reduce-scatter. + chunks = list(torch.flatten(p.data).chunk(self.world_size)) + while len(chunks) < self.world_size: + chunks.append(chunks[0].new_empty(0)) - The p._full_param is already allocated and have the size equal - to shard_size * world_size. + # Determine number of padding elements. + num_to_pad = chunks[0].numel() - chunks[self.rank].numel() + assert num_to_pad >= 0, num_to_pad - It is up to the caller to do necessary resize/reshape to the - unpadded _full_param. - """ - full_param_chunks = list(p._full_param.chunk(self.world_size)) - assert len(full_param_chunks) == self.world_size - assert full_param_chunks[-1].numel() == p.data.numel(), f"{full_param_chunks[-1].numel()} {p.data.numel()}" - param_shard = p.data # we will gather this from each worker + # Replace p.data with the relevant shard. + orig_data = p.data + p.data = chunks[self.rank].clone() # clone since we free storage below + if num_to_pad > 0: + p.data = F.pad(p.data, [0, num_to_pad]) + free_storage_(orig_data) - dist.all_gather(full_param_chunks, param_shard, group=self.process_group) - # ^ updates p._full_param + def extra_repr(self) -> str: + return ( + f"rank={self.rank}, world_size={self.world_size}, " + f"reshard_after_forward={self.reshard_after_forward}, " + f"mixed_precision={self.mixed_precision}, " + f"fp32_reduce_scatter={self.fp32_reduce_scatter}, " + f"flatten_parameters={self.flatten_parameters}, " + f"cpu_offload={self.cpu_offload}, " + f"compute_dtype={self.compute_dtype}, " + f"move_grads_to_cpu={self.move_grads_to_cpu}" + ) def __getattr__(self, name: str) -> Any: """Forward missing attributes to wrapped module.""" @@ -259,6 +254,7 @@ def __getstate__(self) -> Dict[str, str]: we remove them and try to reconstruct them in :func:`__setstate__`. 
""" state = copy.copy(self.__dict__) + state["is_sharded"] = [p._is_sharded for p in self.params] state["orig_sizes"] = [p._orig_size for p in self.params] if state["process_group"] is not None: state["process_group"] = "MISSING" # process_group isn't pickleable @@ -269,14 +265,17 @@ def __setstate__(self, state: Dict[str, Any]) -> None: """Intercept state setting and perform needed changes on params.""" super().__setstate__(state) - def fixup(p: Parameter, size: torch.Size) -> Parameter: + def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter: assert isinstance(p, Parameter) p.data = p.data.clone() # move tensors out of shared memory - p._is_sharded = True + p._is_sharded = is_sharded p._orig_size = size return p - self.params = [fixup(p, size) for p, size in zip(self.params, self.orig_sizes)] + self.params = [ + fixup(p, is_sharded, size) for p, is_sharded, size in zip(self.params, self.is_sharded, self.orig_sizes) + ] + del self.is_sharded del self.orig_sizes self._reset_lazy_init() @@ -371,7 +370,7 @@ def _lazy_init(self) -> None: if self._is_root is None: self._set_is_root() self._setup_streams() - if self.cpu_offload: # Buffers stay on GPU, and dont get sharded + if self.cpu_offload: # Buffers stay on GPU, and don't get sharded self._all_buffers_to(device=torch.device("cuda"), dtype=self.compute_dtype) else: self._all_buffers_to(dtype=self.compute_dtype) @@ -387,7 +386,9 @@ def _init_param_attributes(self, p: Parameter) -> None: We manage several attributes on each Parameter instance. The first two are set by :func:`_shard_parameters_`: - ``_is_sharded``: ``True`` after the Parameter is initially sharded + ``_is_sharded``: ``True`` if the Parameter is sharded or ``False`` + if the Parameter is intentionally not sharded (in which case we + will all-reduce grads for this param). ``_orig_size``: the size of the original Parameter (before sharding) The remaining attributes are set here: @@ -397,12 +398,13 @@ def _init_param_attributes(self, p: Parameter) -> None: depending on the value of *``cpu_offload``*. ``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be a single shard of the parameters in FP16, used for all-gather. - ``_full_param``: the full weight, used for computation in the - forward/backward pass. This will be resized in place and only - materialized (via all-gather) as needed. + ``_full_param_padded``: the full weight (padded to be evenly + divisible by ``world_size``), used for computation in the + forward and backward pass. This will be resized in place and + only materialized (via all-gather) as needed. """ - assert p._is_sharded and hasattr(p, "_orig_size") - if hasattr(p, "_full_param"): + assert hasattr(p, "_is_sharded") and hasattr(p, "_orig_size") + if hasattr(p, "_fp32_shard"): return # Compute device defaults to CUDA when *cpu_offload* is enabled, or the @@ -435,9 +437,15 @@ def _init_param_attributes(self, p: Parameter) -> None: # We also maintain a full-sized parameter of type self.compute_dtype # (FP16 for mixed_precision or FP32 otherwise). We resize the - # storage to size 0 at init (here) and only materialize as needed. - p._full_param = torch.zeros(p._orig_size, device=compute_device, dtype=self.compute_dtype) - free_storage_(p._full_param) + # storage to size 0 at init (here) and only materialize as needed. The + # storage may contain padding elements so that it is evenly divisible by + # world_size, although these padding elements will be removed before the + # relevant computation. 
+ if p._is_sharded: + p._full_param_padded = torch.zeros( + p.data.numel() * self.world_size, device=compute_device, dtype=self.compute_dtype + ) + free_storage_(p._full_param_padded) if self.move_grads_to_cpu: # We can optionally move the grad shard to CPU during the backward @@ -626,10 +634,12 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # Average grad by world_size for consistency with PyTorch DDP. param.grad.data.div_(self.world_size) + if param._is_sharded: # Reduce-scatter grad. - param.grad.data = self._reduce_scatter(param.grad.data, param.data.numel()) - else: - param.grad.data = torch.flatten(param.grad.data) + param.grad.data = self._reduce_scatter_grad(param) + elif self.world_size > 1: + # All-reduce non-sharded grad. + dist.all_reduce(param.grad.data, group=self.process_group) # Cast grad to param's dtype (typically FP32). Note: we do this # before the move_grads_to_cpu step so that this entire hook remains @@ -669,26 +679,24 @@ def _rebuild_full_params(self) -> None: self._cast_fp32_param_shards_to_fp16() for p in self.params: - p_size = p.data.numel() * self.world_size - if p._full_param.storage().size() != p_size: + if not p._is_sharded: + if self.mixed_precision: + p.data = p._fp16_shard + continue + + p_size = p._full_param_padded.size() + if p._full_param_padded.storage().size() != p_size.numel(): # Allocate based on full size from all shards. - p._full_param.resize_(p_size) - alloc_storage_(p._full_param, torch.Size((p_size,))) - if self.world_size > 1: - # Fill p._full_param with (p.data for each shard in self.world_size) - self._all_gather_full_param(p) - if p._orig_size.numel() < p._full_param.numel(): - # We need a smaller view into _full_param and save - # _full_param_padded. - p._full_param_padded = p._full_param - # Note, full size can be >> orig_size when world_size is - # large and param size is tiny. 
- p._full_param = p._full_param.split(p._orig_size.numel())[0] + alloc_storage_(p._full_param_padded, size=p_size) + assert p_size.numel() % self.world_size == 0 + if p._is_sharded: + # Fill p._full_param_padded with (p.data for each shard in self.world_size) + chunks = list(p._full_param_padded.chunk(self.world_size)) + dist.all_gather(chunks, p.data, group=self.process_group) else: - torch.flatten(p._full_param).copy_(p.data) - p._full_param = p._full_param.reshape(p._orig_size) + p._full_param_padded.copy_(torch.flatten(p.data), non_blocking=True) - p.data = p._full_param + p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size) if self.mixed_precision: self._free_fp16_param_shard([p]) @@ -697,8 +705,13 @@ def _rebuild_full_params(self) -> None: @torch.no_grad() def _use_full_params(self) -> None: for p in self.params: - assert p._full_param.storage().size() != 0 - p.data = p._full_param + if not p._is_sharded: + if self.mixed_precision: + assert p._fp16_shard.storage().size() != 0 + p.data = p._fp16_shard + else: + assert p._full_param_padded.storage().size() != 0 + p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size) @torch.no_grad() def _prep_grads_for_backward(self) -> None: @@ -715,18 +728,18 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: current_stream = torch.cuda.current_stream() with torch.cuda.stream(self._streams["all_gather"]): for p in params: + if not p._is_sharded: + if self.mixed_precision: + self._free_fp16_param_shard([p]) + continue # There may be external references to the Tensor Storage that we # can't modify, such as references that are created by # ctx.save_for_backward in the forward pass. Thus when we # unshard parameters, we should reuse the original Tensor # Storage object and unshard it in-place. For now, just resize # the Storage to 0 to save memory. - p._full_param.record_stream(current_stream) - if hasattr(p, "_full_param_padded"): - free_storage_(p._full_param_padded) - delattr(p, "_full_param_padded") - else: - free_storage_(p._full_param) + p._full_param_padded.record_stream(current_stream) + free_storage_(p._full_param_padded) @torch.no_grad() def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None: @@ -767,44 +780,21 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> No free_storage_(p._fp16_shard) @torch.no_grad() - def _reduce_scatter(self, tensor: torch.Tensor, shard_size: int) -> torch.Tensor: - """Reduce-scatter a Tensor (gradient from the local worker) and return - the result (a single "flattened" shard of the summed gradient across workers). - - Shard_size is passed in to compute the padding, but we don't use F.pad since - it reallocates the tensor, which can be a big chunk of memory consumed. Instead, - we allocate only for missing and incomplete shards and copy only the needed - data to the first allocated shard. (The remining allocated shards are just - padding for reduce_scatter.) - """ - tensor = torch.flatten(tensor) - full_shards, rem = divmod(tensor.numel(), shard_size) - assert full_shards <= self.world_size, ( - f"incorrect shard_size {shard_size} " f"full_shards {full_shards} " f"world_size {self.world_size}" - ) - full_shards_view = tensor - if rem > 0: - # Get two views in to the tensor. - full_shards_view, rem_view = tensor.split(full_shards * shard_size) - - # This is first part of to_scatter list. 
- to_scatter = list(full_shards_view.view(-1, shard_size).unbind(0)) - - tail = [] - if full_shards < self.world_size: - # This is the second part of the to_scatter list. - tail = [torch.zeros_like(to_scatter[0]) for i in range(full_shards, self.world_size)] - - if rem > 0: - # Copy the right data in to the first partial shard. - tail[0][:rem].copy_(rem_view) - - assert len(to_scatter) + len(tail) == self.world_size, ( - f"incorrect length {len(to_scatter)} + {len(tail)} vs. " f"{self.world_size}" - ) - - output = torch.zeros_like(to_scatter[0]) # will be filled with gradient summed across workers - dist.reduce_scatter(output, to_scatter + tail, group=self.process_group) + def _reduce_scatter_grad(self, p: nn.Parameter) -> torch.Tensor: + """Reduce-scatter a Parameter's gradient and return a single shard of + the summed gradient across workers.""" + assert p.grad is not None and p._is_sharded + grad_chunks = list(torch.flatten(p.grad.data).chunk(self.world_size)) + + # torch.chunk may return fewer than world_size chunks, pad accordingly. + num_pad_for_partial_chunk = grad_chunks[0].numel() - grad_chunks[-1].numel() + if num_pad_for_partial_chunk > 0: + grad_chunks[-1] = F.pad(grad_chunks[-1], [0, num_pad_for_partial_chunk]) + if len(grad_chunks) < self.world_size: + grad_chunks.extend([torch.zeros_like(grad_chunks[0])] * (self.world_size - len(grad_chunks))) + + output = torch.zeros_like(grad_chunks[0]) # filled with gradient summed across workers + dist.reduce_scatter(output, grad_chunks, group=self.process_group) return output def assert_idle(self) -> None: diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py index 85530ec12..371e5624d 100644 --- a/fairscale/utils/testing.py +++ b/fairscale/utils/testing.py @@ -100,7 +100,7 @@ def torch_version() -> Tuple[int, ...]: # Assuming that we're interested in the second usecase more than the first, # return the pre-release or dev numbering - logging.warning(f"Pytorch pre-relase version {torch.__version__} - assuming intent to test it") + logging.warning(f"Pytorch pre-release version {torch.__version__} - assuming intent to test it") numbering[2] = "0" return tuple(int(n) for n in numbering) diff --git a/stubs/torch/nn/parameter.pyi b/stubs/torch/nn/parameter.pyi index 82658bbc3..c6fdd30d0 100644 --- a/stubs/torch/nn/parameter.pyi +++ b/stubs/torch/nn/parameter.pyi @@ -12,7 +12,6 @@ class Parameter(Tensor): _orig_size: Size _cpu_grad: Tensor _full_param_padded: Tensor - _full_param: Tensor _fp32_shard: Tensor _fp16_shard: Optional[Tensor] diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index abddc8eb6..fc2397c66 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -47,7 +47,9 @@ def setUp(self): @staticmethod def _train_for_several_steps(model, num_steps, autocast, lr=0.01): model_device = next(model.parameters()).device - optim = torch.optim.Adam(model.parameters(), lr=lr) + # use SGD with momentum instead of Adam, since Adam is scale invariant + # and this makes it bad for tests + optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) for _ in range(num_steps): optim.zero_grad() with torch.cuda.amp.autocast(enabled=autocast): @@ -65,9 +67,9 @@ def _train_for_several_steps(model, num_steps, autocast, lr=0.01): @staticmethod def get_wrapped_model(group, cuda_first=False, config={}, **model_kwargs) -> FullyShardedDataParallel: if cuda_first: - model = FullyShardedDataParallel(TransformerWithSharedParams(**model_kwargs).cuda(), group, **config) 
+ model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs).cuda(), group, **config) else: - model = FullyShardedDataParallel(TransformerWithSharedParams(**model_kwargs), group, **config).cuda() + model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs), group, **config).cuda() return model @@ -140,18 +142,19 @@ def _spawn_test_case(self, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, @staticmethod def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group): - # Patch _reduce_scatter op to check the dtype of the reduction - orig_reduce_scatter = FullyShardedDataParallel._reduce_scatter + # Patch torch.distributed.reduce_scatter to check the dtype of the reduction + orig_reduce_scatter = torch.distributed.reduce_scatter model = DeviceAndTypeCheckModule( expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype, ) - def _reduce_scatter(self, tensor, shard_size): - model._check("reduce_scatter.dtype", tensor.dtype, expected=reduce_dtype) - return orig_reduce_scatter(self, tensor, shard_size) + def _reduce_scatter(output, input_list, **kwargs): + for tensor in input_list: + model._check("reduce_scatter.dtype", tensor.dtype, expected=reduce_dtype) + return orig_reduce_scatter(output, input_list, **kwargs) - with mock.patch.object(FullyShardedDataParallel, "_reduce_scatter", new=_reduce_scatter): + with mock.patch("torch.distributed.reduce_scatter", new=_reduce_scatter): model = FullyShardedDataParallel(model, group, **cfg).cuda() device = next(model.parameters()).device x = torch.rand(2, 5).to(device) @@ -184,10 +187,8 @@ def test_cpu_offload_and_cpu_grads(self): # the device transfer and PyTorch optimizers don't support this. config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True} test_fn = functools.partial( - self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.001 + self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01 ) - # We use lower lr to reduce this test's sensitivity to slightly different CPU vs CUDA behavior of pytorch. - # With lr=0.01, it fails on torch 1.6.0. spawn_and_init(test_fn) def test_cpu_offload_and_cuda_grads_breaks(self): @@ -218,8 +219,26 @@ def test_delayed_reduce_scatter(self): test_fn = functools.partial(self._test_identical_outputs, model_fn, config) spawn_and_init(test_fn) + def test_mixture_of_experts(self): + config = {"mixed_precision": True} + test_fn = functools.partial( + self._test_identical_outputs, + MixtureOfExperts, + config, + # MixtureOfExperts implements custom reduce logic, so the reference + # behavior should use that logic instead of PyTorch DDP. + ref_ddp_fn=self._dummy_ddp_fn, + ) + spawn_and_init(test_fn) + + @classmethod + def _dummy_ddp_fn(self, model, group): + return DummyDDP(model) + @classmethod - def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True, lr=0.01): + def _test_identical_outputs( + cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None + ): if config["mixed_precision"]: autocast = True # Force the compute dtype to be torch.float32 so that we get @@ -232,7 +251,12 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3 # Establish reference behavior with PyTorch DDP (+ optionally autocast). 
model = model_init_fn(group=group, wrapper_config=None).cuda() - model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group) + if ref_ddp_fn is None: + model = nn.parallel.DistributedDataParallel( + model, device_ids=[rank], output_device=rank, process_group=group + ) + else: + model = ref_ddp_fn(model, group) ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr) ref_state_dict = model.module.state_dict() @@ -319,7 +343,7 @@ def _one_step(self, model, group): for m in model.modules(): if isinstance(m, FullyShardedDataParallel): m.process_group = group - optim = torch.optim.Adam(model.parameters(), lr=0.0001) + optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) input = model.module.get_input(torch.device("cuda")) output = model(*input) loss = model.module.get_loss(input, output) @@ -415,7 +439,7 @@ def _test_module_state_dict(cls, config, rank, group): cls._train_for_several_steps(ddp_model, 2, autocast) state_1 = ddp_model.state_dict() # You must make a new FullyShardedDataParallel instance to use module.load_state_dict - unwrapped_model = TransformerWithSharedParams() + unwrapped_model = TransformerWithSharedParams(group) unwrapped_model.load_state_dict(state_1) new_ddp_model = FullyShardedDataParallel(unwrapped_model, group, **config).cuda() cls._train_for_several_steps(new_ddp_model, 2, autocast) @@ -451,7 +475,7 @@ def _test_backward_hooks_after_save(self, rank, group, cuda_first=False): def _test_output_backward_hooks(self, rank, group, cuda_first=False, model=None): if model is None: model = self.get_wrapped_model(group, cuda_first=cuda_first) - optim = torch.optim.Adam(model.parameters(), lr=0.0001) + optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) optim.zero_grad() # Inputs always cuda regardless of move_grads_cpu, or model.device input = model.module.get_input(torch.device("cuda")) @@ -577,8 +601,10 @@ def _test_no_sync(self, model, batch_dim): class TransformerWithSharedParams(nn.Module): - def __init__(self, *unused_args, d_vocab=23, d_model=16, **unused_kwargs): + def __init__(self, group, *unused_args, d_vocab=23, d_model=16, **unused_kwargs): super().__init__() + self.rank = group.rank() + self.world_size = group.size() torch.manual_seed(0) # keep everything deterministic assert d_vocab >= 12 # we use torch.arange(12) as input self.embed_tokens = nn.Embedding(d_vocab, d_model) @@ -591,7 +617,7 @@ def __init__(self, *unused_args, d_vocab=23, d_model=16, **unused_kwargs): self.register_buffer(_BUFFER_NAME, self.embed_tokens.weight.new_ones((d_model,))) def get_input(self, device): - torch.manual_seed(1) # keep everything deterministic + torch.manual_seed(1 + self.rank) # keep everything deterministic src = torch.arange(12, device=device).view(6, 2) # T x B tgt = torch.arange(8, device=device).view(4, 2) # T x B return (src, tgt) @@ -614,6 +640,9 @@ def run_backward(self, loss): class NestedWrappedModule(nn.Module): def __init__(self, group, wrapper_config): super().__init__() + self.rank = group.rank() + self.world_size = group.size() + self.wrapper_config = wrapper_config def _maybe_wrap(layer): if wrapper_config is not None: @@ -626,7 +655,7 @@ def _maybe_wrap(layer): ) def get_input(self, device): - torch.manual_seed(1) # keep everything deterministic + torch.manual_seed(1 + self.rank) # keep everything deterministic return (torch.rand(4, 8, device=device),) def forward(self, x): @@ -640,6 +669,52 @@ def run_backward(self, loss): loss.backward() +class 
DummyDDP(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + +class MixtureOfExperts(NestedWrappedModule): + def __init__(self, group, wrapper_config): + super().__init__(group, wrapper_config) + self.group = group + + # "expert" params are different on each rank + torch.manual_seed(42 + group.rank()) + expert = nn.Linear(16, 4) + for p in expert.parameters(): + p.expert = True + + # everything else is shared + torch.manual_seed(0) + shared = nn.Linear(4, 16) + + if wrapper_config is not None: + # we create a process group of size 1 for the expert params + expert_group = torch.distributed.new_group([group.rank()]) + expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config) + + shared = FullyShardedDataParallel(shared, group, **wrapper_config) + + self.module = nn.Sequential(nn.Linear(8, 4), shared, expert, nn.Linear(4, 8)) + + def run_backward(self, loss): + loss.backward() + + # manually reduce gradients if not wrapped in FullyShardedDataParallel + if self.wrapper_config is None: + with torch.no_grad(): + for p in self.parameters(): + if hasattr(p, "expert"): + continue # these params don't need grad reduction + p.grad.data.div_(self.world_size) + torch.distributed.all_reduce(p.grad.data, group=self.group) + + class ModuleWithDelay(nn.Module): def __init__(self, module, delay_after_loss_ms=0, delay_before_reduction_ms=0): super().__init__() From a0fe832b08be4387926104f4c4d7a6da813d23d9 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 19 Feb 2021 22:49:46 -0500 Subject: [PATCH 42/48] Bugfix for backward stream (#91) --- .../nn/data_parallel/fully_sharded_data_parallel.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 20f794e17..83b54b68a 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -626,6 +626,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # reductions in post_backward stream. self._streams["post_backward"].wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self._streams["post_backward"]): + orig_grad_data = param.grad.data + if self.mixed_precision and self.fp32_reduce_scatter: # Cast grad to FP32. param.grad.data = param.grad.data.to(param.dtype) @@ -653,6 +655,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: param._cpu_grad.copy_(param.grad.data, non_blocking=True) param.grad.data = param._cpu_grad + # After _post_backward_hook returns, orig_grad_data will eventually + # go out of scope, at which point it could otherwise be freed for + # further reuse by the main stream while the div/reduce_scatter/copy + # are underway in the post_backward stream. See: + # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py + orig_grad_data.record_stream(self._streams["post_backward"]) + # Enqueue a callback at the end of the backward pass to ensure that all # post-backward work has finished. We only need one callback and it only # needs to be called from the outer-most (root) instance. 
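
Editor's note: the fix in the patch above hinges on ``Tensor.record_stream``, which tells the CUDA caching allocator that a tensor's storage is still in use by a side stream, so the allocator will not recycle that storage for work queued on the producer stream once the last Python reference goes away. The sketch below is illustrative only and is not part of the patch; ``reduce_on_side_stream`` and the dummy ``div_`` stand in for the real div/reduce-scatter/copy pipeline in ``_post_backward_hook``. Example::

    import torch

    def reduce_on_side_stream(grad: torch.Tensor, side_stream: torch.cuda.Stream) -> None:
        # Order the side stream after the producer (current) stream.
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            grad.div_(2)  # stand-in for the real div_/reduce-scatter/copy work
        # Without this call, the caching allocator may hand grad's storage back
        # to the current stream as soon as the last Python reference dies, even
        # though the side stream is still reading/writing it.
        grad.record_stream(side_stream)

    if torch.cuda.is_available():
        side = torch.cuda.Stream()
        g = torch.randn(1024, device="cuda")
        reduce_on_side_stream(g, side)
        torch.cuda.current_stream().wait_stream(side)
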
From 2abf21f7048cf023ef36976ca2a3f5418e74d25d Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 22 Feb 2021 09:05:17 -0500 Subject: [PATCH 43/48] Small fix to activation checkpointing + FSDP (#92) * Add test (+ fix) for combining checkpointing + FSDP * Small fix for local_state_dict where it didn't call lazy_init --- .../fully_sharded_data_parallel.py | 19 ++++++++------ fairscale/nn/misc/checkpoint_activations.py | 25 ++++++++++++++++--- stubs/torch/__init__.pyi | 1 + stubs/torch/cuda/amp/__init__.pyi | 7 ++++++ tests/nn/data_parallel/test_fsdp.py | 16 ++++++++---- 5 files changed, 52 insertions(+), 16 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 83b54b68a..8b536d941 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -306,6 +306,8 @@ def local_state_dict(self, *args, **kwargs): # type: ignore so the resulting state_dict can only be loaded after the Module has been wrapped with FullyShardedDataParallel. """ + torch.cuda.synchronize() + self._lazy_init() if self.flatten_parameters: return self.module.flat_state_dict(*args, **kwargs) # type: ignore else: @@ -326,6 +328,7 @@ def load_local_state_dict( self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True ) -> NamedTuple: """Load a local (sharded) state_dict.""" + torch.cuda.synchronize() return self.module.load_state_dict(state_dict, strict) @contextlib.contextmanager @@ -375,11 +378,16 @@ def _lazy_init(self) -> None: else: self._all_buffers_to(dtype=self.compute_dtype) - # Don't free the full params for the outer-most (root) instance, since - # those params will be needed immediately after for the backward pass. if self._is_root: + # Don't free the full params for the outer-most (root) instance, + # since those params will be needed immediately after for the + # backward pass. self.reshard_after_forward = False + # Due to the use of streams, we need to make sure the previous + # ``optim.step()`` is done before we all-gather parameters. + self._wait_for_previous_optim_step() + @torch.no_grad() def _init_param_attributes(self, p: Parameter) -> None: """ @@ -500,11 +508,6 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: # Start of a forward pass. self.training_state = TrainingState.FORWARD - # Due to the use of streams, we need to make sure the previous - # ``optim.step()`` is done before we all-gather parameters. - if self._is_root: - self._wait_for_previous_optim_step() - if self.mixed_precision: args, kwargs = cast_inputs_to_fp16(*args, **kwargs) @@ -590,7 +593,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: At the start of :func:`_post_backward_hook`, ``param.grad`` contains the full gradient for the local batch. The reduce-scatter op will replace ``param.grad`` with a single shard of the summed gradient across all - GPUs. This shard will align with the current GPU rank. For example:: + GPUs. This shard will align with the current GPU rank. 
For example:: before reduce_scatter: param.grad (GPU #0): [1, 2, 3, 4] diff --git a/fairscale/nn/misc/checkpoint_activations.py b/fairscale/nn/misc/checkpoint_activations.py index 44f29c612..a18f1fe06 100644 --- a/fairscale/nn/misc/checkpoint_activations.py +++ b/fairscale/nn/misc/checkpoint_activations.py @@ -3,8 +3,9 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +from contextlib import contextmanager import functools -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Generator, Optional, Tuple import torch from torch import Tensor @@ -66,6 +67,23 @@ def set_rng_state(state: Dict[str, Any]) -> None: torch.cuda.set_rng_state(state["cuda_rng_state"]) +def is_autocast_enabled() -> bool: + """Similar to torch.is_autocast_enabled, but compatible with torch 1.5.1""" + if hasattr(torch, "is_autocast_enabled"): + return torch.is_autocast_enabled() + return False + + +@contextmanager +def autocast(enabled: bool) -> Generator: + """Similar to torch.cuda.amp.autocast, but compatible with torch 1.5.1""" + if enabled: + with torch.cuda.amp.autocast(enabled): + yield + else: + yield + + class CheckpointFunction(torch.autograd.Function): """Similar to the torch version, but support non-Tensor outputs. @@ -89,13 +107,13 @@ def forward( # type: ignore ctx.run_function = run_function ctx.kwarg_keys = kwarg_keys ctx.fwd_rng_state = get_rng_state() + ctx.had_autocast_in_fwd = is_autocast_enabled() tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) if parent_ctx_dict["offload"]: ctx.fwd_device = tuple(x.device for x in tensor_inputs) ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs) tensor_inputs = tuple(x.cpu() for x in tensor_inputs) - else: ctx.fwd_device, ctx.grad_requirements = None, None @@ -135,10 +153,11 @@ def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]: # Set the states to what it used to be before the forward pass. set_rng_state(ctx.fwd_rng_state) - with torch.enable_grad(): + with torch.enable_grad(), autocast(ctx.had_autocast_in_fwd): unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs) outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) tensor_outputs, _ = split_non_tensors(outputs) + # Set the states back to what it was at the start of this function. set_rng_state(bwd_rng_state) diff --git a/stubs/torch/__init__.pyi b/stubs/torch/__init__.pyi index f22720d09..5732644b8 100644 --- a/stubs/torch/__init__.pyi +++ b/stubs/torch/__init__.pyi @@ -1914,6 +1914,7 @@ def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API def set_default_dtype(d : _dtype) -> None: ... def manager_path() -> str: ... def compiled_with_cxx11_abi() -> _bool: ... +def is_autocast_enabled() -> _bool: ... # The return value of this function depends on the value of `as_tuple`, # (similar to `unique`, `lu`, etc.); as such, it is not diff --git a/stubs/torch/cuda/amp/__init__.pyi b/stubs/torch/cuda/amp/__init__.pyi index 848378a60..f5bc87a59 100644 --- a/stubs/torch/cuda/amp/__init__.pyi +++ b/stubs/torch/cuda/amp/__init__.pyi @@ -1,3 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import Any, Generator + from .grad_scaler import GradScaler as GradScaler + +class autocast: + def __init__(self, enabled=True) -> None: ... + def __enter__(self) -> None: ... + def __exit__(self, *args: Any) -> None: ... 
diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index fc2397c66..9563b2b79 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -16,6 +16,7 @@ from torch import nn from fairscale.nn.data_parallel import FullyShardedDataParallel +from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper from fairscale.utils.testing import ( DeviceAndTypeCheckModule, DummyProcessGroup, @@ -219,12 +220,13 @@ def test_delayed_reduce_scatter(self): test_fn = functools.partial(self._test_identical_outputs, model_fn, config) spawn_and_init(test_fn) - def test_mixture_of_experts(self): - config = {"mixed_precision": True} + @parameterized.expand([[{"checkpoint_act": False}], [{"checkpoint_act": True}]], name_func=rename_test) + def test_mixture_of_experts(self, moe_config): + fsdp_config = {"mixed_precision": True} test_fn = functools.partial( self._test_identical_outputs, - MixtureOfExperts, - config, + functools.partial(MixtureOfExperts, **moe_config), + fsdp_config, # MixtureOfExperts implements custom reduce logic, so the reference # behavior should use that logic instead of PyTorch DDP. ref_ddp_fn=self._dummy_ddp_fn, @@ -679,7 +681,7 @@ def forward(self, *args, **kwargs): class MixtureOfExperts(NestedWrappedModule): - def __init__(self, group, wrapper_config): + def __init__(self, group, wrapper_config, checkpoint_act=False): super().__init__(group, wrapper_config) self.group = group @@ -693,6 +695,10 @@ def __init__(self, group, wrapper_config): torch.manual_seed(0) shared = nn.Linear(4, 16) + if checkpoint_act: + expert = checkpoint_wrapper(expert) + shared = checkpoint_wrapper(shared) + if wrapper_config is not None: # we create a process group of size 1 for the expert params expert_group = torch.distributed.new_group([group.rank()]) From 977f459a51d8b8cc8ea39c2c1756b46be38d9c09 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 22 Feb 2021 09:05:38 -0500 Subject: [PATCH 44/48] Bugfix (+ add test) for no_sync before first forward (#93) --- .../fully_sharded_data_parallel.py | 17 +++++++++-------- tests/nn/data_parallel/test_fsdp.py | 12 ++++++++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 8b536d941..c3b777e55 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -339,7 +339,8 @@ def no_sync(self) -> Generator: variables, which will later be synchronized in the first forward-backward pass exiting the context. """ - assert self._is_root, "no_sync on inner FSDP is not tested." + self._lazy_init() + assert self._is_root, "no_sync on inner FSDP is not supported" self.assert_idle() # This instance may wrap other FullyShardedDataParallel instances and we # need to set all of them to accumulate gradients. @@ -622,6 +623,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # pre_backward_hook. self._free_fp16_param_shard([param]) + # Enqueue a callback at the end of the backward pass to ensure that all + # post-backward work has finished. We only need one callback and it only + # needs to be called from the outer-most (root) instance. 
+ if self._is_root and not self._post_backward_callback_queued: + self._post_backward_callback_queued = True + Variable._execution_engine.queue_callback(self._wait_for_post_backward) + if not self.require_backward_grad_sync: return @@ -665,13 +673,6 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py orig_grad_data.record_stream(self._streams["post_backward"]) - # Enqueue a callback at the end of the backward pass to ensure that all - # post-backward work has finished. We only need one callback and it only - # needs to be called from the outer-most (root) instance. - if self._is_root and not self._post_backward_callback_queued: - self._post_backward_callback_queued = True - Variable._execution_engine.queue_callback(self._wait_for_post_backward) - @torch.no_grad() def _wait_for_post_backward(self) -> None: """Wait for post-backward work to finish. Only called on root instance.""" diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 9563b2b79..957ffdbf3 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -549,6 +549,18 @@ def test_nested_wrapper(self): fn = functools.partial(self._test_nested_wrapper, config={}) spawn_and_init(fn) + def test_no_sync_before_first_forward(self): + group = DummyProcessGroup(rank=0, size=1) + model = self.get_wrapped_model(group, config={}) + batch = model.module.get_input(torch.device("cuda")) + with model.no_sync(): + output = model(*batch) + loss = model.module.get_loss(batch, output) + loss.backward() + output = model(*batch) + loss = model.module.get_loss(batch, output) + loss.backward() + @classmethod def _test_transformer(self, rank, group, config): model = self.get_wrapped_model(group, config=config) From 43d1f7306819a0a347100e552f6dd29dab9a9747 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 22 Feb 2021 13:14:02 -0500 Subject: [PATCH 45/48] FSDP.clip_grad_norm (#89) --- .../fully_sharded_data_parallel.py | 72 ++++++++++++++++++- fairscale/optim/oss.py | 10 +-- fairscale/optim/utils.py | 22 +++++- fairscale/utils/parallel.py | 2 +- tests/nn/data_parallel/test_fsdp.py | 33 +++++++-- tests/optim/test_oss.py | 1 + 6 files changed, 125 insertions(+), 15 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index c3b777e55..a67478abb 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -7,6 +7,7 @@ import copy from enum import Enum, auto import functools +from math import inf from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union import torch @@ -18,6 +19,7 @@ import torch.nn.functional as F from fairscale.nn.misc import FlattenParamsWrapper +from fairscale.optim.utils import calc_grad_norm from fairscale.utils.containers import ( apply_to_tensors, pack_kwargs, @@ -173,6 +175,68 @@ def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype) self.apply(cast_fn) + @property + def params_with_grad(self) -> List[Parameter]: + """[p for p in self.parameters() if p.grad is not None] """ + return [p for p in self.parameters() if p.grad is not None] + + @torch.no_grad() + def clip_grad_norm_( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + # 
filter_params_fn: Callable[[Any], Any] = None, + ) -> torch.Tensor: + """ + Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Arguments: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + + .. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple devices per rank + under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads + in the model, so calling it in the OSS context would lead to different scaling being applied per subset of model parameters + + .. warning: This needs to be called on all ranks, since synchronization primitives will be used + + """ + assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance" + assert self.training_state == TrainingState.IDLE + + max_norm = float(max_norm) + norm_type = float(norm_type) + params_with_grad = self.params_with_grad + if not self.children_share_process_group: + raise NotImplementedError( + "clip_grad_norm requires that all params share one process group. clip_grad_by_value_ should work" + ) + # Computes the max norm for this shard's gradients and sync's across workers + local_norm = calc_grad_norm(params_with_grad, norm_type).cuda() + if norm_type == inf: + total_norm = local_norm + dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group) + else: + total_norm = local_norm ** norm_type + dist.all_reduce(total_norm, group=self.process_group) + total_norm = total_norm ** (1.0 / norm_type) + + if self.move_grads_to_cpu: + total_norm = total_norm.cpu() + # Now multiply each grad by (max_norm/total_norm), same as torch 1.7 https://tinyurl.com/3wtxhhqq) + clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6) + if clip_coef < 1: + + # multiply by clip_coef + for p in params_with_grad: + p.grad.detach().mul_(clip_coef.to(p.grad.device)) # type: ignore + + return total_norm + @torch.no_grad() def _shard_parameters_(self) -> None: """ @@ -464,16 +528,22 @@ def _init_param_attributes(self, p: Parameter) -> None: def _set_is_root(self) -> None: """If ``True``, implies that no other :class:`FullyShardedDataParallel` - instance wraps this one. Called once by :func:`_lazy_init`.""" + instance wraps this one. Called once by :func:`_lazy_init`. + Also sets self.children_share_process_group = True if all child instances share the same process group. + If some child instances use a different process group, self.clip_grad_norm_ will raise an error. + """ if self._is_root is not None: return # No FullyShardedDataParallel instance wraps this, else _is_root would be set to False self._is_root = True # As the root, we now set all children instances to False. 
+ self.children_share_process_group = True for n, m in self.named_modules(): if n != "" and isinstance(m, FullyShardedDataParallel): assert m._is_root is None m._is_root = False + if m.process_group != self.process_group: + self.children_share_process_group = False def _setup_streams(self) -> None: """Create streams to overlap data transfer and computation.""" diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py index cdf9f6425..15e908bde 100644 --- a/fairscale/optim/oss.py +++ b/fairscale/optim/oss.py @@ -15,7 +15,7 @@ from torch.nn import Parameter from torch.optim import SGD, Optimizer -from .utils import Workhandle, broadcast_object, recursive_copy_to_device +from .utils import Workhandle, broadcast_object, calc_grad_norm, recursive_copy_to_device __all__ = ["OSS"] @@ -274,18 +274,14 @@ def clip_grad_norm( # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54 local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params + local_norm = calc_grad_norm(local_params, norm_type) # Compute the norm on this grad set, # then sync all the norms from all ranks if norm_type == inf: - total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params) + total_norm = local_norm # all reduce over data parallel and model parallel workers dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD) else: - local_norm = torch.norm( - input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._device) for p in local_params]), # type: ignore - p=norm_type, - ) - # local norm result can be accumulated with the remote ones if put to the right power # n_i = sum_rank(a^p)^1/p # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p diff --git a/fairscale/optim/utils.py b/fairscale/optim/utils.py index 8167d70f3..d693ea768 100644 --- a/fairscale/optim/utils.py +++ b/fairscale/optim/utils.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. import io -from typing import Any, Callable, Dict, Optional +from math import inf +from typing import Any, Callable, Dict, List, Optional import torch from torch._six import container_abcs @@ -101,3 +102,22 @@ def reset(self) -> None: def full(self) -> bool: """ is the bucket full ? """ return self.max_params_checked_in == self.params_checked_in + + +def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor: + r"""Calculate gradient norm of an iterable of parameters. + Returns: + Total norm of the parameters (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda par: par.grad is not None, parameters)) + + if len(parameters) == 0: + return torch.tensor(0.0) + p = float(p) + if p == inf: + local_norm = max(par.grad.detach().abs().max() for par in parameters) # type: ignore + else: + local_norm = torch.norm(torch.stack([torch.norm(par.grad.detach(), p) for par in parameters]), p) # type: ignore + return local_norm diff --git a/fairscale/utils/parallel.py b/fairscale/utils/parallel.py index eba13546f..a02ebb13f 100644 --- a/fairscale/utils/parallel.py +++ b/fairscale/utils/parallel.py @@ -21,7 +21,7 @@ def compute_shard_size(numel: int, world_size: int) -> int: # two shards, shard size is the size of the bigger one. 
shard_size = numel // world_size + 1 else: - # find the equal chunks until reminder is smaller than shard_size + # find the equal chunks until remainder is smaller than shard_size for div in range(world_size - 1, 1, -1): shard_size, rem = divmod(numel, div) if shard_size >= rem: diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 957ffdbf3..15948f096 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -2,9 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - import functools import itertools +from math import inf import pickle import sys from typing import Dict @@ -46,7 +46,7 @@ def setUp(self): raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") @staticmethod - def _train_for_several_steps(model, num_steps, autocast, lr=0.01): + def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None): model_device = next(model.parameters()).device # use SGD with momentum instead of Adam, since Adam is scale invariant # and this makes it bad for tests @@ -60,6 +60,12 @@ def _train_for_several_steps(model, num_steps, autocast, lr=0.01): loss = model.module.get_loss(input, output).to(model_device) assert loss.dtype == torch.float32 model.module.run_backward(loss) + if norm_type is not None: + clip_norm = 0.3 + if isinstance(model, FullyShardedDataParallel): + model.clip_grad_norm_(clip_norm, norm_type) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type) optim.step() if hasattr(model, "assert_idle"): model.assert_idle() @@ -230,16 +236,25 @@ def test_mixture_of_experts(self, moe_config): # MixtureOfExperts implements custom reduce logic, so the reference # behavior should use that logic instead of PyTorch DDP. ref_ddp_fn=self._dummy_ddp_fn, + norm_type=None, ) spawn_and_init(test_fn) + def test_mixture_of_experts_grad_clip_breaks(self): + config = {"mixed_precision": True} + test_fn = functools.partial( + self._test_identical_outputs, MixtureOfExperts, config, ref_ddp_fn=self._dummy_ddp_fn, norm_type=2, + ) + with self.assertRaises(Exception): + spawn_and_init(test_fn) + @classmethod def _dummy_ddp_fn(self, model, group): return DummyDDP(model) @classmethod def _test_identical_outputs( - cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None + cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2, ): if config["mixed_precision"]: autocast = True @@ -259,7 +274,7 @@ def _test_identical_outputs( ) else: model = ref_ddp_fn(model, group) - ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr) + ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type) ref_state_dict = model.module.state_dict() # Confirm we get the same behavior using FullyShardedDataParallel. 
@@ -268,7 +283,7 @@ def _test_identical_outputs( model = model.cuda() else: assert next(model.parameters()).device == torch.device("cpu") - shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr) + shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type) shard_state_dict = model.state_dict() try: @@ -277,6 +292,14 @@ def _test_identical_outputs( except (AssertionError, RuntimeError) as e: raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}") + @parameterized.expand([[1], [inf]], name_func=rename_test) + def test_clip_norm_transformer(self, norm_type): + config = {"mixed_precision": True} + test_fn = functools.partial( + self._test_identical_outputs, TransformerWithSharedParams, config, norm_type=norm_type, + ) + spawn_and_init(test_fn) + class TestParamInit(DistributedTest): def test_param_change_after_init(self): diff --git a/tests/optim/test_oss.py b/tests/optim/test_oss.py index 7f7cb058e..76b725c23 100644 --- a/tests/optim/test_oss.py +++ b/tests/optim/test_oss.py @@ -617,6 +617,7 @@ def check(norm): loss_oss = loss_fn(outputs_oss, target) loss_oss.backward() + torch.testing.assert_allclose(loss_oss, loss) # Check the equivalence with the non-sharded optim oss_total_norm = sharded_optimizer.clip_grad_norm(CLIP_NORM, norm_type=norm) From 77dc3649de34acce95a5feb9df41b12fbd3233ff Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 22 Feb 2021 17:08:59 -0500 Subject: [PATCH 46/48] Add ReduceScatterBucketer (#94) * Add ReduceScatterBucketer * Add test * Move chunk_and_pad to fairscale.utils.parallel and add tests * Iterate on tests * CR comments * more * Remove most tests to speed up CI iteration cycle * Fix for Python < 3.8 * more * Revert "Remove most tests to speed up CI iteration cycle" This reverts commit a1981ae496b021a767fbf411f95840c55fad8d17. * CR * lint --- .../fully_sharded_data_parallel.py | 98 +++++++----- fairscale/utils/parallel.py | 34 ++-- fairscale/utils/reduce_scatter_bucketer.py | 151 ++++++++++++++++++ tests/nn/data_parallel/test_fsdp.py | 2 +- tests/nn/data_parallel/test_fsdp_uneven.py | 3 +- tests/utils/test_parallel.py | 24 +-- tests/utils/test_reduce_scatter_bucketer.py | 115 +++++++++++++ 7 files changed, 351 insertions(+), 76 deletions(-) create mode 100644 fairscale/utils/reduce_scatter_bucketer.py create mode 100644 tests/utils/test_reduce_scatter_bucketer.py diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index a67478abb..98751c639 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -27,7 +27,8 @@ unpack_kwargs, unpack_non_tensors, ) -from fairscale.utils.parallel import validate_process_group +from fairscale.utils.parallel import chunk_and_pad, validate_process_group +from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer if TYPE_CHECKING: from collections import OrderedDict # noqa: F401 @@ -107,6 +108,12 @@ class FullyShardedDataParallel(nn.Module): move_grads_to_cpu (bool, Optional): move gradient shard to CPU after reduction. This is useful when combined with CPU-based optimizers. It defaults to the value of *``cpu_offload``*. + bucket_cap_mb (int, Optional): FSDP will bucket parameters so that + gradient reduction can potentially overlap with backward + computation. bucket_cap_mb controls the bucket size in MegaBytes + (MB). 
Buckets are sub-divided based on world_size, so the max shard + size is roughly ``bucket_cap_mb / world_size``. Values <= 0 disable + bucketing. Default: 25. """ def __init__( @@ -120,6 +127,7 @@ def __init__( cpu_offload: bool = False, compute_dtype: Optional[torch.dtype] = None, move_grads_to_cpu: Optional[bool] = None, + bucket_cap_mb: int = 25, ): super().__init__() self.process_group = process_group or dist.new_group() @@ -132,6 +140,7 @@ def __init__( self.cpu_offload = cpu_offload self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32) self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu + self.bucket_cap_mb = bucket_cap_mb if self.fp32_reduce_scatter and not self.mixed_precision: raise ValueError("fp32_reduce_scatter requires mixed_precision=True") @@ -405,7 +414,7 @@ def no_sync(self) -> Generator: """ self._lazy_init() assert self._is_root, "no_sync on inner FSDP is not supported" - self.assert_idle() + self.assert_state(TrainingState.IDLE) # This instance may wrap other FullyShardedDataParallel instances and we # need to set all of them to accumulate gradients. old_flags = [] @@ -423,6 +432,7 @@ def _reset_lazy_init(self) -> None: """Reset instance so :func:`_lazy_init` will run on the next forward.""" self._is_root: Optional[bool] = None self._streams: Dict[str, torch.cuda.Stream] = {} + self._reducer: Optional[ReduceScatterBucketer] = None def _lazy_init(self) -> None: """Initialization steps that should happen lazily, typically right @@ -438,6 +448,7 @@ def _lazy_init(self) -> None: if self._is_root is None: self._set_is_root() self._setup_streams() + if self.cpu_offload: # Buffers stay on GPU, and don't get sharded self._all_buffers_to(device=torch.device("cuda"), dtype=self.compute_dtype) else: @@ -555,12 +566,16 @@ def _setup_streams(self) -> None: self._streams["all_gather"] = torch.cuda.Stream() # Stream for overlapping grad reduction with the backward pass. self._streams["post_backward"] = torch.cuda.Stream() + # Helper for bucketing reduce-scatter ops. This is also shared with + # children instances to improve bucket utilization. + self._reducer = ReduceScatterBucketer(self.bucket_cap_mb) # We share streams with all children instances, which allows them to # overlap transfers across the forward pass without synchronizing with # the default stream. for n, m in self.named_modules(): if n != "" and isinstance(m, FullyShardedDataParallel): m._streams = self._streams + m._reducer = self._reducer def _wait_for_previous_optim_step(self) -> None: """ @@ -679,6 +694,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: alignment is created by :func:`_shard_parameters_`, which ensures that the local optimizer only sees the relevant parameter shard. """ + self.assert_state(TrainingState.BACKWARD) if param.grad is None: return if param.grad.requires_grad: @@ -717,24 +733,18 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # Average grad by world_size for consistency with PyTorch DDP. param.grad.data.div_(self.world_size) + callback_fn = functools.partial(self._post_reduction_hook, param) if param._is_sharded: - # Reduce-scatter grad. - param.grad.data = self._reduce_scatter_grad(param) - elif self.world_size > 1: - # All-reduce non-sharded grad. - dist.all_reduce(param.grad.data, group=self.process_group) - - # Cast grad to param's dtype (typically FP32). Note: we do this - # before the move_grads_to_cpu step so that this entire hook remains - # non-blocking. 
The downside is a bit more D2H transfer in that case. - if self.mixed_precision: - param.grad.data = param.grad.data.to(dtype=param.data.dtype) - - # Optionally move gradients to CPU, typically used if one is running - # the optimizer on the CPU. - if self.move_grads_to_cpu: - param._cpu_grad.copy_(param.grad.data, non_blocking=True) - param.grad.data = param._cpu_grad + assert param._is_sharded + assert self._reducer is not None + grad_chunks = chunk_and_pad(param.grad.data, self.world_size) + self._reducer.reduce_scatter_async(grad_chunks, group=self.process_group, callback_fn=callback_fn) + else: + # Currently the only way for _is_sharded to be False is if + # world_size == 1. This could be relaxed in the future, in which + # case grads should be all-reduced here. + assert self.world_size == 1 + callback_fn(param.grad.data) # After _post_backward_hook returns, orig_grad_data will eventually # go out of scope, at which point it could otherwise be freed for @@ -743,10 +753,34 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py orig_grad_data.record_stream(self._streams["post_backward"]) + def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None: + """Hook to call on each param after the reduce-scatter.""" + assert torch.cuda.current_stream() == self._streams["post_backward"] + assert param.grad is not None + self.assert_state(TrainingState.BACKWARD) + param.grad.data = reduced_grad + # Cast grad to param's dtype (typically FP32). Note: we do this + # before the move_grads_to_cpu step so that this entire hook remains + # non-blocking. The downside is a bit more D2H transfer in that case. + if self.mixed_precision: + param.grad.data = param.grad.data.to(dtype=param.data.dtype) + # Optionally move gradients to CPU, typically used if one is running + # the optimizer on the CPU. + if self.move_grads_to_cpu: + param._cpu_grad.copy_(param.grad.data, non_blocking=True) + param.grad.data = param._cpu_grad + # Don't let this memory get reused until after the transfers. + reduced_grad.record_stream(torch.cuda.current_stream()) + @torch.no_grad() def _wait_for_post_backward(self) -> None: """Wait for post-backward work to finish. Only called on root instance.""" assert self._is_root + self.assert_state(TrainingState.BACKWARD) + # Flush any unreduced buckets in the post_backward stream. + with torch.cuda.stream(self._streams["post_backward"]): + assert self._reducer is not None + self._reducer.flush() torch.cuda.current_stream().wait_stream(self._streams["post_backward"]) if self.move_grads_to_cpu: # Wait for the non-blocking GPU -> CPU grad transfers to finish. @@ -862,29 +896,11 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> No p._fp16_shard.record_stream(current_stream) free_storage_(p._fp16_shard) - @torch.no_grad() - def _reduce_scatter_grad(self, p: nn.Parameter) -> torch.Tensor: - """Reduce-scatter a Parameter's gradient and return a single shard of - the summed gradient across workers.""" - assert p.grad is not None and p._is_sharded - grad_chunks = list(torch.flatten(p.grad.data).chunk(self.world_size)) - - # torch.chunk may return fewer than world_size chunks, pad accordingly. 
- num_pad_for_partial_chunk = grad_chunks[0].numel() - grad_chunks[-1].numel() - if num_pad_for_partial_chunk > 0: - grad_chunks[-1] = F.pad(grad_chunks[-1], [0, num_pad_for_partial_chunk]) - if len(grad_chunks) < self.world_size: - grad_chunks.extend([torch.zeros_like(grad_chunks[0])] * (self.world_size - len(grad_chunks))) - - output = torch.zeros_like(grad_chunks[0]) # filled with gradient summed across workers - dist.reduce_scatter(output, grad_chunks, group=self.process_group) - return output - - def assert_idle(self) -> None: - """Assert we are in the idle state.""" + def assert_state(self, state: TrainingState) -> None: + """Assert we are in the given state.""" assert ( - self.training_state == TrainingState.IDLE - ), f"wrong state to call no_sync. current state is {self.training_state}" + self.training_state == state + ), f"expected to be in state {state} but current state is {self.training_state}" @torch.no_grad() diff --git a/fairscale/utils/parallel.py b/fairscale/utils/parallel.py index a02ebb13f..2b3cfaf30 100644 --- a/fairscale/utils/parallel.py +++ b/fairscale/utils/parallel.py @@ -5,32 +5,24 @@ """Useful functions for parallel training.""" +from typing import List + import torch import torch.distributed as dist from torch.distributed import ProcessGroup +import torch.nn.functional as F -def compute_shard_size(numel: int, world_size: int) -> int: - """Compute shard size like the behavior of torch.chunk().""" - assert numel > 0 and world_size > 0, "invalid inputs" - if numel % world_size == 0: - # easy case, including world_size == 1. - shard_size = numel // world_size - else: - if world_size == 2: - # two shards, shard size is the size of the bigger one. - shard_size = numel // world_size + 1 - else: - # find the equal chunks until remainder is smaller than shard_size - for div in range(world_size - 1, 1, -1): - shard_size, rem = divmod(numel, div) - if shard_size >= rem: - break - # corner case: bunch of 1 elements and rest are 0s. - if shard_size == 0: - shard_size = 1 - assert shard_size > 0, f"bug: {shard_size}" - return shard_size +def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]: + """Chunk a given Tensor into num_chunks parts and add any necessary padding.""" + chunks = list(torch.flatten(tensor).chunk(num_chunks)) + # torch.chunk may return fewer than num_chunks chunks, pad accordingly. + num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel() + if num_pad_for_partial_chunk > 0: + chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk]) + if len(chunks) < num_chunks: + chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))]) + return chunks def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None: diff --git a/fairscale/utils/reduce_scatter_bucketer.py b/fairscale/utils/reduce_scatter_bucketer.py new file mode 100644 index 000000000..b8e2eba54 --- /dev/null +++ b/fairscale/utils/reduce_scatter_bucketer.py @@ -0,0 +1,151 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import functools +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from torch import Tensor +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +class Bucket: + def __init__(self, data: Tensor, group: ProcessGroup): + self.data = data + self.group = group + self.offset = 0 + self.callbacks: List[Callable] = [] + self.output_shard = torch.zeros_like(data[0]) + + def flush(self) -> None: + if self.offset == 0: + assert len(self.callbacks) == 0 + return + # reduce-scatter bucket + dist.reduce_scatter( + self.output_shard[: self.offset], list(self.data[:, : self.offset].unbind(0)), group=self.group + ) + # execute post-reduction callbacks + for callback_fn in self.callbacks: + callback_fn() + # reuse input bucket but allocate a fresh output shard + self.data[:, : self.offset].zero_() + self.offset = 0 + self.callbacks.clear() + self.output_shard = torch.zeros_like(self.data[0]) + + +class ReduceScatterBucketer: + """ + Helper for bucketing multiple reduce-scatter operations on small tensors + into larger reduce-scatter ops to improve communication efficiency. + + Usage:: + + bucketer = ReduceScatterBucketer() + bucketer.reduce_scatter_async( + small_tensors, callback_fn=lambda result: print("small") + ) + bucketer.reduce_scatter_async( + big_tensors, callback_fn=lambda result: print("big") + ) + bucketer.reduce_scatter_async( + more_small_tensors, callback_fn=lambda result: print("small2") + ) + bucketer.flush() # callbacks only guaranteed to be called after flush() + # Example output (note that it is out of order, due to bucketing): + # big + # small + # small2 + + Args: + bucket_cap_mb (int, Optional): bucket size for communicating. Buckets + are sub-divided based on world_size. Values <= 0 disable bucketing. + """ + + def __init__(self, bucket_cap_mb: int = 25): + self.bucket_cap_mb = bucket_cap_mb + self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {} + + @torch.no_grad() + def reduce_scatter_async( + self, input_list: List[Tensor], group: ProcessGroup, callback_fn: Optional[Callable] = None, + ) -> None: + """ + Reduce-scatter a list of tensors asynchronously, so smaller reductions + can be bucketed together. The given callback (``callback_fn``) will be + called with the reduced result at some later time. Call ``flush()`` to + force all queued ops and callbacks to be executed. + + Note that large inputs will be reduced immediately, and this function + may also flush the relevant bucket to make room for ``input_list``. + + Args: + input_list (List[Tensor]): list of tensors to reduce-scatter. List + should contain ``group.size()`` tensors and each tensor should + have identical shape, dtype and device. + group (ProcessGroup): process group for reduction + callback_fn (Callable, Optional): callback function to call after + the reduction executes. Function will be called with a single + argument corresponding to the reduced result. 
+ """ + world_size = group.size() + + assert ( + len(input_list) == world_size + ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})" + + first_input = input_list[0] + first_input_size = first_input.numel() + + bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size) + if first_input_size > bucket_shard_size: + # input is too big to fit in the bucket, reduce-scatter directly + output = torch.zeros_like(input_list[0]) + dist.reduce_scatter(output, input_list, group=group) + if callback_fn is not None: + callback_fn(output) + return + + bucket = self._get_bucket(first_input, group) + if first_input_size > bucket.data.size(1) - bucket.offset: + # not enough space remaining in bucket, flush it now + bucket.flush() + + # copy data from input_list into bucket + stacked_input = torch.stack(input_list).view(world_size, first_input_size) + offset = bucket.offset + bucket.data[:, offset : offset + first_input_size].copy_(stacked_input) + bucket.offset += first_input_size + + # callback will be given the reduced result + if callback_fn is not None: + result_view = bucket.output_shard[offset : offset + first_input_size].view_as(first_input) + bucket.callbacks.append(functools.partial(callback_fn, result_view)) + + @torch.no_grad() + def flush(self) -> None: + """Reduce-scatter any partial buckets.""" + for bucket in self.buckets.values(): + bucket.flush() + + @functools.lru_cache() + def _get_shard_size(self, element_size: int, num_shards: int) -> int: + if self.bucket_cap_mb <= 0: # Values <= 0 disable bucketing. + return 0 + MB = 1024 * 1024 + bucket_size = self.bucket_cap_mb * MB / element_size + return int(bucket_size // num_shards) + + def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket: + key = (tensor.dtype, tensor.device, group) + if key not in self.buckets: + # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size) + world_size = group.size() + shard_size = self._get_shard_size(tensor.element_size(), world_size) + data = tensor.new_zeros((world_size, shard_size)) + self.buckets[key] = Bucket(data, group) + return self.buckets[key] diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 15948f096..27f4f5986 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -36,7 +36,7 @@ class DistributedTest(unittest.TestCase): def setUp(self): major, minor = torch.__version__.split(".")[:2] major, minor = int(major), int(minor) - if major < 1 or minor < 6: + if major < 1 or (major == 1 and minor < 6): raise unittest.SkipTest("Need pytorch version >= 1.6 due to autocast") if not torch.cuda.is_available(): raise unittest.SkipTest("CUDA not available, skipping test") diff --git a/tests/nn/data_parallel/test_fsdp_uneven.py b/tests/nn/data_parallel/test_fsdp_uneven.py index 0a5b5f761..e2ce116db 100644 --- a/tests/nn/data_parallel/test_fsdp_uneven.py +++ b/tests/nn/data_parallel/test_fsdp_uneven.py @@ -19,6 +19,7 @@ from torch.optim import SGD from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP +from fairscale.nn.data_parallel.fully_sharded_data_parallel import TrainingState from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version @@ -48,7 +49,7 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test if test_case["assert_ref_out"]: torch.testing.assert_allclose(ref_out, out) - model.assert_idle() + model.assert_state(TrainingState.IDLE) 
teardown() diff --git a/tests/utils/test_parallel.py b/tests/utils/test_parallel.py index 00a62d521..a09c2a385 100644 --- a/tests/utils/test_parallel.py +++ b/tests/utils/test_parallel.py @@ -7,20 +7,20 @@ # pylint: disable=missing-class-docstring # pylint: disable=missing-function-docstring -""" Test utility classes from containers.py. """ +""" Test utility classes from fairscale.utils.parallel """ -import pytest +from parameterized import parameterized import torch -from fairscale.utils.parallel import compute_shard_size +from fairscale.utils.parallel import chunk_and_pad -@pytest.mark.parametrize( - "test_case", [(1, 2), (2, 2), (3, 2), (4, 2), (4, 4), (3, 4), (9, 4), (9, 6), (10, 5), (14, 5)] -) -def test_compute_shard_size(test_case): - """Test compute_shard_size, verify using torch.chunk()""" - numel, world_size = test_case - result = compute_shard_size(numel, world_size) - expected = torch.zeros(numel).chunk(world_size)[0].numel() - assert result == expected, f"{result} == {expected}" +@parameterized.expand([[num_chunks] for num_chunks in range(1, 33)]) +def test_chunk_and_pad(num_chunks): + max_tensor_size = 256 + tensor = torch.zeros(max_tensor_size) + for tensor_size in range(1, max_tensor_size + 1): + tensor_i = tensor[:tensor_size] + chunks = chunk_and_pad(tensor_i, num_chunks) + assert len(chunks) == num_chunks + assert all(len(chunks[0]) == len(chunk) for chunk in chunks) diff --git a/tests/utils/test_reduce_scatter_bucketer.py b/tests/utils/test_reduce_scatter_bucketer.py new file mode 100644 index 000000000..4baa23c29 --- /dev/null +++ b/tests/utils/test_reduce_scatter_bucketer.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import itertools +import sys +import unittest +from unittest import mock + +from parameterized import parameterized +import torch + +from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer +from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes + + +def rename_test(testcase_func, param_num, param): + return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),) + + +CONFIG_OPTIONS = [ + [dict(zip(["bucket_cap_mb", "shard_size"], config))] for config in itertools.product([0, 0.25], [1, 262144]) +] + + +class TestReduceScatterBucketer(unittest.TestCase): + # TODO(sshleifer): check if possible to reuse `DistributedTest, spawn_and_init`. 
+ def setUp(self): + major, minor = torch.__version__.split(".")[:2] + major, minor = int(major), int(minor) + if major < 1 or (major == 1 and minor < 6): + raise unittest.SkipTest("Need pytorch version >= 1.6 due to reduce_scatter") + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA not available, skipping test") + if sys.platform == "win32": + raise unittest.SkipTest("NCCL doesn't support Windows, skipping test") + if torch.cuda.device_count() < 2: + raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") + + @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) + def test_reduce_scatter(self, config): + spawn_and_init(functools.partial(self._test_reduce_scatter, **config)) + + @staticmethod + def _test_reduce_scatter(rank, group, bucket_cap_mb=None, shard_size=None): + bucketer = ReduceScatterBucketer(bucket_cap_mb=bucket_cap_mb) + world_size = group.size() + + tensors = [torch.ones(shard_size).cuda() for _ in range(world_size)] + tensors[rank].fill_(0) + + input_bytes = shard_size * world_size * 4 + bucket_bytes = bucket_cap_mb * 1024 * 1024 + + callback = mock.MagicMock() + bucketer.reduce_scatter_async(tensors, group, callback_fn=callback) + + if bucket_cap_mb > 0 and input_bytes < bucket_bytes: + assert callback.call_count == 0 + bucketer.flush() + assert callback.call_count == 1 + + result = callback.call_args[0][0] # get first positional arg + assert torch.is_tensor(result), result + assert torch.all(result == (world_size - 1)) + + def test_out_of_order_reduction(self): + spawn_and_init(self._test_out_of_order_reduction) + + @staticmethod + def _test_out_of_order_reduction(rank, group): + bucketer = ReduceScatterBucketer(bucket_cap_mb=0.25) + world_size = group.size() + + small_tensors = [torch.ones(1).cuda() for _ in range(world_size)] + big_tensors = [torch.ones(262144).cuda() for _ in range(world_size)] + more_small_tensors = [torch.ones(2).cuda() for _ in range(world_size)] + + callback1 = mock.MagicMock() + callback2 = mock.MagicMock() + callback3 = mock.MagicMock() + + bucketer.reduce_scatter_async(small_tensors, group, callback_fn=callback1) + assert callback1.call_count == 0 + bucketer.reduce_scatter_async(big_tensors, group, callback_fn=callback2) + assert callback1.call_count == 0 + assert callback2.call_count == 1 + bucketer.reduce_scatter_async(more_small_tensors, group, callback_fn=callback3) + assert callback1.call_count == 0 + assert callback2.call_count == 1 + assert callback3.call_count == 0 + + bucketer.flush() + assert callback1.call_count == 1 + assert callback2.call_count == 1 + assert callback3.call_count == 1 + + +def spawn_and_init(fn, args=None, **spawn_kwargs): + if args is None: + args = () + run_fn = functools.partial(init_and_run, fn, args) + spawn_for_all_world_sizes(run_fn, **spawn_kwargs) + + +def init_and_run(fn, args, rank, world_size, filename, filename_rpc): + dist_init(rank, world_size, filename, filename_rpc) + group = torch.distributed.new_group() + fn(rank, group, *args) + + +if __name__ == "__main__": + unittest.main() From ec8bf9f1ef8d6856c8c7e4338d5cbce9c328b682 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 22 Feb 2021 14:21:23 -0800 Subject: [PATCH 47/48] Fix merge conflict --- fairscale/optim/oss.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py index a9db5c4b7..035c80137 100644 --- a/fairscale/optim/oss.py +++ b/fairscale/optim/oss.py @@ -15,11 +15,7 @@ from torch.nn import Parameter from torch.optim import SGD, Optimizer 
-<<<<<<< HEAD -from .utils import Workhandle, broadcast_object, calc_grad_norm, recursive_copy_to_device -======= -from .utils import broadcast_object, recursive_copy_to_device ->>>>>>> oss-master +from .utils import broadcast_object, calc_grad_norm, recursive_copy_to_device __all__ = ["OSS"] From 3f2b7f1d6a3f47c908a34c441a9ce82373c204dd Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 22 Feb 2021 14:44:36 -0800 Subject: [PATCH 48/48] Update docstring slightly --- fairscale/nn/data_parallel/fully_sharded_data_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 98751c639..2c19b6d42 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -55,8 +55,8 @@ class TrainingState(Enum): class FullyShardedDataParallel(nn.Module): """ - A wrapper for sharding Module parameters. This is inspired by `Xu et al.`_ - as well as the ZeRO Stage 3 from the DeepSpeed_ work. + A wrapper for sharding Module parameters across data parallel workers. This + is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_. .. _`Xu et al.`: https://arxiv.org/abs/2004.13336 .. _DeepSpeed: https://www.deepspeed.ai/
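Patch 45's ``clip_grad_norm_`` leans on the fact that a p-norm combines across shards: each rank takes the p-norm of its own gradient shard (``calc_grad_norm``), the p-th powers are summed across ranks with ``dist.all_reduce``, and the p-th root of that sum is the global norm (for ``p == inf`` the all-reduce uses ``MAX`` instead). The arithmetic can be checked in a single process; the flattened gradient tensor and the four-way split below merely stand in for the per-rank shards::

    import torch

    p = 2.0
    full_grad = torch.randn(1000)        # all gradients of the model, flattened
    shards = full_grad.chunk(4)          # pretend: one shard per rank

    # Each rank computes a local norm over its shard (calc_grad_norm) ...
    local_norms = torch.stack([torch.norm(s, p) for s in shards])

    # ... the p-th powers are summed (dist.all_reduce in FSDP) and the p-th
    # root recovers the norm of the full, unsharded gradient.
    total_norm = (local_norms ** p).sum() ** (1.0 / p)
    assert torch.allclose(total_norm, torch.norm(full_grad, p))

    # Clipping then scales every local shard by max_norm / (total_norm + 1e-6),
    # matching what torch.nn.utils.clip_grad_norm_ does on unsharded gradients.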