Datawhale干货
作者:苏剑林,科学空间
原文链接:https://kexue.fm/archives/11260
在之前的文章《当Batch Size增大时,学习率该如何随之变化[1]?》和《Adam的epsilon如何影响学习率的Scaling Law?[2]》中,我们从理论上讨论了学习率随Batch Size的变化规律,其中比较经典的部分是由OpenAI提出的展开到二阶的分析。然而,当我们要处理非SGD优化器时,这套分析方法的计算过程往往会相当复杂,有种无从下手的感觉。
接下来的文章,笔者将重新整理和思考上述文章中的相关细节,尝试简化其中的一些推导步骤,给出一条更通用、更轻盈的推导路径,并且探讨推广到Muon优化器的可能性。
方法大意
首先回顾一下之前的分析方法。在《当Batch Size增大时,学习率该如何随之变化?》中,我们介绍了多种分析学习率与Batch Size规律的思路,其中OpenAI在《An Empirical Model of Large-Batch Training[3]》提出的二阶近似分析占了主要篇幅,本文也是沿用同样的思路。
接着需要引入一些记号。设损失函数为,是参数向量,是它的梯度。注意理想的损失函数是在全体训练样本上算的期望,但实际我们只能采样一个Batch来算,这导致梯度也带有随机性,我们将单个样本的梯度记为,它的均值就是,而协方差矩阵记为;当Batch Size为时,梯度记为,它的均值还是,但协方差矩阵变为。
进一步地,设当前学习率为,更新向量为,那么更新后的损失函数将是
右侧我们泰勒展开到了二阶,是Hessian矩阵,是矩阵的迹,第二个等号用到了这个恒等式。为了得到一个确定性的结果,我们对两边求期望:
我们把右端看成是关于的二次函数,并假设二次项系数是正的(更强的假设是矩阵是正定的),那么可以得到最小值点
这便是平均来说让损失函数下降最快的学习率,是学习率的理论最优解。我们要做的事情,就是针对具体的算出和,然后从上式析出它与Batch Size(即)的关系。
热身练习
作为第一个例子,我们自然是考虑最简单的SGD,此时有,那么简单可得以及,于是有
其中
对于结果,我们可以有多种解读方式。首先,它是一个单调递增但有上界的函数,上界为,这表明学习率不能无限增加,相比简单的线性律或者平方根律,它更符合我们的直觉认知;当时,我们有
这表明在Batch Size比较小时,SGD的学习率与Batch Size确实呈线性关系,同时也暗示了是一个关键统计量。不过的定义依赖于Hessian矩阵,这在LLM中是几乎不可能精确计算的,所以实践中我们通常假设它是单位阵(的若干倍),得到一个简化的形式
该结果具有噪音强度()除以信号强度()的形式,它其实就是信噪比的倒数,它表明信噪比越小,那么就需要更大的Batch Size才能用上相同的,这也跟我们的直觉认知相符。只依赖于的对角线元素,这表明我们只需要将每个参数独立地估计均值和方差,这在实践上是可行的。
数据效率
除了学习率与Batch Size的直接关系外,笔者认为由此衍生出来的关于训练数据量和训练步数的渐近关系,也是必须要学习的精彩部分。特别地,这个结论似乎比学习率的关系式更为通用,因为后面我们将会看到,SignSGD也能得到同样形式的结论,但它的学习率规律并不是式。
原论文对这部分的讨论比较复杂,下面的推导是经过笔者简化的。具体来说,我们将代回到,将得到
其中。怎么理解这个结果呢?首先,它是关于的单调递增函数,当时等于,换言之如果我们能开无穷大的Batch Size,那么每一步的损失下降量是,此时所需的训练步数最少,记为。
如果Batch Size是有限值,那么每一步的损失下降量平均来说只有,这意味着平均而言我们要花
由于Batch Size为
这便是训练数据量和训练步数的经典关系式,它有两个参数
困难分析
前面写了那么多,都还停留在SGD中。从计算角度看,SGD是平凡的,真正复杂的是
在这些非线性场景下,
1、假设
2、假设
3、将
也就是说,我们要经过一堆弯弯绕绕的步骤,才勉强算出一个可以分析下去的近似结果(这个过程首次出现在Tencent的论文《Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling[4]》),而且这已经算是简单的了,因为如果是SoftSignSGD,则更加复杂:
1、假设
2、将
3、假设
4、将复杂函数用
事情还没完。费那么大劲,加那么多假设,我们才堪堪算出
本文引用链接
[1] 当Batch Size增大时,学习率该如何随之变化: https://kexue.fm/archives/10542[2] Adam的epsilon如何影响学习率的Scaling Law?: https://kexue.fm/archives/10563[3] An Empirical Model of Large-Batch Training: https://papers.cool/arxiv/1812.06162[4] Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling: https://papers.cool/arxiv/2405.14578

一起“点赞”三连↓
学习率与Batch Size关系再思考
2万+

被折叠的 条评论
为什么被折叠?



