《Decoupling Representation and Classifier》笔记

长尾识别的解耦表示与分类
本文探讨了解耦表示学习与分类在长尾识别问题中的应用,研究了不同采样策略对学习高质量表示的影响,以及如何通过调整分类器获得平衡决策边界,提出了包括随机采样、类别平衡采样等策略,以及重新训练分类器、最近类平均分类和归一化分类器等方法。
Paper:《Decoupling Representation and Classifier for Long-tailed Recognition》
Published at ICLR 2020
Keywords:Long-Tailed Image Recognition.

https://zhuanlan.zhihu.com/p/111518894

【概览】

  • 作者将分类网络分解为representation learning 和 classification 两部分,系统的研究了这两部分对于Long-tailed问题的影响。通过实验得到的两点发现是:
    • 数据不均衡问题不会影响高质量Representations的学习。即,random sampling策略往往会学到泛化性更好的representations;
    • 使用最简单的random sampling 来学习representations,然后只调整classifier的学习也可以作为一个strong baseline。
  • 具体地,作者首先使用不同的sampling策略(random sampling; class-balanced sampling; mixture of them )来训练representations,然后,研究了三种不同的方法来获得有balanced决策边界的classifier。分别为:使用class-balanced sampling来re-training线性分类器的参数;对学到的representations使用KNN进行分类;通过normalize classifier的weight来使得weight的尺度变得更加balanced,并添加了一个temperature参数来调节normalization的过程。

【Sampling策略】

  • 作者将三种Sampling策略统一为一个公式,即从类别 [公式] 中挑选一个样本的概率 [公式] 如公式1所示。其中, [公式] 代表的是类别 [公式] 的样本数量, [公式] 是类别数量, [公式] 。若 [公式] ,为random sampling,文中也称之为Instance-balanced sampling当 [公式] 的时候,为Class-balanced Samling;当 [公式] 时为Square-root Sampling
    [公式] (公式1)
  • 最近也有工作是在训练前期使用random-sampling策略,后期使用class-balanced sampling。由此,作者提出了一个Softer版本的结合方式,Progressively-balanced Sampling。如公式2所示。其中, [公式] 为当前训练的Epoch, [公式] 为总的训练Epoch。
    [公式] (公式2)

【Classifier的学习策略】

  • Classifier Re-training (cRT).
    • 固定住representations部分,随机初始化classifier的weight和bias参数,并使用class-balanced sampling在训练少量epoch。
  • Nearest Class Mean classifier (NCM).
    • 首先将training set里的每个类别计算feature representaitions的均值,然后在test set上执行最近邻查找。或者将mean features进行L2-Normalization之后,使用余弦距离或者欧氏距离计算相似度。作者指出,余弦相似度可以通过其本身的normalization特性来缓解weight imbalance的问题。
  • [公式]-normalized classifier ( [公式]-normalized).
    • 令 [公式] ,其中, [公式] 是类别 [公式] 的classifier权重。按照公式3对 [公式] 进行re-scale。其中, [公式] 代表的是L2 Norm,当 [公式] 时,就是标准的L2-Normalization;当 [公式] 时,表示没有进行scaling操作。 [公式] ,其值是通过cross-validation来选择的。
      [公式] (公式3)
  • Learnable weight scaling (LWS).
    • 如果将公式3写为公式4的形式,我们可以将 [公式] 看作是一个可学习的参数,我们通过固定住representations和classifier两部分的weighs来只学习这个scaling factors。
      [公式] (公式4)

【Experiments】

  • 为了更好地测试在训练期间每个类别样本数量对结果的影响,作者将类别集合分成三种情况。Many-shot (more than 100 images), Medium-shot (20100 images) and Few-shot (less than 20 images).
  • 图1展示的是ImageNet-LTwith ResNeXt-50上的指标。

Figure 1: The performance of different classifiers for each split on ImageNet-LT with ResNeXt-50.

  • 其中,Joint代表的是传统的representations和classifier结合训练的策略。由此,作者的几点发现包括:
    • Sampling matters when training jointly. 可见,使用更好地sampling策略可以显著提高Joint的性能。疑惑的是,第一个Many的实验中,作者解释出现这种情况的原因是,模型高度偏向many-shot的类别,但是为什么这样就会导致acc高呢?感觉作者的解释不是很好。
    • 保持 representation frozen训练classifier要比finetune效果好。如图2所示。和一些SOTA的模型比较如图3所示。其他的实验结果及分析具体可参见paper。

Figure 2.

