-
Notifications
You must be signed in to change notification settings - Fork 0
/
inner_loop_optimizers.py
149 lines (133 loc) · 6.5 KB
/
inner_loop_optimizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import logging
import os
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class GradientDescentLearningRule(nn.Module):
"""Simple (stochastic) gradient descent learning rule.
For a scalar error function `E(p[0], p_[1] ... )` of some set of
potentially multidimensional parameters this attempts to find a local
minimum of the loss function by applying updates to each parameter of the
form
p[i] := p[i] - learning_rate * dE/dp[i]
With `learning_rate` a positive scaling parameter.
The error function used in successive applications of these updates may be
a stochastic estimator of the true error function (e.g. when the error with
respect to only a subset of data-points is calculated) in which case this
will correspond to a stochastic gradient descent learning rule.
"""
def __init__(self, device, learning_rate=1e-3):
"""Creates a new learning rule object.
Args:
learning_rate: A postive scalar to scale gradient updates to the
parameters by. This needs to be carefully set - if too large
the learning dynamic will be unstable and may diverge, while
if set too small learning will proceed very slowly.
"""
super(GradientDescentLearningRule, self).__init__()
assert learning_rate > 0.0, "learning_rate should be positive."
self.learning_rate = torch.ones(1) * learning_rate
self.learning_rate.to(device)
def update_params(
self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.9
):
"""Applies a single gradient descent update to all parameters.
All parameter updates are performed using in-place operations and so
nothing is returned.
Args:
grads_wrt_params: A list of gradients of the scalar loss function
with respect to each of the parameters passed to `initialise`
previously, with this list expected to be in the same order.
"""
updated_names_weights_dict = dict()
for key in names_weights_dict.keys():
updated_names_weights_dict[key] = (
names_weights_dict[key]
- self.learning_rate * names_grads_wrt_params_dict[key]
)
return updated_names_weights_dict
class LSLRGradientDescentLearningRule(nn.Module):
"""Simple (stochastic) gradient descent learning rule.
For a scalar error function `E(p[0], p_[1] ... )` of some set of
potentially multidimensional parameters this attempts to find a local
minimum of the loss function by applying updates to each parameter of the
form
p[i] := p[i] - learning_rate * dE/dp[i]
With `learning_rate` a positive scaling parameter.
The error function used in successive applications of these updates may be
a stochastic estimator of the true error function (e.g. when the error with
respect to only a subset of data-points is calculated) in which case this
will correspond to a stochastic gradient descent learning rule.
"""
def __init__(
self,
device,
total_num_inner_loop_steps,
use_learnable_learning_rates,
init_learning_rate=1e-3,
init_class_head_lr_multiplier=1,
):
"""Creates a new learning rule object.
Args:
init_learning_rate: A postive scalar to scale gradient updates to the
parameters by. This needs to be carefully set - if too large
the learning dynamic will be unstable and may diverge, while
if set too small learning will proceed very slowly.
"""
super(LSLRGradientDescentLearningRule, self).__init__()
print(init_learning_rate)
assert init_learning_rate > 0.0, "learning_rate should be positive."
self.init_learning_rate = torch.ones(1) * init_learning_rate
self.init_class_head_lr_multiplier = init_class_head_lr_multiplier
self.total_num_inner_loop_steps = total_num_inner_loop_steps
self.use_learnable_learning_rates = use_learnable_learning_rates
def initialise(self, names_weights_dict):
self.names_learning_rates_dict = nn.ParameterDict()
for idx, (key, param) in enumerate(names_weights_dict.items()):
multiplier = 1
if "classifier" in key:
multiplier = self.init_class_head_lr_multiplier
self.names_learning_rates_dict[key.replace(".", "-")] = nn.Parameter( #
data=torch.ones(self.total_num_inner_loop_steps)
* self.init_learning_rate
* multiplier,
requires_grad=self.use_learnable_learning_rates,
)
def reset(self):
# for key, param in self.names_learning_rates_dict.items():
# param.fill_(self.init_learning_rate)
pass
def update_params_(self, names_weights_dict, names_grads_wrt_params_dict, num_step):
for key in names_grads_wrt_params_dict.keys():
if key in ["bert.pooler.dense.weight", "bert.pooler.dense.bias"]:
continue
names_weights_dict[key] = (
names_weights_dict[key]
- self.names_learning_rates_dict[key.replace(".", "-")][num_step]
* names_grads_wrt_params_dict[key]
)
def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step):
"""Applies a single gradient descent update to all parameters.
All parameter updates are performed using in-place operations and so
nothing is returned.
Args:
grads_wrt_params: A list of gradients of the scalar loss function
with respect to each of the parameters passed to `initialise`
previously, with this list expected to be in the same order.
"""
updated_names_weights_dict = dict()
# device = names_weights_dict[next(iter(names_weights_dict.keys()))].device
for key in names_grads_wrt_params_dict.keys():
if "LayerNorm" in key:
# don't update, just pass on
updated_names_weights_dict[key] = names_weights_dict[key]
else:
updated_names_weights_dict[key] = (
names_weights_dict[key]
- self.names_learning_rates_dict[key.replace(".", "-")][num_step]
* names_grads_wrt_params_dict[key]
)
return updated_names_weights_dict