Skip to content

Commit

Permalink
Create callback to load checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Oct 1, 2024
1 parent e5e2f74 commit 00682c7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from composer.callbacks.free_outputs import FreeOutputs
from composer.callbacks.generate import Generate
from composer.callbacks.image_visualizer import ImageVisualizer
from composer.callbacks.load_checkpoint import LoadCheckpoint
from composer.callbacks.lr_monitor import LRMonitor
from composer.callbacks.memory_monitor import MemoryMonitor
from composer.callbacks.memory_snapshot import MemorySnapshot
Expand Down Expand Up @@ -44,4 +45,5 @@
'FreeOutputs',
'MemorySnapshot',
'OOMObserver',
'LoadCheckpoint',
]
34 changes: 34 additions & 0 deletions composer/callbacks/load_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Load a checkpoint."""
from typing import Optional

from checkpoint.load import CheckpointLoadOptions, load_checkpoint

from composer.core import Callback, State
from composer.loggers import Logger


class LoadCheckpoint(Callback):
"""Callback that loads a checkpoint after other checkpoints have been loaded.
Args:
load_path: The path to the checkpoint to load.
load_options: A dictionary of options to pass to the checkpoint loading function.
"""

def __init__(self, load_path: str, load_options: Optional[dict] = None):
super().__init__()
self.load_path = load_path
self.load_options = CheckpointLoadOptions(**(load_options or {}))

def after_load(self, state: State, logger: Logger) -> None:
load_checkpoint(
load_path=self.load_path,
load_options=self.load_options,
state=state,
model_child_path=None,
optim_child_path=None,
)
return super().after_load(state, logger)

0 comments on commit 00682c7

Please sign in to comment.