From 8a5f81c7ba7cca3da8d0d95781b79e7ce6802a6a Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Thu, 4 Feb 2021 13:13:23 -0500
Subject: [PATCH] Do reduce-scatter in a separate CUDA stream (#62)

* Do reduce-scatter in a separate CUDA stream

* Add _post_backward_stream to stubs
---
 .../shard_params_data_parallel.py       | 39 ++++++++++++-------
 stubs/torch/nn/parameter.pyi            |  2 +
 .../test_shard_params_data_parallel.py  | 15 +++++++
 3 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py
index ab29c0418..abd6f35ea 100644
--- a/fairscale/nn/data_parallel/shard_params_data_parallel.py
+++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py
@@ -300,6 +300,9 @@ def _init_param(self, p: Parameter) -> None:
             # shard in pinned memory so that we can do a non-blocking transfer.
             p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
 
+        # Stream for overlapping the backward pass and gradient reductions.
+        p._post_backward_stream = torch.cuda.Stream()
+
     @torch.no_grad()
     def _pre_forward_init(self) -> None:
         first_time_params = [p for p in self.params if not hasattr(p, "_full_param")]
@@ -430,29 +433,35 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Free full params and switch to FP32 shard after backward.
         self._free_full_params([param])
         self._use_fp32_param_shard([param])
-
         if self.mixed_precision:
+            # This is a no-op if reshard_after_forward is True, since we already
+            # free the param shard when rebuilding the full params in the
+            # pre_backward_hook.
             self._free_fp16_param_shard([param])
 
-            if self.fp32_reduce_scatter:
+        # Wait for all work in the current stream to finish, then start the
+        # reductions in _post_backward_stream.
+        param._post_backward_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(param._post_backward_stream):
+            if self.mixed_precision and self.fp32_reduce_scatter:
                 # Cast grad to FP32.
                 param.grad.data = param.grad.data.to(param.dtype)
 
-        # Average grad by world_size for consistency with PyTorch DDP.
-        param.grad.data.div_(self.world_size)
+            # Average grad by world_size for consistency with PyTorch DDP.
+            param.grad.data.div_(self.world_size)
 
-        # Reduce-scatter grad.
-        param.grad.data = self._reduce_scatter(torch.flatten(param.grad.data))
+            # Reduce-scatter grad.
+            param.grad.data = self._reduce_scatter(torch.flatten(param.grad.data))
 
-        # Cast grad to param's dtype (typically FP32). Note: we do this
-        # before the move_grads_to_cpu step so that this entire hook remains
-        # non-blocking. The downside is a bit more D2H transfer in that case.
-        if self.mixed_precision:
-            param.grad.data = param.grad.data.to(dtype=param.data.dtype)
+            # Cast grad to param's dtype (typically FP32). Note: we do this
+            # before the move_grads_to_cpu step so that this entire hook remains
+            # non-blocking. The downside is a bit more D2H transfer in that case.
+            if self.mixed_precision:
+                param.grad.data = param.grad.data.to(dtype=param.data.dtype)
 
-        if self.move_grads_to_cpu:
-            param._cpu_grad.copy_(param.grad.data, non_blocking=True)
-            param.grad.data = param._cpu_grad
+            if self.move_grads_to_cpu:
+                param._cpu_grad.copy_(param.grad.data, non_blocking=True)
+                param.grad.data = param._cpu_grad
 
         # Enqueue a callback at the end of the backward pass to ensure that all
         # post-backward work has finished. We only need one callback.
@@ -463,6 +472,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
     @torch.no_grad()
     def _wait_for_post_backward(self) -> None:
         """Wait for all post-backward work to finish."""
+        for p in self.params:
+            torch.cuda.current_stream().wait_stream(p._post_backward_stream)
         if self.move_grads_to_cpu:
             # Wait for the non-blocking GPU -> CPU grad transfers to finish.
             torch.cuda.current_stream().synchronize()
diff --git a/stubs/torch/nn/parameter.pyi b/stubs/torch/nn/parameter.pyi
index 05f24df38..7cd76d9d2 100644
--- a/stubs/torch/nn/parameter.pyi
+++ b/stubs/torch/nn/parameter.pyi
@@ -2,6 +2,7 @@
 from typing import Optional
 
 from .. import Size, Tensor
+from ..cuda import Stream
 import builtins
 
 class Parameter(Tensor):
@@ -13,6 +14,7 @@ class Parameter(Tensor):
     _full_param: Tensor
     _fp32_shard: Tensor
     _fp16_shard: Optional[Tensor]
+    _post_backward_stream: Stream
 
     def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ...
 
diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py
index 9bd9346ab..452f2421a 100644
--- a/tests/nn/data_parallel/test_shard_params_data_parallel.py
+++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py
@@ -203,6 +203,21 @@ def _maybe_wrap(layer):
         )
         return ModuleWithDelay(model, delay_after_loss_ms=250)
 
+    def test_delayed_reduce_scatter(self):
+        # We insert a delay in the torch.distributed.reduce_scatter op, so that
+        # the post_backward_stream takes much longer than the backward pass.
+        # This tests that we properly block at the end of the backward pass for
+        # the reductions to finish.
+        config = {"mixed_precision": True}
+        with mock.patch("torch.distributed.reduce_scatter", wraps=self._delayed_reduce_scatter):
+            test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config)
+            spawn_and_init(test_fn)
+
+    @classmethod
+    def _delayed_reduce_scatter(cls, *args, **kwargs):
+        torch.cuda._sleep(int(250 * get_cycles_per_ms()))
+        return torch.distributed.reduce_scatter(*args, **kwargs)
+
     @classmethod
     def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True):
         if config["mixed_precision"]:
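
Note (not part of the patch): the hook change above follows the usual CUDA producer/consumer stream pattern. The post-backward stream waits on the current stream before it reads the freshly computed gradients, and the current stream waits on the post-backward stream before the gradients are consumed. Below is a minimal, self-contained sketch of that pattern, assuming a CUDA device and using illustrative names (main, side) that do not appear in the patch.

    import torch

    assert torch.cuda.is_available()  # sketch requires a CUDA device

    main = torch.cuda.current_stream()
    side = torch.cuda.Stream()  # plays the role of param._post_backward_stream

    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x  # "backward-like" work enqueued on the main stream

    side.wait_stream(main)  # side stream must not read y before it is ready
    with torch.cuda.stream(side):
        r = y.sum()  # "reduction-like" work, overlapped with later main-stream work

    z = x * 2  # more main-stream work, can run concurrently with r

    main.wait_stream(side)  # analogous to the waits added in _wait_for_post_backward
    torch.cuda.synchronize()
    print(float(r), float(z.norm()))

Without the final wait, the consumer (here the optimizer step) could read gradients the side stream has not finished producing; test_delayed_reduce_scatter exercises exactly that ordering by slowing the reduce-scatter down.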