diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py
index bb99a9287e..2e5fd5d07b 100644
--- a/tests/trainer/test_fsdp_checkpoint.py
+++ b/tests/trainer/test_fsdp_checkpoint.py
@@ -289,21 +289,21 @@ def _compare_timestamps_between_state_dicts(state_dict1, state_dict2):
 @pytest.mark.gpu
 @pytest.mark.filterwarnings(r'ignore:.*scatter_full_optim_state_dict``is being deprecated.*:UserWarning')
 @pytest.mark.parametrize(
-    'world_size,optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp',
+    'optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp,use_hsdp',
     [
-        pytest.param(2, 'adam', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adamw', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', True, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_fp16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_bf16', True, True, False, False,
+        pytest.param('adam', False, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adamw', False, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', True, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_fp16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', True, True, False, False, False,
                      marks=pytest.mark.world_size(2)),  # save_weights_only requires load_weights_only
-        pytest.param(2, 'adam', False, 'amp_bf16', False, True, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_bf16', False, False, True, False, marks=pytest.mark.world_size(2)),
-        pytest.param(4, 'adam', False, 'amp_bf16', False, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param('adam', False, 'amp_bf16', False, True, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', False, False, True, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', False, False, False, True, False, marks=pytest.mark.world_size(4)),
+        pytest.param('adam', False, 'amp_bf16', False, False, False, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 def test_fsdp_full_state_dict_load(
-    world_size,
     tmp_path: pathlib.Path,
     autoresume: bool,
     precision: str,
@@ -312,7 +312,10 @@ def test_fsdp_full_state_dict_load(
     load_weights_only: bool,
     load_monolith_rank0_only: bool,
     use_tp: bool,
+    use_hsdp: bool,
 ):
+    if use_hsdp:
+        pytest.xfail('Known Pytorch issue with HSDP, waiting for pytorch patch')
     if autoresume:
         run_name = 'my-cool-autoresume-run'
     else:
@@ -320,11 +323,20 @@ def test_fsdp_full_state_dict_load(
     save_folder = tmp_path
     save_filename = 'rank{rank}.pt'
-    fsdp_config = FSDPConfig(
-        sharded_ckpt_prefix_dir='ba{batch}',
-        sync_module_states=load_monolith_rank0_only,
-        load_monolith_rank0_only=load_monolith_rank0_only,
-    )
+    if use_hsdp:
+        fsdp_config = FSDPConfig(
+            sharding_strategy='HYBRID_SHARD',
+            sharded_ckpt_prefix_dir='ba{batch}',
+            data_parallel_shard_degree=2,
+            data_parallel_replicate_degree=2,
+            sync_module_states=True,
+        )
+    else:
+        fsdp_config = FSDPConfig(
+            sharded_ckpt_prefix_dir='ba{batch}',
+            sync_module_states=load_monolith_rank0_only,
+            load_monolith_rank0_only=load_monolith_rank0_only,
+        )
 
     tp_config = None
     if use_tp:
         from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
@@ -778,23 +790,33 @@ def mock_get_checkpoint_validation_function():
 @pytest.mark.gpu
 @pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
 @pytest.mark.parametrize(
-    'world_size,weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp',
+    'weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp,use_hsdp',
     [
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, True, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adam', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_fp16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', True, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, ['rng'], False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, True, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(True, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adam', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_fp16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_bf16', True, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(
+            False,
+            'adamw',
+            'amp_bf16',
+            False,
+            ['rng'],
+            False,
+            False,
+            False,
+            marks=pytest.mark.world_size(2),
+        ),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, True, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, True, False, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
 @pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
 @pytest.mark.filterwarnings(r'ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning')
 def test_fsdp_partitioned_state_dict_load(
-    world_size,
     tmp_path: pathlib.Path,
     autoresume: bool,
     precision: str,
@@ -803,6 +825,7 @@ def test_fsdp_partitioned_state_dict_load(
     load_ignore_keys: Union[list[str], None],
     use_symlink: bool,
     use_tp: bool,
+    use_hsdp: bool,
     use_remote,
     s3_bucket,
     s3_ephemeral_prefix,
@@ -829,10 +852,19 @@ def test_fsdp_partitioned_state_dict_load(
     save_filename = 'ba{batch}-rank{rank}.pt'
 
-    fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
+    if use_hsdp:
+        fsdp_config = FSDPConfig(
+            sharding_strategy='HYBRID_SHARD',
+            sharded_ckpt_prefix_dir='ba{batch}',
+            state_dict_type='sharded',
+            data_parallel_shard_degree=2,
+            data_parallel_replicate_degree=2,
+            sync_module_states=True,
+        )
+    else:
+        fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
 
     tp_config = None
     if use_tp:
-        fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
         from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
         tp_config = {
             'tensor_parallel_degree': 2,
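
For reference, a minimal standalone sketch of the FSDPConfig branch the new `use_hsdp` cases exercise. The `from composer.utils import FSDPConfig` import path and the `make_fsdp_config` helper name are assumptions for illustration only; the argument values are taken directly from the diff. With data_parallel_shard_degree=2 and data_parallel_replicate_degree=2, the hybrid-sharded mesh spans 2 x 2 = 4 ranks, which is why the HSDP parametrizations carry pytest.mark.world_size(4).

# Illustrative sketch only: mirrors the config branch added in
# test_fsdp_full_state_dict_load. The import path is an assumption.
from composer.utils import FSDPConfig


def make_fsdp_config(use_hsdp: bool, load_monolith_rank0_only: bool = False) -> FSDPConfig:
    if use_hsdp:
        # HYBRID_SHARD: parameters are sharded within a group of 2 ranks and
        # replicated across 2 such groups, so 2 * 2 = 4 GPUs are required.
        return FSDPConfig(
            sharding_strategy='HYBRID_SHARD',
            sharded_ckpt_prefix_dir='ba{batch}',
            data_parallel_shard_degree=2,
            data_parallel_replicate_degree=2,
            sync_module_states=True,
        )
    # Non-HSDP path, unchanged from the pre-existing test behavior.
    return FSDPConfig(
        sharded_ckpt_prefix_dir='ba{batch}',
        sync_module_states=load_monolith_rank0_only,
        load_monolith_rank0_only=load_monolith_rank0_only,
    )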