-
Notifications
You must be signed in to change notification settings - Fork 0
/
maml.py
748 lines (618 loc) · 25.8 KB
/
maml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
import os
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
import gc
from transformers import AutoModel, AutoModelForSequenceClassification, AutoConfig
#from transformers import AdapterConfig, AdapterType
from few_shot_learning_system import *
from meta_bert import MetaBERT, distil_state_dict_to_bert
from inner_loop_optimizers import LSLRGradientDescentLearningRule
from ranger import Ranger
def print_torch_stats():
    """Print the values of torch.cuda.memory_allocated() and memory_reserved()."""
    allocated_now = torch.cuda.memory_allocated()
    reserved_now = torch.cuda.memory_reserved()
    print(f"Allocated : {allocated_now}, Reserved : {reserved_now}")
class MAMLFewShotClassifier(FewShotClassifier):
def __init__(self, device, args):
    """
    Set up the MAML few-shot learner: the meta-learned classifier, the
    inner-loop learning rule, the outer-loop optimizer and its LR schedule.
    :param device: Device forwarded to the base class and to the classifier factory.
    :param args: Object exposing the hyperparameters read below as attributes.
    """
    super(MAMLFewShotClassifier, self).__init__(device, args)
    print("Init")
    print_torch_stats()

    # The Hugging Face model is only loaded to obtain its weights and config.
    hf_config = AutoConfig.from_pretrained(args.pretrained_weights)
    hf_config.num_labels = args.num_classes_per_set
    pretrained_model = AutoModelForSequenceClassification.from_pretrained(
        args.pretrained_weights, config=hf_config
    )
    print("Base model created")
    print_torch_stats()

    pretrained_state = pretrained_model.state_dict()
    hf_config = pretrained_model.config
    del pretrained_model
    print("Base model deleted")
    print_torch_stats()

    # Build the meta-learnable classifier from the harvested weights.
    self.classifier = MetaBERT.init_from_pretrained(
        pretrained_state,
        hf_config,
        num_labels=args.num_classes_per_set,
        is_distil=self.is_distil,
        is_xlm=self.is_xlm,
        per_step_layer_norm_weights=args.per_step_layer_norm_weights,
        num_inner_loop_steps=args.number_of_training_steps_per_iter,
        device=device,
    )
    print("Classifier model created")
    print_torch_stats()

    self.classifier.to("cpu")
    self.classifier.train()
    print("Classifier moved to CPU")
    print_torch_stats()

    # Learning rule applied to the fast weights in the inner loop.
    self.inner_loop_optimizer = LSLRGradientDescentLearningRule(
        device=torch.device("cpu"),
        init_learning_rate=self.task_learning_rate,
        total_num_inner_loop_steps=self.args.number_of_training_steps_per_iter,
        use_learnable_learning_rates=self.args.learnable_per_layer_per_step_inner_loop_learning_rate,
        init_class_head_lr_multiplier=self.args.init_class_head_lr_multiplier,
    )
    inner_loop_weights = self.get_inner_loop_parameter_dict(
        params=self.classifier.named_parameters()
    )
    self.inner_loop_optimizer.initialise(names_weights_dict=inner_loop_weights)

    print("Inner Loop parameters")
    for key, value in self.inner_loop_optimizer.named_parameters():
        print(key, value.shape)

    print("Outer Loop parameters")
    for name, param in self.named_parameters():
        if not param.requires_grad:
            continue
        print(name, param.shape, param.device, param.requires_grad)

    # Outer-loop optimizer: one group for the classifier, one for the
    # inner-loop learning rule, each with its own learning rate.
    param_groups = [
        {"params": self.classifier.parameters(), "lr": args.meta_learning_rate},
        {
            "params": self.inner_loop_optimizer.parameters(),
            "lr": args.meta_inner_optimizer_learning_rate,
        },
    ]
    self.optimizer = Ranger(param_groups, lr=args.meta_learning_rate)
    self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=self.optimizer,
        T_max=self.args.total_epochs * self.args.total_iter_per_epoch,
        eta_min=self.args.min_learning_rate,
    )

    self.inner_loop_optimizer.to(self.device)
    print("Optimizer moved to GPU")
    print_torch_stats()

    # Gradient clipping: clamp every classifier gradient element-wise to
    # [-clip_value, clip_value]; clip_value is read when the hook fires.
    self.clip_value = 1.0

    def _clamp_gradient(grad):
        return torch.clamp(grad, -self.clip_value, self.clip_value)

    for weight in self.classifier.parameters():
        if weight.requires_grad:
            weight.register_hook(_clamp_gradient)

    self.num_freeze_epochs = args.num_freeze_epochs
    if self.num_freeze_epochs > 0:
        self.classifier.freeze()
