Skip to content

Commit

Permalink
Leave buffers on self.compute_device (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
sshleifer authored Feb 9, 2021
1 parent 74b0223 commit 6797964
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 19 deletions.
27 changes: 19 additions & 8 deletions fairscale/nn/data_parallel/shard_params_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,18 @@ def __init__(
# Shard module parameters in place
self._shard_parameters_()

if self.mixed_precision:
# Cast all module buffers to FP16 (buffers are not sharded).
self.apply(cast_buffers_to_fp16)

# Make sure all parameters are sharded.
for n, p in self.named_parameters():
assert getattr(p, "_is_sharded", False), f"found unsharded parameter: {n} ; {p.size()}"

self._reset_lazy_init()

@torch.no_grad()
def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
"""Move all buffers to the specified device and dtype, recursively."""
cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype)
self.apply(cast_fn)

@torch.no_grad()
def _shard_parameters_(self) -> None:
"""
Expand Down Expand Up @@ -217,17 +219,19 @@ def state_dict(self, *args, **kwargs): # type: ignore
"""
Returns the whole (unsharded) state of the module. Parameters are not
sharded, so the resulting state_dict can be loaded directly by the
wrapped Module without any sharding-specific logic.
wrapped Module without any sharding-specific logic. Returned tensors will always be typed float32
"""
torch.cuda.synchronize()
self._lazy_init()
self._rebuild_full_params()
self._all_buffers_to(dtype=torch.float32) # Buffers dtype stays consistent with parameters.
state_dict = self.module.state_dict(*args, **kwargs)
# We don't free the params after generating the state dict, since
# freeing is done in-place (via the Storage) and would corrupt the
# returned state dict. However, we need to maintain the invariant that
# p.data corresponds to the FP32 param shard, so we do that here.
self._use_fp32_param_shard()
self._all_buffers_to(dtype=self.compute_dtype)
return state_dict

# TODO (Min): figuring out how to do typing for this overloaded function.
Expand Down Expand Up @@ -278,6 +282,10 @@ def _lazy_init(self) -> None:
if self._is_root is None:
self._set_is_root()
self._setup_streams()
if self.cpu_offload: # Buffers stay on GPU, and dont get sharded
self._all_buffers_to(device=torch.device("cuda"), dtype=self.compute_dtype)
else:
self._all_buffers_to(dtype=self.compute_dtype)

# Don't free the full params for the outer-most (root) instance, since
# those params will be needed immediately after for the backward pass.
Expand Down Expand Up @@ -652,11 +660,14 @@ def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
return args, kwargs


def cast_buffers_to_fp16(module: nn.Module) -> None:
"""Cast buffers of a module to FP16."""
def cast_buffers_(
module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> None:
"""Cast all of module.named_buffers to device, dtype."""
# if buffers are already on the right device and/or dtype this is just python loop cost
for key, buf in module.named_buffers(recurse=False):
if buf is not None:
setattr(module, key, buf.half())
setattr(module, key, buf.to(dtype=dtype, device=device))


def free_storage_(data: torch.Tensor) -> None:
Expand Down
30 changes: 19 additions & 11 deletions tests/nn/data_parallel/test_shard_params_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4

_BUFFER_NAME = "vocab_bias"


class DistributedTest(unittest.TestCase):
def setUp(self):
Expand All @@ -42,9 +44,9 @@ def setUp(self):
raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")

@staticmethod
def _train_for_several_steps(model, num_steps, autocast):
def _train_for_several_steps(model, num_steps, autocast, lr=0.01):
model_device = next(model.parameters()).device
optim = torch.optim.Adam(model.parameters(), lr=0.01)
optim = torch.optim.Adam(model.parameters(), lr=lr)
for _ in range(num_steps):
optim.zero_grad()
with torch.cuda.amp.autocast(enabled=autocast):
Expand Down Expand Up @@ -178,7 +180,11 @@ def test_cpu_offload_and_cpu_grads(self):
# We don't test the False condition because that requires the optimizer to internally do
# the device transfer and PyTorch optimizers don't support this.
config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True}
test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False)
test_fn = functools.partial(
self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.001
)
# We use lower lr to reduce this test's sensitivity to slightly different CPU vs CUDA behavior of pytorch.
# With lr=0.01, it fails on torch 1.6.0.
spawn_and_init(test_fn)

def test_cpu_offload_and_cuda_grads_breaks(self):
Expand Down Expand Up @@ -210,7 +216,7 @@ def test_delayed_reduce_scatter(self):
spawn_and_init(test_fn)

@classmethod
def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True):
def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True, lr=0.01):
if config["mixed_precision"]:
autocast = True
# Force the compute dtype to be torch.float32 so that we get
Expand All @@ -224,7 +230,7 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3
# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr)
ref_state_dict = model.module.state_dict()

# Confirm we get the same behavior using ShardParamsDataParallel.
Expand All @@ -233,14 +239,14 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast)
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr)
shard_state_dict = model.state_dict()

try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
except (AssertionError, RuntimeError) as e:
raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}")
raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")


class TestParamInit(DistributedTest):
Expand Down Expand Up @@ -332,7 +338,7 @@ def test_local_state_dict_odd_vocab_shape_breaks(self):
spawn_and_init(test_fn)

@classmethod
def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32):
def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=16):
"""Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
model = ShardParamsDataParallel(
TransformerWithSharedParams(d_model=d_model, d_vocab=d_vocab), group, **config
Expand All @@ -346,11 +352,11 @@ def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32):
state_1_weight = state_1[weight_key]
assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32"
if not model.flatten_parameters:
# This weight will be sharded since we access module.state_dict directly
# The weight will be sharded since we access module.state_dict directly
state_1_module_weight = model.module.state_dict()[weight_key]
torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
self._train_for_several_steps(model, 4, model.mixed_precision)
self._train_for_several_steps(model, 1, model.mixed_precision)

state_2 = model.local_state_dict()
state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
Expand All @@ -361,7 +367,7 @@ def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32):
# Assert that parameters were updated since before training
unchanged = []
for k in state_1:
if (state_before_training[k] == state_after_training[k]).all():
if (state_before_training[k] == state_after_training[k]).all() and (_BUFFER_NAME not in k):
unchanged.append(k)
if unchanged:
raise AssertionError(f"params {unchanged} not changed after training")
Expand Down Expand Up @@ -520,6 +526,7 @@ def __init__(self, *unused_args, d_vocab=32, d_model=16, **unused_kwargs):
self.output_proj = nn.Linear(d_model, d_vocab)
# share the embedding and output projection weights
self.output_proj.weight = self.embed_tokens.weight
self.register_buffer(_BUFFER_NAME, self.embed_tokens.weight.new_ones((d_model,)))

def get_input(self, device):
torch.manual_seed(1) # keep everything deterministic
Expand All @@ -529,6 +536,7 @@ def get_input(self, device):

def forward(self, src_ids, tgt_ids):
src = self.embed_tokens(src_ids)
src = src + self.vocab_bias
tgt = self.embed_tokens(tgt_ids)
x = self.transformer(src, tgt)
return self.output_proj(x)
Expand Down

0 comments on commit 6797964

Please sign in to comment.