-
Notifications
You must be signed in to change notification settings - Fork 18
/
options.py
138 lines (115 loc) · 5.31 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
""" Options
This script is largely based on junyanz/pytorch-CycleGAN-and-pix2pix.
Returns:
[argparse]: Class containing argparse
"""
import argparse
import os
import torch
class Options():
    """Command-line options for the training / test scripts.

    The argparse parser is built in ``__init__``; call :meth:`parse` to read
    the arguments, normalise ``gpu_ids``, pick the CUDA device, and dump the
    resolved options to ``<outf>/<name>/opt.txt``.

    Returns:
        [argparse]: argparse containing train and test options
    """

    def __init__(self):
        # Inputs for the main function
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        # original
        self.parser.add_argument(
            '--data_name',
            choices=['sine', 'stock', 'energy'],
            default='stock',
            type=str)
        self.parser.add_argument(
            '--z_dim',
            help='z or data dimension',
            default=6,
            type=int)
        self.parser.add_argument(
            '--seq_len',
            help='sequence length',
            default=24,
            type=int)
        self.parser.add_argument(
            '--module',
            choices=['gru', 'lstm', 'lstmLN'],
            default='gru',
            type=str)
        self.parser.add_argument(
            '--hidden_dim',
            help='hidden state dimensions (should be optimized)',
            default=24,
            type=int)
        self.parser.add_argument(
            '--num_layer',
            help='number of layers (should be optimized)',
            default=3,
            type=int)
        self.parser.add_argument(
            '--iteration',
            help='Training iterations (should be optimized)',
            default=50000,
            type=int)
        self.parser.add_argument(
            '--batch_size',
            help='the number of samples in mini-batch (should be optimized)',
            default=128,
            type=int)
        self.parser.add_argument(
            '--metric_iteration',
            help='iterations of the metric computation',
            default=10,
            type=int)

        # Add
        self.parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
        self.parser.add_argument('--device', type=str, default='gpu', help='Device: gpu | cpu')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        self.parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
        self.parser.add_argument('--model', type=str, default='TimeGAN', help='chooses which model to use. timegan')
        self.parser.add_argument('--outf', default='./output', help='folder to output images and model checkpoints')
        self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment')
        self.parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
        self.parser.add_argument('--display', action='store_true', help='Use visdom.')
        self.parser.add_argument('--manualseed', default=-1, type=int, help='manual seed')

        # Train
        self.parser.add_argument('--print_freq', type=int, default=1000, help='frequency of showing training results on console')
        self.parser.add_argument('--load_weights', action='store_true', help='Load the pretrained weights')
        self.parser.add_argument('--resume', default='', help="path to checkpoints (to continue training)")
        self.parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam')
        self.parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam')
        self.parser.add_argument('--w_gamma', type=float, default=1, help='Gamma weight')
        self.parser.add_argument('--w_es', type=float, default=0.1, help='Encoder loss weight')
        self.parser.add_argument('--w_e0', type=float, default=10, help='Encoder loss weight')
        self.parser.add_argument('--w_g', type=float, default=100, help='Generator loss weight.')

        self.isTrain = True
        self.opt = None

    def parse(self, args=None):
        """Parse arguments, resolve the device and save the options to disk.

        Args:
            args: Optional list of argument strings. ``None`` (the default)
                parses ``sys.argv[1:]``, exactly like the original behaviour.

        Returns:
            argparse.Namespace: the parsed options, with ``isTrain`` added,
            ``gpu_ids`` converted to a list of non-negative ints, ``device``
            set to ``'cpu'`` when ``gpu`` was requested but no GPU id remains,
            and ``name`` defaulted to ``"<model>/<data_name>"``.
        """
        self.opt = self.parser.parse_args(args)
        self.opt.isTrain = self.isTrain  # train or test

        self.opt.gpu_ids = self._parse_gpu_ids(self.opt.gpu_ids)

        # set gpu ids
        if self.opt.device == 'gpu':
            if self.opt.gpu_ids:
                torch.cuda.set_device(self.opt.gpu_ids[0])
            else:
                # `--gpu_ids -1` is documented as "use -1 for CPU"; indexing
                # gpu_ids[0] here used to crash with IndexError instead.
                self.opt.device = 'cpu'

        if self.opt.name == 'experiment_name':
            self.opt.name = "%s/%s" % (self.opt.model, self.opt.data_name)

        self._save_options()
        return self.opt

    @staticmethod
    def _parse_gpu_ids(spec):
        """Convert a comma-separated id string into a list of ids >= 0."""
        gpu_ids = []
        for token in spec.split(','):
            token = token.strip()
            if not token:
                continue  # tolerate "", "0," and similar
            gpu_id = int(token)
            if gpu_id >= 0:  # negative ids (e.g. -1) mean "no GPU"
                gpu_ids.append(gpu_id)
        return gpu_ids

    def _save_options(self):
        """Write every option, sorted by name, to ``<outf>/<name>/opt.txt``."""
        expr_dir = os.path.join(self.opt.outf, self.opt.name)
        # exist_ok avoids the check-then-create race of isdir + makedirs.
        os.makedirs(expr_dir, exist_ok=True)
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt', encoding='utf-8') as opt_file:
            opt_file.write('------------ Options -------------\n')
            for k, v in sorted(vars(self.opt).items()):
                opt_file.write('%s: %s\n' % (str(k), str(v)))
            opt_file.write('-------------- End ----------------\n')