def get_inner_loop_parameter_dict(self, params, adapter_only=False):
    """
    Select the parameters that take part in the inner-loop updates.
    :param params: An iterable of (name, parameter) pairs, e.g. ``named_parameters()``.
    :param adapter_only: If True (and LayerNorm parameters are excluded), keep only
        parameters whose name contains "adapter" or "classifier"; the names of the
        other candidates are printed.
    :return: A dictionary mapping names (leading "module." removed) to the selected
        parameters, each moved to ``self.device``.
    """
    selected = dict()
    for raw_name, tensor in params:
        if not tensor.requires_grad:
            continue

        # Drop a single leading "module." prefix, if present.
        prefix = "module."
        key = raw_name[len(prefix):] if raw_name.startswith(prefix) else raw_name

        if not self.args.enable_inner_loop_optimizable_ln_params:
            if "LayerNorm" in key:
                continue
            if adapter_only and not ("adapter" in key or "classifier" in key):
                print(key)
                continue

        selected[key] = tensor.to(device=self.device)
    return selected
def apply_inner_loop_update(
    self,
    loss,
    names_weights_copy,
    use_second_order,
    current_step_idx,
    allow_unused=True,
):
    """
    Perform one inner-loop step on the fast weights.
    Entries whose value is None are left out of the gradient computation and the
    update, and are put back as None in the returned dictionary.
    :param loss: Current step's loss with respect to the support set.
    :param names_weights_copy: A dictionary mapping names to the parameters to update.
    :param use_second_order: Whether to build the graph of the gradient computation
        (passed to ``torch.autograd.grad`` as ``create_graph``).
    :param current_step_idx: Current step's index.
    :param allow_unused: Passed through to ``torch.autograd.grad``.
    :return: A dictionary with the updated weights (name, param)
    """
    every_name = list(names_weights_copy.keys())
    live_weights = {
        name: weight
        for name, weight in names_weights_copy.items()
        if weight is not None
    }

    gradients = torch.autograd.grad(
        loss,
        live_weights.values(),
        create_graph=use_second_order,
        allow_unused=allow_unused,
    )
    grads_by_name = dict(zip(live_weights.keys(), gradients))

    updated_weights = self.inner_loop_optimizer.update_params(
        names_weights_dict=live_weights,
        names_grads_wrt_params_dict=grads_by_name,
        num_step=current_step_idx,
    )
    del grads_by_name

    # Re-insert the names that were skipped because their weight was None.
    for name in every_name:
        updated_weights.setdefault(name, None)
    return updated_weights
def net_forward(
    self,
    x,
    teacher_unary,
    fast_model,
    training,
    num_step,
    return_nr_correct=False,
    mask=None,
    task_name="",
):
    """
    Run the classifier on a batch and compute the inner loss against the targets.
    If ``task_name`` is listed in ``self.gold_label_tasks`` while ``self.meta_loss``
    is "kl" (case-insensitive), the loss is computed with ``self.meta_loss``
    temporarily set to "ce"; it is set back to "kl" afterwards, even when the loss
    computation raises.
    :param x: Input ids passed to the classifier.
    :param teacher_unary: Targets handed to ``self.inner_loss``.
    :param fast_model: Parameters passed to the classifier as ``params``.
    :param training: Not used by this method.
    :param num_step: Inner-loop step index passed to the classifier.
    :param return_nr_correct: Forwarded to ``self.inner_loss``; callers in this file
        unpack ``(loss, is_correct)`` when it is True and a single loss otherwise.
    :param mask: Attention mask passed to the classifier.
    :param task_name: Name of the task the batch belongs to.
    :return: Whatever ``self.inner_loss`` returns.
    """
    student_logits = self.classifier(
        input_ids=x, attention_mask=mask, num_step=num_step, params=fast_model
    )[0]

    use_gold_labels = (
        task_name in self.gold_label_tasks and self.meta_loss.lower() == "kl"
    )
    if not use_gold_labels:
        return self.inner_loss(
            student_logits, teacher_unary, return_nr_correct=return_nr_correct
        )

    # Switch the loss type only for this call; try/finally guarantees the
    # instance does not stay on "ce" if inner_loss raises.
    self.meta_loss = "ce"
    try:
        return self.inner_loss(
            student_logits, teacher_unary, return_nr_correct=return_nr_correct
        )
    finally:
        self.meta_loss = "kl"
