SHARPNESS-AWARE MINIMIZATION FOR EFFICIENTLY IMPROVING GENERALIZATION 论文精读

SHARPNESS-AWARE MINIMIZATION FOR EFFICIENTLY IMPROVING GENERALIZATION-论文精读

锐度感知最小化 (SAM)paper-link 是一种简单而有效的技术,用于提高深度学习模型在看不见的数据样本上的泛化能力。

摘要

在当今过度参数化的模型中,训练损失的值很少能保证模型的泛化能力。事实上,仅优化训练损失值(如通常所做的那样)很容易导致次优的模型质量。
受先前将损失景观的几何形状与泛化能力相关联的工作的启发,我们引入了一种新颖而有效的方法,该方法通过同时最小化损失值和损失锐度来提高模型泛化能力。特别是,我们的方法,锐度感知最小化 (SAM),寻求位于具有均匀低损失的邻域中的参数;这种公式导致了一个可以在其上有效地进行梯度下降的最小-最大优化问题。我们展示了实验结果,表明 SAM 在各种基准数据集(例如 CIFAR-{10, 100}、ImageNet、微调任务)和模型上提高了模型的泛化能力,为其中一些模型带来了新的最先进性能。此外,我们发现 SAM 本身就具有与专门针对噪声标签学习而设计的最先进程序相当的对标签噪声的鲁棒性。
代码: https://github.com/goo

