forked from shentianxiao/text-autoencoders
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
156 lines (146 loc) · 6.89 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import argparse
import time
import os
import random
import collections
import numpy as np
import torch
from model import DAE, VAE, AAE
from vocab import Vocab
from meter import AverageMeter
from utils import set_seed, logging, load_sent
from batchify import get_batches
# Command-line interface for training. argparse turns dashes in long option
# names into underscores for the attribute name, so e.g. '--vocab-size' is
# read as args.vocab_size and '--save-dir' as args.save_dir. Flag spellings
# (dash vs underscore) are left unchanged so existing command lines keep
# working.
parser = argparse.ArgumentParser()
# Path arguments
parser.add_argument('--train', metavar='FILE', required=True,
                    help='path to training file')
parser.add_argument('--valid', metavar='FILE', required=True,
                    help='path to validation file')
parser.add_argument('--save-dir', default='checkpoints', metavar='DIR',
                    help='directory to save checkpoints and outputs')
parser.add_argument('--load-model', default='', metavar='FILE',
                    help='path to load checkpoint if specified')
# Architecture arguments
parser.add_argument('--vocab-size', type=int, default=10000, metavar='N',
                    help='keep N most frequent words in vocabulary')
parser.add_argument('--dim_z', type=int, default=128, metavar='D',
                    help='dimension of latent variable z')
parser.add_argument('--dim_emb', type=int, default=512, metavar='D',
                    help='dimension of word embedding')
parser.add_argument('--dim_h', type=int, default=1024, metavar='D',
                    help='dimension of hidden state per layer')
parser.add_argument('--nlayers', type=int, default=1, metavar='N',
                    help='number of layers')
parser.add_argument('--dim_d', type=int, default=512, metavar='D',
                    help='dimension of hidden state in AAE discriminator')
# Model arguments
parser.add_argument('--model_type', default='dae', metavar='M',
                    choices=['dae', 'vae', 'aae'],
                    help='which model to learn')
parser.add_argument('--lambda_kl', type=float, default=0, metavar='R',
                    help='weight for kl term in VAE')
parser.add_argument('--lambda_adv', type=float, default=0, metavar='R',
                    help='weight for adversarial loss in AAE')
parser.add_argument('--lambda_p', type=float, default=0, metavar='R',
                    help='weight for L1 penalty on posterior log-variance')
# Kept as a raw comma-separated string here; the __main__ guard converts it
# to a list of floats before calling main().
parser.add_argument('--noise', default='0,0,0,0', metavar='P,P,P,K',
                    help='word drop prob, blank prob, substitute prob, '
                         'max word shuffle distance')
# Training arguments
parser.add_argument('--dropout', type=float, default=0.5, metavar='DROP',
                    help='dropout probability (0 = no dropout)')
parser.add_argument('--lr', type=float, default=0.0005, metavar='LR',
                    help='learning rate')
#parser.add_argument('--clip', type=float, default=0.25, metavar='NORM',
#                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=50, metavar='N',
                    help='number of training epochs')
parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                    help='batch size')
# Others
parser.add_argument('--seed', type=int, default=1111, metavar='N',
                    help='random seed')
parser.add_argument('--no-cuda', action='store_true',
                    help='disable CUDA')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='report interval')
def evaluate(model, batches):
    """Return per-loss AverageMeters for ``model`` over the validation batches.

    The model is put into eval mode and run under ``torch.no_grad()``.
    For each ``(inputs, targets)`` pair, every loss term produced by
    ``model.autoenc`` is accumulated, weighted by ``inputs.size(1)``
    (presumably the batch dimension of sequence-first tensors; confirm
    against get_batches). After the pass, ``model.loss`` is applied to the
    averaged terms. Its result is recorded under the ``'loss'`` key.

    Args:
        model: autoencoder exposing ``eval``, ``autoenc`` and ``loss``.
        batches: iterable of ``(inputs, targets)`` pairs.

    Returns:
        A ``collections.defaultdict`` of loss name -> AverageMeter.
    """
    model.eval()
    meters = collections.defaultdict(AverageMeter)
    with torch.no_grad():
        for batch_in, batch_out in batches:
            terms = model.autoenc(batch_in, batch_out)
            weight = batch_in.size(1)
            for name in terms:
                meters[name].update(terms[name].item(), weight)
    # The total is recomputed from the averaged components rather than
    # averaged per batch.
    averaged = {name: meter.avg for name, meter in meters.items()}
    total = model.loss(averaged)
    meters['loss'].update(total)
    return meters
