LONG-TAILED RECOGNITION 精读

博客聚焦机器学习中类别不平衡问题,介绍解决思路,如重采样、设计损失函数、知识迁移等。阐述多种采样策略,包括标准基于实例采样、类别平衡采样等,还提及不同损失函数、分类器及作者实验结论,最后介绍了采样策略更新、额外特征等内容。

BackGround

解决类别不平衡问题一般的思路:

  1. re-sample the data 重采样
  2. design specific loss functions that better facilitate learning with imbalanced data 设计针对不平衡数据的损失函数
  3. enhance recognition performance of the tail classes by transferring knowledge from the head classes 知识迁移
    • 解释一下:通过从头部类别中转移知识来提高尾部类别的识别性能。在机器学习和模式识别领域,通常存在一些类别的样本数量较少,被称为尾部类别;而另一些类别的样本数量较多,被称为头部类别。由于尾部类别样本数量的不足,其识别性能往往较低。因此,通过从头部类别中学习到的知识,可以帮助提高尾部类别的识别性能。这种方法通常被称为知识迁移。

下面直奔主题:decouple the learning procedure into representation learning and classification
即,解耦成两个部分:表征学习和分类器。

  • 表征学习部分:the model is exposed to the training instances and trained through different sampling strategies or losses.
  • 分类器部分:upon the learned representations, the model recognizes the long-tailed classes through various classifiers

采样策略

  1. standard instance-based sampling
    标准基于实例的采样(standard instance-based sampling)是一种常用的采样策略,用于解决类别不平衡问题。在类别不平衡的数据集中,某些类别的样本数量远远多于其他类别,这可能导致模型对于少数类别的识别性能较差。

    标准基于实例的采样通过对数据集中的样本进行重采样,以平衡不同类别的样本数量。具体而言,该采样策略会从多数类别中随机选择一定数量的样本,使其数量与少数类别相当。这样可以避免模型过于偏向多数类别,提高对少数类别的识别能力。

    标准基于实例的采样通常有两种方式:欠采样(undersampling)和过采样(oversampling)。

    • 欠采样:从多数类别中删除一些样本,使其数量与少数类别相当。这种方法可能会导致信息丢失,因为删除了一些多数类别的样本。但它可以减少多数类别的影响,使得模型更关注少数类别。
    • 过采样:通过复制或生成新的样本,使得多数类别的样本数量与少数类别相当。这种方法可能会导致过拟合,因为复制或生成的样本可能过于接近多数类别的分布。但它可以增加少数类别的样本数量,提高模型对少数类别的学习能力。
  2. class-balanced sampling
    类别平衡采样的目标是通过调整样本的采样权重,使得每个类别在采样过程中都有相似的机会被选中。这样可以避免模型过于偏向多数类别,提高对少数类别的学习能力。

  3. 以上两种方法混合

Class-balanced Losses

  1. Focal loss
  2. Meta-Weight-Net
  3. re-weighted training
  4. based on Bayesian uncertainty
  5. to balance the classification regions of head and tail classes using an affinity measure to enforce cluster centers of classes to be uniformly spaced and equidistant

Transfer learning from head- to tail classes

transferring features learned from head classes with abundant
training instances to under-represented tail classes

  1. transferring the intra-class variance
  2. transferring semantic deep features
  • However it is usually a non-trivial task to design specific modules (e.g. external memory) for feature transfer
  1. low-shot recognition,the setup for long-tail recognition assumes access to both head and tail classes and a more continuous decrease in in class labels
  2. adopt re-balancing schedules that learn representation and classifier jointly within a two-stage training scheme
  3. OLTR:uses instance-balanced sampling to first learn
    representations that are fine-tuned in a second stage with class-balanced sampling together with a
    memory module
  4. LDAM:a label-distribution-aware margin loss that
    expands the decision boundaries of few-shot classes.

