本文尝试使用 fast-reid 中的 triplet loss，并结合源代码给出对其实现的理解。
源代码
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
from .utils import euclidean_dist, cosine_dist
def softmax_weights(dist, mask):
    """Compute row-wise softmax weights over the entries selected by ``mask``.

    For each row ``i``::

        W[i, j] = mask[i, j] * exp(dist[i, j]) / (sum_k mask[i, k] * exp(dist[i, k]) + 1e-6)

    evaluated in a numerically stable way. Larger ``dist`` values get larger
    weights, so callers pass ``dist`` to favour far samples and ``-dist`` to
    favour near samples.

    Args:
        dist: 2-D score tensor, shape [N, M].
        mask: tensor of the same shape; 0 excludes an entry (``triplet_loss``
            builds these as 0./1. floats).

    Returns:
        Weight tensor of shape [N, M], zero wherever ``mask`` is 0. Rows whose
        mask is entirely zero get all-zero weights.
    """
    excluded = mask == 0
    # Shift by the maximum over the *selected* entries only (shape [N, 1]).
    # Taking the max of ``dist * mask`` instead would let the zeros of excluded
    # entries win whenever every selected score is negative (e.g. ``-dist_an``);
    # exp() would then underflow and the 1e-6 epsilon would dominate the
    # normaliser, collapsing every weight towards 0.
    max_v = dist.masked_fill(excluded, float("-inf")).max(dim=1, keepdim=True)[0]
    # A row with nothing selected yields -inf; any finite shift works there.
    max_v = max_v.masked_fill(max_v == float("-inf"), 0.0)
    diff = dist - max_v
    # Excluded entries become exp(-inf) = 0, so they can neither overflow
    # (inf * 0 = nan) nor contribute to the normaliser.
    exp_diff = torch.exp(diff.masked_fill(excluded, float("-inf"))) * mask
    Z = torch.sum(exp_diff, dim=1, keepdim=True) + 1e-6  # avoid division by zero
    W = exp_diff / Z
    return W
def hard_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the distance to its hardest positive and hardest negative.

    The hardest positive is the *farthest* sample sharing the anchor's label;
    the hardest negative is the *nearest* sample with a different label.

    Args:
        dist_mat: pairwise distances between samples, shape [N, M]. Assumed
            non-negative, because excluded entries are zeroed before the max
            and a zero must never beat a real positive distance.
        is_pos: positive mask, shape [N, M] (1. where labels match, else 0.).
        is_neg: negative mask, shape [N, M] (1. where labels differ, else 0.).

    Returns:
        dist_ap: distance(anchor, hardest positive); shape [N]
        dist_an: distance(anchor, hardest negative); shape [N]

    NOTE: only the distances are returned, not the indices of the selected
    samples. When ``is_neg`` is the complement of ``is_pos`` and an anchor has
    no negative in the batch, its ``dist_an`` is ~1e9.
    """
    assert dist_mat.dim() == 2
    # Hardest positive: largest distance among same-label samples.
    # torch.max returns (values, indices); only the values are needed.
    dist_ap = torch.max(dist_mat * is_pos, dim=1)[0]
    # Hardest negative: smallest distance among different-label samples.
    # Positive entries are pushed to ~1e9 so the min can never select them.
    dist_an = torch.min(dist_mat * is_neg + is_pos * 1e9, dim=1)[0]
    return dist_ap, dist_an
def weighted_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, compute softmax-weighted positive and negative distances.

    Instead of picking one hardest sample, every positive (negative) of the
    anchor contributes with a weight from ``softmax_weights``. Far positives
    and near negatives receive the largest weights.

    Args:
        dist_mat: pairwise distances between samples, shape [N, N].
        is_pos: positive mask, shape [N, N] (1. where labels match, else 0.).
        is_neg: negative mask, shape [N, N] (1. where labels differ, else 0.).

    Returns:
        dist_ap: weighted distance(anchor, positives); shape [N]
        dist_an: weighted distance(anchor, negatives); shape [N]
    """
    assert dist_mat.dim() == 2
    dist_ap = dist_mat * is_pos
    dist_an = dist_mat * is_neg
    weights_ap = softmax_weights(dist_ap, is_pos)  # larger distance -> larger weight
    weights_an = softmax_weights(-dist_an, is_neg)  # smaller distance -> larger weight
    dist_ap = torch.sum(dist_ap * weights_ap, dim=1)
    dist_an = torch.sum(dist_an * weights_an, dim=1)
    return dist_ap, dist_an
def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
    r"""Batch triplet loss.

    Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
    Loss for Person Re-Identification'.

    Args:
        embedding: feature tensor, shape [N, D].
        targets: label tensor, shape [N].
        margin: if > 0, use ``F.margin_ranking_loss`` with this margin;
            otherwise use the soft-margin formulation ``F.soft_margin_loss``.
        norm_feat: if True measure distances with ``cosine_dist``, otherwise
            with ``euclidean_dist``.
        hard_mining: if True use ``hard_example_mining``, otherwise
            ``weighted_example_mining``.

    Returns:
        Scalar loss tensor.
    """
    if norm_feat:
        dist_mat = cosine_dist(embedding, embedding)
    else:
        dist_mat = euclidean_dist(embedding, embedding)

    # NOTE: features are not gathered across processes here, so in distributed
    # training each process mines triplets only within its local batch.

    N = dist_mat.size(0)
    # [N, N] masks via broadcasting: entry (i, j) compares targets[i] with targets[j].
    labels = targets.view(N, 1)
    is_pos = labels.eq(labels.t()).float()
    is_neg = labels.ne(labels.t()).float()

    if hard_mining:
        dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
    else:
        dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)

    # y = 1 for every anchor: dist_an should be ranked above dist_ap. Shape [N].
    y = torch.ones_like(dist_an)

    if margin > 0:
        loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
    else:
        loss = F.soft_margin_loss(dist_an - dist_ap, y)

    # The soft-margin loss can overflow to inf when dist_ap greatly exceeds
    # dist_an; fall back to a hard margin ranking loss with margin 0.3.
    if torch.isinf(loss):
        loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)

    return loss
