eval.py
"""Evaluates model on benchmark of tasks"""
import os
from pprint import pprint

import pytorch_lightning as pl
import hydra
from omegaconf import DictConfig, OmegaConf
import transformers

from claficle.data.benchmark import BenchmarkDataModule
from claficle.models.base import BaseModel
from claficle.utils.general import run_script_preamble


@hydra.main(version_base=None, config_path="../conf", config_name="eval")
def main(cfg: DictConfig):
    # sets seed, parses model name
    model: BaseModel
    model, cfg = run_script_preamble(cfg)

    lang = cfg.lang
    # separate benchmarks by language
    benchmark = BenchmarkDataModule(config=cfg.data, lang=lang)
    benchmark.prepare_data()
    benchmark.setup()
    pprint(benchmark.get_metadata())
    # so that the model knows names and metrics of dataloaders before testing
    model.set_benchmark_metadata(benchmark.get_metadata())

    # gewechselt models come with trained tokenizers
    print("Loading tokenizer...")
    if cfg.model.tokenizer_name is not None:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            os.path.join(
                cfg.model.checkpoint_dir, "tokenizers", cfg.model.tokenizer_name
            )
        )
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2-large")
    benchmark.set_tokenizer(tokenizer)

    # set up pl trainer (tester)
    log_save_dir = os.path.join(
        cfg.trainer.log_dir, cfg.model.name, f"seed_{cfg.seed}", lang
    )
    os.makedirs(log_save_dir, exist_ok=True)
    script_host = "slurm" if "SLURM_JOB_ID" in os.environ else "local"
    logger = pl.loggers.WandbLogger(
        save_dir=log_save_dir,
        project="claficle",
        entity="giulio-uva",
        job_type="eval",
        mode="disabled" if cfg.disable_wandb else "online",
        group=script_host,
        config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
        log_model=False,  # don't log or upload artifacts
    )
    trainer = pl.Trainer(
        logger=logger,
        enable_progress_bar=cfg.trainer.progress_bar,
        accelerator=cfg.trainer.accelerator,
        devices=cfg.trainer.devices,
    )

    print(f"Evaluating in {lang}...")
    trainer.test(model, datamodule=benchmark)
    print("Done.")


if __name__ == "__main__":
    main()