diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 1ce2414a1e17..90afaf03550a 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -38,7 +38,7 @@ def _apply_forward_and_backward_to_tensors_only(module, forward_function, backwa class ZeROOrderedDict(OrderedDict): - def __init__(self, parent_module=None, *args, **kwargs): + def __init__(self, parent_module, *args, **kwargs): """A replacement for ``collections.OrderedDict`` to detect external ZeRO params. Args: @@ -49,6 +49,10 @@ def __init__(self, parent_module=None, *args, **kwargs): self._parent_module = parent_module self._in_forward = False + def __reduce__(self): + r0, _, *r2 = super().__reduce__() + return (r0, (self._parent_module, )) + r2 + def __getitem__(self, key): param = super().__getitem__(key) @@ -56,6 +60,7 @@ def __getitem__(self, key): if param is None: return param + # TODO: only weaken this check during compilation if hasattr(param, "ds_status") and param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if self._parent_module._parameters._in_forward: register_external_parameter(FWD_MODULE_STACK[-1], param)