def forward(
    self,
    data_batch,
    epoch,
    use_second_order,
    num_steps,
    training_phase,
):
    """
    Runs a forward outer loop pass on the batch of tasks using the MAML/++ framework.
    For every task the fast weights are adapted on the support set, the adapted
    weights are evaluated on the target set, and the target loss (divided by the
    meta batch size) is backpropagated straight away so that gradients accumulate
    across tasks.
    :param data_batch: A 7-tuple (x_support_set, len_support_set, x_target_set,
        len_target_set, y_support_set, y_target_set, teacher_names); its entries are
        zipped task by task. The ``len_*`` entries are passed to ``net_forward`` as ``mask``.
    :param epoch: Current epoch's index
    :param use_second_order: A boolean saying whether to use second order derivatives.
    :param num_steps: Number of inner loop steps.
    :param training_phase: Whether this is a training phase (True) or an evaluation phase (False).
        It only selects what is returned; ``backward()`` is called in both cases.
    :return: ``(losses, task_lang_logs)`` when ``training_phase`` is True, otherwise ``losses``.
        ``losses`` holds the summed task losses under "loss" and the mean target
        accuracy under "accuracy"; each task log is a list starting with
        ``[teacher_name, epoch]`` followed by the appended support/target statistics.
    """
    (
        x_support_set,
        len_support_set,
        x_target_set,
        len_target_set,
        y_support_set,
        y_target_set,
        teacher_names,
    ) = data_batch
    meta_batch_size = self.args.batch_size
    self.classifier.zero_grad()
    # Once the freeze period is over, make sure the classifier is unfrozen.
    if self.num_freeze_epochs <= epoch:
        self.classifier.unfreeze()
    losses = {"loss": 0}
    task_accuracies = []
    task_lang_logs = []
    for (
        task_id,
        (
            x_support_set_task,
            len_support_set_task,
            y_support_set_task,
            x_target_set_task,
            len_target_set_task,
            y_target_set_task,
            teacher_name,
        ),
    ) in enumerate(
        zip(
            x_support_set,
            len_support_set,
            y_support_set,
            x_target_set,
            len_target_set,
            y_target_set,
            teacher_names,
        )
    ):
        task_lang_log = [teacher_name, epoch]
        task_losses = []
        # freeze and unfreeze if necessary to get correct params
        if epoch <= self.num_freeze_epochs:
            self.classifier.unfreeze()
        fast_weights = self.classifier.get_inner_loop_params()
        if epoch < self.num_freeze_epochs:
            self.classifier.freeze()
        total_task_loss = 0
        # squeeze() removes every size-1 dimension.
        # NOTE(review): presumably meant to drop a leading task dimension — confirm
        # it cannot also collapse a batch of a single example.
        x_support_set_task = x_support_set_task.squeeze()
        len_support_set_task = len_support_set_task.squeeze()
        y_support_set_task = y_support_set_task.squeeze()
        x_target_set_task = x_target_set_task.squeeze()
        len_target_set_task = len_target_set_task.squeeze()
        y_target_set_task = y_target_set_task.squeeze()
        for num_step in range(num_steps):
            torch.cuda.empty_cache()
            # Support-set pass with the current fast weights.
            support_loss, is_correct = self.net_forward(
                x=x_support_set_task,
                mask=len_support_set_task,
                num_step=num_step,
                teacher_unary=y_support_set_task,
                fast_model=fast_weights,
                training=True,
                return_nr_correct=True,
                task_name=teacher_name,
            )
            fast_weights = self.apply_inner_loop_update(
                loss=support_loss,
                names_weights_copy=fast_weights,
                use_second_order=use_second_order,
                current_step_idx=num_step,
            )
            # NOTE(review): the final step is identified through
            # args.number_of_training_steps_per_iter, not through num_steps; the
            # two are presumably always equal — confirm against the callers.
            if num_step == (self.args.number_of_training_steps_per_iter - 1):
                # store support set statistics
                task_lang_log.append(support_loss.detach().item())
                task_lang_log.append(np.mean(is_correct))
                # Target-set pass with the adapted fast weights.
                target_loss, is_correct = self.net_forward(
                    x=x_target_set_task,
                    mask=len_target_set_task,
                    teacher_unary=y_target_set_task,
                    num_step=num_step,
                    fast_model=fast_weights,
                    training=True,
                    return_nr_correct=True,
                    task_name=teacher_name,
                )
                task_losses.append(target_loss)
                accuracy = np.mean(is_correct)
                task_accuracies.append(accuracy)
                # store query set statistics
                task_lang_log.append(target_loss.detach().item())
                task_lang_log.append(accuracy)
        # Achieve gradient accumulation by already backpropping current loss
        torch.cuda.empty_cache()
        task_losses = torch.sum(torch.stack(task_losses)) / meta_batch_size
        task_losses.backward()
        total_task_loss += task_losses.detach().cpu().item()
        losses["loss"] += total_task_loss
        task_lang_logs.append(task_lang_log)
        torch.cuda.synchronize()
    losses["accuracy"] = np.mean(task_accuracies)
    if training_phase:
        return losses, task_lang_logs
    else:
        return losses