分类器

  1. re-training the parametric linear classifier(预训练的线性分类器) in a class-balancing manner (i.e., re-sampling)
  2. non-parametric nearest class mean classifier(KNN), which classifies the data based on their closest class-specific mean representations from the training set
  3. normalizing the classifier weights(分类器权重进行归一化,加入温度参数), which adjusts the weight magnitude directly to be more balanced, adding a temperature to modulate the normalization procedure

作者的实验结论

  1. instance-balanced sampling learns the best and most generalizable representations.
  2. 重新调整由联合学习分类器指定的决策边界是有优势的,这可以通过表示学习期间重新训练分类器使用类别平衡采样,或者通过简单而有效的分类器权重归一化来实现。我们的实验结果表明,这两种方法都可以实现决策边界的调整,并且它们都只有一个超参数来控制“温度”,而且不需要额外的训练。
  3. By applying the decoupled learning scheme to standard networks (e.g., ResNeXt), we
    achieve significantly higher accuracy than well established state-of-the-art methods (different sampling strategies, new loss designs and other complex modules) on multiple longtailed recognition benchmark datasets, including ImageNet-LT, Places-LT, and iNaturalist

Start

采样策略上的更新:

  • 均方根采样
    在这里插入图片描述

  • 调节采样
    在这里插入图片描述

  • Loss re-weighting strategies


  • Classifier Re-training (cRT)

  • Nearest Class Mean classifier (NCM)

  • τ -normalized classifier (τ -normalized)
    在这里插入图片描述

  • Learnable weight scaling (LWS)
    在这里插入图片描述


Extra Feature

  1. ResNet/ + modulatedatt -> x , featuremap
  • ResNet 残差网络就不提了
  • ResNext 则使用了类似的残差块,但在残差块内部引入了分组卷积(grouped convolution)的概念,通过增加卷积层的宽度来增加模型的表示能力。
  • ModulatedAttLayer

这个layer是一个模块化的注意力层,用于提取输入特征图的重要信息。它通过计算输入特征图的不同部分之间的关联性,以及对特征图进行空间注意力调整,来增强特征图中的有用信息。

  • 具体来说,该层包括三个卷积操作(g、theta、phi)和一个卷积操作(conv_mask)。它通过计算输入特征图的三个不同表示(g、theta、phi),并使用这些表示计算特征图之间的关联性(map_t_p),然后将关联性与输入特征图中的空间注意力调整(mask)相乘,得到最终的特征图输出(final)。
  • 该层还包括一个全连接层(fc_spatial),用于计算特征图的空间注意力权重(spatial_att)。最后,该层返回特征图输出和中间结果(x、spatial_att、mask)作为输出。
  • 在这个模块化的注意力层中,卷积层的kernel_size被设置为1是为了保持特征图的空间维度不变。这是因为在这个层中,主要关注的是特征图之间的关联性和空间注意力调整,而不是特征图的空间维度变化。

  • 因此,通过将kernel_size设置为1,可以确保卷积操作只对特征图的通道维度进行操作,而不改变其空间维度。这样可以保持特征图的空间结构不变,使得后续的操作能够更好地利用特征图中的信息。

import torch
from torch import nn
from torch.nn import functional as F
import pdb

