diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index 2af494e68b..c47184eada 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -367,8 +367,7 @@ def load_sharded_checkpoint(
     ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
     exclude_algorithms: Optional[list[str]] = None,
     algorithm_passes: Optional[list[AlgorithmPass]] = None,
-) -> list[dict]:
-
+) -> Union[list[dict], None]:
     if not using_torch_2():
         raise ValueError(
             f'Sharded checkpoint loading requires torch version >= 2.0.0. You have torch version {torch.__version__}')
@@ -389,16 +388,6 @@ def load_sharded_checkpoint(
     from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
     from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
 
-    # This function is used so we can figure out which ranks need to load saved rngs and which can just make their own.
-    def _get_num_ranks_that_saved_rng(metadata: Metadata):
-        rng_inds = []
-        for field_name, field_value in metadata.planner_data.items():
-            if 'rng' in field_name:
-                _, rng_rank_index, _ = field_value
-                rng_inds.append(rng_rank_index)
-        rng_inds = set(rng_inds)
-        return len(rng_inds)
-
     class FileSystemReaderWithValidation(dist_cp.FileSystemReader):
         """FileSystemReader that validates checkpoint files prior to reading."""
 
@@ -501,13 +490,16 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
     with torch.no_grad():
         # 1. Load model and metadata first
         if load_weights_only:
-            state_dict = {'state': {'model': state.get_model_state_dict()}}
+            state_dict: Dict[str, Any] = {'state': {'model': state.get_model_state_dict()}}
         else:
             cur_state_dict = state.state_dict()
             # For older versions of torch, we load optimizer separately.
             if version.parse(torch.__version__) < version.parse('2.1.3'):
                 cur_state_dict.pop('optimizers')
-            state_dict = {'state': cur_state_dict}
+            state_dict: Dict[str, Any] = {
+                'state': cur_state_dict,
+                'rng': reproducibility.get_rng_state(),
+            }
 
         if ignore_keys:
             # Filter provided list of key paths
@@ -518,17 +510,32 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
             # Ensure state exists
             state_dict['state'] = state_dict.get('state', {})
 
+        # Only some ranks are meant to load checkpoint
+        expect_file = False
+        process_group = None
+        device_mesh = state.fsdp_device_mesh
+        if device_mesh is not None and device_mesh.ndim == 2:
+            # If hybrid shard, only rank in first replica saves
+            expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0)
+            if expect_file:
+                process_group = device_mesh.get_group(1)  # Shard process_group for first replica
+            log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}')
+        else:
+            expect_file = True
+
         if version.parse(torch.__version__) > version.parse('2.1.3'):
             dist_cp.load(  # type: ignore
                 state_dict=state_dict,
                 storage_reader=storage_reader,
                 planner=load_planner,
+                process_group=process_group,
             )
         else:
             dist_cp.load_state_dict(
                 state_dict=state_dict,
                 storage_reader=storage_reader,
                 planner=load_planner,
+                process_group=process_group,
             )
 
         state.load_state_dict(
@@ -547,26 +554,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
                                                             storage_reader=storage_reader)
             state._legacy_load_optim_state(optim_state)
 
-        # 3. Optionally load RNG
-        rng_state_dicts = reproducibility.get_rng_state()
-        if not load_weights_only:
-            # If we are resuming on more ranks than were used at save time we only want to load in rngs for those ranks
-            num_ranks_that_saved_rng = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
-            rng_state_dicts_load = {}
-            rng_state_dicts_load['rng'] = rng_state_dicts[:num_ranks_that_saved_rng] if len(
-                rng_state_dicts) > num_ranks_that_saved_rng else rng_state_dicts
-            dist_cp.load_state_dict(
-                state_dict=rng_state_dicts_load,
-                storage_reader=storage_reader,
-                planner=load_planner,
-            )
-            # We also want to append newly generated rng states for the ranks that don't have an rng state to load in
-            # if we are resuming on more ranks than were used at save time.
-            if len(rng_state_dicts) > num_ranks_that_saved_rng:
-                rng_state_dicts_load['rng'].extend(rng_state_dicts[num_ranks_that_saved_rng:])
-            rng_state_dicts = rng_state_dicts_load['rng']
-
-        return rng_state_dicts
+    return state_dict.get('rng', None)
 
 
 def _get_local_rank_zero_path(path: Optional[str]) -> str:
@@ -1010,9 +998,10 @@ def _save_checkpoint(
         process_group = None
         device_mesh = state.fsdp_device_mesh
         if device_mesh is not None and device_mesh.ndim == 2:
+            # If hybrid shard, only rank in first replica saves
            expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0)
             if expect_file:
-                process_group = device_mesh.get_group(1)  # Only save on first replica
+                process_group = device_mesh.get_group(1)  # Shard process_group for first replica
             log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}')
         else:
             expect_file = True
diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py
index 28216d3541..5bd416f4c7 100644
--- a/tests/trainer/test_fsdp_checkpoint.py
+++ b/tests/trainer/test_fsdp_checkpoint.py
@@ -11,7 +11,7 @@
 import uuid
 from contextlib import nullcontext as does_not_raise
 from functools import partial
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Optional, Sequence, Union
 from unittest.mock import patch
 
 import numpy as np
@@ -600,12 +600,16 @@ def mock_get_checkpoint_validation_function():
 
 @pytest.mark.gpu
 @world_size(2)
-@pytest.mark.parametrize('weights_only', [False, True])
-@pytest.mark.parametrize('optimizer', ['adam', 'adamw'])
 @pytest.mark.parametrize('state_dict_type', ['sharded', 'local'])
-@pytest.mark.parametrize('precision', ['amp_bf16', 'amp_fp16'])
 @pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
-@pytest.mark.parametrize('autoresume', [True, False])
+@pytest.mark.parametrize('weights_only,optimizer,precision,autoresume,load_ignore_keys', [
+    [False, 'adamw', 'amp_bf16', False, None],
+    [True, 'adamw', 'amp_bf16', False, None],
+    [False, 'adam', 'amp_bf16', False, None],
+    [False, 'adamw', 'amp_fp16', False, None],
+    [False, 'adamw', 'amp_bf16', True, None],
+    [False, 'adamw', 'amp_bf16', False, ['rng']],
+])
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
                     reason='requires PyTorch 1.13 or higher')
 @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
@@ -619,6 +623,7 @@ def test_fsdp_partitioned_state_dict_load(
     precision: str,
     optimizer: str,
     weights_only: bool,
+    load_ignore_keys: Union[list[str], None],
     use_remote,
     s3_bucket,
     s3_ephemeral_prefix,
@@ -630,6 +635,7 @@ def test_fsdp_partitioned_state_dict_load(
         pytest.xfail(('Loading a state_dict_type="local" checkpoint with strict=True '
                       'errors out. See https://github.com/pytorch/pytorch/issues/102667 '
                       'for more info'))
+    load_ignore_keys = [] if load_ignore_keys is None else load_ignore_keys
 
     if autoresume:
         local_run_name = f'my-cool-autoresume-run-{uuid.uuid1()}'
@@ -700,6 +706,7 @@ def test_fsdp_partitioned_state_dict_load(
         optimizer=optimizer,
         load_weights_only=weights_only,
         fsdp_config=fsdp_config,
+        load_ignore_keys=load_ignore_keys,
     )
     state_dict_from_trainer2 = trainer2.state.state_dict()
     rng2 = trainer2._rng_state
@@ -709,7 +716,10 @@ def test_fsdp_partitioned_state_dict_load(
         state_dict_from_trainer2,
     )
     if not weights_only:
-        _compare_rng_states_between_trainers(rng1, rng2)
+        if any('rng' in x for x in load_ignore_keys):
+            assert rng1 is not None and rng2 is None
+        else:
+            _compare_rng_states_between_trainers(rng1, rng2)
         _compare_optims_between_state_dicts(
             state_dict_from_trainer1_ba2,
             state_dict_from_trainer2,
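A minimal, hypothetical sketch of the rank-gating rule the patch applies on the load path (this is not Composer code; DummyMesh and plan_checkpoint_load are invented names, and a plain string stands in for a real ProcessGroup): with a 2-D hybrid-shard device mesh, only ranks in the first replica expect checkpoint files on disk, and the collective load is scoped to that replica's shard process group, while plain 1-D FSDP keeps every rank involved.

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class DummyMesh:
    """Stand-in for a 2-D DeviceMesh: mesh dim 0 = replica, mesh dim 1 = shard."""
    ndim: int
    replica_rank: int

    def get_local_rank(self, mesh_dim: int) -> int:
        # Only the replica dimension matters for the gating decision here.
        return self.replica_rank if mesh_dim == 0 else 0

    def get_group(self, mesh_dim: int) -> str:
        # The real mesh returns a ProcessGroup; a label is enough for this sketch.
        return f'shard-group(dim={mesh_dim})'


def plan_checkpoint_load(device_mesh: Optional[DummyMesh]) -> Tuple[bool, Optional[str]]:
    # Mirrors the expect_file / process_group logic in the diff: under hybrid
    # sharding (2-D mesh), only the first replica expects files and the load is
    # scoped to its shard process group; otherwise every rank participates.
    if device_mesh is not None and device_mesh.ndim == 2:
        expect_file = device_mesh.get_local_rank(mesh_dim=0) == 0
        process_group = device_mesh.get_group(1) if expect_file else None
        return expect_file, process_group
    return True, None


print(plan_checkpoint_load(DummyMesh(ndim=2, replica_rank=0)))  # (True, 'shard-group(dim=1)')
print(plan_checkpoint_load(DummyMesh(ndim=2, replica_rank=1)))  # (False, None)
print(plan_checkpoint_load(None))  # (True, None)

The save path already used this rule; the patch mirrors it on load and moves RNG restoration into the same state_dict passed to dist_cp.load, which is why load_sharded_checkpoint now returns state_dict.get('rng', None) instead of running a separate dist_cp.load_state_dict pass for RNG.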