different learning rate for different parts
Summary:
Adds the ability to use different learning rates for different parts of the model. The trainable parts of Implicitron have a new member:

       param_groups: dictionary where keys are names of individual parameters
            or module members, and values are the parameter group to which the
            parameter/member will be assigned. The "self" key is used to denote the
            parameter group at the module level. Possible keys, including the "self" key,
            do not have to be defined. By default all parameters are put into the "default"
            parameter group and have the learning rate defined in the optimizer;
            it can be overridden at the:
                - module level with the "self" key: all the parameters and child
                    modules' parameters will be put into that parameter group
                - member level, which is the same as if the `param_groups` in that
                    member had key="self" and value equal to that parameter group.
                    This is useful if members do not have `param_groups`, for
                    example torch.nn.Linear.
                - parameter level: the parameter with the same name as the key
                    will be put into that parameter group.

In the optimizer factory, parameters and their learning rates are then gathered recursively.
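
For illustration, here is a minimal sketch of how the two pieces fit together (the module, member, and group names below are hypothetical, not part of this commit): a module routes one of its members to a named parameter group via `param_groups`, and the optimizer factory maps that group name to a learning rate via `group_learning_rates`.

    import torch

    class ToyDecoder(torch.nn.Module):
        """Hypothetical module: routes its `color` member to the "color" group."""

        def __init__(self):
            super().__init__()
            self.density = torch.nn.Linear(16, 1)  # stays in the "default" group
            self.color = torch.nn.Linear(16, 3)    # routed to the "color" group below
            # Member-level override: all parameters of `self.color` are assigned
            # to the "color" parameter group.
            self.param_groups = {"color": "color"}

    # The optimizer factory would then be configured with, e.g.,
    # group_learning_rates = {"color": 1.0e-4}; all other parameters use `lr`.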

Reviewed By: shapovalov

Differential Revision: D40145802

fbshipit-source-id: 631c02b8d79ee1c0eb4c31e6e42dbd3d2882078a
bottler authored and facebook-github-bot committed Oct 18, 2022
1 parent a819ecb commit fe5bdb2
Showing 6 changed files with 293 additions and 5 deletions.
96 changes: 93 additions & 3 deletions projects/implicitron_trainer/impl/optimizer_factory.py
@@ -7,7 +7,9 @@
import inspect
import logging
import os
from typing import Any, Dict, Optional, Tuple
from collections import defaultdict
from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple

import torch.optim

@@ -64,6 +66,12 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
weight_decay: The optimizer weight_decay (L2 penalty on model weights).
foreach: Whether to use new "foreach" implementation of optimizer where
available (e.g. requires PyTorch 1.12.0 for Adam)
group_learning_rates: Parameters or modules can be assigned to parameter
groups. This dictionary has names of those parameter groups as keys
and learning rates as values. All parameter group names have to be
defined in this dictionary. Parameters which do not have a predefined
parameter group are put into the "default" parameter group, which has
`lr` as its learning rate.
"""

betas: Tuple[float, ...] = (0.9, 0.999)
@@ -78,6 +86,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
linear_exponential_lr_milestone: int = 200
linear_exponential_start_gamma: float = 0.1
foreach: Optional[bool] = True
group_learning_rates: Dict[str, float] = field(default_factory=lambda: {})

def __post_init__(self):
run_auto_creation(self)
@@ -115,8 +124,10 @@ def __call__(
# pyre-ignore[29]
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": self.lr}]
p_groups = [
{"params": params, "lr": self._get_group_learning_rate(group)}
for group, params in self._get_param_groups(model).items()
]

# Initialize the optimizer
optimizer_kwargs: Dict[str, Any] = {
@@ -233,3 +244,82 @@ def _get_optimizer_state(
else:
raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
return optimizer_state

def _get_param_groups(
self, module: torch.nn.Module
) -> Dict[str, List[torch.nn.Parameter]]:
"""
Recursively visits all the modules inside `module` and sorts all the
parameters into parameter groups.
Uses the `param_groups` dictionary member, where keys are names of individual
parameters or module members and values are the names of the parameter groups
for those parameters or members. The "self" key is used to denote the parameter
group at the module level. Possible keys, including the "self" key, do not have
to be defined. By default all parameters have the learning rate defined in the
optimizer. This can be overridden by setting a parameter group in the
`param_groups` member of a specific module, at the:
- module level with the "self" key: all the parameters and child
modules' parameters will inherit it
- member level, which is the same as if the `param_groups` in that
member had key="self" and value equal to that parameter group.
This is useful if members do not have `param_groups`, for
example torch.nn.Linear.
- parameter level: only the parameter with the same name as the key
will have it.
Args:
module: module from which to extract the parameters and their parameter
groups
Returns:
dictionary with parameter groups as keys and lists of parameters as values
"""

param_groups = defaultdict(list)

def traverse(module, default_group):
# If the "self" key is defined in param_groups then change the default param
# group for all parameters and children in the module.
if hasattr(module, "param_groups") and "self" in module.param_groups:
default_group = module.param_groups["self"]

# Collect all the parameters that are directly inside the `module`;
# they go into the default param group unless they have a group
# defined explicitly.
for name, param in module.named_parameters(recurse=False):
if param.requires_grad:
if hasattr(module, "param_groups") and name in module.param_groups:
param_groups[module.param_groups[name]].append(param)
else:
param_groups[default_group].append(param)

# If a child has a default param group defined then use it, else pass
# our own default.
for child_name, child in module.named_children():
if (
hasattr(module, "param_groups")
and child_name in module.param_groups
):
traverse(child, module.param_groups[child_name])
else:
traverse(child, default_group)

traverse(module, "default")
return param_groups

def _get_group_learning_rate(self, group_name: str) -> float:
"""
Wraps the `group_learning_rates` dictionary, raising an error for unknown
groups and returning `self.lr` for the "default" group_name.
Args:
group_name: a string representing the name of the group
Returns:
learning rate for a specific group
"""
if group_name == "default":
return self.lr
lr = self.group_learning_rates.get(group_name, None)
if lr is None:
raise ValueError(f"no learning rate given for group {group_name}")
return lr
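
For context, a minimal sketch (not part of the diff, parameter names assumed) of what the factory does with the groups gathered above: each entry becomes an ordinary torch.optim parameter group with its own learning rate, and the "default" group falls back to the factory-level `lr`.

    import torch

    # Hypothetical parameters standing in for what _get_param_groups returns.
    w_default = torch.nn.Parameter(torch.zeros(8))
    w_color = torch.nn.Parameter(torch.zeros(3))

    lr = 0.01                                 # factory-level learning rate
    group_learning_rates = {"color": 1.0e-4}  # per-group overrides

    p_groups = [
        {"params": [w_default], "lr": lr},                           # "default" group
        {"params": [w_color], "lr": group_learning_rates["color"]},  # "color" group
    ]
    optimizer = torch.optim.Adam(p_groups, lr=lr)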
1 change: 1 addition & 0 deletions projects/implicitron_trainer/tests/experiment.yaml
@@ -409,6 +409,7 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
linear_exponential_lr_milestone: 200
linear_exponential_start_gamma: 0.1
foreach: true
group_learning_rates: {}
training_loop_ImplicitronTrainingLoop_args:
evaluator_class_type: ImplicitronEvaluator
evaluator_ImplicitronEvaluator_args:
162 changes: 162 additions & 0 deletions projects/implicitron_trainer/tests/test_optimizer_factory.py
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest

import torch
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args

from ..impl.optimizer_factory import ImplicitronOptimizerFactory

internal = os.environ.get("FB_TEST", False)


class TestOptimizerFactory(unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(42)
expand_args_fields(ImplicitronOptimizerFactory)

def _get_param_groups(self, model):
default_cfg = get_default_args(ImplicitronOptimizerFactory)
factory = ImplicitronOptimizerFactory(default_cfg)
return factory._get_param_groups(model)

def _assert_allin(self, a, param_groups, key):
with self.subTest(f"Testing key {key}"):
b = param_groups[key]
for el in a:
if el not in b:
raise ValueError(
f"Element {el}\n\n from:\n\n {a}\n\n not in:\n\n {b}\n\n."
+ f" Full param groups = \n\n{param_groups}"
)
for el in b:
if el not in a:
raise ValueError(
f"Element {el}\n\n from:\n\n {b}\n\n not in:\n\n {a}\n\n."
+ f" Full param groups = \n\n{param_groups}"
)

def test_default_param_group_assignment(self):
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
na, nb = Node(params=[pa]), Node(params=[pb])
root = Node(children=[na, nb], params=[pc])
param_groups = self._get_param_groups(root)
self._assert_allin([pa, pb, pc], param_groups, "default")

def test_member_overrides_default_param_group_assignment(self):
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
na, nb = Node(params=[pa]), Node(params=[pb])
root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb"})
param_groups = self._get_param_groups(root)
self._assert_allin([pa, pc], param_groups, "default")
self._assert_allin([pb], param_groups, "pb")

def test_self_overrides_member_param_group_assignment(self):
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
na, nb = Node(params=[pa]), Node(params=[pb], param_groups={"self": "pb_self"})
root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
param_groups = self._get_param_groups(root)
self._assert_allin([pa, pc], param_groups, "default")
self._assert_allin([pb], param_groups, "pb_self")
assert len(param_groups["pb_member"]) == 0, param_groups

def test_param_overrides_self_param_group_assignment(self):
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
na, nb = Node(params=[pa]), Node(
params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}
)
root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
param_groups = self._get_param_groups(root)
self._assert_allin([pa, pc], param_groups, "default")
self._assert_allin([pb], param_groups, "pb_self")
assert len(param_groups["pb_member"]) == 0, param_groups

def test_no_param_groups_defined(self):
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
na, nb = Node(params=[pa]), Node(params=[pb])
root = Node(children=[na, nb], params=[pc])
param_groups = self._get_param_groups(root)
self._assert_allin([pa, pb, pc], param_groups, "default")

def test_tree_param_groups_defined(self):
"""
Test generic tree assignment.
A0
|---------------------------
| | |
Bb M J-
|----- |-------
| | | |
C Ddg K Ll
|--------------
| | | |
E4 Ff G H-
All nodes have one parameter. A character next to the capital
letter means something has been added to that node's `param_groups`:
- a small letter equal to the capital means "self" is set to that letter
- a small letter different than the capital means that member is set
(the one that is named like that)
- a number means the parameter's param group is set to that value
- "-" means the node does not have a `param_groups` member
"""
p = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(12)]
L = Node(params=[p[11]], param_groups={"self": "l"})
K = Node(params=[p[10]], param_groups={})
J = Node(params=[p[9]], param_groups=None, children=[K, L])
M = Node(params=[p[8]], param_groups={})

E = Node(params=[p[4]], param_groups={"p0": "4"})
F = Node(params=[p[5]], param_groups={"self": "f"})
G = Node(params=[p[6]], param_groups={})
H = Node(params=[p[7]], param_groups=None)

D = Node(
params=[p[3]], param_groups={"self": "d", "m2": "g"}, children=[E, F, G, H]
)
C = Node(params=[p[2]], param_groups={})

B = Node(params=[p[1]], param_groups={"self": "b"}, children=[C, D])

A = Node(params=[p[0]], param_groups={"p0": "0"}, children=[B, M, J])

param_groups = self._get_param_groups(A)

# if parts of the group belong to two different categories the assert is repeated
# parameter level
self._assert_allin([p[0]], param_groups, "0")
self._assert_allin([p[4]], param_groups, "4")
# self level
self._assert_allin([p[5]], param_groups, "f")
self._assert_allin([p[11]], param_groups, "l")
self._assert_allin([p[2], p[1]], param_groups, "b")
self._assert_allin([p[7], p[3]], param_groups, "d")
# member level
self._assert_allin([p[6]], param_groups, "g")
# inherit level
self._assert_allin([p[7], p[3]], param_groups, "d")
self._assert_allin([p[2], p[1]], param_groups, "b")
# default level
self._assert_allin([p[8], p[9], p[10]], param_groups, "default")


class Node(torch.nn.Module):
def __init__(self, children=(), params=(), param_groups=None):
super().__init__()
for i, child in enumerate(children):
self.add_module("m" + str(i), child)
for i, param in enumerate(params):
setattr(self, "p" + str(i), param)
if param_groups is not None:
self.param_groups = param_groups

def __str__(self):
return (
"modules:\n" + str(self._modules) + "\nparameters\n" + str(self._parameters)
)
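
As a quick usage sketch mirroring the tests above (and assuming the same imports, the `expand_args_fields` setup, and the `Node` helper), the group assignment can be inspected directly:

    # Route all parameters of the second child ("m1") to a "pb" group via the
    # parent's param_groups; everything else stays in "default".
    pa = torch.nn.Parameter(torch.tensor(0.0))
    pb = torch.nn.Parameter(torch.tensor(1.0))
    root = Node(children=[Node(params=[pa]), Node(params=[pb])], param_groups={"m1": "pb"})

    factory = ImplicitronOptimizerFactory(get_default_args(ImplicitronOptimizerFactory))
    groups = factory._get_param_groups(root)
    # Expected: groups["default"] == [pa] and groups["pb"] == [pb]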
@@ -13,9 +13,10 @@
"""

import logging
from dataclasses import field

from enum import Enum
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

import torch

@@ -42,8 +43,27 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
"""
Decoding function is a torch.nn.Module which takes the embedding of a location in
space and transforms it into the required quantity (for example density and color).
Members:
param_groups: dictionary where keys are names of individual parameters
or module members and values are the parameter group to which the
parameter/member will be assigned. The "self" key is used to denote the
parameter group at the module level. Possible keys, including the "self" key,
do not have to be defined. By default all parameters are put into the "default"
parameter group and have the learning rate defined in the optimizer;
it can be overridden at the:
- module level with the "self" key: all the parameters and child
modules' parameters will be put into that parameter group
- member level, which is the same as if the `param_groups` in that
member had key="self" and value equal to that parameter group.
This is useful if members do not have `param_groups`, for
example torch.nn.Linear.
- parameter level: the parameter with the same name as the key
will be put into that parameter group.
"""

param_groups: Dict[str, str] = field(default_factory=lambda: {})

def __post_init__(self):
super().__init__()

16 changes: 16 additions & 0 deletions pytorch3d/implicitron/models/implicit_function/voxel_grid.py
@@ -808,6 +808,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
with mean=init_mean and std=init_std. Default 0.
hold_voxel_grid_as_parameters: if True components of the underlying voxel grids
will be saved as parameters and therefore be trainable. Default True.
param_groups: dictionary where keys are names of individual parameters
or module members and values are the parameter group to which the
parameter/member will be assigned. The "self" key is used to denote the
parameter group at the module level. Possible keys, including the "self" key,
do not have to be defined. By default all parameters are put into the "default"
parameter group and have the learning rate defined in the optimizer;
it can be overridden at the:
- module level with the "self" key: all the parameters and child
modules' parameters will be put into that parameter group
- member level, which is the same as if the `param_groups` in that
member had key="self" and value equal to that parameter group.
This is useful if members do not have `param_groups`, for
example torch.nn.Linear.
- parameter level: the parameter with the same name as the key
will be put into that parameter group.
"""

voxel_grid_class_type: str = "FullResolutionVoxelGrid"
@@ -820,6 +835,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
init_mean: float = 0

hold_voxel_grid_as_parameters: bool = True
param_groups: Dict[str, str] = field(default_factory=lambda: {})

def __post_init__(self):
super().__init__()
1 change: 0 additions & 1 deletion tests/implicitron/test_voxel_grids.py
@@ -19,7 +19,6 @@
from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
CPFactorizedVoxelGrid,
FullResolutionVoxelGrid,
FullResolutionVoxelGridValues,
VMFactorizedVoxelGrid,
VoxelGridModule,
)
