Today let's talk about label smoothing, a trick that has become a well-known regularization technique in machine learning and deep learning. Label smoothing (label smooth regularization) is a simple regularizer that can improve generalization and accuracy on classification tasks, ease problems caused by imbalanced data distributions, and it also shows up in model distillation. In three NLP algorithm competitions over the past two months I used label smoothing as one of the tricks to boost my scores. So why does label smoothing work, and how do we explain it? In this post we will go through the math behind label smoothing, look at an intuitive, higher-level explanation, and finally see whether recent papers offer a better way to do label smoothing.
1. Label smoothing: principle and interpretation
Label smoothing is defined relative to hard labels and soft labels. In a typical classification task the labels are hard labels encoded as one-hot vectors; label smoothing adds a little bit of noise to this one-hot encoding. An example is shown in the figure below, taken from 如何理解soft target这一做法?:
The pros and cons of hard labels and soft labels are also listed in the figure. Roughly speaking, a soft label carries more information and describes the class structure of the data better, while a hard label throws away the intra-class and inter-class relations. From this point of view a soft label can indeed improve the model's generalization to some extent, i.e. gain a few points on the same data.
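As a concrete toy example (the numbers below are my own illustration with 3 classes and a smoothing factor of 0.1, not taken from the figure), the hard one-hot label and its smoothed counterpart look like this:

import torch

n_class, eps = 3, 0.1
hard = torch.tensor([1.0, 0.0, 0.0])                  # one-hot hard label for class 0
soft = torch.full((n_class,), eps / (n_class - 1))    # spread eps evenly over the wrong classes
soft[0] = 1.0 - eps                                   # keep most of the mass on the true class
print(hard)   # tensor([1., 0., 0.])
print(soft)   # tensor([0.9000, 0.0500, 0.0500])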
In a classification problem, suppose a sample $x$ has label $y$, and let $p$ be the predicted probability vector for the sample (i.e. the softmax output). The cross-entropy loss is

$$L = -\sum_{i=1}^{K} y_i \log p_i.$$

The raw outputs of the network are called logits, written $z_i$ for short; the softmax turns them into probabilities that sum to 1,

$$p_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}},$$

where $y_i$ denotes the target and $K$ is the number of classes. With a hard (one-hot) label, $y_i = 1$ for the true class and 0 otherwise, so the loss reduces to $L = -\log p_y$, and from the softmax formula the loss can reach 0 only when $p_y = 1$, i.e. only when

$$z_y - z_i \to +\infty \quad \text{for every } i \neq y.$$

What does this mean? Under the cross-entropy loss, when the model's loss is very low (approaching 0), the logit of the true class must sit at some constant while the logits of the wrong classes are pushed towards negative infinity. In practice the logits cannot be infinite, because of activation functions, bounded ranges and so on, so with hard labels the optimum can never actually be reached. Put differently, demanding a softmax value of exactly 1 for the true class and exactly 0 for the wrong classes is simply too absolute. That is what is wrong with hard labels.
Label smoothing, which uses a soft label, changes the picture.

The encoding that label smoothing trains against is

$$y_i^{LS} = \begin{cases} 1 - \varepsilon, & i = y \\ \dfrac{\varepsilon}{K - 1}, & i \neq y \end{cases}$$

where $\varepsilon$ is a predefined hyperparameter, usually set to 0.1, and $K$ is the number of classes of the problem.

Going through a derivation similar to the one above (for the detailed steps see the article 简单的label smoothing为什么能够涨点呢), setting the derivative to zero gives the optimal logits

$$z_i^* = \begin{cases} \log\dfrac{(K - 1)(1 - \varepsilon)}{\varepsilon} + c, & i = y \\ c, & i \neq y \end{cases}$$

for an arbitrary constant $c$. In other words, with label smoothing the logits of the wrong classes are no longer required to go to negative infinity: as soon as the gap between the true-class logit and the wrong-class logits reaches a certain finite size, the loss is already very small, and this is clearly beneficial for the model.
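As a quick sanity check of that formula (this small experiment is mine, not from the referenced article): with K = 3 and ε = 0.1 the optimal gap should be log((K-1)(1-ε)/ε) = log(18) ≈ 2.89, and directly minimizing the smoothed cross-entropy over a free logit vector converges to exactly that finite gap:

import torch

K, eps = 3, 0.1
target = torch.full((K,), eps / (K - 1))
target[0] = 1 - eps                        # smoothed label, true class is 0

z = torch.zeros(K, requires_grad=True)     # free logits to optimize directly
opt = torch.optim.Adam([z], lr=0.1)
for _ in range(2000):
    opt.zero_grad()
    loss = -(target * z.log_softmax(dim=0)).sum()   # cross-entropy with the soft label
    loss.backward()
    opt.step()

print((z[0] - z[1]).item())                                        # ~2.89, a finite gap
print(torch.log(torch.tensor((K - 1) * (1 - eps) / eps)).item())   # log(18) ~ 2.89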
2. Implementing label smoothing
Label smoothing can be implemented by building the soft label directly and then computing the loss with KLDivLoss.
import torch

def label_smooth(label, n_class=3, alpha=0.1):
    """
    Label smoothing
    :param label: ground-truth labels, shape [batch_size]
    :param n_class: number of classes
    :param alpha: smoothing coefficient
    :return: smoothed soft labels, shape [batch_size, n_class]
    """
    k = alpha / (n_class - 1)
    # temp: [batch_size, n_class], every entry starts at alpha / (n_class - 1)
    temp = torch.full((label.shape[0], n_class), k)
    # scatter_(dim, index, src) is a bit tricky: it writes src into temp at the positions
    # given by index along dim, so here the true class of each row is set to 1 - alpha
    temp = temp.scatter_(1, label.unsqueeze(1), 1 - alpha)
    return temp
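A minimal usage sketch, assuming the label_smooth function above (the tensors are made up for illustration): nn.KLDivLoss expects log-probabilities as input and probabilities as target, so the logits have to go through log_softmax first.

import torch
import torch.nn as nn

logits = torch.randn(4, 3)                    # fake model outputs: batch of 4, 3 classes
labels = torch.tensor([0, 2, 1, 0])           # fake ground-truth class indices
soft_targets = label_smooth(labels, n_class=3, alpha=0.1)   # [4, 3] smoothed labels
kl = nn.KLDivLoss(reduction='batchmean')
loss = kl(logits.log_softmax(dim=-1), soft_targets)
print(loss)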
You can also wrap the soft-label construction and the loss computation together and implement a new loss function, as follows:
"""
标签平滑
可以把真实标签平滑集成在loss函数里面,然后计算loss
也可以直接在loss函数外面执行标签平滑,然后计算散度loss
"""
import torch.nn as nn
import torch
class LabelSmoothingLoss(nn.Module):
"""
标签平滑Loss
"""
def __init__(self, classes, smoothing=0.0, dim=-1):
"""
:param classes: 类别数目
:param smoothing: 平滑系数
:param dim: loss计算平均值的维度
"""
super(LabelSmoothingLoss, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.cls = classes
self.dim = dim
self.loss = nn.KLDivLoss()
def forward(self, pred, target):
pred = pred.log_softmax(dim=self.dim)
with torch.no_grad():
# true_dist = pred.data.clone()
true_dist = torch.zeros_like(pred)
true_dist.fill_(self.smoothing / (self.cls - 1))
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
#torch.mean(torch.sum(-true_dist * pred, dim=self.dim))就是按照公式来计算损失
loss = torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
#采用KLDivLoss来计算
loss = self.loss(pred,true_dist)
return loss
There are two ways to do the computation in forward: either hand the log-probabilities and the smoothed targets to KLDivLoss directly, or compute the formula step by step yourself.
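A minimal usage sketch, assuming the LabelSmoothingLoss class defined above (the tensors are made up):

import torch

criterion = LabelSmoothingLoss(classes=3, smoothing=0.1)
logits = torch.randn(4, 3)               # raw model outputs for a batch of 4
target = torch.tensor([0, 2, 1, 0])      # ground-truth class indices
print(criterion(logits, target))         # a scalar loss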
3. The latest label smoothing: online label smoothing
The paper 《Delving Deep into Label Smoothing》 proposes an online label smoothing strategy that generates the soft labels in an online-learning fashion during training. Compared with the traditional soft-label approach, the paper claims improved classification performance and model robustness, outperforming methods such as LS and Bootsoft.
(Figures from the paper: the overall pipeline, the algorithm flow, and the new loss function; Equation 4 in the paper is the final loss.)
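Reading the loss off the implementation below (so this is my paraphrase rather than the paper's exact Equation 4): for a sample with hard label $y$ and predicted distribution $p$, the loss at epoch $t$ has the form

$$\mathcal{L} = \alpha \, \mathcal{L}_{CE}(p, y) + (1 - \alpha) \, \mathcal{L}_{soft}, \qquad \mathcal{L}_{soft} = -\sum_{k=1}^{K} S^{t}_{k,y} \log p_k,$$

where the column $S^{t}_{\cdot,y}$ of the $K \times K$ soft-label matrix is the average predicted distribution of all samples that were correctly classified as class $y$ during epoch $t-1$; in the first epoch $S^{1}$ is initialized with ordinary label smoothing.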
For the concrete implementation, here is an implementation written by someone else:
import torch
import torch.nn as nn
from torch import Tensor
class OnlineLabelSmoothing(nn.Module):
"""
Implements Online Label Smoothing from paper
https://arxiv.org/pdf/2011.12562.pdf
使用方法
from ols import OnlineLabelSmoothing
criterion = OnlineLabelSmoothing(alpha=..., n_classes=...)
for epoch in range(...): # loop over the dataset multiple times
for i, data in enumerate(...):
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch {epoch} finished!')
# Update the soft labels for next epoch
criterion.next_epoch()
criterion.eval()
dev()/test()
"""
def __init__(self, alpha: float, n_classes: int, smoothing: float = 0.1):
"""
:param alpha: Term for balancing soft_loss and hard_loss
:param n_classes: Number of classes of the classification problem
:param smoothing: Smoothing factor to be used during first epoch in soft_loss
"""
super(OnlineLabelSmoothing, self).__init__()
assert 0 <= alpha <= 1, 'Alpha must be in range [0, 1]'
self.a = alpha
self.n_classes = n_classes
# Initialize soft labels with normal LS for first epoch
self.register_buffer('supervise', torch.zeros(n_classes, n_classes))
self.supervise.fill_(smoothing / (n_classes - 1))
self.supervise.fill_diagonal_(1 - smoothing)
# Update matrix is used to supervise next epoch
self.register_buffer('update', torch.zeros_like(self.supervise))
# For normalizing we need a count for each class
self.register_buffer('idx_count', torch.zeros(n_classes))
self.hard_loss = nn.CrossEntropyLoss()
def forward(self, y_h: Tensor, y: Tensor):
# Calculate the final loss
soft_loss = self.soft_loss(y_h, y)
hard_loss = self.hard_loss(y_h, y)
return self.a * hard_loss + (1 - self.a) * soft_loss
def soft_loss(self, y_h: Tensor, y: Tensor):
"""
Calculates the soft loss and calls step
to update `update`.
:param y_h: Predicted logits.
:param y: Ground truth labels.
:return: Calculates the soft loss based on current supervise matrix.
"""
y_h = y_h.log_softmax(dim=-1)
if self.training:
with torch.no_grad():
self.step(y_h.exp(), y)
true_dist = torch.index_select(self.supervise, 1, y).swapaxes(-1, -2)
return torch.mean(torch.sum(-true_dist * y_h, dim=-1))
def step(self, y_h: Tensor, y: Tensor) -> None:
"""
Updates `update` with the probabilities
of the correct predictions and updates `idx_count` counter for
later normalization.
Steps:
1. Calculate correct classified examples.
2. Filter `y_h` based on the correct classified.
3. Add `y_h_f` rows to the `j` (based on y_h_idx) column of `memory`.
4. Keep count of # samples added for each `y_h_idx` column.
5. Average memory by dividing column-wise by result of step (4).
Note on (5): This is done outside this function since we only need to
normalize at the end of the epoch.
"""
# 1. Calculate predicted classes
y_h_idx = y_h.argmax(dim=-1)
# 2. Filter only correct
mask = torch.eq(y_h_idx, y)
y_h_c = y_h[mask]
y_h_idx_c = y_h_idx[mask]
# 3. Add y_h probabilities rows as columns to `memory`
self.update.index_add_(1, y_h_idx_c, y_h_c.swapaxes(-1, -2))
# 4. Update `idx_count`
self.idx_count.index_add_(0, y_h_idx_c, torch.ones_like(y_h_idx_c, dtype=torch.float32))
def next_epoch(self) -> None:
"""
This function should be called at the end of the epoch.
It basically sets the `supervise` matrix to be the `update`
and re-initializes to zero this last matrix and `idx_count`.
"""
# 5. Divide memory by `idx_count` to obtain average (column-wise)
self.idx_count[torch.eq(self.idx_count, 0)] = 1 # Avoid 0 denominator
# Normalize by taking the average
self.update /= self.idx_count
self.idx_count.zero_()
self.supervise = self.update
self.update = self.update.clone().zero_()
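A minimal smoke test of the class above, with made-up tensors, just to show the call pattern outside a real training loop:

import torch

criterion = OnlineLabelSmoothing(alpha=0.5, n_classes=5, smoothing=0.1)
logits = torch.randn(8, 5, requires_grad=True)   # fake model outputs
labels = torch.randint(0, 5, (8,))               # fake ground-truth labels
loss = criterion(logits, labels)                 # alpha * CE + (1 - alpha) * soft loss
print(loss.item())
loss.backward()                                  # would update the model in a real loop
criterion.next_epoch()                           # refresh the soft-label matrix for the next epoch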
Whether it helps in practice remains to be verified!
References