def finetune_epoch(
    self,
    names_weights_copy,
    model_config,
    train_dataloader,
    dev_dataloader,
    best_loss,
    eval_every,
    model_save_dir,
    task_name,
    epoch,
    train_on_cpu=False,
    writer=None,
):
    """
    Finetunes the meta-learned classifier on a dataset for one pass over
    ``train_dataloader`` using the inner-loop learning rule, evaluating on the dev
    set every ``eval_every`` batches.
    :param names_weights_copy: Fast weights to continue from; when None they are
        taken from ``self.classifier.get_inner_loop_params()``.
    :param model_config: Not used by this method.
    :param train_dataloader: Dataloader with train examples; each batch is unpacked as (x, mask, y_true)
    :param dev_dataloader: Dataloader with validation examples, same batch layout
    :param best_loss: best achieved loss on dev set up till now
    :param eval_every: eval on dev set after eval_every batches (capped at ``len(train_dataloader)``)
    :param model_save_dir: directory to save the model to
    :param task_name: name of the task finetuning is performed on
    :param epoch: current epoch number
    :param train_on_cpu: If True, ``self.device`` is set to CPU for training and
        switched to CUDA for evaluation.
    :param writer: Optional object with ``add_histogram``/``flush`` used to log weight histograms.
    :return: ``(names_weights_copy, best_loss, avg_loss, accuracy)``
    """
    if train_on_cpu:
        self.device = torch.device("cpu")
    # The learning rule itself is not trained during finetuning.
    self.inner_loop_optimizer.requires_grad_(False)
    self.inner_loop_optimizer.eval()
    self.inner_loop_optimizer.to(self.device)
    self.classifier.to(self.device)
    if names_weights_copy is None:
        if epoch <= self.num_freeze_epochs:
            self.classifier.unfreeze()
        # # Get fast weights
        names_weights_copy = self.classifier.get_inner_loop_params()
        if epoch < self.num_freeze_epochs:
            self.classifier.freeze()
    # Never wait longer than one full pass before evaluating.
    # NOTE(review): a value of 0 here would make the modulo below raise — confirm
    # callers always pass a positive eval_every and a non-empty dataloader.
    eval_every = (
        eval_every if eval_every < len(train_dataloader) else len(train_dataloader)
    )
    if writer is not None:  # create histogram of weights
        for param_name, param in names_weights_copy.items():
            writer.add_histogram(task_name + "/" + param_name, param, 0)
        writer.flush()
    with tqdm(
        initial=0, total=eval_every * self.args.number_of_training_steps_per_iter
    ) as pbar_train:
        for batch_idx, batch in enumerate(train_dataloader):
            torch.cuda.empty_cache()
            batch = tuple(t.to(self.device) for t in batch)
            x, mask, y_true = batch
            # Take several inner-loop steps on the same batch.
            for train_step in range(self.args.number_of_training_steps_per_iter):
                support_loss = self.net_forward(
                    x,
                    mask=mask,
                    teacher_unary=y_true,
                    num_step=train_step,
                    fast_model=names_weights_copy,
                    training=True,
                )
                names_weights_copy = self.apply_inner_loop_update(
                    loss=support_loss,
                    names_weights_copy=names_weights_copy,
                    use_second_order=False,
                    current_step_idx=train_step,
                )
                self.inner_loop_optimizer.zero_grad()
                pbar_train.update(1)
                pbar_train.set_description(
                    "finetuning phase {} -> loss: {}".format(
                        batch_idx * self.args.number_of_training_steps_per_iter
                        + train_step
                        + 1,
                        support_loss.item(),
                    )
                )
                if writer is not None:  # create histogram of weights
                    for param_name, param in names_weights_copy.items():
                        writer.add_histogram(
                            task_name + "/" + param_name, param, train_step + 1
                        )
                    writer.flush()
            if (batch_idx + 1) % eval_every == 0:
                print("Evaluating model...")
                losses = []
                is_correct_preds = []
                # NOTE(review): this switches self.device and the classifier to CUDA
                # but does not move names_weights_copy, and nothing switches back to
                # CPU for later training batches — confirm this is intended.
                if train_on_cpu:
                    self.device = torch.device("cuda")
                    self.classifier.to(self.device)
                with torch.no_grad():
                    for batch in tqdm(
                        dev_dataloader,
                        desc="Evaluating",
                        leave=False,
                        total=len(dev_dataloader),
                    ):
                        batch = tuple(t.to(self.device) for t in batch)
                        x, mask, y_true = batch
                        # train_step still holds the index of the last inner step taken.
                        loss, is_correct = self.net_forward(
                            x,
                            mask=mask,
                            teacher_unary=y_true,
                            fast_model=names_weights_copy,
                            training=False,
                            return_nr_correct=True,
                            num_step=train_step,
                        )
                        losses.append(loss.item())
                        is_correct_preds.extend(is_correct.tolist())
                avg_loss = np.mean(losses)
                accuracy = np.mean(is_correct_preds)
                print("Accuracy", accuracy)
                if avg_loss < best_loss:
                    best_loss = avg_loss
                    print(
                        "New best finetuned model with loss {:.05f}".format(
                            best_loss
                        )
                    )
                    # Save the fast weights under a file name derived from the task
                    # name with its "train/", "val/" or "test/" part removed.
                    torch.save(
                        names_weights_copy,
                        os.path.join(
                            model_save_dir,
                            "model_finetuned_{}".format(
                                task_name.replace("train/", "", 1)
                                .replace("val/", "", 1)
                                .replace("test/", "", 1)
                            ),
                        ),
                    )
    # NOTE(review): avg_loss and accuracy are only assigned in the evaluation
    # branch above; if it never runs, this return raises.
    return names_weights_copy, best_loss, avg_loss, accuracy
