Allow sharded grad scaler to cpu offload with FSDP #831

Merged: 33 commits, Nov 15, 2021
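
For orientation, here is a minimal sketch of the training pattern this PR enables: an FSDP-wrapped model with CPU offload whose loss is scaled by ShardedGradScaler. It is illustrative only and is not taken from the diff; `MyModel` and `data_loader` are hypothetical placeholders, while the FSDP flags (`mixed_precision`, `cpu_offload`, `move_grads_to_cpu`) and the `scale`/`step`/`update` calls are the ones exercised by the test changes below.

```python
# Illustrative sketch (not part of this PR's diff). Assumes torch.distributed
# is already initialized and a CUDA device has been selected for this rank.
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim.grad_scaler import ShardedGradScaler

model = FSDP(
    MyModel().cuda(),        # MyModel is a hypothetical nn.Module
    mixed_precision=True,
    cpu_offload=True,        # parameters live on CPU between forward/backward
    move_grads_to_cpu=True,  # gradients are offloaded to CPU as well
)
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scaler = ShardedGradScaler()

for batch in data_loader:    # hypothetical iterable of CUDA-ready batches
    optim.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch).sum()
    scaler.scale(loss).backward()
    scaler.step(optim)       # unscales the (possibly CPU-resident) grads; skips the step on inf/nan
    scaler.update()
```
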
Commits (33), changes from all commits

e51e891 - first commit (Oct 19, 2021)
c63d1c3 - sharded scaler hitting nan assertions (Oct 21, 2021)
a39b37b - adding test for sharded grad scaler without cpu offload (Oct 26, 2021)
daafd25 - ddp grad scaler and fsdp sharded grad scaler test failing (Oct 27, 2021)
bc9e244 - removing test_output (Oct 27, 2021)
917ae0e - fix no cpu offload test (Oct 27, 2021)
65c3093 - changing optimizer from OSS to SGD (Oct 27, 2021)
3404b69 - all tests passing, code cleanup pending (Oct 28, 2021)
cf7d2e2 - code cleanup (Oct 29, 2021)
fa18c8e - fix pyproject.toml (Oct 29, 2021)
6bb3a71 - removing .isort.cfg (Oct 29, 2021)
cda515f - resolving merge conflicts (Oct 29, 2021)
98d04fd - running isort linter (Nov 1, 2021)
41d012e - resolving isort issues (Nov 2, 2021)
6346960 - resolving black linter issue (Nov 2, 2021)
c4e94a5 - resolving mypy issues (Nov 2, 2021)
8d79f1b - fix import statement (Nov 2, 2021)
006db9c - Merge branch 'main' into cpu_gradscaler. Taking in changes from PR 838. (Nov 2, 2021)
9ad7d3e - fix mypy error (Nov 2, 2021)
bd7c7a9 - modifying import statement (Nov 2, 2021)
a51b49d - adding pytorch version requirement (Nov 3, 2021)
cc63fbd - fixing pytest skip test decorator (Nov 3, 2021)
a973fb6 - apply version guard for ShardedGradScaler (Nov 3, 2021)
d4bb7c5 - removing test_fsdp_grad_scaler (Nov 3, 2021)
5fb0a77 - increasing num_epochs for ShardedGradScaler so that updates are not s… (Nov 4, 2021)
345835b - adding support for torch 1.8 (Nov 11, 2021)
b5cfc86 - minor edit (Nov 11, 2021)
1ad6277 - [skip ci] more torch 1.8 changes (Nov 12, 2021)
110e52d - parametrizing the tests (Nov 12, 2021)
447b9db - Merge branch 'main' into cpu_gradscaler (Nov 12, 2021)
ad5e979 - cleanup code with linters (Nov 12, 2021)
e693c31 - [skip ci] update doc string (Nov 12, 2021)
acb4304 - [skip ci] addressing some more comments (Nov 15, 2021)
4 changes: 2 additions & 2 deletions fairscale/experimental/nn/distributed_pipeline/pipeline.py
@@ -21,8 +21,8 @@


def check_pytorch_version() -> None:
if torch_version() < (1, 9, 0):
raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
if torch_version() < (1, 8, 0):
raise Exception("DistributedPipeline requires PyTorch version 1.8 or higher")


MOVING_DENIED = TypeError(
383 changes: 362 additions & 21 deletions fairscale/optim/grad_scaler.py

Large diffs are not rendered by default.
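
Since the grad_scaler.py diff is not rendered, here is a hedged sketch of the general mechanism a sharded grad scaler needs once FSDP offloads gradients to CPU: the unscale/inf check has to run on whatever device each gradient lives on, and the per-rank found-inf flags have to be reduced across ranks (over a gloo-backed group for CPU tensors) before deciding whether to skip the step. The function names below are illustrative assumptions, not the PR's actual API.

```python
# Illustrative sketch only -- not the code added in fairscale/optim/grad_scaler.py.
import torch
import torch.distributed as dist


def unscale_and_check(optimizer: torch.optim.Optimizer, inv_scale: torch.Tensor) -> torch.Tensor:
    """Unscale grads in place and return a flag tensor marking any inf/nan seen locally."""
    found_inf = torch.zeros(1)
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            # Works for both CPU-offloaded (fp32, CPU) and regular CUDA gradients.
            p.grad.mul_(inv_scale.to(p.grad.device))
            if not torch.isfinite(p.grad).all():
                found_inf.fill_(1.0)
    return found_inf


def sync_found_inf(found_inf: torch.Tensor, process_group=None) -> torch.Tensor:
    """Combine per-rank flags so every rank agrees on whether to skip optimizer.step()."""
    if dist.is_available() and dist.is_initialized():
        # For CPU-resident flags this requires a gloo process group.
        dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=process_group)
    return found_inf
```
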

1 change: 0 additions & 1 deletion tests/ci_test_list_1.txt
@@ -3,7 +3,6 @@ tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
tests/nn/data_parallel/test_fsdp_freezing_weights.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_grad_acc.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
39 changes: 24 additions & 15 deletions tests/nn/data_parallel/test_fsdp.py
@@ -13,6 +13,7 @@
from unittest import mock

from parameterized import parameterized
import pytest
import torch
from torch import nn
import torch.distributed
@@ -29,6 +30,9 @@
spawn_for_all_world_sizes,
)

if torch_version() >= (1, 8, 0):
from fairscale.optim.grad_scaler import ShardedGradScaler

# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod

@@ -49,14 +53,17 @@ def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None
model_device = next(model.parameters()).device
# use SGD with momentum instead of Adam, since Adam is scale invariant
# and this makes it bad for tests
optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

optim = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)
scaler = ShardedGradScaler()
for _ in range(num_steps):
optim.zero_grad()
with torch.cuda.amp.autocast(enabled=autocast):
# Inputs always cuda regardless of move_grads_cpu, or model.device
input = model.module.get_input(torch.device("cuda"))
output = model(*input)
loss = model.module.get_loss(input, output).to(model_device)
loss = scaler.scale(loss)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
if norm_type is not None:
@@ -65,10 +72,10 @@ def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None
model.clip_grad_norm_(clip_norm, norm_type)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
params = [p for p in model.parameters()]
print(f"params.device {params[0].device} param.grad.device {params[0].grad.device}")

optim.step()
scaler.step(optim)
scaler.update()
if hasattr(model, "assert_idle"):
model.assert_idle()
if isinstance(model, FullyShardedDataParallel):
model.assert_state(TrainingState.IDLE)
return loss.detach()
@@ -308,21 +315,21 @@ def test_transformer_parameterized(self, config):
# Test every combination of these options:
spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config))

def test_cpu_offload_and_cpu_grads(self):
# We don't test the False condition because that requires the optimizer to internally do
# the device transfer and PyTorch optimizers don't support this.
config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True}
# testing moving params to cpu while using full and mixed precision
@parameterized.expand([(True,), (False,)], name_func=rename_test)
def test_cpu_offload_and_cpu_grads(self, mixed_precision):
config = {"mixed_precision": mixed_precision, "cpu_offload": True}
test_fn = functools.partial(
self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01
)
spawn_and_init(test_fn)

def test_cpu_offload_and_cpu_grads_no_mixed_precision(self):
# We don't test the False condition because that requires the optimizer to internally do
# the device transfer and PyTorch optimizers don't support this.
config = {"mixed_precision": False, "cpu_offload": True, "move_grads_to_cpu": True}
# testing full and mixed precision on the gpu
@parameterized.expand([(True,), (False,)], name_func=rename_test)
def test_no_cpu_offload_with_sharded_grad_scaler(self, mixed_precision):
config = {"mixed_precision": mixed_precision, "move_params_to_cpu": False}
test_fn = functools.partial(
self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01
self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=True, lr=0.01
)
spawn_and_init(test_fn)

@@ -485,10 +492,10 @@ def _one_step(self, model, group):
optim.step()


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestHooks(DistributedTest):
# Feel free to modify these tests as the implementation changes.
# They aspire to make sure that backward hooks are registered and used

@parameterized.expand([[True], [False]])
def test_output_backward_hooks(self, cuda_first):
fn = functools.partial(self._test_output_backward_hooks, cuda_first=cuda_first)
@@ -541,6 +548,7 @@ def _test_register_functions_called(self, rank, group, cuda_first=False):
assert model._register_pre_backward_hooks.called


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestNoGrad(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_parameterized(self, config):
@@ -568,6 +576,7 @@ def _test_transformer(self, rank, group, config):
assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestModuleProperties(DistributedTest):
@parameterized.expand([[{"flatten_parameters": False}], [{"flatten_parameters": True}]], name_func=rename_test)
def test_named_parameters(self, config):
4 changes: 4 additions & 0 deletions tests/nn/data_parallel/test_fsdp_apply.py
@@ -7,8 +7,11 @@
import unittest

from parameterized import parameterized
import pytest
import torch.nn as nn

from fairscale.utils import torch_version

from .test_fsdp import (
CONFIG_OPTIONS,
DistributedTest,
@@ -19,6 +22,7 @@
)


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestApply(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_weight_init(self, config):
63 changes: 0 additions & 63 deletions tests/nn/data_parallel/test_fsdp_grad_scaler.py

This file was deleted.

8 changes: 5 additions & 3 deletions tests/nn/data_parallel/test_fsdp_regnet.py
@@ -35,7 +35,6 @@

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils import torch_version
from fairscale.utils.testing import (
dist_init,
@@ -47,6 +46,9 @@
torch_cuda_version,
)

if torch_version() >= (1, 8, 0):
from fairscale.optim.grad_scaler import ShardedGradScaler

# Const test params.
# Reduce iterations to 1 for debugging.
# Change world_size to 8 on beefy machines for better test coverage.
@@ -352,8 +354,8 @@ def dump(d):
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
@pytest.mark.parametrize("sync_bn", ["none", "pytorch"])
def test_regnet(temp_files, ddp_ref, precision, flatten, sync_bn):
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter")
if torch_version() < (1, 8, 0):
pytest.skip("pytorch version >= 1.8.0 required")

state_before, inputs, conv_bias, linear_bias, state_after = ddp_ref

9 changes: 8 additions & 1 deletion tests/nn/data_parallel/test_fsdp_state_dict.py
@@ -7,10 +7,12 @@
import unittest

from parameterized import parameterized
import pytest
import torch
from torch import nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx

from .test_fsdp import (
@@ -23,6 +25,7 @@
)


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestLocalStateDict(DistributedTest):
@parameterized.expand([[True, True], [False, False]], name_func=rename_test)
def test_load_local_state_dict(self, flatten_params, mixed_precision):
@@ -50,7 +53,9 @@ def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
state_1_module_weight = model.module.state_dict()[weight_key]
torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
self._train_for_several_steps(model, 1, model.mixed_precision)
# increasing number of epochs from 1 to 6 for ShardedGradScaler to work properly.
# test fails for num_epochs < 6 since the updates are skipped due to gradient being inf.
self._train_for_several_steps(model, 6, model.mixed_precision)

state_2 = model.local_state_dict()
state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
@@ -69,6 +74,7 @@ def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
raise AssertionError(f"params {unchanged} not changed after training")


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestSaveLoadStateDict(DistributedTest):
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_calling_state_dict_twice_mixed_precision(self, mixed_precision):
@@ -178,6 +184,7 @@ def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, l
), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestStateDictDeviceDtype(DistributedTest):
@parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
def test_state_dict_device(self, mixed_precision, cpu_offload):
4 changes: 4 additions & 0 deletions tests/nn/data_parallel/test_fsdp_summon_full_params.py
@@ -8,8 +8,11 @@
import unittest

from parameterized import parameterized
import pytest
import torch

from fairscale.utils.version import torch_version

from .test_fsdp import CONFIG_OPTIONS, DistributedTest, rename_test, spawn_and_init


@@ -19,6 +22,7 @@ def get_cuda_mem():
return torch.cuda.memory_allocated()


@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestMemory(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_memory(self, config):
6 changes: 5 additions & 1 deletion tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
@@ -21,10 +21,12 @@

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils import torch_version
from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx

if torch_version() >= (1, 8, 0):
from fairscale.optim.grad_scaler import ShardedGradScaler

"""
Check that ShardedDDP gets the same results as DDP in a variety of scenarii
"""
@@ -249,6 +251,8 @@ def test_ddp_parity(
manual_reduction,
multiple_fw,
):
if torch_version() < (1, 8, 0):
pytest.skip("pytorch version >= 1.8.0 required")
if manual_reduction and change_train_graph:
pytest.skip("Skipping changing model and grad accumulation combination, makes little sense")

Expand Down