✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。
我是Srlua小谢,在这里我会分享我的知识和经验。🎥
希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮
记得先点赞👍后阅读哦~ 👏👏
📘📚 所属专栏:传知代码论文复现
欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙
目录
本文所有资源均可在该地址处获取。
概述
本文复现论文 Revisiting Consistency Regularization for Deep Partial Label Learning[1] 提出的偏标记学习方法。程序基于Pytorch,会保存完整的训练日志,并生成损失变化图和准确度变化图。
偏标记学习(Partial Label Learning)是一个经典的弱监督问题。在偏标记学习中,每个样例的监督信息为一个包含多个标签的候选标签集合。目前的偏标记方法大多基于自监督或者对比学习范式,或多或少地会遇到低性能或低效率的问题。该论文基于一致性正则化的思想,改进基于自监督的偏标记学习方法。具体地,该论文所提出的方法设计了两个训练目标。其中第一个训练目标为最小化非候选标签的预测输出,第二个目标最大化不同视图的预测输出之间的一致性。
总的来说,该论文所提出的方法着眼于将模型对同一图像不同增强视图的预测输出对齐,以提升模型输出的可靠性和对标签的消歧能力,这一方法同样能给其他弱监督学习任务带来提升。
算法原理
首先,论文所提出方法的第一项损失(监督损失)如下:
Losssupervised(x)=−∑i=1cI(i∉candidates)⋅log[1−fi(x)]Losssupervised(x)=−i=1∑cI(i∈candidates)⋅log[1−fi(x)]
其中,当事件 AA 为真时I(A)=1I(A)=1 否则 I(A)=0I(A)=0,f(⋅)f(⋅) 表示模型的输出概率。
然后,论文所提出方法的第二项损失(一致性损失)如下:
Lossconsistency(x)=KL-Divergence[f(x),label-distribution(x)]Lossconsistency(x)=KL-Divergence[f(x),label-distribution(x)]
其在训练过程中通过所有增强视图预测结果的几何平均来更新标签分布:
label-distributioni(x)=I(i∈candidates)⋅f‾i(augment(x))∑jI(j∈candidates)⋅f‾j(augment(x))label-distributioni(x)=∑jI(j∈candidates)⋅fj(augment(x))I(i∈candidates)⋅fi(augment(x))
由于数据增强的不稳定性,该论文通过叠加 KK 个不同的增强视图的一致性损失来提升方法性能。
最后,考虑到训练初期模型的预测准确率较低,一致性损失的权重被设置为从零开始随着训练轮数的增加逐渐提高:
λ(t)=min(λmax⋅tT,λmax)λ(t)=min(λmax⋅Tt,λmax)
综上所述,模型的总损失函数如下:
Losssummary=Losssupervised[weak-augment(x)]+λ(t)⋅∑k=1KLossconsistency[strong-augmentk(x)]Losssummary=Losssupervised[weak-augment(x)]+λ(t)⋅k=1∑KLossconsistency[strong-augmentk(x)]
核心逻辑
具体的核心逻辑如下所示:
def dpll_sup_loss(probs, partial_labels):
loss = -torch.sum(torch.log(1 + 1e-6 - probs) * (1 - partial_labels), dim=-1)
loss_avg = torch.mean(loss)
return loss_avg
def dpll_cont_loss(logits, targets):
logits_log = torch.log_softmax(logits, dim=-1)
loss = F.kl_div(logits_log, targets, reduction='batchmean')
return loss
def train():
# main loops
for epoch_id in range(total_epochs):
# train
model.train()
for batch in train_dataloader:
optimizer.zero_grad()
ids = batch['ids']
data1 = batch['data1'].to(device)
data2 = batch['data2'].to(device)
data3 = batch['data3'].to(device)
partial_labels = batch['partial_labels'].to(device)
targets = train_targets[ids].to(device)
logits1 = model(data1)
logits2 = model(data2)
logits3 = model(data3)
probs1 = F.softmax(logits1, dim=-1)
# update targets
with torch.no_grad():
probs2 = F.softmax(logits2.detach(), dim=-1)
probs3 = F.softmax(logits3.detach(), dim=-1)
new_targets = torch.pow(probs1.detach() * probs2 * probs3, 1 / 3)
new_targets = F.normalize(new_targets * partial_labels, p=1, dim=-1)
train_targets[ids] = new_targets.cpu()
# dynamic weight
balancing_weight = max_weight * (epoch_id + 1) / max_weight_epoch
balancing_weight = min(max_weight, balancing_weight)
# supervised loss
loss_sup = dpll_sup_loss(probs1, partial_labels)
# consistency regularization loss
loss_cont1 = dpll_cont_loss(logits1, targets)
loss_cont2 = dpll_cont_loss(logits2, targets)
loss_cont3 = dpll_cont_loss(logits3, targets)
# all loss
loss = loss_sup + balancing_weight * (loss_cont1 + loss_cont2 + loss_cont3)
loss.backward()
optimizer.step()
if epoch_id in lr_decay_epochs:
lr_scheduler.step()
以上代码仅作展示,更详细的代码文件请参见附件。
效果演示
本文基于网络 Wide-ResNet[2] 和数据集 CIFAR-10[3] 进行实验,偏标记的随机翻转概率为0.1。当然,本文所提供的程序不仅仅提供了上述的实验设置,同时也可以直接基于CIFAR-100(100类图像分类数据集),SVHN(数字号牌识别数据集),Fashion-MNIST(时装识别数据集),Kuzushiji-MNIST(日本古草体识别数据集)进行实验。仅仅需要替换运行命令的对应部分即可(使用说明见下文)
- 损失曲线:
- 准确率曲线:
使用方式
- 解压附件压缩包并进入工作目录,并通过执行如下命令进行环境配置:
pip install -r requirements.txt
- 如果希望运行训练程序,请执行如下命令:
python main.py -m [模型名称] -d [数据集名称] -p [翻转概率]
- 通过如下命令执行本文展示的实验:
python main.py -m "Wide-ResNet-34-10" -d "CIFAR-10" -p 0.1
- 实验结果保存在
./data/logs/
目录下,并以执行时间为名称。
参考文献
[1] Wu D D, Wang D B, Zhang M L. Revisiting consistency regularization for deep partial label learning[C]//International conference on machine learning. PMLR, 2022: 24212-24225.
[2] Zagoruyko S. Wide residual networks[J]. arXiv preprint arXiv:1605.07146, 2016.
[3] Krizhevsky A, Hinton G. Learning multiple layers of features from tiny images[J]. 2009.
希望对你有帮助!加油!
若您认为本文内容有益,请不吝赐予赞同并订阅,以便持续接收有价值的信息。衷心感谢您的关注和支持!