Figure 3.

【参考资料】

[1]. https://arxiv.org/abs/1910.0921

### 解耦表示与分类器的实现细节 在论文 *Decoupling Representation and Classifier for Long-Tailed Recognition* 中,作者提出了一种解耦特征表示和分类器的训练策略,以应对长尾数据分布带来的类别不平衡问题。论文的实现细节主要围绕两个阶段展开:**联合训练阶段** 和 **解耦训练阶段**。 #### 联合训练阶段(Joint Training) 在联合训练阶段,网络的特征提取器(backbone)和分类器classifier)同时进行训练。此阶段的目标是学习一个能够区分各类别的通用特征表示,同时训练一个初步的分类器。训练过程中使用了不同的采样策略,包括: - **实例平衡采样(Instance-balanced sampling)**:对所有样本均匀采样。 - **类别平衡采样(Class-balanced sampling)**:对每个类别采样相同数量的样本。 - **平方根采样(Square root sampling)**:每个类别的采样概率与其样本数的平方根成正比。 - **渐进平衡采样(Progressive-balanced sampling)**:先使用实例平衡采样,再切换到类别平衡采样。 损失函数使用的是标准的交叉熵损失(Cross-Entropy Loss),即: ```python criterion = nn.CrossEntropyLoss() ``` 在 PyTorch 中,联合训练的代码结构大致如下: ```python for epoch in range(joint_epochs): for images, labels in train_loader: features = backbone(images) outputs = classifier(features) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() ``` #### 解耦训练阶段(Decoupled Training) 在解耦训练阶段,特征提取器被冻结,仅对分类器进行微调。该阶段的目的是通过调整分类器来适应长尾数据分布,而不影响已经学习到的特征表示。 ##### 1. **类别均值分类器(NCM:Nearest Class Mean)** 在该策略中,首先计算每个类别的特征均值,然后在测试阶段使用最近邻方法(如欧氏距离或余弦相似度)进行分类。 ```python # 计算每个类别的特征均值 class_means = compute_class_means(backbone, train_loader) # 测试阶段使用最近邻分类 def predict(feature): distances = [torch.norm(feature - mean) for mean in class_means] return torch.argmin(distances) ``` ##### 2. **τ-归一化分类器(τ-normalized classifier)** τ-归一化分类器通过归一化分类器权重来调整决策边界,缓解类别不平衡问题。归一化后的权重计算如下: $$ \tilde{w_i} = \frac{w_i}{||w_i||^{\tau}} $$ 其中 $\tau$ 是一个可调参数,通常设置为 1,即 L2 归一化。 在 PyTorch 中,τ-归一化可以通过以下方式实现: ```python def apply_tau_normalization(classifier, tau=1.0): with torch.no_grad(): weights = classifier.weight norm = torch.norm(weights, dim=1, keepdim=True) normalized_weights = weights / (norm ** tau) classifier.weight.copy_(normalized_weights) ``` 随后,使用类别均等采样重新训练分类器,通常不使用偏置项: ```python for epoch in range(decoupled_epochs): for images, labels in class_balanced_loader: features = backbone(images).detach() # 冻结特征提取器 outputs = classifier(features) loss = criterion(outputs, labels) optimizer_decoupled.zero_grad() loss.backward() optimizer_decoupled.step() ``` ##### 3. **LWS(Learnable Weight Scaling)** LWS 是 τ-归一化的扩展,其中缩放因子 $f_i$ 是可学习参数。该方法通过在解耦阶段学习每个类别的缩放因子来优化分类器的决策边界。 ```python # 定义可学习的缩放因子 scaling_factors = nn.Parameter(torch.ones(num_classes)) # 在损失反向传播时,仅更新 scaling_factors optimizer_lws = torch.optim.SGD([scaling_factors], lr=lw_lr) for epoch in range(lws_epochs): for images, labels in class_balanced_loader: features = backbone(images).detach() logits = classifier(features) scaled_logits = logits * scaling_factors.unsqueeze(0) loss = criterion(scaled_logits, labels) optimizer_lws.zero_grad() loss.backward() optimizer_lws.step() ``` #### 代码结构总结 整个训练流程可以分为以下几个步骤: 1. **联合训练特征提取器和分类器**,使用不同的采样策略。 2. **冻结特征提取器**,进入解耦阶段。 3. **使用类别均等采样重新训练分类器**,采用 τ-归一化或 LWS 策略。 4. **在测试集上评估模型性能**,使用 NCM、τ-归一化或 LWS 分类器。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值