class ModulatedAttLayer(nn.Module):

    def __init__(self, in_channels, reduction = 2, mode='embedded_gaussian'):
        super(ModulatedAttLayer, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.inter_channels = in_channels // reduction
        self.mode = mode
        assert mode in ['embedded_gaussian']

        self.g = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size = 1)
        self.theta = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size = 1)
        self.phi = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size = 1)
        self.conv_mask = nn.Conv2d(self.inter_channels, self.in_channels, kernel_size = 1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc_spatial = nn.Linear(7 * 7 * self.in_channels, 7 * 7)

        self.init_weights()

    def init_weights(self):
        msra_list = [self.g, self.theta, self.phi]
        for m in msra_list:
            nn.init.kaiming_normal_(m.weight.data)
            m.bias.data.zero_()
        self.conv_mask.weight.data.zero_()

    def embedded_gaussian(self, x):
        # embedded_gaussian cal self-attention, which may not strong enough
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        map_t_p = torch.matmul(theta_x, phi_x)
        mask_t_p = F.softmax(map_t_p, dim=-1)

        map_ = torch.matmul(mask_t_p, g_x)
        map_ = map_.permute(0, 2, 1).contiguous()
        map_ = map_.view(batch_size, self.inter_channels, x.size(2), x.size(3))
        mask = self.conv_mask(map_)
        
        x_flatten = x.view(-1, 7 * 7 * self.in_channels)

        spatial_att = self.fc_spatial(x_flatten)
        spatial_att = spatial_att.softmax(dim=1)
        
        spatial_att = spatial_att.view(-1, 7, 7).unsqueeze(1)
        spatial_att = spatial_att.expand(-1, self.in_channels, -1, -1)

        final = spatial_att * mask + x

        return final, [x, spatial_att, mask]

    def forward(self, x):
        if self.mode == 'embedded_gaussian':
            output, feature_maps = self.embedded_gaussian(x)
        else:
            raise NotImplemented("The code has not been implemented.")
        return output, feature_maps

Classifier

KNNClassifier

"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""


import torch
import torch.nn as nn
import numpy as np
import pickle
from os import path

class KNNClassifier(nn.Module):
    def __init__(self, feat_dim=512, num_classes=1000, feat_type='cl2n', dist_type='l2'):
        super(KNNClassifier, self).__init__()
        assert feat_type in ['un', 'l2n', 'cl2n'], "feat_type is wrong!!!"
        assert dist_type in ['l2', 'cos'], "dist_type is wrong!!!"
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        self.centroids = torch.randn(num_classes, feat_dim)
        self.feat_mean = torch.randn(feat_dim)
        self.feat_type = feat_type
        self.dist_type = dist_type
        self.initialized = False
    
    def update(self, cfeats):
        mean = cfeats['mean']
        centroids = cfeats['{}cs'.format(self.feat_type)]

        mean = torch.from_numpy(mean)
        centroids = torch.from_numpy(centroids)
        self.feat_mean.copy_(mean)
        self.centroids.copy_(centroids)
        if torch.cuda.is_available():
            self.feat_mean = self.feat_mean.cuda()
            self.centroids = self.centroids.cuda()
        self.initialized = True

    def forward(self, inputs, *args):
        centroids = self.centroids
        feat_mean = self.feat_mean

        # Feature transforms
        if self.feat_type == 'cl2n':
            inputs = inputs - feat_mean
            #centroids = centroids - self.feat_mean

        if self.feat_type in ['l2n', 'cl2n']:
            norm_x = torch.norm(inputs, 2, 1, keepdim=True)
            inputs = inputs / norm_x

            #norm_c = torch.norm(centroids, 2, 1, keepdim=True)
            #centroids = centroids / norm_c
        
        # Logit calculation
        if self.dist_type == 'l2':
            logit = self.l2_similarity(inputs, centroids)
        elif self.dist_type == 'cos':
            logit = self.cos_similarity(inputs, centroids)
        
        return logit, None

    def l2_similarity(self, A, B):
        # input A: [bs, fd] (batch_size x feat_dim)
        # input B: [nC, fd] (num_classes x feat_dim)
        feat_dim = A.size(1)

        AB = torch.mm(A, B.t())
        AA = (A**2).sum(dim=1, keepdim=True)
        BB = (B**2).sum(dim=1, keepdim=True)
        dist = AA + BB.t() - 2*AB

        return -dist
    
    def cos_similarity(self, A, B):
        feat_dim = A.size(1)
        AB = torch.mm(A, B.t())
        AB = AB / feat_dim
        return AB


def create_model(feat_dim, num_classes=1000, feat_type='cl2n', dist_type='l2',
                 log_dir=None, test=False, *args):
    print('Loading KNN Classifier')
    print(feat_dim, num_classes, feat_type, dist_type, log_dir, test)
    clf = KNNClassifier(feat_dim, num_classes, feat_type, dist_type)

    if log_dir is not None:
        fname = path.join(log_dir, 'cfeats.pkl')
        if path.exists(fname):
            print('===> Loading features from %s' % fname)
            with open(fname, 'rb') as f:
                data = pickle.load(f)
            clf.update(data)
    else:
        print('Random initialized classifier weights.')
    
    return clf


if __name__ == "__main__":
    cens = np.eye(4)
    mean = np.ones(4)
    xs = np.array([
        [0.9, 0.1, 0.0, 0.0],
        [0.2, 0.1, 0.1, 0.6],
        [0.3, 0.3, 0.4, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.25, 0.25, 0.25, 0.25]
    ])
    xs = torch.Tensor(xs)

    classifier = KNNClassifier(feat_dim=4, num_classes=4, 
                               feat_type='un')
    classifier.update(mean, cens)
    import pdb; pdb.set_trace()
    logits, _ = classifier(xs)

Dot_Classifier

"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""


import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

from utils import *
from os import path

class DotProduct_Classifier(nn.Module):
    
    def __init__(self, num_classes=1000, feat_dim=2048, *args):
        super(DotProduct_Classifier, self).__init__()
        # print('<DotProductClassifier> contains bias: {}'.format(bias))
        self.fc = nn.Linear(feat_dim, num_classes)
        self.scales = Parameter(torch.ones(num_classes))
        for param_name, param in self.fc.named_parameters():
            param.requires_grad = False
        
    def forward(self, x, *args):
        x = self.fc(x)
        x *= self.scales
        return x, None
    
def create_model(feat_dim, num_classes=1000, stage1_weights=False, dataset=None, log_dir=None, test=False, *args):
    print('Loading Dot Product Classifier.')
    clf = DotProduct_Classifier(num_classes, feat_dim)

    if not test:
        if stage1_weights:
            assert(dataset)
            print('Loading %s Stage 1 Classifier Weights.' % dataset)
            if log_dir is not None:
                subdir = log_dir.strip('/').split('/')[-1]
                subdir = subdir.replace('stage2', 'stage1')
                weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), subdir)
                # weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1')
            else:
                weight_dir = './logs/%s/stage1' % dataset
            print('==> Loading classifier weights from %s' % weight_dir)
            clf.fc = init_weights(model=clf.fc,
                                  weights_path=path.join(weight_dir, 'final_model_checkpoint.pth'),
                                  classifier=True)
        else:
            print('Random initialized classifier weights.')

    return clf

