# training.py
# Forked from lorenlugosch/end-to-end-SLU.
import torch
from tqdm import tqdm # for displaying progress bar
import os
from data import SLUDataset, ASRDataset
from models import PretrainedModel, Model
import pandas as pd
class Trainer:
    """Run training/evaluation epochs for a model and log the epoch metrics.

    The learning rate and the checkpoint sub-folder are picked from ``config``
    depending on whether ``model`` is a ``PretrainedModel`` (``pretraining_lr``,
    ``<config.folder>/pretraining``) or anything else (``training_lr``,
    ``<config.folder>/training``).

    Attributes:
        model: the model being optimized; called directly on each batch.
        config: configuration object (reads ``folder``, ``pretraining_lr``,
            ``training_lr`` and ``pretraining_type``).
        lr: learning rate handed to the Adam optimizer.
        checkpoint_path: folder holding ``model_state.pth`` and ``log.csv``.
        optimizer: ``torch.optim.Adam`` over ``model.parameters()``.
        epoch: number of completed calls to :meth:`train`.
        df: pandas DataFrame with one row per logged epoch (``None`` until
            the first call to :meth:`log`).
    """

    _STATE_FILE = "model_state.pth"
    _LOG_FILE = "log.csv"

    def __init__(self, model, config):
        self.model = model
        self.config = config
        if isinstance(self.model, PretrainedModel):
            self.lr = config.pretraining_lr
            self.checkpoint_path = os.path.join(self.config.folder, "pretraining")
        else:
            self.lr = config.training_lr
            self.checkpoint_path = os.path.join(self.config.folder, "training")
        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
        self.epoch = 0
        self.df = None

    def load_checkpoint(self):
        """Load ``model_state.pth`` from the checkpoint folder, if present.

        Loading is best-effort: when the file is missing or cannot be loaded
        a message is printed and the model keeps its current weights.
        """
        state_path = os.path.join(self.checkpoint_path, self._STATE_FILE)
        if not os.path.isfile(state_path):
            print("No previous model; starting from scratch")
            return
        try:
            # Remap the stored tensors to CPU unless the model reports CUDA.
            map_location = None if self.model.is_cuda else "cpu"
            state = torch.load(state_path, map_location=map_location)
            self.model.load_state_dict(state)
        except Exception as err:  # best-effort, but let Ctrl-C/SystemExit through
            print("Could not load previous model; starting from scratch")
            print("Reason: " + repr(err))

    def save_checkpoint(self):
        """Write the model's ``state_dict`` to ``model_state.pth``.

        Saving is best-effort: on failure a message is printed and training
        can continue. The checkpoint folder is created when it is missing.
        """
        try:
            os.makedirs(self.checkpoint_path, exist_ok=True)
            torch.save(
                self.model.state_dict(),
                os.path.join(self.checkpoint_path, self._STATE_FILE),
            )
        except Exception as err:  # best-effort, but let Ctrl-C/SystemExit through
            print("Could not save model")
            print("Reason: " + repr(err))

    def log(self, results):
        """Append ``results`` as a new row and rewrite ``log.csv``.

        Args:
            results: dict mapping column name to value. The keys of the first
                dict ever logged define the columns of the log.
        """
        if self.df is None:
            self.df = pd.DataFrame(columns=list(results))
        self.df.loc[len(self.df)] = results
        # The folder may not exist yet on the first epoch of a fresh run.
        os.makedirs(self.checkpoint_path, exist_ok=True)
        self.df.to_csv(os.path.join(self.checkpoint_path, self._LOG_FILE))

    def train(self, dataset, print_interval=100):
        """Train for one pass over ``dataset.loader`` and log the averages.

        Args:
            dataset: an ``ASRDataset`` (batches of ``x, y_phoneme, y_word``)
                or any other dataset, treated as SLU (batches of
                ``x, y_intent``).
            print_interval: running metrics are printed on every batch whose
                index is a multiple of this value.

        Returns:
            ``(phone_acc, phone_loss, word_acc, word_loss)`` for an
            ``ASRDataset``, otherwise ``(intent_acc, intent_loss)``. Values
            are means weighted by batch size.

        Raises:
            ValueError: if ``config.pretraining_type`` is not 1, 2 or 3
                (ASR training only).
            ZeroDivisionError: if the loader yields no examples, or if
                ``print_interval`` is 0.
        """
        if isinstance(dataset, ASRDataset):
            return self._asr_epoch(dataset, training=True, print_interval=print_interval)
        return self._slu_epoch(dataset, training=True, print_interval=print_interval)

    def test(self, dataset):
        """Evaluate on one pass over ``dataset.loader`` and log the averages.

        The model is put in eval mode and gradients are disabled for the
        forward passes. Rows are logged with ``"set": "valid"``.

        Args:
            dataset: an ``ASRDataset`` or any other dataset (treated as SLU).

        Returns:
            ``(phone_acc, phone_loss, word_acc, word_loss)`` for an
            ``ASRDataset``, otherwise ``(intent_acc, intent_loss)``.

        Raises:
            ZeroDivisionError: if the loader yields no examples.
        """
        if isinstance(dataset, ASRDataset):
            return self._asr_epoch(dataset, training=False)
        return self._slu_epoch(dataset, training=False)

    @staticmethod
    def _scalar(value):
        """Return a one-element tensor as a plain Python number."""
        return value.item()

    def _pretraining_loss(self, phoneme_loss, word_loss):
        """Select the loss to optimize according to ``config.pretraining_type``."""
        pretraining_type = self.config.pretraining_type
        if pretraining_type == 1:
            return phoneme_loss
        if pretraining_type == 2:
            return phoneme_loss + word_loss
        if pretraining_type == 3:
            return word_loss
        raise ValueError(
            "pretraining_type must be 1, 2 or 3, got " + repr(pretraining_type)
        )

    def _asr_epoch(self, dataset, training, print_interval=None):
        """Run one ASR pass (training or evaluation) and log averaged metrics."""
        phone_loss_sum = 0
        phone_acc_sum = 0
        word_loss_sum = 0
        word_acc_sum = 0
        num_examples = 0
        if training:
            self.model.train()
            batches = tqdm(dataset.loader)
        else:
            self.model.eval()
            batches = dataset.loader
        # No autograd graph is needed when only evaluating.
        with torch.set_grad_enabled(training):
            for idx, (x, y_phoneme, y_word) in enumerate(batches):
                batch_size = len(x)
                num_examples += batch_size
                phoneme_loss, word_loss, phoneme_acc, word_acc = self.model(
                    x, y_phoneme, y_word
                )
                if training:
                    loss = self._pretraining_loss(phoneme_loss, word_loss)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                phone_loss_value = self._scalar(phoneme_loss)
                word_loss_value = self._scalar(word_loss)
                phone_acc_value = self._scalar(phoneme_acc)
                word_acc_value = self._scalar(word_acc)
                # Weight by batch size so a short final batch is not over-counted.
                phone_loss_sum += phone_loss_value * batch_size
                word_loss_sum += word_loss_value * batch_size
                phone_acc_sum += phone_acc_value * batch_size
                word_acc_sum += word_acc_value * batch_size
                if training and idx % print_interval == 0:
                    print("phoneme loss: " + str(phone_loss_value))
                    print("word loss: " + str(word_loss_value))
                    print("phoneme acc: " + str(phone_acc_value))
                    print("word acc: " + str(word_acc_value))
        phone_loss_mean = phone_loss_sum / num_examples
        phone_acc_mean = phone_acc_sum / num_examples
        word_loss_mean = word_loss_sum / num_examples
        word_acc_mean = word_acc_sum / num_examples
        results = {
            "phone_loss": phone_loss_mean,
            "phone_acc": phone_acc_mean,
            "word_loss": word_loss_mean,
            "word_acc": word_acc_mean,
            "set": "train" if training else "valid",
        }
        self.log(results)
        if training:
            self.epoch += 1
        return phone_acc_mean, phone_loss_mean, word_acc_mean, word_loss_mean

    def _slu_epoch(self, dataset, training, print_interval=None):
        """Run one SLU pass (training or evaluation) and log averaged metrics."""
        intent_loss_sum = 0
        intent_acc_sum = 0
        num_examples = 0
        if training:
            self.model.train()
            self.model.print_frozen()
            batches = tqdm(dataset.loader)
        else:
            self.model.eval()
            batches = dataset.loader
        # No autograd graph is needed when only evaluating.
        with torch.set_grad_enabled(training):
            for idx, (x, y_intent) in enumerate(batches):
                batch_size = len(x)
                num_examples += batch_size
                intent_loss, intent_acc = self.model(x, y_intent)
                if training:
                    self.optimizer.zero_grad()
                    intent_loss.backward()
                    self.optimizer.step()
                intent_loss_value = self._scalar(intent_loss)
                intent_acc_value = self._scalar(intent_acc)
                # Weight by batch size so a short final batch is not over-counted.
                intent_loss_sum += intent_loss_value * batch_size
                intent_acc_sum += intent_acc_value * batch_size
                if training and idx % print_interval == 0:
                    print("intent loss: " + str(intent_loss_value))
                    print("intent acc: " + str(intent_acc_value))
        intent_loss_mean = intent_loss_sum / num_examples
        intent_acc_mean = intent_acc_sum / num_examples
        if training:
            # Called once at the end of every training epoch.
            self.model.unfreeze_one_layer()
        results = {
            "intent_loss": intent_loss_mean,
            "intent_acc": intent_acc_mean,
            "set": "train" if training else "valid",
        }
        self.log(results)
        if training:
            self.epoch += 1
        return intent_acc_mean, intent_loss_mean