<think>好的,我需要回答用户关于Adaptive Sharpness-Aware Minimization (ASAM) 的详细信息及具体算法结构的问题。首先,我需要回顾用户提供的引用内容,特别是引用[1]和[2]。根据引用,ASAM是ICML '21的一篇论文,针对SAM(Sharpness-Aware Minimization)进行了改进,使其具有尺度不变性。 首先,我应该解释ASAM要解决的核心问题。SAM的初衷是寻找平坦的最小值,以提高模型的泛化能力,但SAM在处理不同参数尺度时可能不够有效,因为其损失函数对参数缩放敏感。ASAM通过引入自适应缩放,使得sharpness的度量不受参数缩放的影响,从而提升了SAM的效果。这部分需要详细说明SAM的局限性以及ASAM如何改进。 接下来,算法结构部分需要具体。我需要回忆ASAM的具体步骤,可能包括如何计算自适应权重,如何调整梯度更新方向。例如,ASAM可能对每个参数进行归一化,使得梯度更新时考虑参数的相对大小。这里可能需要数学公式,根据用户的要求,独立公式用$$...$$,行内用$...$。例如,SAM的目标函数是: $$\min_{\mathbf{w}} \mathcal{L}(\mathbf{w}) + \lambda \max_{\|\epsilon\|_2 \leq \rho} \mathcal{L}(\mathbf{w} + \epsilon)$$ 而ASAM可能修改了扰动区域,使其适应参数尺度,比如使用参数相关的扰动大小: $$\max_{\|\epsilon_i\|_2 \leq \rho \|\mathbf{w}_i\|_2} \mathcal{L}(\mathbf{w} + \epsilon)$$ 这样,每个参数的扰动范围与其自身的范数成比例,从而实现尺度不变性。 然后,需要描述ASAM的具体实现步骤,可能包括梯度计算、扰动方向的确定以及参数更新方式。例如,ASAM在每次迭代中计算梯度,然后根据参数的大小调整扰动方向,进而更新参数。这部分需要结合算法流程,可能类似于: 1. 计算当前参数$\mathbf{w}$的损失梯度。 2. 对每个参数$\mathbf{w}_i$,计算自适应扰动范围$\rho \|\mathbf{w}_i\|_2$。 3. 确定最坏情况扰动$\epsilon$,使得损失最大化。 4. 沿扰动后的梯度方向更新参数。 此外,需要提到ASAM的理论贡献,如证明其尺度不变性,以及在实验中的表现,比如在图像分类任务上的效果提升。 用户还要求生成相关问题,我需要考虑ASAM与其他方法的比较,实现细节,应用场景等。例如,可能的问题包括ASAM与SAM的区别,实际应用中的调参技巧,或者在不同网络结构中的效果差异。 现在需要确保所有数学表达式格式正确,引用标识正确添加。同时,使用中文回答,结构清晰,分点说明。检查是否有遗漏的关键点,如ASAM的动机、算法步骤、数学公式、实验结果等。最后,确保相关问题与ASAM相关,能够引导进一步的学习。</think>### Adaptive Sharpness-Aware Minimization (ASAM) 详解 #### **1. 解决的问题** ASAM 是 SAM(Sharpness-Aware Minimization)的改进版本,旨在解决 SAM 对参数尺度敏感的问题[^1]。SAM 的目标是寻找损失函数中“平坦”的极小值(即对参数扰动鲁棒的区域),以提升模型泛化能力。然而,SAM 的扰动范围是固定且与参数尺度无关的,这可能导致对不同量级的参数施加相同的扰动强度,影响优化效果。ASAM 通过引入**自适应的扰动范围**,使扰动与参数自身尺度解耦,从而实现**尺度不变性**(Scale-Invariant Learning)[^2]。 --- #### **2. 核心思想** ASAM 将 SAM 的固定扰动范围替换为与参数尺度相关的动态扰动。具体来说: - **SAM 的扰动范围**:所有参数的扰动大小统一为 $\epsilon$,即 $\|\epsilon\|_2 \leq \rho$。 - **ASAM 的扰动范围**:每个参数 $\mathbf{w}_i$ 的扰动范围与其自身尺度成比例,即 $\|\epsilon_i\|_2 \leq \rho \|\mathbf{w}_i\|_2$。 数学上,ASAM 的优化目标为: $$ \min_{\mathbf{w}} \mathcal{L}(\mathbf{w}) + \lambda \max_{\|\epsilon_i\|_2 \leq \rho \|\mathbf{w}_i\|_2} \mathcal{L}(\mathbf{w} + \epsilon) $$ 其中 $\rho$ 控制扰动强度,$\lambda$ 是正则化系数。 --- #### **3. 算法步骤** ASAM 的迭代过程分为以下步骤: 1. **计算当前损失梯度**:计算 $\nabla_{\mathbf{w}} \mathcal{L}(\mathbf{w})$。 2. **确定自适应扰动方向**: - 对每个参数 $\mathbf{w}_i$,计算归一化梯度方向 $\mathbf{g}_i = \frac{\nabla_{\mathbf{w}_i} \mathcal{L}(\mathbf{w})}{\|\mathbf{w}_i\|_2}$。 - 扰动 $\epsilon_i$ 的最坏方向为 $\epsilon_i = \rho \|\mathbf{w}_i\|_2 \cdot \mathbf{g}_i$。 3. **更新参数**:沿扰动后的梯度方向更新: $$ \mathbf{w} \leftarrow \mathbf{w} - \eta \left( \nabla_{\mathbf{w}} \mathcal{L}(\mathbf{w}) + \lambda \nabla_{\mathbf{w}} \mathcal{L}(\mathbf{w} + \epsilon) \right) $$ 其中 $\eta$ 是学习率。 --- #### **4. 关键改进** - **尺度不变性**:扰动范围与参数自身大小成比例,避免对小参数过度扰动或对大参数扰动不足。 - **兼容性**:可无缝集成到现有优化器(如 SGD、Adam)中。 - **实验效果**:在图像分类(如 CIFAR-10/100)和自然语言处理任务中,ASAM 相比 SAM 显著提升了泛化性能。 --- #### **5. 代码示例(伪代码)** ```python def ASAM_update(w, loss_fn, rho, lambda_, lr): # 计算原始梯度 loss = loss_fn(w) grad = compute_gradient(loss, w) # 计算自适应扰动 perturb = [] for wi, gi in zip(w, grad): norm_wi = torch.norm(wi) perturb_i = rho * norm_wi * gi / (torch.norm(gi) + 1e-12) perturb.append(perturb_i) # 计算扰动后的梯度 w_perturbed = w + torch.cat(perturb) loss_perturbed = loss_fn(w_perturbed) grad_perturbed = compute_gradient(loss_perturbed, w_perturbed) # 更新参数 w = w - lr * (grad + lambda_ * grad_perturbed) return w ``` ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值