diff --git a/requirements.txt b/requirements.txt index c5e6a6119321..5e813fb0420c 100755 --- a/requirements.txt +++ b/requirements.txt @@ -16,8 +16,9 @@ tqdm>=4.41.0 # logging ------------------------------------- # wandb -# coco ---------------------------------------- -# pycocotools>=2.0 +# plotting ------------------------------------ +seaborn +pandas # export -------------------------------------- # coremltools==4.0 @@ -26,4 +27,4 @@ tqdm>=4.41.0 # extras -------------------------------------- # thop # FLOPS computation -# seaborn # plotting +# pycocotools>=2.0 # COCO mAP diff --git a/test.py b/test.py index e50fb117ef80..0c06f37e5e4b 100644 --- a/test.py +++ b/test.py @@ -14,7 +14,7 @@ from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \ non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path from utils.loss import compute_loss -from utils.metrics import ap_per_class +from utils.metrics import ap_per_class, ConfusionMatrix from utils.plots import plot_images, output_to_target from utils.torch_utils import select_device, time_synchronized @@ -89,6 +89,7 @@ def test(data, dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0] seen = 0 + confusion_matrix = ConfusionMatrix(nc=nc) names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)} coco91class = coco80_to_coco91_class() s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95') @@ -176,6 +177,8 @@ def test(data, # target boxes tbox = xywh2xyxy(labels[:, 1:5]) scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels + if plots: + confusion_matrix.process_batch(pred, torch.cat((labels[:, 0:1], tbox), 1)) # Per target class for cls in torch.unique(tcls_tensor): @@ -218,10 +221,12 @@ def test(data, else: nt = torch.zeros(1) - # W&B logging - if plots and wandb and wandb.run: - wandb.log({"Images": wandb_images}) - wandb.log({"Validation": [wandb.Image(str(x), caption=x.name) for x in sorted(save_dir.glob('test*.jpg'))]}) + # Plots + if plots: + confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) + if wandb and wandb.run: + wandb.log({"Images": wandb_images}) + wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]}) # Print results pf = '%20s' + '%12.3g' * 6 # print format diff --git a/train.py b/train.py index a2244fcbf395..ca8148a82176 100644 --- a/train.py +++ b/train.py @@ -396,8 +396,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): if plots: plot_results(save_dir=save_dir) # save as results.png if wandb: - wandb.log({"Results": [wandb.Image(str(save_dir / x), caption=x) for x in - ['results.png', 'precision_recall_curve.png']]}) + files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png'] + wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files + if (save_dir / f).exists()]}) logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) else: dist.destroy_process_group() diff --git a/utils/metrics.py b/utils/metrics.py index 62add1da1f8d..26872f2704c5 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -4,6 +4,9 @@ import matplotlib.pyplot as plt import numpy as np +import torch + +from . import general def fitness(x): @@ -102,6 +105,84 @@ def compute_ap(recall, precision): return ap, mpre, mrec +class ConfusionMatrix: + # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix + def __init__(self, nc, conf=0.25, iou_thres=0.45): + self.matrix = np.zeros((nc + 1, nc + 1)) + self.nc = nc # number of classes + self.conf = conf + self.iou_thres = iou_thres + + def process_batch(self, detections, labels): + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + Arguments: + detections (Array[N, 6]), x1, y1, x2, y2, conf, class + labels (Array[M, 5]), class, x1, y1, x2, y2 + Returns: + None, updates confusion matrix accordingly + """ + detections = detections[detections[:, 4] > self.conf] + gt_classes = labels[:, 0].int() + detection_classes = detections[:, 5].int() + iou = general.box_iou(labels[:, 1:], detections[:, :4]) + + x = torch.where(iou > self.iou_thres) + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + else: + matches = np.zeros((0, 3)) + + n = matches.shape[0] > 0 + m0, m1, _ = matches.transpose().astype(np.int16) + for i, gc in enumerate(gt_classes): + j = m0 == i + if n and sum(j) == 1: + self.matrix[gc, detection_classes[m1[j]]] += 1 # correct + else: + self.matrix[gc, self.nc] += 1 # background FP + + if n: + for i, dc in enumerate(detection_classes): + if not any(m1 == i): + self.matrix[self.nc, dc] += 1 # background FN + + def matrix(self): + return self.matrix + + def plot(self, save_dir='', names=()): + try: + import seaborn as sn + + array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize + array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) + + fig = plt.figure(figsize=(12, 9)) + sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size + labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels + sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True, + xticklabels=names + ['background FN'] if labels else "auto", + yticklabels=names + ['background FP'] if labels else "auto").set_facecolor((1, 1, 1)) + fig.axes[0].set_xlabel('True') + fig.axes[0].set_ylabel('Predicted') + fig.tight_layout() + fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) + except Exception as e: + pass + + def print(self): + for i in range(self.nc + 1): + print(' '.join(map(str, self.matrix[i]))) + + +# Plots ---------------------------------------------------------------------------------------------------------------- + def plot_pr_curve(px, py, ap, save_dir='.', names=()): fig, ax = plt.subplots(1, 1, figsize=(9, 6)) py = np.stack(py, axis=1)