fix 2.4.1 test #3612

Merged: 2 commits, Sep 10, 2024
94 changes: 49 additions & 45 deletions composer/trainer/_patch_pytorch.py
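The substance of the change is a widened version gate: the patched FSDP `unshard` that previously applied only to PyTorch 2.4.0 now also applies to 2.4.1. As a minimal sketch of that gating pattern, assuming `version` is `packaging.version` (consistent with the `version.parse` calls in the diff; the bounds come from the diff below, the body is a placeholder):

from packaging import version
import torch

# Patch FSDP's unshard on any 2.4.x release prior to 2.4.2.
if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
    torch.__version__,
) < version.parse('2.4.2'):
    ...  # install unshard_with_sync here, keeping a reference to the original for later revert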
@@ -946,51 +946,7 @@ def unshard_with_sync(self):
if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
    torch.__version__,
) < version.parse('2.4.1'):
    # Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks
    from torch.distributed.fsdp._flat_param import FlatParamHandle
    original_unshard = FlatParamHandle.unshard

    @no_type_check
    def unshard_with_sync(self):
        """Run the unshard logic, but with a sync after a :meth:`_alloc_padded_unsharded_flat_param`.
        This prevents deadlocks when some ranks OOM after the alloc call and others do not.
        This is a patched method from pytorch, meant to be called when automicrobatching
        turns on hooks in its search process for the optimal non-OOMing microbatch size.
        This includes all-gathering the flat parameter
        and switching to using the unsharded flat parameter. If the handle does
        not need unsharding, then this only switches to using the unsharded
        flat parameter. For ``NO_SHARD``, this is a no-op.
        If FSDP is in :meth:`summon_full_params` and the handle uses parameter
        mixed precision, then the parameter is forced to full precision.
        """
        if not self.needs_unshard():
            # Even when not needing an unshard, we should switch to using
            # the unsharded flat parameter
            unsharded_flat_param = (
                self._get_padded_unsharded_flat_param()
                if self.uses_sharded_strategy
                else self.flat_param
            )
            self._use_unsharded_flat_param(unsharded_flat_param)
            return
        unsharded_flat_param = self._alloc_padded_unsharded_flat_param()

        # Check if any other rank hit an OOM
        found_cuda_oom_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)

        dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX')
        found_cuda_oom = found_cuda_oom_tensor.item()
        # Signal current rank is still in batch
        all_ranks_finished_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)

        dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN')

        if found_cuda_oom == 1:
            raise RuntimeError('CUDA out of memory encountered on a different rank')
        padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
        self._use_unsharded_flat_param(padded_unsharded_flat_param)

# 2.4.0 only patch
# PyTorch issue: https://github.com/pytorch/pytorch/issues/133923
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from typing import Mapping, Collection
@@ -1046,3 +1002,51 @@ def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:

for key, value in state_dict.items():
    _traverse_obj((str(key),), value)
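The unchanged context above is the tail of the separate 2.4.0-only patch for the linked PyTorch issue, which walks a checkpoint state dict recursively. For orientation only, a traversal of roughly that shape could look like the hypothetical sketch below; `traverse_state_dict` and `visitor` are illustrative names, not the actual PyTorch or Composer API:

from typing import Any, Callable, Mapping, Tuple

OBJ_PATH = Tuple[Any, ...]

def traverse_state_dict(state_dict: Mapping[str, Any], visitor: Callable[[OBJ_PATH, Any], None]) -> None:
    """Invoke `visitor(path, value)` on every leaf of a possibly nested state dict."""

    def _traverse_obj(path: OBJ_PATH, value: Any) -> None:
        if isinstance(value, Mapping):
            for key, child in value.items():
                _traverse_obj(path + (str(key),), child)
        elif isinstance(value, (list, tuple)):
            for index, child in enumerate(value):
                _traverse_obj(path + (index,), child)
        else:
            # Tensors, scalars, strings, etc. are treated as leaves.
            visitor(path, value)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)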

if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
    torch.__version__,
) < version.parse('2.4.2'):
    # Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks
    from torch.distributed.fsdp._flat_param import FlatParamHandle
    original_unshard = FlatParamHandle.unshard

    @no_type_check
    def unshard_with_sync(self):
        """Run the unshard logic, but with a sync after a :meth:`_alloc_padded_unsharded_flat_param`.
        This prevents deadlocks when some ranks OOM after the alloc call and others do not.
        This is a patched method from pytorch, meant to be called when automicrobatching
        turns on hooks in its search process for the optimal non-OOMing microbatch size.
        This includes all-gathering the flat parameter
        and switching to using the unsharded flat parameter. If the handle does
        not need unsharding, then this only switches to using the unsharded
        flat parameter. For ``NO_SHARD``, this is a no-op.
        If FSDP is in :meth:`summon_full_params` and the handle uses parameter
        mixed precision, then the parameter is forced to full precision.
        """
        if not self.needs_unshard():
            # Even when not needing an unshard, we should switch to using
            # the unsharded flat parameter
            unsharded_flat_param = (
                self._get_padded_unsharded_flat_param()
                if self.uses_sharded_strategy
                else self.flat_param
            )
            self._use_unsharded_flat_param(unsharded_flat_param)
            return
        unsharded_flat_param = self._alloc_padded_unsharded_flat_param()

        # Check if any other rank hit an OOM
        found_cuda_oom_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)

        dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX')
        found_cuda_oom = found_cuda_oom_tensor.item()
        # Signal current rank is still in batch
        all_ranks_finished_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)

        dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN')

        if found_cuda_oom == 1:
            raise RuntimeError('CUDA out of memory encountered on a different rank')
        padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
        self._use_unsharded_flat_param(padded_unsharded_flat_param)
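Because `original_unshard` is saved before the patched method is defined, the synced variant can be swapped in while automicrobatching searches for a working microbatch size and swapped out afterwards. A hypothetical sketch of that toggle (the function names are illustrative, not Composer's API; only the attribute assignments reflect what the saved reference enables):

from torch.distributed.fsdp._flat_param import FlatParamHandle

def enable_unshard_oom_sync() -> None:
    # While automicrobatching probes candidate microbatch sizes, route every
    # FSDP unshard through the variant that all-reduces an OOM flag first.
    FlatParamHandle.unshard = unshard_with_sync

def disable_unshard_oom_sync() -> None:
    # Once a safe microbatch size is found, drop the hooks and restore the
    # stock PyTorch unshard that was saved as `original_unshard`.
    FlatParamHandle.unshard = original_unshard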