def single_task_finetune(
    self,
    train_dataloader,
    dev_dataloader,
    task_name,
    num_epochs,
    train_on_cpu=False
):
    """
    Finetune the classifier's inner-loop parameters on a single task with the
    inner-loop learning rule, then evaluate once on the dev set.
    :param train_dataloader: Dataloader whose batches unpack as (x, mask, y_true).
    :param dev_dataloader: Dataloader with validation batches of the same layout.
    :param task_name: Not used by this method.
    :param num_epochs: Number of passes over ``train_dataloader``.
    :param train_on_cpu: If True, train with ``self.device`` set to CPU and switch
        to CUDA for the evaluation.
    :return: ``(avg_loss, accuracy)`` measured on the dev set.
    """
    if train_on_cpu:
        self.device = torch.device("cpu")

    # The learning rule is only applied here, not trained.
    learning_rule = self.inner_loop_optimizer
    learning_rule.requires_grad_(False)
    learning_rule.eval()
    learning_rule.to(self.device)

    self.classifier.to(self.device)
    self.classifier.unfreeze()
    fast_weights = self.classifier.get_inner_loop_params()
    print(f"Dataloader sizes : Train - {len(train_dataloader)}, Val - {len(dev_dataloader)}")

    # The same step index is used for every update and for the evaluation.
    step_idx = 1
    train_losses = []
    with tqdm(
        initial=0, total=num_epochs * len(train_dataloader)
    ) as progress:
        for epoch in range(num_epochs):
            for batch_idx, raw_batch in enumerate(train_dataloader):
                torch.cuda.empty_cache()
                x, mask, y_true = tuple(t.to(self.device) for t in raw_batch)
                support_loss = self.net_forward(
                    x,
                    mask=mask,
                    teacher_unary=y_true,
                    num_step=step_idx,
                    fast_model=fast_weights,
                    training=True,
                )
                fast_weights = self.apply_inner_loop_update(
                    loss=support_loss,
                    names_weights_copy=fast_weights,
                    use_second_order=False,
                    current_step_idx=step_idx,
                )
                self.inner_loop_optimizer.zero_grad()
                progress.update(1)
                train_losses.append(support_loss.item())
                updates_done = len(train_dataloader) * epoch + batch_idx + 1
                # The bar shows the mean loss over all updates so far.
                progress.set_description(
                    "finetuning phase {} -> loss: {}".format(
                        updates_done, np.mean(train_losses)
                    )
                )

    print("Evaluating model...")
    dev_losses = []
    correct_flags = []
    if train_on_cpu:
        self.device = torch.device("cuda")
        self.classifier.to(self.device)
    with torch.no_grad():
        dev_batches = tqdm(
            dev_dataloader,
            desc="Evaluating",
            leave=False,
            total=len(dev_dataloader),
        )
        for raw_batch in dev_batches:
            x, mask, y_true = tuple(t.to(self.device) for t in raw_batch)
            loss, is_correct = self.net_forward(
                x,
                mask=mask,
                teacher_unary=y_true,
                fast_model=fast_weights,
                training=False,
                return_nr_correct=True,
                num_step=step_idx,
            )
            dev_losses.append(loss.item())
            correct_flags.extend(is_correct.tolist())

    avg_loss = np.mean(dev_losses)
    accuracy = np.mean(correct_flags)
    print("Val loss : ", avg_loss, "Val accuracy : ", accuracy)
    return avg_loss, accuracy