def main(args):
    """Train the selected text autoencoder and checkpoint the best model.

    Args:
        args: Namespace produced by ``parser.parse_args()``, whose ``noise``
            field has already been converted to a list of floats.

    Side effects:
        - Creates ``args.save_dir`` if needed.
        - Writes log messages via ``logging(..., log_file)`` to
          ``<save_dir>/log.txt``.
        - Builds ``<save_dir>/vocab.txt`` with ``Vocab.build`` if it is
          missing.
        - Saves ``{'args', 'model'}`` to ``<save_dir>/model.pt`` whenever
          the validation loss improves.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save_dir, exist_ok=True)
    log_file = os.path.join(args.save_dir, 'log.txt')
    logging(str(args), log_file)

    # Prepare data
    train_sents = load_sent(args.train)
    logging('# train sents {}, tokens {}'.format(
        len(train_sents), sum(len(s) for s in train_sents)), log_file)
    valid_sents = load_sent(args.valid)
    logging('# valid sents {}, tokens {}'.format(
        len(valid_sents), sum(len(s) for s in valid_sents)), log_file)
    # Build the vocabulary from the training sentences only when save_dir has
    # none yet; an existing vocab.txt is reused as-is.
    vocab_file = os.path.join(args.save_dir, 'vocab.txt')
    if not os.path.isfile(vocab_file):
        Vocab.build(train_sents, vocab_file, args.vocab_size)
    vocab = Vocab(vocab_file)
    logging('# vocab size {}'.format(vocab.size), log_file)

    # Build the model
    set_seed(args.seed)
    cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    model = {'dae': DAE, 'vae': VAE, 'aae': AAE}[args.model_type](
        vocab, args).to(device)
    if args.load_model:
        # map_location lets a checkpoint saved on GPU be resumed on CPU
        # (e.g. with --no-cuda).
        # The checkpoint also pickles the argparse Namespace, which
        # torch.load's weights_only=True default (torch >= 2.6) refuses to
        # load, so full unpickling is requested explicitly. Only load
        # checkpoints from a trusted source.
        try:
            ckpt = torch.load(args.load_model, map_location=device,
                              weights_only=False)
        except TypeError:
            # Older torch releases do not accept the weights_only keyword.
            ckpt = torch.load(args.load_model, map_location=device)
        model.load_state_dict(ckpt['model'])
        model.flatten()
    logging('# model parameters: {}'.format(
        sum(x.data.nelement() for x in model.parameters())), log_file)

    # get_batches returns (batches, order); only the batches are used here.
    train_batches, _ = get_batches(train_sents, vocab, args.batch_size, device)
    valid_batches, _ = get_batches(valid_sents, vocab, args.batch_size, device)
    best_val_loss = None
    for epoch in range(args.epochs):
        start_time = time.time()
        logging('-' * 80, log_file)
        model.train()
        meters = collections.defaultdict(lambda: AverageMeter())
        # Visit the pre-built batches in a fresh random order each epoch.
        indices = list(range(len(train_batches)))
        random.shuffle(indices)
        for i, idx in enumerate(indices):
            inputs, targets = train_batches[idx]
            losses = model.autoenc(inputs, targets, is_train=True)
            losses['loss'] = model.loss(losses)
            model.step(losses)
            for k, v in losses.items():
                meters[k].update(v.item())

            if (i + 1) % args.log_interval == 0:
                log_output = '| epoch {:3d} | {:5d}/{:5d} batches |'.format(
                    epoch + 1, i + 1, len(indices))
                for k, meter in meters.items():
                    log_output += ' {} {:.2f},'.format(k, meter.avg)
                    # Each report covers only the batches since the previous
                    # report.
                    meter.clear()
                logging(log_output, log_file)

        valid_meters = evaluate(model, valid_batches)
        logging('-' * 80, log_file)
        log_output = '| end of epoch {:3d} | time {:5.0f}s | valid'.format(
            epoch + 1, time.time() - start_time)
        for k, meter in valid_meters.items():
            log_output += ' {} {:.2f},'.format(k, meter.avg)
        # Compare against None explicitly: a best loss of 0.0 is falsy and
        # must not be mistaken for "no checkpoint saved yet".
        if best_val_loss is None or valid_meters['loss'].avg < best_val_loss:
            log_output += ' | saving model'
            ckpt = {'args': args, 'model': model.state_dict()}
            torch.save(ckpt, os.path.join(args.save_dir, 'model.pt'))
            best_val_loss = valid_meters['loss'].avg
        logging(log_output, log_file)
    logging('Done training', log_file)
if __name__ == '__main__':
    args = parser.parse_args()
    # --noise arrives as a comma-separated string (metavar P,P,P,K):
    # word drop prob, blank prob, substitute prob, max word shuffle distance.
    # Convert it to four floats, and report malformed input as a normal
    # usage error instead of a traceback.
    try:
        args.noise = [float(x) for x in args.noise.split(',')]
    except ValueError:
        parser.error('--noise expects comma-separated numbers, got {!r}'
                     .format(args.noise))
    if len(args.noise) != 4:
        parser.error('--noise expects 4 values (P,P,P,K), got {}'
                     .format(len(args.noise)))
    main(args)