forked from open-mmlab/mmpretrain
-
Notifications
You must be signed in to change notification settings - Fork 1
/
score_tta.py
36 lines (28 loc) · 1.15 KB
/
score_tta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmengine.model import BaseTTAModel
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
@MODELS.register_module()
class AverageClsScoreTTA(BaseTTAModel):
    """Test-time-augmentation model that averages classification scores.

    For every input sample, the predictions of all its augmented views are
    merged into one prediction whose ``pred_score`` is the element-wise mean
    of the views' ``pred_score``.
    """

    def merge_preds(
        self,
        data_samples_list: List[List[DataSample]],
    ) -> List[DataSample]:
        """Merge the predictions of all augmented views of each sample.

        Args:
            data_samples_list (List[List[DataSample]]): One inner list per
                input sample, holding the predictions of every augmented
                view of that sample.

        Returns:
            List[DataSample]: One merged prediction per input sample, in the
            same order as ``data_samples_list``.

        Raises:
            ValueError: If any inner list of view predictions is empty.
        """
        return [
            self._merge_single_sample(data_samples)
            for data_samples in data_samples_list
        ]

    def _merge_single_sample(
            self, data_samples: List[DataSample]) -> DataSample:
        """Average the ``pred_score`` of all augmented views of one sample.

        Args:
            data_samples (List[DataSample]): Predictions of every augmented
                view of a single input sample. Must be non-empty.

        Returns:
            DataSample: A data sample created via ``new()`` on the first view,
            with ``pred_score`` set to the mean score over all views.

        Raises:
            ValueError: If ``data_samples`` is empty.
        """
        if not data_samples:
            # Without this guard the failure would surface as an opaque
            # IndexError from ``data_samples[0]`` below.
            raise ValueError(
                'Cannot merge TTA predictions: received an empty list of '
                'augmented-view predictions for a sample.')
        # The first view serves as the template for the merged result.
        # NOTE(review): ``new()`` presumably copies the first view's other
        # data fields (e.g. ``pred_label``) unchanged, so they are not
        # recomputed from the merged score — confirm against mmengine
        # ``BaseDataElement.new`` if downstream code reads ``pred_label``.
        merged_data_sample: DataSample = data_samples[0].new()
        merged_score = sum(data_sample.pred_score
                           for data_sample in data_samples) / len(data_samples)
        merged_data_sample.set_pred_score(merged_score)
        return merged_data_sample