-
Notifications
You must be signed in to change notification settings - Fork 2
/
decorators.py
66 lines (54 loc) · 2.03 KB
/
decorators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import functools
import inspect
import os
import pickle
def auto_init_args(init):
    """Decorate ``__init__`` so its arguments are stored on ``self`` first.

    Before the wrapped ``__init__`` body runs, each argument it receives is
    assigned as an instance attribute of the same name:

    * positional arguments are matched, in order, to the named positional
      parameters (those declared before any ``*args``);
    * every keyword argument is stored under its keyword, including extra
      keywords that end up in a ``**kwargs`` parameter;
    * parameters that were not supplied but declare a default receive that
      default.

    Variadic parameters themselves (``*args`` / ``**kwargs``) are not stored
    as attributes, and surplus positional arguments are not stored at all.
    Required parameters that were not supplied are not set either (the
    original ``__init__`` then raises ``TypeError`` as usual). The original
    ``__init__`` is finally called with the unchanged arguments.

    Args:
        init: The ``__init__`` function to wrap; its first parameter is
            assumed to be ``self``.

    Returns:
        The wrapping ``__init__``, carrying ``init``'s name and docstring.
    """
    # Inspect the signature once at decoration time, not on every call.
    params = list(inspect.signature(init).parameters.values())[1:]  # skip self
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD)
    # Only these names can be bound by position; zipping against *args,
    # keyword-only or **kwargs names would store values __init__ never gets.
    positional_names = [p.name for p in params if p.kind in positional_kinds]
    # Variadic parameters have no default, so they are excluded here too.
    defaults = {p.name: p.default for p in params
                if p.default is not inspect.Parameter.empty}

    @functools.wraps(init)
    def new_init(self, *args, **kwargs):
        values = dict(defaults)
        values.update(zip(positional_names, args))
        values.update(kwargs)
        for name, value in values.items():
            setattr(self, name, value)
        init(self, *args, **kwargs)

    return new_init
def auto_init_pytorch(init):
    """Decorate a model ``__init__`` with shared post-construction setup.

    After the wrapped ``__init__`` returns, the wrapper:

    1. builds the optimizer with
       ``self.init_optimizer(config.opt, config.lr, config.l2)``, where
       ``config`` is ``self.expe.config``, and stores it as ``self.opt``;
    2. unless ``self.expe.config.resume`` is truthy, calls
       ``self.to(self.device)`` and logs the transfer;
    3. logs the total and trainable parameter counts via ``self.expe.log``.

    The wrapped ``__init__`` must therefore set ``self.expe`` and
    ``self.device``, and the class must provide ``init_optimizer``, ``to``,
    ``count_all_parameters`` and ``count_trainable_parameters``.

    Args:
        init: The ``__init__`` function to wrap.

    Returns:
        The wrapping ``__init__``, carrying ``init``'s name and docstring.
    """
    @functools.wraps(init)
    def new_init(self, *args, **kwargs):
        init(self, *args, **kwargs)
        config = self.expe.config
        self.opt = self.init_optimizer(config.opt, config.lr, config.l2)
        # NOTE(review): when resuming, device placement is presumably handled
        # by checkpoint-loading code elsewhere — confirm.
        if not self.expe.config.resume:
            self.to(self.device)
            self.expe.log.info(
                "transferred model to {}".format(self.device))
        self.expe.log.info(
            "#all parameters: {}, #trainable parameters: {}".format(
                self.count_all_parameters(),
                self.count_trainable_parameters()))

    return new_init
class lazy_execute:
    """Decorator caching a method's return value in a pickle file.

    Usage::

        @lazy_execute("load_cache")
        def build(self, ...): ...

        obj.build(..., file_name="cache.pkl")

    The decorated method must be called with a ``file_name`` keyword
    argument; the wrapper consumes it and does not forward it.

    * If ``file_name`` already exists, the instance method named
      ``func_name`` is called with that path and its result is returned;
      the decorated method is not run.
    * Otherwise the decorated method runs, ``"saving to <file_name>"`` is
      logged through ``self.expe.log.info``, and the result is pickled
      (highest protocol) to ``file_name`` and returned.

    The pickle is first written to a sibling ``.tmp`` file and moved into
    place only after a successful dump. An interrupted or failed write
    therefore never leaves a truncated file that later calls would mistake
    for a valid cache.

    NOTE(review): the loader presumably reads the file back with ``pickle``;
    only point ``file_name`` at trusted, locally produced files.
    """

    def __init__(self, func_name):
        # Name of the instance method used to load an existing cache file.
        self.func_name = func_name

    @staticmethod
    def _dump_atomically(data, file_name):
        """Pickle *data* to *file_name* via a temporary sibling file."""
        path = os.fspath(file_name)
        tmp_path = path + (b".tmp" if isinstance(path, bytes) else ".tmp")
        try:
            with open(tmp_path, "wb") as fp:
                pickle.dump(data, fp, protocol=-1)
            # Atomic rename: readers see either no file or a complete pickle.
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort removal of the partial temp file; the original
            # error is re-raised regardless of whether removal succeeds.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def __call__(self, fn):
        loader_name = self.func_name
        dump = self._dump_atomically

        @functools.wraps(fn)
        def new_fn(self, *args, **kwargs):
            file_name = kwargs.pop('file_name')
            if os.path.isfile(file_name):
                return getattr(self, loader_name)(file_name)
            data = fn(self, *args, **kwargs)
            self.expe.log.info("saving to {}"
                               .format(file_name))
            dump(data, file_name)
            return data

        return new_fn