Loss

DiscCentroidsLoss

class DiscCentroidsLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, size_average=True):
        super(DiscCentroidsLoss, self).__init__()
        self.num_classes = num_classes
        self.centroids = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.disccentroidslossfunc = DiscCentroidsLossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, feat, label):
        batch_size = feat.size(0)
        
        # calculate attracting loss

        feat = feat.view(batch_size, -1)
        # To check the dim of centroids and features
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's \
                            dim: {1}".format(self.feat_dim,feat.size(1)))
        batch_size_tensor = feat.new_empty(1).fill_(batch_size if self.size_average else 1)
        loss_attract = self.disccentroidslossfunc(feat, label, self.centroids, batch_size_tensor).squeeze()
        
        # calculate repelling loss

        distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centroids, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(1, -2, feat, self.centroids.t())

        classes = torch.arange(self.num_classes).long().cuda()
        labels_expand = label.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels_expand.eq(classes.expand(batch_size, self.num_classes))

        distmat_neg = distmat
        distmat_neg[mask] = 0.0
        # margin = 50.0
        margin = 10.0
        loss_repel = torch.clamp(margin - distmat_neg.sum() / (batch_size * self.num_classes), 0.0, 1e6)

        # loss = loss_attract + 0.05 * loss_repel
        loss = loss_attract + 0.01 * loss_repel

        return loss
 
 
