
Commit

More docs (#35)
myleott committed Feb 1, 2021
1 parent 5fb36d1 commit 774e130
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions fairscale/nn/data_parallel/shard_params_data_parallel.py
@@ -28,23 +28,25 @@ class ShardParamsDataParallel(nn.Module):
     Usage::

         sharded_module = ShardParamsDistributedWrapper(my_module)
         optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
         x = sharded_module(x, y=3, z=torch.Tensor([1]))
         loss = x.sum()
         loss.backward()
         optim.step()

     It is also possible to shard individual layers separately and have an outer
-    wrapper handle any leftover parameters::
-
-        model = nn.Sequential(
-            nn.Linear(5, 100),
-            ShardParamsDistributedWrapper(nn.Linear(100, 100)),
-            ShardParamsDistributedWrapper(nn.Linear(100, 100)),
-            nn.Linear(100, 5),
-        )
-        sharded_model = ShardParamsDistributedWrapper(model)
+    wrapper handle any leftover parameters. This can be helpful to further
+    reduce memory usage and to improve training speed by distributing the
+    unsharding (all-gather) across the forward pass. For example::
+
+        sharded_model = ShardParamsDistributedWrapper(
+            nn.Sequential(
+                nn.Linear(5, 100),
+                ShardParamsDistributedWrapper(nn.Linear(100, 100)),
+                ShardParamsDistributedWrapper(nn.Linear(100, 100)),
+                nn.Linear(100, 5),
+            )
+        )
         x = sharded_model(x)
         loss = x.sum()
         loss.backward()

     Args:
         module (nn.Module): module to checkpoint
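The nested-wrapping example added above works because each inner wrapper gathers only its own parameters just before its sub-module runs, then frees them again, so the all-gather cost is spread across the forward pass instead of paid all at once. The following is a minimal, single-process sketch of that lifecycle under stated assumptions; it is not fairscale's implementation, the class name ToyShardedLinear is hypothetical, and the torch.distributed.all_gather a real multi-worker version would use is only indicated in a comment.

# Conceptual sketch only: shows the "unshard -> compute -> reshard" cycle that
# nesting wrappers distributes across the forward pass. Runs in one process;
# a real implementation would replace the commented step with dist.all_gather.
import torch
import torch.nn as nn


class ToyShardedLinear(nn.Module):  # hypothetical name, illustration only
    def __init__(self, linear: nn.Linear, world_size: int = 4, rank: int = 0):
        super().__init__()
        self.linear = linear
        self.rank, self.world_size = rank, world_size
        self._shape = linear.weight.shape
        full = linear.weight.detach().flatten()
        # In a real multi-process setup each rank keeps only its own slice; here
        # we keep all slices locally so the example runs in a single process.
        self._shards = [c.clone() for c in full.chunk(world_size)]
        self._local_shard = self._shards[rank]
        # Drop the full weight so only sharded state persists between calls.
        self.linear.weight.data = torch.empty(0)

    def forward(self, x):
        # Unshard just before this layer runs. A real implementation would call
        # torch.distributed.all_gather(self._shards, self._local_shard) here.
        full = torch.cat(self._shards)
        self.linear.weight.data = full.view(self._shape)
        out = self.linear(x)
        # Reshard right after, so at most one layer holds full weights at a time.
        self.linear.weight.data = torch.empty(0)
        return out


block = nn.Sequential(
    ToyShardedLinear(nn.Linear(100, 100)),
    ToyShardedLinear(nn.Linear(100, 100)),
)
y = block(torch.randn(8, 100))  # each layer gathers, computes, then frees

Because each layer frees its full weights before the next layer gathers, peak parameter memory stays close to one layer's worth plus the shards, which is the point the new docstring text is making.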
@@ -129,6 +131,27 @@ def __init__(

     @torch.no_grad()
     def _shard_initial_params(self):
+        """
+        At initialization we wrap a module with full parameters and shard the
+        parameters in-place. Sharding is implemented by viewing each parameter
+        as a 1D Tensor and retaining only a single slice, where the slice size
+        is determined by the number of data parallel workers.
+
+        Wrapping modules with many small parameters (or with a very large data
+        parallel world size) will result in many small parameter shards and slow
+        performance. In this case it's better to set *flatten_parameters* to
+        ``True``, so that all of the small parameters in the module are combined
+        into a single contiguous Tensor and sharded once.
+
+        After this initial sharding is complete, the user can initialize a
+        ``torch.optim.Optimizer`` in the usual way, i.e.::
+
+            optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
+
+        The optimizer will see only a single slice of parameters and will thus
+        allocate less memory for optimizer state, avoiding redundancy across
+        data parallel workers.
+        """
         for p in self.params:
             assert not hasattr(p, "_is_sharded")
             assert p.is_floating_point()
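The new _shard_initial_params docstring describes sharding as viewing each parameter as a 1D tensor and keeping one slice per data parallel worker, with flatten_parameters collapsing many small parameters into one contiguous tensor first. Below is a minimal sketch of that slicing arithmetic; shard_parameter and flatten_and_shard are hypothetical helper names for illustration, not fairscale APIs, and padding shards to equal size is an assumption of this sketch.

# Sketch of "view as 1D, keep one slice per worker" sharding, assuming shards
# are padded so every rank holds the same number of elements.
import math
import torch
import torch.nn as nn


def shard_parameter(p: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # View the parameter as a 1D tensor and keep only this rank's slice.
    flat = p.detach().flatten()
    shard_size = math.ceil(flat.numel() / world_size)  # equal-size shards, last padded
    shard = flat.new_zeros(shard_size)
    begin = rank * shard_size
    end = min(begin + shard_size, flat.numel())
    if begin < flat.numel():
        shard[: end - begin] = flat[begin:end]
    return shard


def flatten_and_shard(module: nn.Module, rank: int, world_size: int) -> torch.Tensor:
    # The flatten_parameters idea: concatenate many small parameters into one
    # contiguous 1D tensor and shard it once, instead of one tiny shard per parameter.
    flat = torch.cat([p.detach().flatten() for p in module.parameters()])
    return shard_parameter(flat, rank, world_size)


if __name__ == "__main__":
    layer = nn.Linear(100, 100)
    shard = shard_parameter(layer.weight, rank=0, world_size=4)
    print(shard.numel(), "of", layer.weight.numel(), "elements held on this rank")
    # An optimizer constructed over shards like this one allocates state only for
    # the local slice, which is where the per-worker optimizer memory savings
    # described in the docstring come from.

As a usage note, this mirrors the docstring's advice: with many small parameters, sharding each one separately yields many tiny slices, while flattening first yields a single large tensor whose one-time sharding amortizes the overhead.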
