Remove Legacy Checkpointing #3631

Merged 9 commits on Oct 1, 2024
Changes from 4 commits
52 changes: 9 additions & 43 deletions composer/utils/checkpoint.py
@@ -413,24 +413,6 @@ def format(self, state: State, is_deepspeed: bool = False, keep_placeholders: bo
) + extra_suffix


def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination, ObjectStore]], source_path: str):
if source_path.endswith('.symlink') or os.path.islink(source_path):
source_path = extract_path_from_symlink(source_path, object_store=object_store)
metadata_path = str(Path(source_path) / Path('.metadata'))
log.debug(f'Checking if checkpoint is legacy sharded by checking for metadata file at {metadata_path}.')
if object_store is None:
return not os.path.exists(metadata_path)
else:
try:
_, _, metadata_path = parse_uri(metadata_path)
with tempfile.TemporaryDirectory() as temp_dir:
metadata_destination = os.path.join(str(temp_dir), '.metadata')
download_object_or_file(metadata_path, metadata_destination, object_store)
return False
except FileNotFoundError:
return True


def load_checkpoint(
path: str,
state: State,
@@ -533,16 +515,8 @@ def load_checkpoint(
"""
path = partial_format(path, run_name=state.run_name)
log.debug(f'Loading checkpoint from formatted path: {path}')
using_legacy_sharded = False

if state.fsdp_sharded_state_dict_enabled:
assert object_store is None or isinstance(
object_store,
ObjectStore,
), 'For loading sharded checkpoints load_object_store must be set with the class ObjectStore'
using_legacy_sharded = is_checkpoint_legacy_sharded(object_store, path)
log.info(f'Using legacy sharded checkpoint: {using_legacy_sharded}')

if state.fsdp_sharded_state_dict_enabled and not using_legacy_sharded:
rng_state_dicts = load_sharded_checkpoint(
source_path=path,
state=state,
@@ -557,26 +531,20 @@
)
else:
# Download the checkpoint to the node-local folder
log.debug('Loading checkpoint at %s', path)
# Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node.
# If fsdp sharded state_dicts is enabled then EVERY rank gets a unique checkpoint folder.
needs_unique_checkpoint_folder = state.fsdp_sharded_state_dict_enabled or dist.get_local_rank() == 0
tempdir_ctx = tempfile.TemporaryDirectory() if needs_unique_checkpoint_folder else contextlib.nullcontext(None)
tempdir_ctx = tempfile.TemporaryDirectory() if dist.get_local_rank() == 0 else contextlib.nullcontext(None)
with tempdir_ctx as tempdir:
try:
# Get the path to the proper checkpoint folder corresponding to the current rank's node.
# If fsdp_sharded_state_dict_enabled then just use that rank's unique tempdir.
node_checkpoint_folder = (
tempdir if state.fsdp_sharded_state_dict_enabled else _get_local_rank_zero_path(tempdir)
)
assert node_checkpoint_folder is not None
node_checkpoint_folder = _get_local_rank_zero_path(tempdir)

composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = download_checkpoint(
path=path,
node_checkpoint_folder=node_checkpoint_folder,
object_store=object_store,
progress_bar=progress_bar,
fsdp_sharded_state_dict_enabled=state.fsdp_sharded_state_dict_enabled,
deepspeed_sharded_checkpoint=is_model_deepspeed(state.model),
)
rng_state_dicts = _restore_checkpoint(
@@ -596,6 +564,8 @@
# be a shared resource between nodes.
dist.barrier()
log.info('%s loaded from %s', 'Model weights' if load_weights_only else 'Trainer checkpoint', path)

# Verify all ranks resumed on same step
step_to_resume_from = state.timestamp.batch.value
max_step_to_resume_from = state.device.tensor_to_device(
torch.tensor(state.timestamp.batch.value, dtype=torch.int64),
@@ -802,7 +772,6 @@ def download_checkpoint(
node_checkpoint_folder: str,
object_store: Optional[Union[ObjectStore, LoggerDestination]],
progress_bar: bool,
fsdp_sharded_state_dict_enabled: bool = False,
deepspeed_sharded_checkpoint: bool = False,
) -> tuple[str, Optional[str], bool]:
"""Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``.
@@ -829,9 +798,7 @@
# and only rank zero has this file unless fsdp_sharded_state_dict_enabled then
# every rank has its own file.
extracted_checkpoint_folder = None
composer_states_filepath = (
rank_n_checkpoint_filepath if fsdp_sharded_state_dict_enabled else rank_zero_checkpoint_filepath
)
composer_states_filepath = rank_zero_checkpoint_filepath

if is_compressed_pt(path):
original_path = path
@@ -841,9 +808,8 @@
with compressor.decompress(original_path) as in_file:
shutil.copyfileobj(in_file, out_file)

checkpoint_is_sharded = fsdp_sharded_state_dict_enabled or deepspeed_sharded_checkpoint
try:
if not checkpoint_is_sharded and dist.get_local_rank() == 0:
if not deepspeed_sharded_checkpoint and dist.get_local_rank() == 0:
# If the checkpoint is not sharded, then local rank 0 on each node needs to download the
# global rank 0 checkpoint
path = _format_path_with_rank_zero(path)
@@ -862,7 +828,7 @@
# the underlying issue is that the checkpoint file does not exist on the disk
# or could not be downloaded
raise RuntimeError(f'Checkpoint {path} does not exist')
elif checkpoint_is_sharded:
elif deepspeed_sharded_checkpoint:
# If the checkpoint is sharded, then every rank needs to download its own checkpoint
path = _format_path_with_current_rank(path)
try:
@@ -892,7 +858,7 @@

finally:
# Use busy wait to avoid timeouts on large downloads for non-sharded checkpoints
if not checkpoint_is_sharded:
if not deepspeed_sharded_checkpoint:
signal_file_path = os.path.join(
node_checkpoint_folder,
dist.get_node_signal_file_name(),
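The deleted `is_checkpoint_legacy_sharded` helper distinguished legacy sharded checkpoints (one `.pt` file per rank) from current ones by probing for the `.metadata` file that `torch.distributed.checkpoint` writes into every sharded checkpoint directory. A minimal local-filesystem sketch of that probe follows; the name `looks_like_dist_cp_checkpoint` and the `checkpoint_dir` argument are illustrative, not part of Composer's API.

```python
import os


def looks_like_dist_cp_checkpoint(checkpoint_dir: str) -> bool:
    """Return True if the directory appears to hold a torch.distributed.checkpoint save."""
    # Current (non-legacy) sharded checkpoints always contain a '.metadata' file;
    # legacy per-rank .pt checkpoints do not, which is what the removed helper keyed on.
    return os.path.exists(os.path.join(checkpoint_dir, '.metadata'))
```

With the probe gone, `load_checkpoint` branches only on `state.fsdp_sharded_state_dict_enabled`: sharded loads go straight to `load_sharded_checkpoint`, while monolithic (and DeepSpeed) checkpoints are downloaded into a node-local folder owned by local rank 0, with per-rank downloads remaining only for the DeepSpeed sharded case.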
25 changes: 4 additions & 21 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -31,7 +31,7 @@
from composer.optim import DecoupledAdamW
from composer.trainer import Trainer
from composer.utils import FSDPConfig, TPConfig, dist, parse_uri
from composer.utils.checkpoint import dist_cp_load, is_checkpoint_legacy_sharded
from composer.utils.checkpoint import dist_cp_load
from composer.utils.file_helpers import get_file
from composer.utils.object_store import S3ObjectStore
from composer.utils.reproducibility import get_rng_state
@@ -537,6 +537,9 @@ def test_fsdp_load_old_checkpoint(
pytest.skip('Current torch version is older than torch version that checkpoint was written with.')

if composer_version in ['0.13.5', '0.14.0', '0.14.1', '0.15.1']:
if state_dict_type == 'sharded':
pytest.skip('Loading legacy sharded checkpoints is not supported after v0.25.0.')

rank = 0 if state_dict_type == 'full' else '{rank}'

load_path_dir = (
@@ -548,10 +551,6 @@
load_path_dir = (load_path_dir + 'ep0-ba2/')

load_path = load_path_dir + f'ba2_rank{rank}.pt'
assert is_checkpoint_legacy_sharded(
object_store=S3ObjectStore(bucket=f'{s3_bucket}'),
source_path=load_path.lstrip(f's3://{s3_bucket}/'),
)
else:
load_path = (
f's3://{s3_bucket}/{s3_read_only_prefix}/backwards_compatibility/'
@@ -909,16 +908,9 @@ def test_fsdp_partitioned_state_dict_load(
load_path = 's3://' + save_folder.strip('s3://').format(
run_name=run_name,
) + ('/ba2' if not use_symlink else '/latest-rank0.pt.symlink')
object_store = S3ObjectStore(bucket=f'{s3_bucket}')
else:
object_store = None
load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2'))

assert not is_checkpoint_legacy_sharded(
object_store=object_store,
source_path=load_path.replace(f's3://{s3_bucket}/', ''),
)

if autoresume:
load_path = None
trainer2 = get_trainer(
@@ -1013,10 +1005,6 @@ def test_elastic_resumption(
else:
save_folder = None
sharded_load_path = os.path.join(base_path, 'ba2')
assert not is_checkpoint_legacy_sharded(
object_store=S3ObjectStore(bucket=f'{s3_bucket}'),
source_path=sharded_load_path.replace(f's3://{s3_bucket}/', ''),
)

sharded_trainer = get_trainer(
save_folder=save_folder,
@@ -1237,11 +1225,6 @@ def set_up_planner(

load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2'))

assert not is_checkpoint_legacy_sharded(
object_store=None,
source_path=load_path,
)

trainer2 = get_trainer(
save_folder=str(save_folder),
load_path=load_path,
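The skip guard added in `test_fsdp_load_old_checkpoint` reflects the pattern backward-compatibility tests now follow: Composer releases that only produced legacy sharded checkpoints skip the sharded case outright. Below is a hedged, self-contained sketch of that pattern; the test name, the extra `0.16.0` entry, and the skip message are illustrative assumptions rather than code from this PR.

```python
import pytest

# Composer releases whose sharded checkpoints use the legacy per-rank .pt layout.
LEGACY_ONLY_VERSIONS = {'0.13.5', '0.14.0', '0.14.1', '0.15.1'}


@pytest.mark.parametrize('composer_version', sorted(LEGACY_ONLY_VERSIONS | {'0.16.0'}))
@pytest.mark.parametrize('state_dict_type', ['full', 'sharded'])
def test_old_checkpoint_guard(composer_version: str, state_dict_type: str):
    # Skip combinations that would require loading a legacy sharded checkpoint,
    # which is no longer supported after this change.
    if composer_version in LEGACY_ONLY_VERSIONS and state_dict_type == 'sharded':
        pytest.skip('Legacy sharded checkpoints can no longer be loaded.')
    # ...the real test would construct a Trainer here and load the old checkpoint.
```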