Skip to content

Commit

Permalink
Add better error handling for non-rank 0 during Monolithic Checkpoint…
Browse files Browse the repository at this point in the history
… Loading (#3647)
  • Loading branch information
j316chuck authored Oct 14, 2024
1 parent 6ca3936 commit 2972a2a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
9 changes: 9 additions & 0 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,15 @@ def safe_torch_load(
'pass `load_ignore_keys = ["state/train_metrics/*", "state/eval_metrics/*"]`.',
) from e
raise e
except FileNotFoundError as e:
if 'No such file or directory' in str(e) and dist.get_local_rank() != 0:
local_rank_zero = dist.get_global_rank() - dist.get_local_rank()
raise FileNotFoundError(
f'No such file or directory: {e.filename}. '
f'This likely implies a download failed on local rank 0, which is global rank {local_rank_zero}'
f'Please check the logs for global rank {local_rank_zero} to debug the checkpoint download issue.',
) from e
raise e


def _restore_checkpoint(
Expand Down
27 changes: 27 additions & 0 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,6 +1696,33 @@ def test_set_dataloaders_to_cur_epoch(
# Epoch count starts at O
assert trainer.state.train_dataloader.batch_sampler.epoch == max_duration - 1

@world_size(2)
@pytest.mark.gpu
def test_load_incorrect_path(self, world_size: int, tmp_path: pathlib.Path, caplog):
save_folder = tmp_path / 'checkpoints'
save_folder.mkdir(exist_ok=True)

# Attempt to load from an incorrect path
incorrect_path = str(tmp_path / 'nonexistent_checkpoint.pt')

with pytest.raises(FileNotFoundError) as exc_info:
self.get_trainer(
load_path=incorrect_path,
max_duration='1ep',
)

# Check error messages for each rank, ensure they are different.
if dist.get_global_rank() == 0:
assert f'Local path {incorrect_path} does not exist' in str(exc_info.value)
else:
assert 'No such file or directory:' in str(exc_info.value)
assert 'This likely implies a download failed on local rank 0, which is global rank 0' in str(
exc_info.value,
)
assert 'Please check the logs for global rank 0 to debug the checkpoint download issue.' in str(
exc_info.value,
)

@pytest.mark.parametrize('spin_dataloaders', [False, True])
def test_spin_dataloaders(
self,
Expand Down

0 comments on commit 2972a2a

Please sign in to comment.