class DiscCentroidsLossFunc(Function):
    @staticmethod
    def forward(ctx, feature, label, centroids, batch_size):
        ctx.save_for_backward(feature, label, centroids, batch_size)
        centroids_batch = centroids.index_select(0, label.long())
        return (feature - centroids_batch).pow(2).sum() / 2.0 / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centroids, batch_size = ctx.saved_tensors
        centroids_batch = centroids.index_select(0, label.long())
        diff = centroids_batch - feature
        # init every iteration
        counts = centroids.new_ones(centroids.size(0))
        ones = centroids.new_ones(label.size(0))
        grad_centroids = centroids.new_zeros(centroids.size())

        counts = counts.scatter_add_(0, label.long(), ones)
        grad_centroids.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centroids = grad_centroids/counts.view(-1, 1)
        return - grad_output * diff / batch_size, None, grad_centroids / batch_size, None


    
def create_loss (feat_dim=512, num_classes=1000):
    print('Loading Discriminative Centroids Loss.')
    return DiscCentroidsLoss(num_classes, feat_dim)
import torch.nn as nn

def create_loss ():
    print('Loading Softmax Loss.')
    return nn.CrossEntropyLoss()

采样

  • 循环采样

循环遍历不同类别的样本,并按照指定数量从每个类别中取样本。它使用了两个迭代器,一个用于循环遍历不同类别的样本,另一个用于循环遍历每个类别中的样本。在每次迭代中,它从当前类别的样本迭代器中取出指定数量的样本,并返回给调用者。当遍历完所有类别的样本后,它会重新随机打乱类别顺序,以便下一轮迭代时可以重新遍历样本。

  • 样本重要性优先采样