def single_task_finetune_hf(
    self,
    train_dataloader,
    dev_dataloader,
    task_name,
    num_epochs,
    train_on_cpu=False
):
    """
    Finetune the classifier on a single task with a standard AdamW loop (all
    classifier parameters, no fast weights), then evaluate once on the dev set.
    :param train_dataloader: Dataloader whose batches unpack as (x, mask, y_true).
    :param dev_dataloader: Dataloader with validation batches of the same layout.
    :param task_name: Not used by this method.
    :param num_epochs: Number of passes over ``train_dataloader``.
    :param train_on_cpu: If True, train with ``self.device`` set to CPU and switch
        to CUDA for the evaluation.
    :return: ``(avg_loss, accuracy)`` measured on the dev set.
    """
    if train_on_cpu:
        self.device = torch.device("cpu")
    print("Training using HF methods")
    # NOTE(review): transformers' AdamW is deprecated upstream in favour of
    # torch.optim.AdamW, whose default eps/weight_decay differ — verify the
    # installed transformers version before migrating.
    from transformers import AdamW

    self.classifier.to(self.device)
    self.classifier.unfreeze()
    self.classifier.train()
    # Named `optimizer` so it does not hide the module-level `optim` alias.
    optimizer = AdamW(self.classifier.parameters(), lr=2e-5)
    print(f"Dataloader sizes : Train - {len(train_dataloader)}, Val - {len(dev_dataloader)}")

    # The same step index is passed to the classifier for every batch.
    train_step = 1
    with tqdm(
        initial=0, total=num_epochs * len(train_dataloader)
    ) as pbar_train:
        for epoch in range(num_epochs):
            self.classifier.train()
            for batch_idx, batch in enumerate(train_dataloader):
                optimizer.zero_grad()
                x, mask, y_true = tuple(t.to(self.device) for t in batch)
                # fast_model=None: the classifier is called without fast weights.
                support_loss = self.net_forward(
                    x,
                    mask=mask,
                    teacher_unary=y_true,
                    num_step=train_step,
                    fast_model=None,
                    training=True,
                )
                support_loss.backward()
                optimizer.step()
                pbar_train.update(1)
                # Report the scalar value, as the other finetuning loops do,
                # rather than the tensor's repr.
                pbar_train.set_description(
                    "finetuning phase {} -> loss: {}".format(
                        len(train_dataloader) * epoch
                        + batch_idx
                        + 1,
                        support_loss.item(),
                    )
                )

    self.classifier.eval()
    print("Evaluating model...")
    losses = []
    is_correct_preds = []
    if train_on_cpu:
        self.device = torch.device("cuda")
        self.classifier.to(self.device)
    with torch.no_grad():
        for batch in tqdm(
            dev_dataloader,
            desc="Evaluating",
            leave=False,
            total=len(dev_dataloader),
        ):
            x, mask, y_true = tuple(t.to(self.device) for t in batch)
            loss, is_correct = self.net_forward(
                x,
                mask=mask,
                teacher_unary=y_true,
                fast_model=None,
                training=False,
                return_nr_correct=True,
                num_step=train_step,
            )
            losses.append(loss.item())
            is_correct_preds.extend(is_correct.tolist())

    avg_loss = np.mean(losses)
    accuracy = np.mean(is_correct_preds)
    print("Val loss : ", avg_loss, "Val accuracy : ", accuracy)
    return avg_loss, accuracy