Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] OSS: Sync all attributes #67

Merged
merged 3 commits into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions fairscale/optim/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any =
# Current device is set by the parameters allocated to this rank
self._device = split_param_groups[self.rank][0]["params"][0].device

# Sync local and global param_groups keys
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for k, v in local_group.items():
if k != "params":
global_group[k] = v

def partition_parameters(self) -> List[List[dict]]:
"""Partitions parameters across distributed ranks.

Expand Down Expand Up @@ -94,8 +100,8 @@ def partition_parameters(self) -> List[List[dict]]:
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
# For example, the apex library contains fused optimizers with a step that supports extra kwargs.
def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
# Sync lr in case its been update by an LRScheduler.
self._sync_lr()
# Sync oss param_groups attributes in case they've been updated by a scheduler.
self._sync_param_groups()

# Run the optimizer step on this shard only
loss = self.optim.step(closure=closure, **kwargs) # type: ignore
Expand All @@ -116,8 +122,8 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:

This needs to be called on all replicas """

# Sync lr in case its been update by an LRScheduler.
self._sync_lr()
# Sync lr and other attributes in case its been updated
self._sync_param_groups()

if self.rank == recipient_rank:
# Pull the sharded state from all the other replicas
Expand Down Expand Up @@ -176,20 +182,20 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
{"state": state_dict["state"][self.rank], "param_groups": state_dict["param_groups"][self.rank]}
)

# Update the param_groups attribute for this instance
# TODO(ben)

def add_param_group(self, param_group: dict) -> None:
super().add_param_group(param_group)
if not self.in_super_constructor:
param_groups = self.partition_parameters()[self.rank]
if len(param_groups) == len(self.optim.param_groups) + 1:
self.optim.add_param_group(param_groups[-1])

def _sync_lr(self) -> None:
"""Sync learning rate (needed to support LRScheduler)."""
def _sync_param_groups(self) -> None:
"""Sync learning rate and other optimizer attributes (needed to support schedulers)."""
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
local_group["lr"] = global_group["lr"]
for k in local_group.keys():
if k != "params":
# Params have been sharded and should not be synced here
local_group[k] = global_group[k]

def _collect_sharded_states(self) -> List[Dict[str, Any]]:
"""
Expand Down
10 changes: 8 additions & 2 deletions tests/optim/test_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,17 @@ def test_state_dict():
assert "param_groups" in state_dict.keys()
assert "state" in state_dict.keys()

# Check that the pulled state is what we expect
# Check that the pulled state is what we expect, and that we have all the expected keys
assert state_dict["param_groups"][0][0]["lr"] == 0.1
assert state_dict["param_groups"][0][0]["momentum"] == 0.9
assert not state_dict["param_groups"][0][0]["nesterov"]
assert state_dict["param_groups"][0][0]["weight_decay"] == 0.0
assert state_dict["param_groups"][0][0]["dampening"] == 0.0

# Check that the pulled state and the .param_groups attribute are in sync
assert state_dict["param_groups"][0][0]["lr"] == o.param_groups[0]["lr"]
for k in state_dict["param_groups"][0][0].keys():
if k != "params":
assert state_dict["param_groups"][0][0][k] == o.param_groups[0][k]

# Check that it's correctly loaded
o = optim.OSS([x], lr=0.01)
Expand Down