Paper link: paper
Starting Point
When a single minimum or multiple minima exist within a given perturbation radius, zeroth-order flatness can fail to distinguish minima with low generalization error from those with high generalization error.
The paper proposes first-order flatness, a stronger flatness measure that focuses on the maximal gradient norm within the perturbation radius; it bounds both the maximal Hessian eigenvalue at local minima and SAM's regularization term.
The paper proposes a new optimizer, GAM, which finds flatter minima than SAM (Sharpness-Aware Minimization).
In Depth: first order flatness
zeroth-order flatness
L^{\mathrm{sam}}(\boldsymbol{\theta})=\hat{L}(\boldsymbol{\theta})+\max _{\boldsymbol{\theta}^{\prime} \in B(\boldsymbol{\theta}, \rho)}\left(\hat{L}\left(\boldsymbol{\theta}^{\prime}\right)-\hat{L}(\boldsymbol{\theta})\right)
ρ-zeroth-order flatness
R_\rho^{(0)}(\boldsymbol{\theta}) \triangleq \max _{\boldsymbol{\theta}^{\prime} \in B(\boldsymbol{\theta}, \rho)}\left(\hat{L}\left(\boldsymbol{\theta}^{\prime}\right)-\hat{L}(\boldsymbol{\theta})\right), \quad \forall \boldsymbol{\theta} \in \Theta
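Substituting this definition back into the SAM objective above shows that SAM penalizes exactly zeroth-order flatness:

L^{\mathrm{sam}}(\boldsymbol{\theta})=\hat{L}(\boldsymbol{\theta})+R_\rho^{(0)}(\boldsymbol{\theta})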
ρ-first-order flatness
R_\rho^{(1)}(\boldsymbol{\theta}) \triangleq \rho \cdot \max _{\boldsymbol{\theta}^{\prime} \in B(\boldsymbol{\theta}, \rho)}\left\|\nabla \hat{L}\left(\boldsymbol{\theta}^{\prime}\right)\right\|, \quad \forall \boldsymbol{\theta} \in \Theta

where ρ is the perturbation radius controlling the neighborhood size; the measure captures the maximal gradient norm within that neighborhood. Intuitively, ρ-first-order flatness requires that the loss not change drastically in the neighborhood of θ, which is why it is measured by the gradient norm.
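A sketch of why first-order flatness is the stronger notion (my reasoning under the assumption that L̂ is differentiable, following the paper's claim that first-order flatness controls zeroth-order flatness): let θ' be the maximizer in the definition of R_ρ^(0)(θ); the mean value theorem along the segment from θ to θ' gives some ξ in B(θ, ρ) with

R_\rho^{(0)}(\boldsymbol{\theta})=\hat{L}\left(\boldsymbol{\theta}^{\prime}\right)-\hat{L}(\boldsymbol{\theta})=\nabla \hat{L}(\boldsymbol{\xi})^{\top}\left(\boldsymbol{\theta}^{\prime}-\boldsymbol{\theta}\right) \leq \rho \cdot \max _{\boldsymbol{\theta}^{\prime \prime} \in B(\boldsymbol{\theta}, \rho)}\left\|\nabla \hat{L}\left(\boldsymbol{\theta}^{\prime \prime}\right)\right\|=R_\rho^{(1)}(\boldsymbol{\theta})

so R_ρ^(0)(θ) ≤ R_ρ^(1)(θ): controlling the first-order measure also controls the zeroth-order one.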
Relationship with Eigenvalue
The authors then establish the relationship between ρ-first-order flatness and the maximal eigenvalue of the Hessian, \lambda_{\max}\left(\nabla^2 \hat{L}\left(\boldsymbol{\theta}^*\right)\right), since the latter has been shown to be an effective measure of generalization ability; the paper formalizes this relationship as a lemma.
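A sketch of where this connection comes from (my own derivation, assuming θ* is a local minimum so that ∇L̂(θ*) = 0 and the Hessian is positive semi-definite; the paper's lemma states it formally): a Taylor expansion gives ∇L̂(θ* + δ) ≈ ∇²L̂(θ*)δ, hence

R_\rho^{(1)}\left(\boldsymbol{\theta}^*\right) \approx \rho \cdot \max _{\|\boldsymbol{\delta}\| \leq \rho}\left\|\nabla^2 \hat{L}\left(\boldsymbol{\theta}^*\right) \boldsymbol{\delta}\right\|=\rho^2 \lambda_{\max }\left(\nabla^2 \hat{L}\left(\boldsymbol{\theta}^*\right)\right), \quad \text { i.e. } \quad \lambda_{\max }\left(\nabla^2 \hat{L}\left(\boldsymbol{\theta}^*\right)\right)=\lim _{\rho \rightarrow 0} \frac{R_\rho^{(1)}\left(\boldsymbol{\theta}^*\right)}{\rho^2}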
Gradient Norm Aware Minimization
The ρ-first-order flatness defined above is then incorporated into the optimizer design. The overall loss is defined as:
L^{\text {overall }}(\boldsymbol{\theta})=L^{\text {oracle }}(\boldsymbol{\theta})+\alpha R_\rho^{(1)}(\boldsymbol{\theta})

where α controls the regularization strength. The gradient of the overall loss is then:
\nabla L^{\text {overall }}(\boldsymbol{\theta})=\nabla L^{\text {oracle }}(\boldsymbol{\theta})+\alpha \nabla R_\rho^{(1)}(\boldsymbol{\theta})

The two terms are differentiated separately. The oracle term ∇L^oracle(θ) can be computed directly; for the regularizer term ∇R_ρ^(1)(θ), a Taylor expansion in the neighborhood of θ, discarding higher-order terms, gives:
\begin{aligned} & \nabla R_\rho^{(1)}(\boldsymbol{\theta}) \approx \rho \cdot \nabla\left\|\nabla \hat{L}\left(\boldsymbol{\theta}^{\mathrm{adv}}\right)\right\|, \quad \boldsymbol{\theta}^{\mathrm{adv}}=\boldsymbol{\theta}+\rho \cdot \frac{\boldsymbol{f}}{\|\boldsymbol{f}\|}, \\ & \boldsymbol{f}=\nabla\|\nabla \hat{L}(\boldsymbol{\theta})\|=\frac{\nabla^2 \hat{L}(\boldsymbol{\theta}) \cdot \nabla \hat{L}(\boldsymbol{\theta})}{\|\nabla \hat{L}(\boldsymbol{\theta})\|} . \end{aligned}
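Note that f is a Hessian-vector product, so the Hessian never needs to be materialized: differentiating the scalar ‖∇L̂(θ)‖ a second time (double backprop) yields ∇²L̂(θ)·∇L̂(θ)/‖∇L̂(θ)‖ directly. A minimal PyTorch sketch of this computation (the helper name grad_norm_and_grad and the single-tensor parameterization are my simplifications, not from the paper):

```python
import torch

def grad_norm_and_grad(loss_fn, theta):
    """Return ‖∇L̂(θ)‖ and f = ∇‖∇L̂(θ)‖ = ∇²L̂(θ)·∇L̂(θ)/‖∇L̂(θ)‖."""
    # First backward pass: g = ∇L̂(θ), keeping the graph for double backprop
    g = torch.autograd.grad(loss_fn(theta), theta, create_graph=True)[0]
    n = g.norm()  # (assumes ∇L̂(θ) ≠ 0; the norm is non-differentiable at 0)
    # Second backward pass: differentiating the scalar norm gives the
    # Hessian-vector product without forming the Hessian
    f = torch.autograd.grad(n, theta)[0]
    return n.detach(), f
```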
The algorithm proceeds as sketched below.
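The paper's full algorithm includes additional scheduling details; the following is a minimal single-step sketch assembled only from the formulas above, reusing grad_norm_and_grad from the previous snippet and assuming one loss function stands in for both L^oracle and L̂ (the names gam_step, rho, alpha, and lr are mine):

```python
def gam_step(loss_fn, theta, lr=0.01, rho=0.05, alpha=0.1):
    # f: ascent direction of the gradient norm at θ
    _, f = grad_norm_and_grad(loss_fn, theta)
    # θ_adv = θ + ρ·f/‖f‖, the approximate worst case for the gradient norm
    theta_adv = (theta + rho * f / (f.norm() + 1e-12)).detach().requires_grad_(True)
    # ∇R_ρ^(1)(θ) ≈ ρ·∇‖∇L̂(θ_adv)‖
    _, grad_R = grad_norm_and_grad(loss_fn, theta_adv)
    # Oracle gradient at θ (here the same empirical loss plays the role of L^oracle)
    grad_oracle = torch.autograd.grad(loss_fn(theta), theta)[0]
    with torch.no_grad():
        theta -= lr * (grad_oracle + alpha * rho * grad_R)
    return theta

# Toy usage on a sharp 1-D quadratic, L̂(θ) = 10θ²
theta = torch.tensor([1.0], requires_grad=True)
for _ in range(100):
    theta = gam_step(lambda t: 10.0 * (t ** 2).sum(), theta)
```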
Summary
Compared with SAM, the proposed GAM shifts the target of the optimization search to the gradient ℓ2-norm within a neighborhood, and the authors prove that the resulting first-order flatness subsumes, to a certain extent, both SAM's zeroth-order flatness and the maximal Hessian eigenvalue; the modification of the loss function also embodies the idea of regularization.