[perf][OSS] Clip grad norm : minor obvious speedup #363

Merged (1 commit) on Feb 4, 2021

Changes from all commits
28 changes: 16 additions & 12 deletions fairscale/optim/oss.py
@@ -8,7 +8,7 @@
 from itertools import chain
 import logging
 from math import inf
-from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Iterable, List, Optional, Type, Union

 import torch
 import torch.distributed as dist
@@ -81,6 +81,7 @@ def __init__(
         self._partition_parameters: List[List[dict]] = []
         self._index_to_param: Dict[int, torch.Tensor] = {}
         self._param_to_index: Dict[int, int] = {}
+        self._local_params: Optional[Iterable[Any]] = None

         # Build the wrapped optimizer, responsible for a shard of the params
         self.group = group if group is not None else dist.group.WORLD
@@ -143,6 +144,17 @@ def partition_parameters(self) -> List[List[dict]]:

         return self._partition_parameters

+    @property
+    def local_params(self) -> Iterable[torch.Tensor]:
+        if self._local_params is None:
+            self._local_params = chain(
+                *[
+                    list(filter(lambda x: x.grad is not None, device_params[self.rank]))
+                    for device_params in self.per_device_params.values()
+                ]
+            )
+        return self._local_params
+
     @property
     def index_to_param(self) -> Dict[int, torch.Tensor]:
         """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
@@ -255,25 +267,16 @@ def clip_grad_norm(
         max_norm = float(max_norm)
         norm_type = float(norm_type)

-        # Filter out the grad-less params, concatenate params from all devices
-        local_params = chain(
-            *[
-                list(filter(lambda x: x.grad is not None, device_params[self.rank]))
-                for device_params in self.per_device_params.values()
-            ]
-        )
-
         # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
         # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
         # 'model_parallel' flag is set in Megatron-LM:
         # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
-        if filter_params_fn is not None:
-            local_params = filter_params_fn(local_params)
+        local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params

         # Compute the norm on this grad set,
         # then sync all the norms from all ranks
         if norm_type == inf:
-            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)  # type: ignore
+            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)
             # all reduce over data parallel and model parallel workers
             dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
         else:
@@ -508,6 +511,7 @@ def _clear_cache(self) -> None:
         self._param_rank.clear()
         self._index_to_param.clear()
         self._param_to_index.clear()
+        self._local_params = None

     @staticmethod
     def get_global_rank(group: Any, rank: int) -> int:
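
The comment block in the clip_grad_norm hunk explains why a filter_params_fn hook exists: under model parallelism, parameters replicated across model-parallel ranks must not be counted more than once in the reduced norm. As a rough illustration only (not part of this PR; the 'model_parallel' attribute and the rank check follow the Megatron-LM convention linked in that comment), such a filter could look like this:

import torch.distributed as dist

def keep_unique_params(params):
    # Hypothetical filter_params_fn: count each parameter once across the job.
    # Model-parallel params are sharded, so every rank keeps its own shard;
    # replicated params are kept only on rank 0.
    rank = dist.get_rank()
    return [p for p in params if getattr(p, "model_parallel", False) or rank == 0]

It would be handed to clip_grad_norm through the filter_params_fn argument shown in the hunk above.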
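
For orientation, a minimal single-process sketch of the call path this PR speeds up: wrap a plain optimizer in OSS, run backward, then call clip_grad_norm, which now pulls its parameter list from the cached local_params property instead of rebuilding the chain inline. The one-rank gloo process group is only there so the snippet can run outside a distributed launcher; it is an illustration, not part of the PR.

import os
import tempfile

import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS

# One-rank process group so OSS can be constructed without a launcher.
init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)

model = torch.nn.Linear(8, 4)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

model(torch.randn(2, 8)).sum().backward()

# Internally iterates over the cached local_params property added above.
optimizer.clip_grad_norm(max_norm=1.0)
optimizer.step()

The last hunk resets _local_params in _clear_cache, so the cached sequence is rebuilt whenever the shard partition is recomputed.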