-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
166 lines (126 loc) · 6.15 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.autograd import Variable
class HardTripletLoss(nn.Module):
    """Hard/Hardest Triplet Loss
    (pytorch implementation of https://omoindrot.github.io/triplet-loss)

    With ``hardest=True``, each anchor forms one triplet from its hardest
    positive (largest anchor-positive distance) and its hardest negative
    (smallest anchor-negative distance), and the loss is the mean over anchors.
    Otherwise every valid (anchor, positive, negative) triplet in the batch is
    scored, and the loss is averaged over the triplets with positive loss.
    """

    def __init__(self, margin=0.3, hardest=False, squared=False):
        """
        Args:
            margin: margin for triplet loss (applied in both modes).
            hardest: If true, loss is considered only hardest triplets.
            squared: If true, use the pairwise squared euclidean distance matrix.
                If false, use the pairwise euclidean distance matrix.
        """
        super(HardTripletLoss, self).__init__()
        self.margin = margin
        self.hardest = hardest
        self.squared = squared

    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: tensor of shape (batch_size, embed_dim)
            labels: labels of the batch, of size (batch_size,)
        Returns:
            triplet_loss: scalar tensor containing the triplet loss
        """
        pairwise_dist = _pairwise_distance(embeddings, squared=self.squared)

        if self.hardest:
            # Hardest positive: largest d(a, p) among distinct same-label p.
            # Masked-out entries become 0, which cannot beat a real distance
            # because the distances are clipped to be non-negative.
            mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
            valid_positive_dist = pairwise_dist * mask_anchor_positive
            hardest_positive_dist, _ = torch.max(valid_positive_dist, dim=1, keepdim=True)

            # Hardest negative: smallest d(a, n) among different-label n.
            # Same-label entries are shifted up by the row maximum so they
            # cannot be selected by the min.
            mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
            max_anchor_negative_dist, _ = torch.max(pairwise_dist, dim=1, keepdim=True)
            anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (
                1.0 - mask_anchor_negative)
            hardest_negative_dist, _ = torch.min(anchor_negative_dist, dim=1, keepdim=True)

            # Combine biggest d(a, p) and smallest d(a, n) into the final loss,
            # using the configured margin rather than a fixed constant.
            triplet_loss = F.relu(hardest_positive_dist - hardest_negative_dist + self.margin)
            triplet_loss = torch.mean(triplet_loss)
        else:
            anc_pos_dist = pairwise_dist.unsqueeze(dim=2)
            anc_neg_dist = pairwise_dist.unsqueeze(dim=1)

            # Compute a 3D tensor of size (batch_size, batch_size, batch_size)
            # triplet_loss[i, j, k] will contain the triplet loss of anc=i, pos=j, neg=k
            # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
            # and the 2nd (batch_size, 1, batch_size)
            loss = anc_pos_dist - anc_neg_dist + self.margin

            # Zero out invalid triplets.
            mask = _get_triplet_mask(labels).float()
            triplet_loss = loss * mask

            # Remove negative losses (i.e. the easy triplets)
            triplet_loss = F.relu(triplet_loss)

            # Count number of hard triplets (where triplet_loss > 0) and average
            # over them; the 1e-16 keeps the division finite (result 0) when
            # no triplet is active.
            hard_triplets = torch.gt(triplet_loss, 1e-16).float()
            num_hard_triplets = torch.sum(hard_triplets)
            triplet_loss = torch.sum(triplet_loss) / (num_hard_triplets + 1e-16)

        return triplet_loss
def _pairwise_distance(x, squared=False, eps=1e-16):
    """Return the matrix of euclidean distances between every pair of rows of ``x``.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2 on the Gram
    matrix ``x @ x.T``. Small negative values caused by floating-point
    round-off are clipped to 0.

    Args:
        x: 2-D tensor with one embedding per row.
        squared: if True, return squared distances; otherwise return the
            square root of them.
        eps: added only at entries that are exactly 0 before taking the
            square root (so its gradient stays finite); those entries are
            reset to 0 afterwards.
    """
    gram = x @ x.t()
    sq_norms = gram.diagonal()
    dist = F.relu(sq_norms.unsqueeze(1) - 2 * gram + sq_norms.unsqueeze(0))
    if squared:
        return dist

    zero_mask = (dist == 0.0).float()
    return torch.sqrt(dist + zero_mask * eps) * (1.0 - zero_mask)
def _get_anchor_positive_triplet_mask(labels):
    """Return a 2-D mask where ``mask[a, p]`` is 1 iff a and p are distinct and have the same label.

    The mask is created on ``labels.device`` instead of a fixed device, so it
    works for CPU tensors on a CUDA machine and for tensors on any GPU index.
    """
    # 1 everywhere except the diagonal (a == p).
    indices_not_equal = torch.eye(labels.shape[0], device=labels.device).byte() ^ 1
    # Check if labels[i] == labels[j]
    labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)
    return indices_not_equal * labels_equal
def _get_anchor_negative_triplet_mask(labels):
    """Return a 2-D mask where ``mask[a, n]`` is 1 iff a and n have different labels.

    Built by flipping the label-equality matrix with ``^ 1``.
    """
    same_label = labels[None, :] == labels[:, None]
    return same_label ^ 1
def _get_triplet_mask(labels):
    """Return a 3-D mask where ``mask[a, p, n]`` is nonzero iff the triplet (a, p, n) is valid.

    A triplet (i, j, k) is valid if:
      - i, j, k are distinct
      - labels[i] == labels[j] and labels[i] != labels[k]

    The mask is created on ``labels.device`` instead of a fixed device, so it
    works for CPU tensors on a CUDA machine and for tensors on any GPU index.
    """
    # Check that i, j and k are distinct (1 off the diagonal, 0 on it).
    indices_not_same = torch.eye(labels.shape[0], device=labels.device).byte() ^ 1
    i_not_equal_j = torch.unsqueeze(indices_not_same, 2)
    i_not_equal_k = torch.unsqueeze(indices_not_same, 1)
    j_not_equal_k = torch.unsqueeze(indices_not_same, 0)
    distinct_indices = i_not_equal_j * i_not_equal_k * j_not_equal_k

    # Check if labels[i] == labels[j] and labels[i] != labels[k]
    label_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1))
    i_equal_j = torch.unsqueeze(label_equal, 2)
    i_equal_k = torch.unsqueeze(label_equal, 1)
    valid_labels = i_equal_j * (i_equal_k ^ 1)

    # Combine the two masks.
    return distinct_indices * valid_labels
class FocalLoss(nn.Module):
    """Focal-style cross-entropy: mean of ``(1 - p_t) ** gamma * CE``.

    ``CE`` is the per-sample cross-entropy and ``p_t = exp(-CE)`` is the
    predicted probability of the target class. With ``gamma == 0`` this
    reduces to the plain mean cross-entropy.
    """

    def __init__(self, gamma=2, eps=1e-7):
        """
        Args:
            gamma: focusing exponent applied to ``(1 - p_t)``.
            eps: stored on the module; ``forward`` does not use it.
        """
        super().__init__()
        self.gamma = gamma
        self.eps = eps
        self.ce = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, input, target):
        """Return the scalar focal loss for logits ``input`` and class indices ``target``."""
        per_sample_ce = self.ce(input, target)
        target_prob = torch.exp(-per_sample_ce)
        weighted = (1 - target_prob) ** self.gamma * per_sample_ce
        return weighted.mean()
def ce_loss(input, target, ohem=0.0):
    """Cross-entropy loss with optional online hard example mining (OHEM).

    Args:
        input: logits in any layout ``F.cross_entropy`` accepts,
            e.g. (N, C) or (N, C, d1, ...).
        target: class indices matching ``input``.
        ohem: fraction of the hardest (highest-loss) elements to keep.
            ``0`` disables mining and returns the plain mean cross-entropy.

    Returns:
        Scalar tensor.

    When mining, the per-element losses are ranked and the top
    ``int(size * ohem)`` of them are averaged. That count is clamped to at
    least one and at most all elements. For (N, C) input the ranking is
    across the batch. For inputs with extra dimensions it is done per
    sample, over all of that sample's positions.
    """
    if ohem == 0:
        return F.cross_entropy(input, target)

    # reduction must be the string "none"; Python None is rejected.
    loss = F.cross_entropy(input, target, reduction="none")
    # Flatten to (N, positions) for K-dim inputs, or to a 1-D vector otherwise.
    loss = loss.flatten(1) if loss.dim() > 1 else loss.flatten()

    # Derive k from the real element count instead of a hard-coded size, and
    # clamp it so topk never requests more elements than exist.
    num_elems = loss.size(-1)
    k = min(num_elems, max(1, int(num_elems * ohem)))
    hardest, _ = loss.topk(k, dim=-1, largest=True, sorted=True)
    return hardest.mean()