测试代码
import torch
from fastreid.modeling.losses import triplet_loss
# Four 3-D features; samples 0/3 share label 0 and samples 1/2 share label 1.
predict_feat = torch.tensor(
    [[1.0, 1.1, 0.0], [1.0, 1.2, 0.0], [1.2, 1.1, 0.22], [2.0, 0.0, 1.0]]
).float()
gt_labels = torch.tensor([0, 1, 1, 0])
# Hard mining + margin ranking loss (margin=0.3).
loss1 = triplet_loss(predict_feat, gt_labels, margin=0.3, norm_feat=False, hard_mining=True)
# Weighted (soft) mining + soft-margin loss (margin=0).
loss2 = triplet_loss(predict_feat, gt_labels, margin=0, norm_feat=False, hard_mining=False)
print(f"loss1={loss1} loss2={loss2}")
输出结果见下文“结果分析”部分。
结果分析
主函数 triplet_loss
```
Args
---------------
embedding: predict_feat (size=[batch_size, feature_size])
targets: gt_labels (size=[batch_size])
margin: default=0.3
norm_feat: default=False
hard_mining: if True use hard_example_mining, else use weighted_example_mining (soft-weight mining)
```
hard_example_mining
- hard positive 与 hard negative 的定义：对于某个 anchor，所有 positive 中与其距离最远的 feature 为 hard positive；所有 negative 中与其距离最近的 feature 为 hard negative。
- 对于每一个 anchor（即 batch 中的每一个 feature），该函数只取 dist_ap = max(all dist_ap)，dist_an = min(all dist_an)。
- loss1 输出
is_pos=tensor([[1., 0., 0., 1.], [0., 1., 1., 0.], [0., 1., 1., 0.], [1., 0., 0., 1.]]) is_neg=tensor([[0., 1., 1., 0.], [1., 0., 0., 1.], [1., 0., 0., 1.], [0., 1., 1., 0.]]) dist_mat=tensor([[1.0000e-06, 9.9999e-02, 2.9732e-01, 1.7916e+00], [9.9999e-02, 1.0000e-06, 3.1369e-01, 1.8547e+00], [2.9732e-01, 3.1369e-01, 1.0000e-06, 1.5679e+00], [1.7916e+00, 1.8547e+00, 1.5679e+00, 1.0000e-06]]) dist_ap=tensor([1.7916, 0.3137, 0.3137, 1.7916]) dist_an=tensor([0.1000, 0.1000, 0.2973, 1.5679])
weighted_example_mining
- 在计算 dist_ap 的过程中，考虑该 anchor 所有的 positive 样本，并利用 softmax 权重（soft weight）对所有 dist_ap 加权求和：dist_ap = sum(weights_ap * all dist_ap)。距离越远的 positive 获得的权值越高。
- 在计算 dist_an 的过程中，考虑该 anchor 所有的 negative 样本，同样利用 softmax 权重对所有 dist_an 加权求和：dist_an = sum(weights_an * all dist_an)。距离越近的 negative 获得的权值越高。
- loss2 输出
weights_ap=tensor([[0.1429, 0.0000, 0.0000, 0.8571], [0.0000, 0.4222, 0.5778, 0.0000], [0.0000, 0.5778, 0.4222, 0.0000], [0.8571, 0.0000, 0.0000, 0.1429]]) weights_an=tensor([[0.0000, 0.5492, 0.4508, 0.0000], [0.8525, 0.0000, 0.0000, 0.1475], [0.7808, 0.0000, 0.0000, 0.2192], [0.0000, 0.4288, 0.5712, 0.0000]]) dist_ap=tensor([1.5357, 0.1812, 0.1812, 1.5357]) dist_an=tensor([0.1890, 0.3587, 0.5758, 1.6909])
最终loss 计算
- torch.nn.functional.margin_ranking_loss（margin > 0 时使用）
  主函数中为 loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)，测试中 margin=0.3，即
  $$\text{loss} = \frac{1}{N}\sum_{i=1}^{N} \max\left(0,\ -(\text{dist\_an}_i - \text{dist\_ap}_i) + 0.3\right)$$
  例如，由 loss1 对应的 dist_ap、dist_an 可得 (1.9916 + 0.5137 + 0.3164 + 0.5237) / 4 ≈ 0.8364。
- torch.nn.functional.soft_margin_loss（margin ≤ 0 时使用）
  主函数中为 loss = F.soft_margin_loss(dist_an - dist_ap, y)，即
  $$\text{loss} = \frac{1}{N}\sum_{i=1}^{N} \log\left(1 + \exp\left(-(\text{dist\_an}_i - \text{dist\_ap}_i)\right)\right)$$
  其中 N 为 batch_size。
- 输出
  loss1=tensor(0.8364) loss2=tensor(0.8300)