class PriorityTree(object):
    def __init__(self, capacity, init_weights, fixed_weights=None, fixed_scale=1.0,
                 alpha=1.0):
        """
        fixed_weights: weights that wont be updated by self.update()
        """
        assert fixed_weights is None or len(fixed_weights) == capacity
        assert len(init_weights) == capacity
        self.alpha = alpha
        self._capacity = capacity
        self._tree_size = 2 * capacity - 1
        self.fixed_scale = fixed_scale
        self.fixed_weights = np.zeros(self._capacity) if fixed_weights is None \
                             else fixed_weights
        self.tree = np.zeros(self._tree_size)
        self._initialized = False
        self.initialize(init_weights)

    def initialize(self, init_weights):
        """Initialize the tree."""

        # Rescale the fixed_weights if it is not zero
        self.fixed_scale_init = self.fixed_scale
        if self.fixed_weights.sum() > 0 and init_weights.sum() > 0:
            self.fixed_scale_init *= init_weights.sum() / self.fixed_weights.sum()
            self.fixed_weights *= self.fixed_scale * init_weights.sum() \
                                / self.fixed_weights.sum()
        print('FixedWeights: {}'.format(self.fixed_weights.sum()))

        self.update_whole(init_weights + self.fixed_weights)
        self._initialized = True
    
    def reset_adaptive_weights(self, adaptive_weights):
        self.update_whole(self.fixed_weights + adaptive_weights)
    
    def reset_fixed_weights(self, fixed_weights, rescale=False):
        """ Reset the manually designed weights and 
            update the whole tree accordingly.

            @rescale: rescale the fixed_weights such that 
            fixed_weights.sum() = self.fixed_scale * adaptive_weights.sum()
        """

        adaptive_weights = self.get_adaptive_weights()
        fixed_sum = fixed_weights.sum()
        if rescale and fixed_sum > 0:
            # Rescale fixedweight based on adaptive weights
            scale = self.fixed_scale * adaptive_weights.sum() / fixed_sum
        else:
            # Rescale fixedweight based on previous fixedweight
            scale = self.fixed_weights.sum() / fixed_sum
        self.fixed_weights = fixed_weights * scale
        self.update_whole(self.fixed_weights + adaptive_weights)
    
    def update_whole(self, total_weights):
        """ Update the whole tree based on per-example sampling weights """
        if self.alpha != 1:
            total_weights = np.power(total_weights, self.alpha)
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity-1)
        self.tree[lefti:righti+1] = total_weights

        # Iteratively find a parent layer
        while lefti != 0 and righti != 0:
            lefti = (lefti - 1) // 2 if lefti != 0 else 0
            righti = (righti - 1) // 2 if righti != 0 else 0
            
            # Assign paraent weights from right to left
            for i in range(righti, lefti-1, -1):
                self.tree[i] = self.tree[2*i+1] + self.tree[2*i+2]
    
    def get_adaptive_weights(self):
        """ Get the instance-aware weights, that are not mannually designed"""
        if self.alpha == 1:
            return self.get_total_weights() - self.fixed_weights
        else:
            return self.get_raw_total_weights() - self.fixed_weights
    
    def get_total_weights(self):
        """ Get the per-example sampling weights
            return shape: [capacity]
        """
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity-1)
        return self.tree[lefti:righti+1]

    def get_raw_total_weights(self):
        """ Get the per-example sampling weights
            return shape: [capacity]
        """
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity-1)
        return np.power(self.tree[lefti:righti+1], 1/self.alpha)

    @property
    def size(self):
        return self._tree_size

    @property
    def capacity(self):
        return self._capacity

    def __len__(self):
        return self.capacity

    def pointer_to_treeidx(self, pointer):
        assert pointer < self.capacity
        return int(pointer + self.capacity - 1)

    def update(self, pointer, priority):
        assert pointer < self.capacity
        tree_idx = self.pointer_to_treeidx(pointer)
        priority += self.fixed_weights[pointer]
        if self.alpha != 1:
            priority = np.power(priority, self.alpha)
        delta = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += delta
    
    def update_delta(self, pointer, delta):
        assert pointer < self.capacity
        tree_idx = self.pointer_to_treeidx(pointer)
        ratio = 1- self.fixed_weights[pointer] / self.tree[tree_idx]
        # delta *= ratio
        if self.alpha != 1:
            # Update delta
            if self.tree[tree_idx] < 0 or \
                np.power(self.tree[tree_idx], 1/self.alpha) + delta < 0:
                import pdb; pdb.set_trace()
            delta = np.power(np.power(self.tree[tree_idx], 1/self.alpha) + delta,
                             self.alpha) \
                  - self.tree[tree_idx]
        self.tree[tree_idx] += delta
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += delta

    def get_leaf(self, value):
        assert self._initialized, 'PriorityTree not initialized!!!!'
        assert self.total > 0, 'No priority weights setted!!'
        parent = 0
        while True:
            left_child = 2 * parent + 1
            right_child = 2 * parent + 2
            if left_child >= len(self.tree):
                tgt_leaf = parent
                break
            if value < self.tree[left_child]:
                parent = left_child
            else:
                value -= self.tree[left_child]
                parent = right_child
        data_idx = tgt_leaf - self.capacity + 1
        return data_idx, self.tree[tgt_leaf]        # data idx, priority

    @property
    def total(self):
        assert self._initialized, 'PriorityTree not initialized!!!!'
        return self.tree[0]

    @property
    def max(self):
        return np.max(self.tree[-self.capacity:])

    @property
    def min(self):
        assert self._initialized, 'PriorityTree not initialized!!!!'
        return np.min(self.tree[-self.capacity:])
    
    def get_weights(self):
        wdict = {'fixed_weights': self.fixed_weights, 
                 'total_weights': self.get_total_weights()}
        if self.alpha != 1:
            wdict.update({'raw_total_weights': self.get_raw_total_weights(),
                          'alpha': self.alpha})

        return wdict
