Triplet Loss Study Notes, Using the Fast re-ID Code as an Example


This post walks through how triplet loss is used in fast re-ID, together with my understanding of the source code.

Source code

# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F

from .utils import euclidean_dist, cosine_dist


def softmax_weights(dist, mask):
    max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]  # row-wise max; shape (N, 1) since keepdim=True
    diff = dist - max_v
    Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6  # avoid division by zero
    W = torch.exp(diff) * mask / Z
    return W


def hard_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the hardest positive and negative sample.
    Args:
      dist_mat: pairwise distance between samples, shape [N, M]
      is_pos: positive mask with shape [N, M]
      is_neg: negative mask with shape [N, M]
    Returns:
      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
    NOTE: Only consider the case in which all labels have same num of samples,
      thus we can cope with all anchors in parallel.
    """

    assert len(dist_mat.size()) == 2

    # `dist_ap` means distance(anchor, positive); shape [N]
    dist_ap, _ = torch.max(dist_mat * is_pos, dim=1)  # torch.max returns (values, indices)
    # `dist_an` means distance(anchor, negative); shape [N]
    # the `is_pos * 1e9` term pushes positive entries out of the row-wise min
    dist_an, _ = torch.min(dist_mat * is_neg + is_pos * 1e9, dim=1)

    return dist_ap, dist_an


def weighted_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the weighted positive and negative sample.
    Args:
      dist_mat: pytorch Variable, pairwise distance between samples, shape [N, N]
      is_pos: positive mask with shape [N, N]
      is_neg: negative mask with shape [N, N]
    Returns:
      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
    """
    assert len(dist_mat.size()) == 2

    dist_ap = dist_mat * is_pos
    dist_an = dist_mat * is_neg

    weights_ap = softmax_weights(dist_ap, is_pos) # weights go higher when dist goes up
    weights_an = softmax_weights(-dist_an, is_neg) # weights go higher when dist goes down 

    dist_ap = torch.sum(dist_ap * weights_ap, dim=1)
    dist_an = torch.sum(dist_an * weights_an, dim=1)

    return dist_ap, dist_an


def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
    r"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
    Loss for Person Re-Identification'."""

    if norm_feat:
        dist_mat = cosine_dist(embedding, embedding)
    else:
        dist_mat = euclidean_dist(embedding, embedding)

    # For distributed training, gather all features from different process.
    # if comm.get_world_size() > 1:
    #     all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0)
    #     all_targets = concat_all_gather(targets)
    # else:
    #     all_embedding = embedding
    #     all_targets = targets

    N = dist_mat.size(0)
    is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
    is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()

    if hard_mining:
        dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
    else:
        dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)

    y = dist_an.new().resize_as_(dist_an).fill_(1)  # y = [1, ..., 1], size = (N,)

    if margin > 0:
        loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
    else:
        loss = F.soft_margin_loss(dist_an - dist_ap, y)
        # fmt: off
        if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
        # fmt: on

    return loss
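
euclidean_dist and cosine_dist are imported from fastreid's utils and are not listed above. The sketch below is my reconstruction of what they are expected to compute, so the exact fast-reid implementation may differ in details:

```
import torch
import torch.nn.functional as F


def euclidean_dist(x, y):
    # pairwise Euclidean distance between the rows of x [m, d] and y [n, d]
    m, n = x.size(0), y.size(0)
    xx = x.pow(2).sum(1, keepdim=True).expand(m, n)
    yy = y.pow(2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy - 2 * torch.matmul(x, y.t())
    # clamp before sqrt for numerical stability (assumed value 1e-12)
    return dist.clamp(min=1e-12).sqrt()


def cosine_dist(x, y):
    # pairwise cosine distance (2 - 2 * cosine similarity) on L2-normalized rows
    x = F.normalize(x, dim=1)
    y = F.normalize(y, dim=1)
    return 2 - 2 * torch.matmul(x, y.t())
```

If the clamp really is 1e-12, its square root is what makes the diagonal of dist_mat print as 1.0000e-06 instead of exactly 0 in the outputs below.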

Test code

import torch
from fastreid.modeling.losses import triplet_loss

predict_feat = torch.tensor([[1.0, 1.1, 0.0],[1.0, 1.2, 0.0], [1.2, 1.1, 0.22], [2.0, 0.0,1.0]]).float()
gt_labels = torch.tensor([0,1,1,0])

loss1 = triplet_loss(predict_feat, gt_labels, margin=0.3, norm_feat=False, hard_mining=True)
loss2 = triplet_loss(predict_feat, gt_labels, margin=0, norm_feat=False, hard_mining=False)

The outputs are discussed in the result analysis below.

Result analysis

Main function: triplet_loss

```
Args
---------------
embedding:   predict_feat (size = [batch_size, feature_size])
targets:     gt_labels (size = [batch_size])
margin:      default = 0.3
norm_feat:   default = False
hard_mining: if True, use hard_example_mining; else use weighted_example_mining
```

hard_example_mining

  1. Definition of hard positive and hard negative: relative to a given anchor, the positive feature that is farthest away is the hard positive, and the negative feature that is closest is the hard negative.
  2. For each anchor in the batch, this function keeps only dist_ap = max(all positive distances) and dist_an = min(all negative distances); a standalone sketch reproducing these values follows the printout below.
  3. loss1 output
    is_pos=tensor([[1., 0., 0., 1.],
                   [0., 1., 1., 0.],
                   [0., 1., 1., 0.],
                   [1., 0., 0., 1.]])
    is_neg=tensor([[0., 1., 1., 0.],
                   [1., 0., 0., 1.],
                   [1., 0., 0., 1.],
                   [0., 1., 1., 0.]])
    dist_mat=tensor([[1.0000e-06, 9.9999e-02, 2.9732e-01, 1.7916e+00],
                     [9.9999e-02, 1.0000e-06, 3.1369e-01, 1.8547e+00],
                     [2.9732e-01, 3.1369e-01, 1.0000e-06, 1.5679e+00],
                     [1.7916e+00, 1.8547e+00, 1.5679e+00, 1.0000e-06]])
    dist_ap=tensor([1.7916, 0.3137, 0.3137, 1.7916])
    dist_an=tensor([0.1000, 0.1000, 0.2973, 1.5679])
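
As a standalone sanity check (not fastreid code), the two core lines of hard_example_mining can be replayed by hand on the same toy batch:

```
import torch

# same toy batch as in the test code above
predict_feat = torch.tensor([[1.0, 1.1, 0.0], [1.0, 1.2, 0.0],
                             [1.2, 1.1, 0.22], [2.0, 0.0, 1.0]])
gt_labels = torch.tensor([0, 1, 1, 0])

# plain pairwise Euclidean distances; the 1e-6 diagonal from the clamp in
# fastreid's euclidean_dist is not reproduced here, but it does not change the mined values
dist_mat = torch.cdist(predict_feat, predict_feat)
is_pos = gt_labels.view(-1, 1).eq(gt_labels.view(1, -1)).float()
is_neg = 1.0 - is_pos

# hardest positive: largest distance among same-label columns (other columns are zeroed)
dist_ap, _ = torch.max(dist_mat * is_pos, dim=1)
# hardest negative: smallest distance among different-label columns;
# adding `is_pos * 1e9` keeps positive entries out of the row-wise min
dist_an, _ = torch.min(dist_mat * is_neg + is_pos * 1e9, dim=1)

print(dist_ap)  # tensor([1.7916, 0.3137, 0.3137, 1.7916])
print(dist_an)  # tensor([0.1000, 0.1000, 0.2973, 1.5679])
```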
    

weighted_example_mining

  1. When computing dist_ap, all positive examples are taken into account, but a softmax weight is applied to each positive distance, so the final dist_ap = sum(weights_ap * all positive distances). Positives that are farther from the anchor get higher weights.
  2. When computing dist_an, all negative examples are taken into account, again with a softmax weight on each negative distance, so the final dist_an = sum(weights_an * all negative distances). Negatives that are closer to the anchor get higher weights. A standalone sketch reproducing the weights follows the printout below.
  3. loss2 output
    weights_ap=tensor([[0.1429, 0.0000, 0.0000, 0.8571],
                       [0.0000, 0.4222, 0.5778, 0.0000],
                       [0.0000, 0.5778, 0.4222, 0.0000],
                       [0.8571, 0.0000, 0.0000, 0.1429]])
    weights_an=tensor([[0.0000, 0.5492, 0.4508, 0.0000],
                       [0.8525, 0.0000, 0.0000, 0.1475],
                       [0.7808, 0.0000, 0.0000, 0.2192],
                       [0.0000, 0.4288, 0.5712, 0.0000]])
    dist_ap=tensor([1.5357, 0.1812, 0.1812, 1.5357])
    dist_an=tensor([0.1890, 0.3587, 0.5758, 1.6909])
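
The weights above can be reproduced with the same softmax_weights helper, restated here together with the same toy batch so the sketch runs on its own (again just a check, not fastreid code):

```
import torch

# same toy batch as in the test code above
predict_feat = torch.tensor([[1.0, 1.1, 0.0], [1.0, 1.2, 0.0],
                             [1.2, 1.1, 0.22], [2.0, 0.0, 1.0]])
gt_labels = torch.tensor([0, 1, 1, 0])

dist_mat = torch.cdist(predict_feat, predict_feat)
is_pos = gt_labels.view(-1, 1).eq(gt_labels.view(1, -1)).float()
is_neg = 1.0 - is_pos


def softmax_weights(dist, mask):
    # identical to the helper in the fastreid source above
    max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
    diff = dist - max_v
    Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6
    return torch.exp(diff) * mask / Z


weights_ap = softmax_weights(dist_mat * is_pos, is_pos)   # far positives weigh more
weights_an = softmax_weights(-dist_mat * is_neg, is_neg)  # near negatives weigh more

dist_ap = torch.sum(dist_mat * is_pos * weights_ap, dim=1)
dist_an = torch.sum(dist_mat * is_neg * weights_an, dim=1)

print(weights_ap[0])  # tensor([0.1429, 0.0000, 0.0000, 0.8571])
print(dist_ap)        # tensor([1.5357, 0.1812, 0.1812, 1.5357])
print(dist_an)        # tensor([0.1890, 0.3587, 0.5758, 1.6909])
```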
    

Final loss computation

  1. torch.nn.functional.margin_ranking_loss
    MarginRankingLoss definition from pytorch docs
    In the main function this is loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3), which with y = 1 and the default 'mean' reduction amounts to
    $loss = \frac{1}{batch\_size}\sum_i \max\bigl(0,\ -(dist\_an_i - dist\_ap_i) + 0.3\bigr)$

  2. torch.nn.functional.soft_margin_loss
    SoftMarginLoss definition from pytorch docs
    In the main function this is loss = F.soft_margin_loss(dist_an - dist_ap, y), i.e.
    $loss = \frac{1}{batch\_size}\sum_i \log\bigl(1 + \exp(-(dist\_an_i - dist\_ap_i))\bigr)$

  3. Output (checked by the sketch after these values)

    loss1=tensor(0.8364)
    loss2=tensor(0.8300)
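
Both scalars can be checked directly from the mined distances, copied here from the printouts above at four decimals, so the printed losses may differ in the last digit (a standalone check, assuming the default 'mean' reduction of both PyTorch losses):

```
import torch
import torch.nn.functional as F

# hard-mined distances from the loss1 analysis
ap1 = torch.tensor([1.7916, 0.3137, 0.3137, 1.7916])
an1 = torch.tensor([0.1000, 0.1000, 0.2973, 1.5679])
# weighted distances from the loss2 analysis
ap2 = torch.tensor([1.5357, 0.1812, 0.1812, 1.5357])
an2 = torch.tensor([0.1890, 0.3587, 0.5758, 1.6909])
y = torch.ones_like(an1)

loss1 = F.margin_ranking_loss(an1, ap1, y, margin=0.3)  # mean of max(0, ap - an + 0.3)
loss2 = F.soft_margin_loss(an2 - ap2, y)                # mean of log(1 + exp(ap - an))

print(loss1)  # ~0.8364 (matches loss1 above up to rounding of the copied inputs)
print(loss2)  # ~0.8300 (matches loss2 above up to rounding of the copied inputs)
```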
    