Do reduce-scatter in a separate CUDA stream (#62)
* Do reduce-scatter in a separate CUDA stream

* Add _post_backward_stream to stubs
myleott authored Feb 4, 2021
1 parent bc5190b commit 8a5f81c
Showing 3 changed files with 42 additions and 14 deletions.
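The heart of the change is a standard CUDA stream-overlap pattern: queue the gradient reductions on a side stream so they run concurrently with the rest of the backward pass, then make the default stream wait on that side stream exactly once before the gradients are consumed. Below is a minimal, self-contained sketch of the pattern; the names side_stream, reduce_grad_async, and wait_for_reductions are hypothetical illustrations, not the fairscale API (the commit itself uses a per-parameter _post_backward_stream and a _wait_for_post_backward callback).

import torch

side_stream = torch.cuda.Stream()

def reduce_grad_async(grad: torch.Tensor) -> None:
    # The side stream must not start until the producer (current) stream has
    # finished writing this gradient.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        # Work queued here runs concurrently with later kernels launched on
        # the default stream for the rest of the backward pass.
        grad.div_(4)  # stand-in for the real averaging / reduce-scatter

def wait_for_reductions() -> None:
    # The consumer (default) stream must not read the reduced gradients until
    # the side stream has drained.
    torch.cuda.current_stream().wait_stream(side_stream)

grad = torch.ones(1024, device="cuda")
reduce_grad_async(grad)
# ... more backward work could be queued on the default stream here ...
wait_for_reductions()

One general caveat with this pattern: a tensor produced on one stream and consumed or freed while another stream still uses it may additionally need Tensor.record_stream so the caching allocator does not recycle its memory too early.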
39 changes: 25 additions & 14 deletions fairscale/nn/data_parallel/shard_params_data_parallel.py
@@ -300,6 +300,9 @@ def _init_param(self, p: Parameter) -> None:
             # shard in pinned memory so that we can do a non-blocking transfer.
             p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
 
+        # Stream for overlapping the backward pass and gradient reductions.
+        p._post_backward_stream = torch.cuda.Stream()
+
     @torch.no_grad()
     def _pre_forward_init(self) -> None:
         first_time_params = [p for p in self.params if not hasattr(p, "_full_param")]
@@ -430,29 +433,35 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Free full params and switch to FP32 shard after backward.
         self._free_full_params([param])
         self._use_fp32_param_shard([param])
 
         if self.mixed_precision:
             # This is a no-op if reshard_after_forward is True, since we already
             # free the param shard when rebuilding the full params in the
             # pre_backward_hook.
             self._free_fp16_param_shard([param])
 
-            if self.fp32_reduce_scatter:
+        # Wait for all work in the current stream to finish, then start the
+        # reductions in _post_backward_stream.
+        param._post_backward_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(param._post_backward_stream):
+            if self.mixed_precision and self.fp32_reduce_scatter:
                 # Cast grad to FP32.
                 param.grad.data = param.grad.data.to(param.dtype)
 
-        # Average grad by world_size for consistency with PyTorch DDP.
-        param.grad.data.div_(self.world_size)
+            # Average grad by world_size for consistency with PyTorch DDP.
+            param.grad.data.div_(self.world_size)
 
-        # Reduce-scatter grad.
-        param.grad.data = self._reduce_scatter(torch.flatten(param.grad.data))
+            # Reduce-scatter grad.
+            param.grad.data = self._reduce_scatter(torch.flatten(param.grad.data))
 
-        # Cast grad to param's dtype (typically FP32). Note: we do this
-        # before the move_grads_to_cpu step so that this entire hook remains
-        # non-blocking. The downside is a bit more D2H transfer in that case.
-        if self.mixed_precision:
-            param.grad.data = param.grad.data.to(dtype=param.data.dtype)
+            # Cast grad to param's dtype (typically FP32). Note: we do this
+            # before the move_grads_to_cpu step so that this entire hook remains
+            # non-blocking. The downside is a bit more D2H transfer in that case.
+            if self.mixed_precision:
+                param.grad.data = param.grad.data.to(dtype=param.data.dtype)
 
-        if self.move_grads_to_cpu:
-            param._cpu_grad.copy_(param.grad.data, non_blocking=True)
-            param.grad.data = param._cpu_grad
+            if self.move_grads_to_cpu:
+                param._cpu_grad.copy_(param.grad.data, non_blocking=True)
+                param.grad.data = param._cpu_grad
 
         # Enqueue a callback at the end of the backward pass to ensure that all
        # post-backward work has finished. We only need one callback.
@@ -463,6 +472,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
     @torch.no_grad()
     def _wait_for_post_backward(self) -> None:
         """Wait for all post-backward work to finish."""
+        for p in self.params:
+            torch.cuda.current_stream().wait_stream(p._post_backward_stream)
         if self.move_grads_to_cpu:
             # Wait for the non-blocking GPU -> CPU grad transfers to finish.
             torch.cuda.current_stream().synchronize()
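The move_grads_to_cpu path in the hook above depends on the pinned _cpu_grad buffer allocated in _init_param: only a copy into page-locked host memory can be issued with non_blocking=True and then fenced with a stream synchronize. A small standalone sketch of that idiom follows (independent of fairscale, assuming a CUDA device is available):

import torch

grad = torch.randn(1024, device="cuda")

# Pre-allocate the host destination in pinned (page-locked) memory; without
# pin_memory(), a device-to-host copy is synchronous and non_blocking=True
# has no effect.
cpu_grad = torch.zeros_like(grad, device="cpu").pin_memory()
cpu_grad.copy_(grad, non_blocking=True)

# The copy is only guaranteed to have landed after the stream synchronizes,
# which is why _wait_for_post_backward ends with current_stream().synchronize().
torch.cuda.current_stream().synchronize()
assert torch.equal(cpu_grad, grad.cpu())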
2 changes: 2 additions & 0 deletions stubs/torch/nn/parameter.pyi
@@ -2,6 +2,7 @@
 
 from typing import Optional
 from .. import Size, Tensor
+from ..cuda import Stream
 import builtins
 
 class Parameter(Tensor):
@@ -13,6 +14,7 @@ class Parameter(Tensor):
     _full_param: Tensor
     _fp32_shard: Tensor
     _fp16_shard: Optional[Tensor]
+    _post_backward_stream: Stream
 
     def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ...
 
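The stub change exists because _post_backward_stream is attached to Parameter dynamically in _init_param; a type checker only accepts such an attribute if a declaration is visible to it, which is what the repository's local stubs/ override for torch provides. A tiny illustration of the idea (not fairscale code; ParamLike is a made-up class):

import torch

class ParamLike:
    # Declaration only (no value): enough for mypy to accept the attribute,
    # mirroring the annotation added to the Parameter stub.
    _post_backward_stream: torch.cuda.Stream

p = ParamLike()
p._post_backward_stream = torch.cuda.Stream()  # assigned at runtime, as in _init_param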
15 changes: 15 additions & 0 deletions tests/nn/data_parallel/test_shard_params_data_parallel.py
@@ -203,6 +203,21 @@ def _maybe_wrap(layer):
         )
         return ModuleWithDelay(model, delay_after_loss_ms=250)
 
+    def test_delayed_reduce_scatter(self):
+        # We insert a delay in the torch.distributed.reduce_scatter op, so that
+        # the post_backward_stream takes much longer than the backward pass.
+        # This tests that we properly block at the end of the backward pass for
+        # the reductions to finish.
+        config = {"mixed_precision": True}
+        with mock.patch("torch.distributed.reduce_scatter", wraps=self._delayed_reduce_scatter):
+            test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config)
+            spawn_and_init(test_fn)
+
+    @classmethod
+    def _delayed_reduce_scatter(cls, *args, **kwargs):
+        torch.cuda._sleep(int(250 * get_cycles_per_ms()))
+        return torch.distributed.reduce_scatter(*args, **kwargs)
+
     @classmethod
     def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True):
         if config["mixed_precision"]:
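The delayed-reduce-scatter test leans on torch.cuda._sleep (a private PyTorch helper that busy-spins the GPU for a given number of cycles) together with the repository's get_cycles_per_ms test utility to convert milliseconds into cycles. A rough sketch of how such a calibration helper can be written with CUDA events, as an approximation rather than the repo's exact utility:

import torch

def estimate_cycles_per_ms(cycles: int = 1_000_000) -> float:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    torch.cuda._sleep(cycles)  # busy-spin the GPU for `cycles` cycles
    end.record()
    end.synchronize()
    # Event.elapsed_time returns milliseconds, so cycles / ms gives the rate.
    return cycles / start.elapsed_time(end)

print(f"~{estimate_cycles_per_ms():.0f} GPU cycles per millisecond")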
