苏剑林:重新思考学习率与Batch Size

学习率与Batch Size关系再思考

 Datawhale干货 

作者:苏剑林,科学空间

原文链接:https://kexue.fm/archives/11260

在之前的文章《当Batch Size增大时,学习率该如何随之变化[1]?》和《Adam的epsilon如何影响学习率的Scaling Law?[2]》中,我们从理论上讨论了学习率随Batch Size的变化规律,其中比较经典的部分是由OpenAI提出的展开到二阶的分析。然而,当我们要处理非SGD优化器时,这套分析方法的计算过程往往会相当复杂,有种无从下手的感觉。

接下来的文章,笔者将重新整理和思考上述文章中的相关细节,尝试简化其中的一些推导步骤,给出一条更通用、更轻盈的推导路径,并且探讨推广到Muon优化器的可能性。

方法大意

首先回顾一下之前的分析方法。在《当Batch Size增大时,学习率该如何随之变化?》中,我们介绍了多种分析学习率与Batch Size规律的思路,其中OpenAI在《An Empirical Model of Large-Batch Training[3]》提出的二阶近似分析占了主要篇幅,本文也是沿用同样的思路。

接着需要引入一些记号。设损失函数为是参数向量,是它的梯度。注意理想的损失函数是在全体训练样本上算的期望,但实际我们只能采样一个Batch来算,这导致梯度也带有随机性,我们将单个样本的梯度记为,它的均值就是,而协方差矩阵记为;当Batch Size为时,梯度记为,它的均值还是,但协方差矩阵变为

进一步地,设当前学习率为,更新向量为,那么更新后的损失函数将是

右侧我们泰勒展开到了二阶,是Hessian矩阵,是矩阵的迹,第二个等号用到了这个恒等式。为了得到一个确定性的结果,我们对两边求期望:

我们把右端看成是关于的二次函数,并假设二次项系数是正的(更强的假设是矩阵是正定的),那么可以得到最小值点

这便是平均来说让损失函数下降最快的学习率,是学习率的理论最优解。我们要做的事情,就是针对具体的算出,然后从上式析出它与Batch Size(即)的关系。

热身练习

作为第一个例子,我们自然是考虑最简单的SGD,此时有,那么简单可得以及,于是有

其中

对于结果,我们可以有多种解读方式。首先,它是一个单调递增但有上界的函数,上界为,这表明学习率不能无限增加,相比简单的线性律或者平方根律,它更符合我们的直觉认知;当时,我们有

这表明在Batch Size比较小时,SGD的学习率与Batch Size确实呈线性关系,同时也暗示了是一个关键统计量。不过的定义依赖于Hessian矩阵,这在LLM中是几乎不可能精确计算的,所以实践中我们通常假设它是单位阵(的若干倍),得到一个简化的形式

该结果具有噪音强度()除以信号强度()的形式,它其实就是信噪比的倒数,它表明信噪比越小,那么就需要更大的Batch Size才能用上相同的,这也跟我们的直觉认知相符。只依赖于的对角线元素,这表明我们只需要将每个参数独立地估计均值和方差,这在实践上是可行的。

数据效率

除了学习率与Batch Size的直接关系外,笔者认为由此衍生出来的关于训练数据量和训练步数的渐近关系,也是必须要学习的精彩部分。特别地,这个结论似乎比学习率的关系式更为通用,因为后面我们将会看到,SignSGD也能得到同样形式的结论,但它的学习率规律并不是式

原论文对这部分的讨论比较复杂,下面的推导是经过笔者简化的。具体来说,我们将代回到,将得到

其中。怎么理解这个结果呢?首先,它是关于的单调递增函数,当时等于,换言之如果我们能开无穷大的Batch Size,那么每一步的损失下降量是,此时所需的训练步数最少,记为

如果Batch Size是有限值,那么每一步的损失下降量平均来说只有,这意味着平均而言我们要花步,才能达到无穷大Batch Size时1步的下降量,于是为了达到相同的损失,我们要训练步。

由于Batch Size为,所以很容易得出训练消耗的数据总量为。从这个结果可以看出,增大Batch Size后,想要达到相同的效果,我们还需要适当增加数据量;当时,所需要的数据量最少,为。利用这些记号,我们可以写出

这便是训练数据量和训练步数的经典关系式,它有两个参数,我们也可以通过实验搜索多个来拟合上式,从而估计,进而可以估算。更多分析细节请看回之前的文章《当Batch Size增大时,学习率该如何随之变化?》或OpenAI的原论文《An Empirical Model of Large-Batch Training》。

困难分析

前面写了那么多,都还停留在SGD中。从计算角度看,SGD是平凡的,真正复杂的是非线性地依赖于的情形,比如SignSGD对应于,在理论分析中它经常用作Adam的近似,更准确的近似则是考虑了的SoftSignSGD,我们在《Adam的epsilon如何影响学习率的Scaling Law?》尝试过分析它。

在这些非线性场景下,的计算往往是相当困难的,即便我们将的分布假设为简单的正态分布也是如此(注意,在SGD的分析中,我们并不需要对它的分布形式做正态假设)。比如,在之前的文章中,对于的SignSGD,为了计算,我们经历了如下步骤:

1、假设的分量相互独立,问题简化为单个分量(没有加粗)的期望;

2、假设(此时是一个标量)服从正态分布,那么就可以算出,答案要用函数来表示;

3、将函数用形式的函数近似,简化结果。

也就是说,我们要经过一堆弯弯绕绕的步骤,才勉强算出一个可以分析下去的近似结果(这个过程首次出现在Tencent的论文《Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling[4]》),而且这已经算是简单的了,因为如果是SoftSignSGD,则更加复杂:

1、假设的分量相互独立,问题简化为单个分量的期望;

2、将函数用分段线性函数近似,这样才能算出下面的积分;

3、假设服从正态分布,结合第2步的近似,可以算出,答案是包含的复杂函数;

4、将复杂函数用形式的函数近似,简化结果。

事情还没完。费那么大劲,加那么多假设,我们才堪堪算出,接着还要算,这往往更加复杂(SignSGD是个例外,因为一定是1,所以反而简单了)。然而,计算的复杂性还是次要的,主要是这些步骤看上去没有任何能推广的规律,似乎只能具体问题具体分析的样子,这就让人觉得非常心累。

本文引用链接

[1] 当Batch Size增大时,学习率该如何随之变化: https://kexue.fm/archives/10542
[2] Adam的epsilon如何影响学习率的Scaling Law?: https://kexue.fm/archives/10563
[3] An Empirical Model of Large-Batch Training: https://papers.cool/arxiv/1812.06162
[4] Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling: https://papers.cool/arxiv/2405.14578

图片

一起“赞”三连

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值