-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
95 lines (74 loc) · 2.85 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import os
import re
from multiprocessing import cpu_count
from operator import index

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataloader.augment import valid_transform
from dataloader.dataset import MyDataset
from utils.common import tta_predict
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with the attributes ``data_csv``, ``weight``,
        ``batch_size``, ``save_dir``, ``fold``, ``img_size`` and ``tta``.
    """
    parser = argparse.ArgumentParser()
    # Both paths are mandatory: without them the script fails later with an
    # unhelpful error, so let argparse report the problem up front.
    parser.add_argument('--data-csv', type=str, required=True,
                        help='CSV file loaded with pandas; needs a "fold" column')
    parser.add_argument('--weight', type=str, required=True,
                        help='checkpoint path named <model>_fold<k>...')
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--save-dir', type=str, default='/content/exp',
                        help='directory that receives prediction.csv')
    # NOTE: main() takes the fold from the checkpoint file name, not from here.
    parser.add_argument('--fold', type=int, default=0)
    # '--img-size' matches the dash style of the other flags; the original
    # '--img_size' spelling is kept as an alias for existing command lines.
    parser.add_argument('--img-size', '--img_size', dest='img_size',
                        type=int, default=448)
    parser.add_argument('--tta', action='store_true',
                        help='predict through tta_predict instead of a plain forward pass')
    return parser.parse_args(argv)
def parse_model_info(weight_path):
    """Extract the model name and fold index from a checkpoint file name.

    Only the last path component is inspected. It must look like
    ``<model_name>_fold<k>`` followed by anything, e.g.
    ``resnet50_fold0_best.pth`` or ``resnet50_fold0.pth``.

    Args:
        weight_path: Path to the checkpoint file.

    Returns:
        Tuple ``(model_name, fold)``: the text before the ``_fold`` marker and
        the integer that directly follows it.

    Raises:
        ValueError: If the file name contains no ``_fold<digits>`` marker.
    """
    fname = os.path.basename(weight_path)
    # Non-greedy name + explicit digits: stops at the first '_fold<k>' and does
    # not require an underscore after the fold number (so 'x_fold0.pth' works).
    match = re.match(r'(?P<name>.+?)_fold(?P<fold>\d+)', fname)
    if match is None:
        raise ValueError(
            f"Cannot parse model name and fold from '{fname}': "
            "expected a file name like '<model_name>_fold<k>...'"
        )
    return match.group('name'), int(match.group('fold'))
def main(args):
    """Evaluate one checkpoint on its validation fold and save the predictions.

    The model name and the fold index are taken from the checkpoint file name
    (see ``parse_model_info``); ``args.fold`` is not consulted. Rows of
    ``args.data_csv`` whose ``fold`` column equals that index form the
    validation set. The AUC (one-vs-rest) is printed and the per-sample
    softmax outputs are written to ``<save_dir>/prediction.csv`` with the
    columns ``path``, ``label`` and ``pred`` (comma-joined probabilities).

    Args:
        args: Namespace with ``weight``, ``data_csv``, ``save_dir``,
            ``img_size``, ``batch_size`` and ``tta`` attributes.

    Raises:
        ValueError: If the CSV has no rows for the fold of the checkpoint.
    """
    model_name, fold = parse_model_info(args.weight)
    print(f'Evaluate model {model_name} - fold {fold}')

    df = pd.read_csv(args.data_csv)
    os.makedirs(args.save_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = timm.create_model(model_name, pretrained=False, num_classes=4).to(device)
    model.load_state_dict(torch.load(args.weight, map_location='cpu'))
    # Modules are created in training mode; switch to eval so BatchNorm uses
    # its running statistics and Dropout is disabled during evaluation.
    model.eval()

    val_df = df[df.fold == fold].reset_index(drop=True)
    if val_df.empty:
        # Fail with a clear message instead of an opaque np.concatenate error.
        raise ValueError(f'No rows with fold == {fold} in {args.data_csv}')

    val_data = MyDataset(val_df, transform=valid_transform(),
                         img_size=args.img_size, return_path=True)
    valid_loader = DataLoader(val_data,
                              shuffle=False,
                              num_workers=cpu_count(),
                              batch_size=args.batch_size)

    preds = []
    y_true = []
    paths = []
    # Gradients are never needed here, so disable autograd for the whole loop.
    with torch.no_grad():
        for path, image, labels in tqdm(valid_loader):
            image = image.to(device)
            if args.tta:
                output = tta_predict(model, image)
            else:
                output = model(image)
            preds.append(F.softmax(output, dim=1).cpu().numpy())
            y_true.append(labels.cpu().numpy())
            paths.append(path)

    # Flatten the per-batch pieces into one array per quantity.
    preds = np.concatenate(preds)
    y_true = np.concatenate(y_true)
    paths = np.concatenate(paths)

    auc = roc_auc_score(y_true, preds, multi_class='ovr')
    print(f'AUC {auc}')

    # One CSV cell per sample: the class probabilities joined with commas.
    pred_strings = [','.join(str(p) for p in pred) for pred in preds]
    result = pd.DataFrame({
        'path': paths,
        'label': y_true,
        'pred': pred_strings,
    })
    result.to_csv(os.path.join(args.save_dir, 'prediction.csv'), index=False)
    print(f'Result saved to {args.save_dir}')
# Script entry point: parse the CLI options, then run the evaluation.
if __name__ == '__main__':
    cli_args = parse_args()
    main(cli_args)