达摩院开源预训练数据微调框架UOT NeurIPS论文深入解读

文章研究了预训练模型在下游任务中微调的问题,提出在数据有限时利用预训练数据进行优化。通过理论分析和UOT(不平衡最优传输)数据选择方法,改善了模型的泛化能力,尤其在小数据集和长尾分布情况下表现突出。实验结果显示,UOT联合微调策略在多个数据集上取得最佳效果。
部署运行你感兴趣的模型镜像

​​​​​​​  团队模型、论文、博文、直播合集,点击此处浏览

一、论文

论文链接: Improved Fine-Tuning by Better Leveraging Pre-Training Data

代码链接:https://github.com/ziquanliu/NeurIPS2022_UOT_fine_tuning

二、背景

        在下游任务中对预训练模型进行微调被广泛用于许多深度学习的应用中。最近的深度学习的研究主要聚焦在预训练数据充分的情况下如何训练更强大的预训练模型,然而预训练模型在下游的使用往往面临着数据不充分的困难。另一方面,最近的研究经验表明,一旦在某些视觉任务中增加了训练样本的数量,从头开始训练的最终性能并不比这种预训练策略差。在这项工作中,我们使用学习理论中流行的超额风险界限,从泛化分析的角度重新审视了这一现象。当下游数据有限时,我们提出利用预训练数据进行微调。使用预训练数据来进行微调的泛化结果表明,当微调中包含适当的预训练数据时,可以改善目标任务的过度风险界限。基于我们的理论分析,我们提出了一种新的选择策略,从预训练数据中选择一个子集,以帮助提高对目标任务的泛化能力。我们的方法如下图所示。

三、理论分析

        首先定义下游的目标函数和上游预训练函数,然后我们对这两个函数的性质做出如下四个假设。其中第一个假设是对两个训练分布的相似性给一个上界,第二和第三是描述优化目标函数的unbiased, bounded variance和smoothness性质,第四个是常用的PL condition.

         给定以上四个假设,我们可以得到下面的引理。引理1说明了当下游数据的数据量到达一定程度的时候,预训练模型对于随机初始化来说其实没有什么优势。这个引理也解释了之前人们在实验中发现的现象[1]。

         当下游任务的数据量有限时,我们提出在fine-tuning时重复利用预训练数据。具体来说我们在fine-tuning时把两个训练数据的损失函数做一个加权,然后用两部分数据训练这个网络。在实际的操作中,我们对于下游数据和预训练数据使用不同的线性分类器,以便我们可以使用任意数量的下游类别,而且下游和预训练的标签互不影响。

         对于这种joint fine-tuning的效果,我们推导了如下定理。该定理说明,如果我们能够选取适当的预训练数据,使得两个训练数据的距离缩小,那么joint fine-tuning的效果会变得更好。

四、数据选择

        我们提出一种基于unbalanced optimal transport的通用的数据选择方法,简记为UOT数据选择。我们的思想是找到一个运输方式把预训练数据运输到下游数据,而这种运输方式体现了预训练数据的某个类和下游数据的距离。我们发现这种UOT数据选择的方式在选择相似数据上非常的有效,如下图所示。因为CUB数据只包含鸟类图片,所以我们知道这时候需要从预训练集上选择的数据是鸟类。UOT成功地选择了几乎所有ImageNet中的鸟类

五、结果

1. 效果展示

        我们分别用有监督和无监督的预训练模型在8个下游数据集上做了我们的实验。当预训练数据集无标签时,我们用K-means聚类给数据打上伪标签。实验结果表明,我们的UOT joint-fine-tuning在多数数据集上取得了最优的效果。

2. 小数据集与长尾类分布

        我们进一步模拟了下游的小数据集和长尾类分布的情况。左图中我们展示了对CUB和Caltech数据集按类进行数据量下采样的实验结果。当数据量越小时,我们的UOT方法取得了越大的提升。右表显示了我们对于3个数据集进行长尾累分布采样后,再进行UOT和标准微调的结果。我们的方法在这种情况下取得了显著更好的结果。

六、参考文献

[1] He, Kaiming, Ross Girshick, and Piotr Dollár. "Rethinking imagenet pre-training." Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019.

七、应用

        接下来给大家介绍下我们研发的各个域上的开源免费模型,欢迎大家体验、下载(大部分手机端即可体验):

ModelScope 魔搭社区

ModelScope 魔搭社区

ModelScope 魔搭社区

ModelScope 魔搭社区

ModelScope 魔搭社区

ModelScope 魔搭社区

ModelScope 魔搭社区

ModelScope 魔搭社区

ModelScope 魔搭社区

ModelScope 魔搭社区

ModelScope 魔搭社区

ModelScope 魔搭社区

您可能感兴趣的与本文相关的镜像

Llama Factory

Llama Factory

模型微调
LLama-Factory

LLaMA Factory 是一个简单易用且高效的大型语言模型(Large Language Model)训练与微调平台。通过 LLaMA Factory,可以在无需编写任何代码的前提下,在本地完成上百种预训练模型的微调

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI记忆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值