Add SavePlanner and LoadPlanner support in Trainer
This PR adds support for `SavePlanner` and `LoadPlanner` from `torch.distributed.checkpoint.planner` to Composer. The planners are passed in when initializing the Trainer (via `fsdp_config`) and are eventually forwarded to `save_checkpoint` and `load_sharded_checkpoint`. Because planners were only introduced in torch 2, the new parameters cannot carry type hints; instead, their types are checked at runtime before the planners are used.
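
A minimal usage sketch (not part of this PR): the planners are supplied through the Trainer's `fsdp_config`. It assumes torch >= 2.0 and that `model` (a ComposerModel) and `train_dataloader` are defined elsewhere; the default planners stand in for any custom `SavePlanner`/`LoadPlanner` implementation.

```python
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, DefaultSavePlanner

from composer import Trainer

trainer = Trainer(
    model=model,  # placeholder: a ComposerModel defined elsewhere
    train_dataloader=train_dataloader,  # placeholder: a torch DataLoader defined elsewhere
    max_duration='10ba',
    save_folder='./checkpoints',
    fsdp_config={
        'state_dict_type': 'sharded',
        'save_planner': DefaultSavePlanner(),  # forwarded to save_checkpoint
        'load_planner': DefaultLoadPlanner(),  # forwarded to load_sharded_checkpoint
    },
)
```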

Some minor formatting fixes are included in this PR.

cc: @vchiley

commit-id:4ea6578b
b-chu committed Nov 9, 2023
1 parent 66096f2 commit 200e466
Showing 5 changed files with 266 additions and 13 deletions.
2 changes: 2 additions & 0 deletions composer/trainer/dist_strategy.py
@@ -140,7 +140,9 @@ def set_fsdp_default(fsdp_config: Dict[str, Any]):
fsdp_config.setdefault('keep_low_precision_grads', False)
fsdp_config.setdefault('limit_all_gathers', True)
fsdp_config.setdefault('load_monolith_rank0_only', False)
fsdp_config.setdefault('load_planner', None)
fsdp_config.setdefault('mixed_precision', 'DEFAULT')
fsdp_config.setdefault('save_planner', None)
fsdp_config.setdefault('sharded_ckpt_prefix_dir', 'ep{epoch}-ba{batch}')
fsdp_config.setdefault('sharding_strategy', 'FULL_SHARD')
fsdp_config.setdefault('state_dict_type', 'full')
63 changes: 60 additions & 3 deletions composer/utils/checkpoint.py
@@ -318,6 +318,11 @@ def load_sharded_checkpoint(
f'Sharded checkpoint loading on >1 node requires torch version >= 2.0.1. You have torch version {torch.__version__}'
)

if state.fsdp_config is None:
raise ValueError('Loading a sharded checkpoint requires passing an FSDP config to Trainer.')
load_planner = state.fsdp_config['load_planner']
_validate_load_planner(load_planner)

from torch.distributed import checkpoint as dist_cp
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
@@ -416,7 +421,11 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
# Call function to modify state_dict
ignore_keys(model_state_dict)

dist_cp.load_state_dict(model_state_dict, storage_reader)
dist_cp.load_state_dict(
state_dict=model_state_dict,
storage_reader=storage_reader,
planner=load_planner,
)

state.load_state_dict(
model_state_dict['state'],
@@ -441,7 +450,11 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
rng_state_dicts_load = {}
rng_state_dicts_load['rng'] = rng_state_dicts[:num_ranks_that_saved_rng] if len(
rng_state_dicts) > num_ranks_that_saved_rng else rng_state_dicts
dist_cp.load_state_dict(rng_state_dicts_load, storage_reader)
dist_cp.load_state_dict(
state_dict=rng_state_dicts_load,
storage_reader=storage_reader,
planner=load_planner,
)
# We also want to append newly generated rng states for the ranks that don't have an rng state to load in
# if we are resuming on more ranks than were used at save time.
if len(rng_state_dicts) > num_ranks_that_saved_rng:
@@ -628,6 +641,40 @@ def filter_func(state_dict: Dict) -> None:
return filter_func


def _validate_save_planner(save_planner: Optional[Any]) -> None:
"""Checks that ``save_planner`` is an instance of a :class:`~torch.distributed.checkpoint.planner.SavePlanner`.
TODO(GRT-2456): Remove validation once we deprecate torch 1.13 and can use
type hints.
Raises:
ValueError: If ``save_planner`` is not a
:class:`~torch.distributed.checkpoint.planner.SavePlanner`.
"""
from torch.distributed.checkpoint.planner import SavePlanner

if save_planner is not None and not isinstance(save_planner, SavePlanner):
raise ValueError((f'save_planner {type(save_planner)} is not a '
'torch.distributed.checkpoint.planner.SavePlanner'))


def _validate_load_planner(load_planner: Optional[Any]) -> None:
"""Checks that ``load_planner`` is an instance of a :class:`~torch.distributed.checkpoint.planner.LoadPlanner`.
TODO(GRT-2456): Remove validation once we deprecate torch 1.13 and can use
type hints.
Raises:
ValueError: If ``load_planner`` is not a
:class:`~torch.distributed.checkpoint.planner.LoadPlanner`.
"""
from torch.distributed.checkpoint.planner import LoadPlanner

if load_planner is not None and not isinstance(load_planner, LoadPlanner):
raise ValueError((f'load_planner {type(load_planner)} is not a '
'torch.distributed.checkpoint.planner.LoadPlanner'))


def safe_torch_load(
composer_states_filepath: Union[Path, str],
map_location: str = 'cpu',
@@ -806,9 +853,19 @@ def save_checkpoint(

# Sharded checkpointing for torch >=2.0 uses the torch.distributed.checkpoint module.
elif state.fsdp_elastic_sharded_enabled:
if state.fsdp_config is None:
raise ValueError('Saving a sharded checkpoint requires passing an FSDP config to Trainer.')
save_planner = state.fsdp_config['save_planner']
_validate_save_planner(save_planner)

import torch.distributed.checkpoint as dist_cp

log.debug('Saving sharded checkpoints to %s...', save_filename)
dist_cp.save_state_dict(state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(dirname))
dist_cp.save_state_dict(
state_dict=state_dict,
storage_writer=dist_cp.FileSystemWriter(dirname),
planner=save_planner,
)

# Only rank 0 saves the state_dict unless you are using sharded checkpointing with torch <2.0
elif dist.get_global_rank() == 0 or state.fsdp_sharded_state_dict_enabled:
11 changes: 9 additions & 2 deletions composer/utils/misc.py
@@ -172,13 +172,20 @@ def is_model_fsdp(model: torch.nn.Module) -> bool:
def is_notebook():
"""Whether Composer is running in a IPython/Jupyter Notebook."""
try:
__IPYTHON__ #type: ignore
__IPYTHON__ # type: ignore
return True
except NameError:
return False


def warning_on_one_line(message: str, category: Type[Warning], filename: str, lineno: int, file=None, line=None):
def warning_on_one_line(
message: str,
category: Type[Warning],
filename: str,
lineno: int,
file=None,
line=None,
):
"""Force Python warnings to consolidate into one line."""
# From https://stackoverflow.com/questions/26430861/make-pythons-warnings-warn-not-mention-itself
return f'{category.__name__}: {message} (source: {filename}:{lineno})\n'
26 changes: 18 additions & 8 deletions docs/source/notes/distributed_training.rst
@@ -183,22 +183,32 @@ The full spec and defaults for Composer's `fsdp_config` is here:
.. code:: python

    fsdp_config = {
        'activation_checkpointing': bool = True | False, # Default: False
        'activation_checkpointing_reentrant': bool = True | False, # Default: True
        'activation_cpu_offload': bool = True | False, # Default: False
        'backward_prefetch': str = 'BACKWARD_PRE' | 'BACKWARD_POST' | 'NONE', # Default: 'BACKWARD_POST'
        'cpu_offload': bool = True | False, # Default: False, cpu_offload not supported yet
        'flatten_parameters': bool = True | False, # Default: True
        'forward_prefetch': bool = True | False, # Default: False
        'ignored_modules': Optional[Iterable[torch.nn.Module]], # Default: None
        'keep_low_precision_grads': bool = True | False, # Default: False
        'limit_all_gathers': bool = True | False, # Default: True
        'load_monolith_rank0_only': bool = True | False, # Default: False
        'load_planner': torch.distributed.checkpoint.planner.LoadPlanner, # Default: None
        'mixed_precision': str = 'FULL' | 'DEFAULT' | 'PURE', # Default: 'DEFAULT'
        # Note: you can explicitly provide a dictionary too
        # 'mixed_precision': dict = {
        #     'param_dtype': 'fp32' | 'fp16' | 'bf16',
        #     'reduce_dtype': 'fp32' | 'fp16' | 'bf16',
        #     'buffer_dtype': 'fp32' | 'fp16' | 'bf16',
        # },
        'save_planner': torch.distributed.checkpoint.planner.SavePlanner, # Default: None
        'sharded_ckpt_prefix_dir': str = 'ep{epoch}-ba{batch}', # Default: 'ep{epoch}-ba{batch}'
        'sharding_strategy': str = 'FULL_SHARD' | 'SHARD_GRAD_OP' | 'NO_SHARD', # Default: 'FULL_SHARD'
        'state_dict_type': str = 'full' | 'local' | 'sharded', # Default: full
        'sync_module_states': bool = True | False, # Default: False
        'use_orig_params': bool = True | False, # Default: True
        'verbose': bool = True | False, # Default: False
    }

All values come with defaults and can be optionally defined in the :code:`fsdp_config`. Most parameters map directly to parameters in the `FSDP documentation <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel>`__.
177 changes: 177 additions & 0 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -83,6 +83,8 @@ class FSDPConfig:
sync_module_states: bool = True
use_orig_params: bool = False
load_fsdp_monolith_rank0_only: bool = False
save_planner: Optional[Any] = None
load_planner: Optional[Any] = None


def get_trainer(
@@ -890,3 +892,178 @@ def test_cleanup_sharded_checkpoints(
non_elastic_file_list = {save_filename.format(rank=rank) for rank in range(dist.get_world_size())}
file_list = elastic_file_list if using_torch_2() else non_elastic_file_list
assert set(os.listdir(full_path_ckpt_dir)) == file_list


@pytest.mark.gpu
@world_size(2)
@pytest.mark.parametrize('weights_only', [False, True])
@pytest.mark.parametrize('planner', [None, 'rename'])
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2.0 or higher')
@pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
def test_fsdp_planner(
world_size,
tmp_path: pathlib.Path,
weights_only: bool,
planner: Optional[str],
s3_bucket,
s3_ephemeral_prefix,
request,
):
optimizer = 'adamw'
state_dict_type = 'sharded'
precision = 'amp_bf16'
use_remote = False
autoresume = False
if weights_only and autoresume:
pytest.xfail('Weights only with autoresume is not supported')

from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, DefaultSavePlanner
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata

class RenameSavePlanner(DefaultSavePlanner):

def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
is_coordinator: bool,
) -> None:
# suffix all keys with `_foo`
state_dict['state']['model'] = {k + '_foo': v for k, v in state_dict['state']['model'].items()}

super().set_up_planner(state_dict, is_coordinator)

class RenameLoadPlanner(DefaultLoadPlanner):

def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
metadata: Metadata,
is_coordinator: bool,
) -> None:
if 'state' not in state_dict:
super().set_up_planner(state_dict, metadata, is_coordinator)
return

self.original_state_dict = state_dict

state_dict = dict(state_dict.items())
state_dict['state'] = dict(state_dict['state'].items())
state_dict['state']['model'] = {k + '_foo': v for k, v in state_dict['state']['model'].items()}

if self.flatten_sharded_tensors:
state_dict = _flatten_sharded_tensors(state_dict)

if self.flatten_state_dict:
state_dict, self.mappings = flatten_state_dict(state_dict)

self.state_dict = state_dict
self.metadata = metadata
self.is_coordinator = is_coordinator

save_planner = planner
load_planner = planner
if planner == 'rename':
save_planner = RenameSavePlanner()
load_planner = RenameLoadPlanner()

if autoresume:
local_run_name = f'my-cool-autoresume-run-{uuid.uuid1()}'
run_name = dist.all_gather_object(local_run_name)[0]
else:
run_name = None

if use_remote:
save_folder = f's3://{s3_bucket}/{s3_ephemeral_prefix}/checkpoints/{{run_name}}'
else:
tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
save_folder = os.path.join(tmp_paths[0], 'checkpoints', '{run_name}')

save_filename = 'ba{batch}-rank{rank}.pt'

fsdp_config = FSDPConfig(
state_dict_type=state_dict_type,
load_planner=load_planner,
save_planner=save_planner,
)

trainer1 = get_trainer(
save_folder=str(save_folder),
save_filename=save_filename,
run_name=run_name,
precision=precision,
autoresume=autoresume,
optimizer=optimizer,
max_duration='2ba',
save_interval='2ba',
save_weights_only=weights_only,
fsdp_config=fsdp_config,
)
run_name = trainer1.state.run_name
trainer1.fit()
rng1 = get_rng_state()
state_dict_from_trainer1_ba2 = trainer1.state.state_dict()
trainer1.close()

if use_remote:
load_path = 's3://' + save_folder.strip('s3://').format(run_name=run_name) + '/ba2'
object_store = S3ObjectStore(bucket=f'{s3_bucket}')
else:
object_store = None
load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2'))

if not using_torch_2():
load_filename = f"{save_filename.format(batch=2, rank='{rank}')}"
assert load_filename == 'ba2-rank{rank}.pt'
load_path += '/' + load_filename
assert is_checkpoint_legacy_sharded(
object_store=object_store,
source_path=load_path.replace(f's3://{s3_bucket}/', ''),
)
else:
assert not is_checkpoint_legacy_sharded(
object_store=object_store,
source_path=load_path.replace(f's3://{s3_bucket}/', ''),
)

if autoresume:
load_path = None
trainer2 = get_trainer(
save_folder=str(save_folder),
save_filename=save_filename,
load_path=load_path,
precision=precision,
autoresume=autoresume,
run_name=run_name,
max_duration='4ba',
save_interval='4ba',
optimizer=optimizer,
load_weights_only=weights_only,
fsdp_config=fsdp_config,
)
state_dict_from_trainer2 = trainer2.state.state_dict()
rng2 = trainer2._rng_state
# Compare saved state and loaded state for both ranks.
_compare_model_params_between_state_dicts(
state_dict_from_trainer1_ba2,
state_dict_from_trainer2,
)
if not weights_only:
_compare_rng_states_between_trainers(rng1, rng2)
_compare_optims_between_state_dicts(
state_dict_from_trainer1_ba2,
state_dict_from_trainer2,
)
_compare_metrics_between_state_dicts(
state_dict_from_trainer1_ba2,
state_dict_from_trainer2,
)
_compare_timestamps_between_state_dicts(
state_dict_from_trainer1_ba2,
state_dict_from_trainer2,
)

trainer2.fit()
trainer2.close()
