Skip to content

Commit

Permalink
Update types to follow PEP 585
Browse files Browse the repository at this point in the history
commit-id:ecb54855
  • Loading branch information
b-chu committed Nov 8, 2023
1 parent d5fed8c commit 2c3ba52
Showing 1 changed file with 25 additions and 25 deletions.
50 changes: 25 additions & 25 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import textwrap
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from packaging import version
Expand Down Expand Up @@ -131,9 +131,9 @@ def load_checkpoint(
load_weights_only: bool = False,
strict_model_weights: bool = False,
progress_bar: bool = True,
ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
exclude_algorithms: Optional[List[str]] = None,
algorithm_passes: Optional[List[AlgorithmPass]] = None,
ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
exclude_algorithms: Optional[list[str]] = None,
algorithm_passes: Optional[list[AlgorithmPass]] = None,
):
"""Load a checkpoint from a local file, URI, or cloud object store into ``state``.
Expand Down Expand Up @@ -184,7 +184,7 @@ def load_checkpoint(
match the model weights. (default: ``False``)
progress_bar (bool, optional): Whether or not to show a progress bar when downloading checkpoints.
Ignored if the checkpoint is a local file path. (default: ``True``)
ignore_keys (List[str] | (Dict) -> None, optional): A list of paths for the ``state_dict`` of the checkpoint,
ignore_keys (list[str] | (dict) -> None, optional): A list of paths for the ``state_dict`` of the checkpoint,
which, when provided, will be ignored from the state_dict before a checkpoint is loaded. Each path is a list
of strings specifying the keys to index into ``state_dict`` joined together with `/` as a separator (as PyTorch
uses `.` in parameter names). If a prefix is provided, all children are also ignored (see Example 2).
Expand All @@ -205,7 +205,7 @@ def load_checkpoint(
the state_dict before it is loaded.
(default: ``None``)
exclude_algorithms (List[str], optional): A list of algorithm names to exclude from loading.
exclude_algorithms (list[str], optional): A list of algorithm names to exclude from loading.
By default, algorithms with `required_on_load=True` which were enabled when training the loaded
checkpoint are automatically applied unless they conflict with a user specified algorithm. These
algorithms often change the model, and not applying them could result in certain layers not having
Expand All @@ -216,11 +216,11 @@ def load_checkpoint(
Example 2: ``exclude_algorithms = ["FusedLayerNorm", "Alibi"]`` would exclude FusedLayerNorm and Alibi from loading.
(default: ``None``)
algorithm_passes (List[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms
algorithm_passes (list[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms
to sort them into the correct order. (default: ``None``)
Returns:
Optional[List[Dict[str, Any]]]: The RNG state dicts, indexed by global rank, if
Optional[list[dict[str, Any]]]: The RNG state dicts, indexed by global rank, if
:attr:`load_weights_only` is ``False``. Otherwise, None.
"""
using_legacy_sharded = False
Expand Down Expand Up @@ -311,10 +311,10 @@ def load_sharded_checkpoint(
load_weights_only: bool = False,
strict_model_weights: bool = False,
progress_bar: bool = True,
ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
exclude_algorithms: Optional[List[str]] = None,
algorithm_passes: Optional[List[AlgorithmPass]] = None,
) -> List[Dict]:
ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
exclude_algorithms: Optional[list[str]] = None,
algorithm_passes: Optional[list[AlgorithmPass]] = None,
) -> list[dict]:

if not using_torch_2():
raise ValueError(
Expand Down Expand Up @@ -484,7 +484,7 @@ def download_checkpoint(path: str,
object_store: Optional[Union[ObjectStore, LoggerDestination]],
progress_bar: bool,
fsdp_sharded_state_dict_enabled: bool = False,
deepspeed_sharded_checkpoint: bool = False) -> Tuple[str, Optional[str], bool]:
deepspeed_sharded_checkpoint: bool = False) -> tuple[str, Optional[str], bool]:
"""Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``.
Returns a tuple of (``composer_states_filepath``, ``extracted_checkpoint_folder``, ``extracted_rank_n``).
Expand Down Expand Up @@ -579,9 +579,9 @@ def download_checkpoint(path: str,
return composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n


def _flatten_keys(obj: Any, paths: List[str], existing_path: str):
def _flatten_keys(obj: Any, paths: list[str], existing_path: str):
"""Recursively flatten the keys of a dictionary or list into a list of ``/``-separated paths."""
# Store path when we reach end, which is either non-Dict or empty Dict
# Store path when we reach end, which is either non-dict or empty dict
if isinstance(obj, list) and len(obj) > 0:
for i, elm in enumerate(obj):
_flatten_keys(elm, paths, f'{existing_path}/{i}')
Expand All @@ -592,7 +592,7 @@ def _flatten_keys(obj: Any, paths: List[str], existing_path: str):
paths.append(existing_path.lstrip('/'))


def _remove_paths(obj: Union[list, Dict[str, Any]], exclude_paths: List[List[str]]):
def _remove_paths(obj: Union[list, dict[str, Any]], exclude_paths: list[list[str]]):
# First determine the keys which will be recursed on and which will be removed entirely
# Group the `exclude_paths` by the key
keys_to_recurse = {}
Expand Down Expand Up @@ -620,10 +620,10 @@ def _remove_paths(obj: Union[list, Dict[str, Any]], exclude_paths: List[List[str
del obj[key]


def glob_filter(exclude_globs: List[str]) -> Callable[[Dict], None]:
def glob_filter(exclude_globs: list[str]) -> Callable[[dict], None]:
"""Provides a function which deletes all subparts of a dictionary based on a list of paths."""

def filter_func(state_dict: Dict) -> None:
def filter_func(state_dict: dict) -> None:
# Flatten dictionary into paths
paths = []
_flatten_keys(state_dict, paths, '/')
Expand Down Expand Up @@ -693,7 +693,7 @@ def safe_torch_load(
composer_states_filepath: Union[Path, str],
map_location: str = 'cpu',
load_fsdp_monolith_rank0_only: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Load a torch checkpoint, catching errors due to backwards compatibility issues.
Args:
Expand All @@ -720,7 +720,7 @@ def safe_torch_load(

log.debug('Broadcasting state_dict to all ranks.')
dist.broadcast_object_list(state_dict_list, src=0)
state_dict: Dict[str, Any] = state_dict_list[0] # type: ignore
state_dict: dict[str, Any] = state_dict_list[0] # type: ignore

if dist.get_global_rank() == 0:
if model is not None:
Expand Down Expand Up @@ -748,10 +748,10 @@ def _restore_checkpoint(
extracted_checkpoint_folder: Optional[str],
load_weights_only: bool,
strict_model_weights: bool,
ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]],
exclude_algorithms: Optional[List[str]],
algorithm_passes: Optional[List[AlgorithmPass]],
) -> Optional[List[Dict[str, Any]]]:
ignore_keys: Optional[Union[list[str], Callable[[dict], None]]],
exclude_algorithms: Optional[list[str]],
algorithm_passes: Optional[list[AlgorithmPass]],
) -> Optional[list[dict[str, Any]]]:
"""Restore a checkpoint into ``state`` and returns the rng state dicts (if ``load_weights_only`` is False)."""
# Now, all ranks load the checkpoint that local rank zero downloaded
state_dict = safe_torch_load(
Expand Down Expand Up @@ -990,7 +990,7 @@ def _save_deepspeed_model(model, filename: str):
compatible with DeepSpeed,
Returns:
List[pathlib.Path]: The list of checkpoint files saved, indexed by the rank of the process.
list[pathlib.Path]: The list of checkpoint files saved, indexed by the rank of the process.
.. note::
Expand Down

0 comments on commit 2c3ba52

Please sign in to comment.