Pass PG into checkpoint load and load rng with state_dict #2897

Merged: 10 commits, Jan 24, 2024
59 changes: 24 additions & 35 deletions composer/utils/checkpoint.py
@@ -367,8 +367,7 @@ def load_sharded_checkpoint(
ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
exclude_algorithms: Optional[list[str]] = None,
algorithm_passes: Optional[list[AlgorithmPass]] = None,
) -> list[dict]:

) -> Union[list[dict], None]:
if not using_torch_2():
raise ValueError(
f'Sharded checkpoint loading requires torch version >= 2.0.0. You have torch version {torch.__version__}')
@@ -389,16 +388,6 @@ def load_sharded_checkpoint(
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner

# This function is used so we can figure out which ranks need to load saved rngs and which can just make their own.
def _get_num_ranks_that_saved_rng(metadata: Metadata):
rng_inds = []
for field_name, field_value in metadata.planner_data.items():
if 'rng' in field_name:
_, rng_rank_index, _ = field_value
rng_inds.append(rng_rank_index)
rng_inds = set(rng_inds)
return len(rng_inds)

class FileSystemReaderWithValidation(dist_cp.FileSystemReader):
"""FileSystemReader that validates checkpoint files prior to reading."""

@@ -501,13 +490,16 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
with torch.no_grad():
# 1. Load model and metadata first
if load_weights_only:
state_dict = {'state': {'model': state.get_model_state_dict()}}
state_dict: Dict[str, Any] = {'state': {'model': state.get_model_state_dict()}}
else:
cur_state_dict = state.state_dict()
# For older versions of torch, we load optimizer separately.
if version.parse(torch.__version__) < version.parse('2.1.3'):
cur_state_dict.pop('optimizers')
state_dict = {'state': cur_state_dict}
state_dict: Dict[str, Any] = {
'state': cur_state_dict,
'rng': reproducibility.get_rng_state(),
}

if ignore_keys:
# Filter provided list of key paths
@@ -518,17 +510,32 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
# Ensure state exists
state_dict['state'] = state_dict.get('state', {})

# Only some ranks are meant to load checkpoint
expect_file = False
process_group = None
device_mesh = state.fsdp_device_mesh
if device_mesh is not None and device_mesh.ndim == 2:
# If hybrid shard, only rank in first replica saves
expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0)
if expect_file:
process_group = device_mesh.get_group(1) # Shard process_group for first replica
log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}')
else:
expect_file = True

if version.parse(torch.__version__) > version.parse('2.1.3'):
dist_cp.load( # type: ignore
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
process_group=process_group,
)
else:
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
process_group=process_group,
)

state.load_state_dict(
@@ -547,26 +554,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
storage_reader=storage_reader)
state._legacy_load_optim_state(optim_state)

# 3. Optionally load RNG
rng_state_dicts = reproducibility.get_rng_state()
if not load_weights_only:
# If we are resuming on more ranks than were used at save time we only want to load in rngs for those ranks
num_ranks_that_saved_rng = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
rng_state_dicts_load = {}
rng_state_dicts_load['rng'] = rng_state_dicts[:num_ranks_that_saved_rng] if len(
rng_state_dicts) > num_ranks_that_saved_rng else rng_state_dicts
dist_cp.load_state_dict(
state_dict=rng_state_dicts_load,
storage_reader=storage_reader,
planner=load_planner,
)
# We also want to append newly generated rng states for the ranks that don't have an rng state to load in
# if we are resuming on more ranks than were used at save time.
if len(rng_state_dicts) > num_ranks_that_saved_rng:
rng_state_dicts_load['rng'].extend(rng_state_dicts[num_ranks_that_saved_rng:])
rng_state_dicts = rng_state_dicts_load['rng']

return rng_state_dicts
return state_dict.get('rng', None)


def _get_local_rank_zero_path(path: Optional[str]) -> str:
@@ -1010,9 +998,10 @@ def _save_checkpoint(
process_group = None
device_mesh = state.fsdp_device_mesh
if device_mesh is not None and device_mesh.ndim == 2:
# If hybrid shard, only rank in first replica saves
expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0)
if expect_file:
process_group = device_mesh.get_group(1) # Only save on first replica
process_group = device_mesh.get_group(1) # Shard process_group for first replica
log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}')
else:
expect_file = True
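The pattern introduced above, as a minimal standalone sketch (the names `state`, `storage_reader`, and `load_planner` stand in for objects built inside `load_sharded_checkpoint`, and it assumes a torch version recent enough that `dist_cp.load` accepts a `process_group`): with a 2-D (hybrid-shard) device mesh, only ranks in the first replica expect checkpoint files and pass their shard process group into the load, and the RNG state travels inside the same state_dict rather than through a separate load call.

# Sketch only; mirrors the logic in the diff above, not the exact Composer implementation.
import torch.distributed.checkpoint as dist_cp
from composer.utils import reproducibility

def _load_sharded_sketch(state, storage_reader, load_planner):
    # RNG rides along in the same state_dict, so a single load restores both.
    state_dict = {
        'state': state.state_dict(),
        'rng': reproducibility.get_rng_state(),
    }

    process_group = None
    device_mesh = state.fsdp_device_mesh
    if device_mesh is not None and device_mesh.ndim == 2:
        # Hybrid shard: only ranks in the first replica read checkpoint files.
        if device_mesh.get_local_rank(mesh_dim=0) == 0:
            process_group = device_mesh.get_group(1)  # shard process group for the first replica

    dist_cp.load(  # torch > 2.1.3; older versions call dist_cp.load_state_dict instead
        state_dict=state_dict,
        storage_reader=storage_reader,
        planner=load_planner,
        process_group=process_group,
    )
    return state_dict.get('rng', None)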
22 changes: 16 additions & 6 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -11,7 +11,7 @@
import uuid
from contextlib import nullcontext as does_not_raise
from functools import partial
from typing import Any, Callable, Optional, Sequence
from typing import Any, Callable, Optional, Sequence, Union
from unittest.mock import patch

import numpy as np
@@ -600,12 +600,16 @@ def mock_get_checkpoint_validation_function():

@pytest.mark.gpu
@world_size(2)
@pytest.mark.parametrize('weights_only', [False, True])
@pytest.mark.parametrize('optimizer', ['adam', 'adamw'])
@pytest.mark.parametrize('state_dict_type', ['sharded', 'local'])
@pytest.mark.parametrize('precision', ['amp_bf16', 'amp_fp16'])
@pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
@pytest.mark.parametrize('autoresume', [True, False])
@pytest.mark.parametrize('weights_only,optimizer,precision,autoresume,load_ignore_keys', [
[False, 'adamw', 'amp_bf16', False, None],
[True, 'adamw', 'amp_bf16', False, None],
[False, 'adam', 'amp_bf16', False, None],
[False, 'adamw', 'amp_fp16', False, None],
[False, 'adamw', 'amp_bf16', True, None],
[False, 'adamw', 'amp_bf16', False, ['rng']],
])
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='requires PyTorch 1.13 or higher')
@pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
@@ -619,6 +623,7 @@ def test_fsdp_partitioned_state_dict_load(
precision: str,
optimizer: str,
weights_only: bool,
load_ignore_keys: Union[list[str], None],
use_remote,
s3_bucket,
s3_ephemeral_prefix,
@@ -630,6 +635,7 @@ def test_fsdp_partitioned_state_dict_load(
pytest.xfail(('Loading a state_dict_type="local" checkpoint with strict=True '
'errors out. See https://github.com/pytorch/pytorch/issues/102667 '
'for more info'))
load_ignore_keys = [] if load_ignore_keys is None else load_ignore_keys

if autoresume:
local_run_name = f'my-cool-autoresume-run-{uuid.uuid1()}'
@@ -700,6 +706,7 @@ def test_fsdp_partitioned_state_dict_load(
optimizer=optimizer,
load_weights_only=weights_only,
fsdp_config=fsdp_config,
load_ignore_keys=load_ignore_keys,
)
state_dict_from_trainer2 = trainer2.state.state_dict()
rng2 = trainer2._rng_state
@@ -709,7 +716,10 @@ def test_fsdp_partitioned_state_dict_load(
state_dict_from_trainer2,
)
if not weights_only:
_compare_rng_states_between_trainers(rng1, rng2)
if any('rng' in x for x in load_ignore_keys):
assert rng1 is not None and rng2 is None
else:
_compare_rng_states_between_trainers(rng1, rng2)
_compare_optims_between_state_dicts(
state_dict_from_trainer1_ba2,
state_dict_from_trainer2,
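On the caller side, a hedged sketch of the new return contract (the `maybe_restore_rng` helper is hypothetical, and `reproducibility.load_rng_state` is assumed to be the matching Composer helper): `load_sharded_checkpoint` now returns the loaded RNG state, or `None` when `'rng'` is filtered out via `load_ignore_keys`, which is what the test above asserts.

from typing import Optional

from composer.utils import reproducibility

def maybe_restore_rng(rng_state: Optional[list]) -> None:
    # With load_ignore_keys=['rng'], load_sharded_checkpoint returns None and the
    # freshly seeded RNG is kept instead of the saved one.
    if rng_state is not None:
        reproducibility.load_rng_state(rng_state)  # assumed Composer API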