diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index 6edca1723b..48a5dc51c8 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -15,7 +15,7 @@
 import textwrap
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 from packaging import version
@@ -123,9 +123,9 @@ def load_checkpoint(
     load_weights_only: bool = False,
     strict_model_weights: bool = False,
     progress_bar: bool = True,
-    ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
-    exclude_algorithms: Optional[List[str]] = None,
-    algorithm_passes: Optional[List[AlgorithmPass]] = None,
+    ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
+    exclude_algorithms: Optional[list[str]] = None,
+    algorithm_passes: Optional[list[AlgorithmPass]] = None,
 ):
     """Load a checkpoint from a local file, URI, or cloud object store into ``state``.
 
@@ -176,7 +176,7 @@
             match the model weights. (default: ``False``)
         progress_bar (bool, optional): Whether or not to show a progress bar when downloading checkpoints.
             Ignored if the checkpoint is a local file path. (default: ``True``)
-        ignore_keys (List[str] | (Dict) -> None, optional): A list of paths for the ``state_dict`` of the checkpoint,
+        ignore_keys (list[str] | (dict) -> None, optional): A list of paths for the ``state_dict`` of the checkpoint,
             which, when provided, will be ignored from the state_dict before a checkpoint is loaded. Each path is a
             list of strings specifying the keys to index into ``state_dict`` joined together with `/` as a separator
             (as PyTorch uses `.` in parameter names). If a prefix is provided, all children are also ignored (see Example 2).
@@ -197,7 +197,7 @@
             the state_dict before it is loaded. (default: ``None``)
 
-        exclude_algorithms (List[str], optional): A list of algorithm names to exclude from loading.
+        exclude_algorithms (list[str], optional): A list of algorithm names to exclude from loading.
             By default, algorithms with `required_on_load=True` which were enabled when training the loaded
             checkpoint are automatically applied unless they conflict with a user specified algorithm. These
             algorithms often change the model, and not applying them could result in certain layers not having
@@ -208,11 +208,11 @@
             Example 2: ``exclude_algorithms = ["FusedLayerNorm", "Alibi"]`` would exclude FusedLayerNorm and Alibi
                 from loading. (default: ``None``)
 
-        algorithm_passes (List[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms
+        algorithm_passes (list[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms
             to sort them into the correct order. (default: ``None``)
 
     Returns:
-        Optional[List[Dict[str, Any]]]: The RNG state dicts, indexed by global rank, if
+        Optional[list[dict[str, Any]]]: The RNG state dicts, indexed by global rank, if
             :attr:`load_weights_only` is not None. Otherwise, None.
""" using_legacy_sharded = False @@ -303,10 +303,10 @@ def load_sharded_checkpoint( load_weights_only: bool = False, strict_model_weights: bool = False, progress_bar: bool = True, - ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None, - exclude_algorithms: Optional[List[str]] = None, - algorithm_passes: Optional[List[AlgorithmPass]] = None, -) -> List[Dict]: + ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None, + exclude_algorithms: Optional[list[str]] = None, + algorithm_passes: Optional[list[AlgorithmPass]] = None, +) -> list[dict]: if not using_torch_2(): raise ValueError( @@ -478,7 +478,7 @@ def download_checkpoint(path: str, object_store: Optional[Union[ObjectStore, LoggerDestination]], progress_bar: bool, fsdp_sharded_state_dict_enabled: bool = False, - deepspeed_sharded_checkpoint: bool = False) -> Tuple[str, Optional[str], bool]: + deepspeed_sharded_checkpoint: bool = False) -> tuple[str, Optional[str], bool]: """Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``. Returns a tuple of (``composer_states_filepath``, ``extracted_checkpoint_folder``, ``extracted_rank_n``). @@ -573,9 +573,9 @@ def download_checkpoint(path: str, return composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n -def _flatten_keys(obj: Any, paths: List[str], existing_path: str): +def _flatten_keys(obj: Any, paths: list[str], existing_path: str): """Recursively flatten the keys of a dictionary or list into a set of paths.""" - # Store path when we reach end, which is either non-Dict or empty Dict + # Store path when we reach end, which is either non-dict or empty dict if isinstance(obj, list) and len(obj) > 0: for i, elm in enumerate(obj): _flatten_keys(elm, paths, f'{existing_path}/{i}') @@ -586,7 +586,7 @@ def _flatten_keys(obj: Any, paths: List[str], existing_path: str): paths.append(existing_path.lstrip('/')) -def _remove_paths(obj: Union[list, Dict[str, Any]], exclude_paths: List[List[str]]): +def _remove_paths(obj: Union[list, dict[str, Any]], exclude_paths: list[list[str]]): # First determine the keys which will be recursed on and which will be removed entirely # Group the `exclude_paths` by the key keys_to_recurse = {} @@ -614,10 +614,10 @@ def _remove_paths(obj: Union[list, Dict[str, Any]], exclude_paths: List[List[str del obj[key] -def glob_filter(exclude_globs: List[str]) -> Callable[[Dict], None]: +def glob_filter(exclude_globs: list[str]) -> Callable[[dict], None]: """Provides a function which deletes all subparts of a dictionary based on a list of paths.""" - def filter_func(state_dict: Dict) -> None: + def filter_func(state_dict: dict) -> None: # Flatten dictionary into paths paths = [] _flatten_keys(state_dict, paths, '/') @@ -679,7 +679,7 @@ def safe_torch_load( composer_states_filepath: Union[Path, str], map_location: str = 'cpu', load_fsdp_monolith_rank0_only: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Load a torch checkpoint, catching errors due to backwards compatibility issues. 
     Args:
@@ -706,7 +706,7 @@ def safe_torch_load(
 
         log.debug('Broadcasting state_dict to all ranks.')
         dist.broadcast_object_list(state_dict_list, src=0)
-        state_dict: Dict[str, Any] = state_dict_list[0]  # type: ignore
+        state_dict: dict[str, Any] = state_dict_list[0]  # type: ignore
 
         if dist.get_global_rank() == 0:
             if model is not None:
@@ -734,10 +734,10 @@ def _restore_checkpoint(
     extracted_checkpoint_folder: Optional[str],
     load_weights_only: bool,
     strict_model_weights: bool,
-    ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]],
-    exclude_algorithms: Optional[List[str]],
-    algorithm_passes: Optional[List[AlgorithmPass]],
-) -> Optional[List[Dict[str, Any]]]:
+    ignore_keys: Optional[Union[list[str], Callable[[dict], None]]],
+    exclude_algorithms: Optional[list[str]],
+    algorithm_passes: Optional[list[AlgorithmPass]],
+) -> Optional[list[dict[str, Any]]]:
     """Restore a checkpoint into ``state`` and returns the rng state dicts (if ``load_weights_only`` is False)."""
     # Now, all ranks load the checkpoint that local rank zero downloaded
     state_dict = safe_torch_load(
@@ -978,7 +978,7 @@ def _save_deepspeed_model(model, filename: str):
             compatible with DeepSpeed,
 
     Returns:
-        List[pathlib.Path]: The list of checkpoint files saved, indexed by the rank of the process.
+        list[pathlib.Path]: The list of checkpoint files saved, indexed by the rank of the process.
 
     .. note::
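
Illustrative only, not part of the diff above: a minimal sketch of the retyped ``glob_filter`` / ``ignore_keys`` path. The nested ``state_dict`` below is a made-up stand-in for a real Composer checkpoint (treat its keys as assumptions); the point is that ``glob_filter`` returns a ``Callable[[dict], None]`` that deletes matching ``/``-joined paths in place, which is the shape ``load_checkpoint`` now accepts for ``ignore_keys``.

from composer.utils.checkpoint import glob_filter

# Made-up, simplified stand-in for a Composer checkpoint state_dict.
state_dict = {
    'state': {
        'model': {'layer1.weights': 0, 'layer1.bias': 1, 'layer2.weights': 2},
        'optimizers': {'Adam': {'step': 3}},
    },
}

# glob_filter builds a Callable[[dict], None] from glob-style paths joined with '/';
# per the docstring, the same list of strings could instead be passed directly as
# load_checkpoint(..., ignore_keys=[...]).
filter_func = glob_filter(['state/model/layer1*'])
filter_func(state_dict)  # mutates state_dict in place

assert 'layer1.weights' not in state_dict['state']['model']  # matched by the glob, removed
assert 'layer2.weights' in state_dict['state']['model']  # not matched, kept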