### 解耦表示与分类器的实现细节 在论文 *Decoupling Representation and Classifier for Long-Tailed Recognition* 中,作者提出了一种解耦特征表示和分类器的训练策略,以应对长尾数据分布带来的类别不平衡问题。论文的实现细节主要围绕两个阶段展开:**联合训练阶段** 和 **解耦训练阶段**。 #### 联合训练阶段(Joint Training) 在联合训练阶段,网络的特征提取器(backbone)和分类器(classifier)同时进行训练。此阶段的目标是学习一个能够区分各类别的通用特征表示,同时训练一个初步的分类器。训练过程中使用了不同的采样策略,包括: - **实例平衡采样(Instance-balanced sampling)**:对所有样本均匀采样。 - **类别平衡采样(Class-balanced sampling)**:对每个类别采样相同数量的样本。 - **平方根采样(Square root sampling)**:每个类别的采样概率与其样本数的平方根成正比。 - **渐进平衡采样(Progressive-balanced sampling)**:先使用实例平衡采样,再切换到类别平衡采样。 损失函数使用的是标准的交叉熵损失(Cross-Entropy Loss),即: ```python criterion = nn.CrossEntropyLoss() ``` 在 PyTorch 中,联合训练的代码结构大致如下: ```python for epoch in range(joint_epochs): for images, labels in train_loader: features = backbone(images) outputs = classifier(features) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() ``` #### 解耦训练阶段(Decoupled Training) 在解耦训练阶段,特征提取器被冻结,仅对分类器进行微调。该阶段的目的是通过调整分类器来适应长尾数据分布,而不影响已经学习到的特征表示。 ##### 1. **类别均值分类器(NCM:Nearest Class Mean)** 在该策略中,首先计算每个类别的特征均值,然后在测试阶段使用最近邻方法(如欧氏距离或余弦相似度)进行分类。 ```python # 计算每个类别的特征均值 class_means = compute_class_means(backbone, train_loader) # 测试阶段使用最近邻分类 def predict(feature): distances = [torch.norm(feature - mean) for mean in class_means] return torch.argmin(distances) ``` ##### 2. **τ-归一化分类器(τ-normalized classifier)** τ-归一化分类器通过归一化分类器权重来调整决策边界,缓解类别不平衡问题。归一化后的权重计算如下: $$ \tilde{w_i} = \frac{w_i}{||w_i||^{\tau}} $$ 其中 $\tau$ 是一个可调参数,通常设置为 1,即 L2 归一化。 在 PyTorch 中,τ-归一化可以通过以下方式实现: ```python def apply_tau_normalization(classifier, tau=1.0): with torch.no_grad(): weights = classifier.weight norm = torch.norm(weights, dim=1, keepdim=True) normalized_weights = weights / (norm ** tau) classifier.weight.copy_(normalized_weights) ``` 随后,使用类别均等采样重新训练分类器,通常不使用偏置项: ```python for epoch in range(decoupled_epochs): for images, labels in class_balanced_loader: features = backbone(images).detach() # 冻结特征提取器 outputs = classifier(features) loss = criterion(outputs, labels) optimizer_decoupled.zero_grad() loss.backward() optimizer_decoupled.step() ``` ##### 3. **LWS(Learnable Weight Scaling)** LWS 是 τ-归一化的扩展,其中缩放因子 $f_i$ 是可学习参数。该方法通过在解耦阶段学习每个类别的缩放因子来优化分类器的决策边界。 ```python # 定义可学习的缩放因子 scaling_factors = nn.Parameter(torch.ones(num_classes)) # 在损失反向传播时,仅更新 scaling_factors optimizer_lws = torch.optim.SGD([scaling_factors], lr=lw_lr) for epoch in range(lws_epochs): for images, labels in class_balanced_loader: features = backbone(images).detach() logits = classifier(features) scaled_logits = logits * scaling_factors.unsqueeze(0) loss = criterion(scaled_logits, labels) optimizer_lws.zero_grad() loss.backward() optimizer_lws.step() ``` #### 代码结构总结 整个训练流程可以分为以下几个步骤: 1. **联合训练特征提取器和分类器**,使用不同的采样策略。 2. **冻结特征提取器**,进入解耦阶段。 3. **使用类别均等采样重新训练分类器**,采用 τ-归一化或 LWS 策略。 4. **在测试集上评估模型性能**,使用 NCM、τ-归一化或 LWS 分类器。 --- ###
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值