Pass PG into checkpoint load and load rng with state_dict (#2897)
* checkdown

* remove comment

* lint

* comments

* fix

* accelerate test

* fix test

* lint

* fix test
mvpatel2000 authored Jan 24, 2024
1 parent 4a53dfe commit cfc439a
Showing 2 changed files with 40 additions and 41 deletions.
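At a glance, the commit consolidates RNG loading into the main sharded checkpoint load and routes the load through the hybrid-shard process group. A condensed sketch of the pattern the diff moves to (not the literal Composer code; `storage_reader`, `load_planner`, `process_group`, and `state` are assumed to be constructed earlier in `load_sharded_checkpoint`, and only the full, non weights-only path is shown):

import torch
import torch.distributed.checkpoint as dist_cp
from packaging import version
from composer.utils import reproducibility

# RNG state now travels inside the same state_dict that torch.distributed.checkpoint fills in,
# instead of being loaded in a separate pass afterwards.
state_dict = {
    'state': state.state_dict(),
    'rng': reproducibility.get_rng_state(),
}
if version.parse(torch.__version__) > version.parse('2.1.3'):
    # Newer torch exposes dist_cp.load, which also accepts the process group.
    dist_cp.load(state_dict=state_dict, storage_reader=storage_reader,
                 planner=load_planner, process_group=process_group)
else:
    dist_cp.load_state_dict(state_dict=state_dict, storage_reader=storage_reader,
                            planner=load_planner, process_group=process_group)
rng_state = state_dict.get('rng', None)  # returned to the caller instead of a separately loaded RNG list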
59 changes: 24 additions & 35 deletions composer/utils/checkpoint.py
@@ -367,8 +367,7 @@ def load_sharded_checkpoint(
ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
exclude_algorithms: Optional[list[str]] = None,
algorithm_passes: Optional[list[AlgorithmPass]] = None,
) -> list[dict]:

) -> Union[list[dict], None]:
if not using_torch_2():
raise ValueError(
f'Sharded checkpoint loading requires torch version >= 2.0.0. You have torch version {torch.__version__}')
@@ -389,16 +388,6 @@ def load_sharded_checkpoint(
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner

# This function is used so we can figure out which ranks need to load saved rngs and which can just make their own.
def _get_num_ranks_that_saved_rng(metadata: Metadata):
rng_inds = []
for field_name, field_value in metadata.planner_data.items():
if 'rng' in field_name:
_, rng_rank_index, _ = field_value
rng_inds.append(rng_rank_index)
rng_inds = set(rng_inds)
return len(rng_inds)

class FileSystemReaderWithValidation(dist_cp.FileSystemReader):
"""FileSystemReader that validates checkpoint files prior to reading."""

@@ -501,13 +490,16 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
with torch.no_grad():
# 1. Load model and metadata first
if load_weights_only:
state_dict = {'state': {'model': state.get_model_state_dict()}}
state_dict: Dict[str, Any] = {'state': {'model': state.get_model_state_dict()}}
else:
cur_state_dict = state.state_dict()
# For older versions of torch, we load optimizer separately.
if version.parse(torch.__version__) < version.parse('2.1.3'):
cur_state_dict.pop('optimizers')
state_dict = {'state': cur_state_dict}
state_dict: Dict[str, Any] = {
'state': cur_state_dict,
'rng': reproducibility.get_rng_state(),
}

if ignore_keys:
# Filter provided list of key paths
@@ -518,17 +510,32 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
# Ensure state exists
state_dict['state'] = state_dict.get('state', {})

# Only some ranks are meant to load checkpoint
expect_file = False
process_group = None
device_mesh = state.fsdp_device_mesh
if device_mesh is not None and device_mesh.ndim == 2:
# If hybrid shard, only rank in first replica saves
expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0)
if expect_file:
process_group = device_mesh.get_group(1) # Shard process_group for first replica
log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}')
else:
expect_file = True

if version.parse(torch.__version__) > version.parse('2.1.3'):
dist_cp.load( # type: ignore
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
process_group=process_group,
)
else:
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
process_group=process_group,
)

state.load_state_dict(
@@ -547,26 +554,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
storage_reader=storage_reader)
state._legacy_load_optim_state(optim_state)

# 3. Optionally load RNG
rng_state_dicts = reproducibility.get_rng_state()
if not load_weights_only:
# If we are resuming on more ranks than were used at save time we only want to load in rngs for those ranks
num_ranks_that_saved_rng = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
rng_state_dicts_load = {}
rng_state_dicts_load['rng'] = rng_state_dicts[:num_ranks_that_saved_rng] if len(
rng_state_dicts) > num_ranks_that_saved_rng else rng_state_dicts
dist_cp.load_state_dict(
state_dict=rng_state_dicts_load,
storage_reader=storage_reader,
planner=load_planner,
)
# We also want to append newly generated rng states for the ranks that don't have an rng state to load in
# if we are resuming on more ranks than were used at save time.
if len(rng_state_dicts) > num_ranks_that_saved_rng:
rng_state_dicts_load['rng'].extend(rng_state_dicts[num_ranks_that_saved_rng:])
rng_state_dicts = rng_state_dicts_load['rng']

return rng_state_dicts
return state_dict.get('rng', None)


def _get_local_rank_zero_path(path: Optional[str]) -> str:
@@ -1010,9 +998,10 @@ def _save_checkpoint(
process_group = None
device_mesh = state.fsdp_device_mesh
if device_mesh is not None and device_mesh.ndim == 2:
# If hybrid shard, only rank in first replica saves
expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0)
if expect_file:
process_group = device_mesh.get_group(1) # Only save on first replica
process_group = device_mesh.get_group(1) # Shard process_group for first replica
log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}')
else:
expect_file = True
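Both the load path above and this save path apply the same rule to a 2-D (hybrid shard) device mesh: mesh dim 0 indexes replicas and dim 1 indexes shards within a replica, so only ranks in the first replica touch checkpoint files, and those ranks coordinate through that replica's shard group. A self-contained, illustrative sketch of the rank arithmetic, assuming the usual row-major mesh layout (this is not Composer code):

# Hypothetical 8-GPU hybrid-shard layout: 2 replicas x 4 shards, ranks laid out row-major.
num_replicas, num_shards = 2, 4

for rank in range(num_replicas * num_shards):
    replica_idx, shard_idx = divmod(rank, num_shards)
    # Mirrors `device_mesh.get_local_rank(mesh_dim=0) == 0`: only replica 0 reads/writes files.
    expect_file = replica_idx == 0
    # Mirrors `device_mesh.get_group(1)`: the shard group this rank coordinates with.
    shard_group = [r for r in range(num_replicas * num_shards) if r // num_shards == replica_idx]
    print(f'{rank=} {expect_file=} {shard_group=}')

# Ranks 0-3 (replica 0) expect files and share one shard group; ranks 4-7 do not expect files.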
22 changes: 16 additions & 6 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -11,7 +11,7 @@
import uuid
from contextlib import nullcontext as does_not_raise
from functools import partial
from typing import Any, Callable, Optional, Sequence
from typing import Any, Callable, Optional, Sequence, Union
from unittest.mock import patch

import numpy as np
@@ -600,12 +600,16 @@ def mock_get_checkpoint_validation_function():

@pytest.mark.gpu
@world_size(2)
@pytest.mark.parametrize('weights_only', [False, True])
@pytest.mark.parametrize('optimizer', ['adam', 'adamw'])
@pytest.mark.parametrize('state_dict_type', ['sharded', 'local'])
@pytest.mark.parametrize('precision', ['amp_bf16', 'amp_fp16'])
@pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
@pytest.mark.parametrize('autoresume', [True, False])
@pytest.mark.parametrize('weights_only,optimizer,precision,autoresume,load_ignore_keys', [
[False, 'adamw', 'amp_bf16', False, None],
[True, 'adamw', 'amp_bf16', False, None],
[False, 'adam', 'amp_bf16', False, None],
[False, 'adamw', 'amp_fp16', False, None],
[False, 'adamw', 'amp_bf16', True, None],
[False, 'adamw', 'amp_bf16', False, ['rng']],
])
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='requires PyTorch 1.13 or higher')
@pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
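The hunk above replaces a cross-product of independent parametrize decorators with an explicit list of rows, which keeps the GPU test matrix small while still adding the new `load_ignore_keys=['rng']` case. A generic, minimal sketch of that pytest pattern (not the Composer test itself):

import pytest

@pytest.mark.parametrize('weights_only,load_ignore_keys', [
    (False, None),      # default path: RNG restored from the checkpoint
    (True, None),       # weights-only load: RNG not loaded at all
    (False, ['rng']),   # full load, but RNG explicitly ignored
])
def test_rows(weights_only, load_ignore_keys):
    # Each tuple is one concrete test case instead of a full cross-product.
    assert weights_only in (False, True)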
@@ -619,6 +623,7 @@ def test_fsdp_partitioned_state_dict_load(
precision: str,
optimizer: str,
weights_only: bool,
load_ignore_keys: Union[list[str], None],
use_remote,
s3_bucket,
s3_ephemeral_prefix,
@@ -630,6 +635,7 @@ def test_fsdp_partitioned_state_dict_load(
pytest.xfail(('Loading a state_dict_type="local" checkpoint with strict=True '
'errors out. See https://github.com/pytorch/pytorch/issues/102667 '
'for more info'))
load_ignore_keys = [] if load_ignore_keys is None else load_ignore_keys

if autoresume:
local_run_name = f'my-cool-autoresume-run-{uuid.uuid1()}'
@@ -700,6 +706,7 @@ def test_fsdp_partitioned_state_dict_load(
optimizer=optimizer,
load_weights_only=weights_only,
fsdp_config=fsdp_config,
load_ignore_keys=load_ignore_keys,
)
state_dict_from_trainer2 = trainer2.state.state_dict()
rng2 = trainer2._rng_state
@@ -709,7 +716,10 @@ def test_fsdp_partitioned_state_dict_load(
state_dict_from_trainer2,
)
if not weights_only:
_compare_rng_states_between_trainers(rng1, rng2)
if any('rng' in x for x in load_ignore_keys):
assert rng1 is not None and rng2 is None
else:
_compare_rng_states_between_trainers(rng1, rng2)
_compare_optims_between_state_dicts(
state_dict_from_trainer1_ba2,
state_dict_from_trainer2,
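The new `load_ignore_keys=['rng']` case exercises the path where the RNG entry is filtered out of the checkpoint before loading, which is why the test expects `rng2` (i.e. `trainer2._rng_state`) to be `None` while `rng1` is populated. A hypothetical, simplified illustration of that filtering (Composer's real ignore-keys handling supports nested glob paths; this sketch only filters top-level keys):

import fnmatch
from typing import Any

def drop_ignored_keys(checkpoint: dict[str, Any], ignore_keys: list[str]) -> dict[str, Any]:
    """Return a copy of the checkpoint without top-level keys matching any ignore pattern."""
    return {k: v for k, v in checkpoint.items()
            if not any(fnmatch.fnmatch(k, pat) for pat in ignore_keys)}

checkpoint = {'state': {'model': {}}, 'rng': [{'torch': 'fake-rng-bytes'}]}
filtered = drop_ignored_keys(checkpoint, ['rng'])
assert 'rng' not in filtered   # nothing to load, so the trainer's RNG state stays None
assert 'state' in filtered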
