Generalized Infinitesimal Gradient Ascent(GIGA)

Generalized Infinitesimal Gradient Ascent(GIGA)

GIGA是一种梯度上升算法,可以用于强化学习和深度学习中,智能体根据收益调整自己的策略。
在这里插入图片描述

### Generalized SAM概述 Generalized Sharpness-Aware Minimization (SAM) 是一种优化算法,旨在通过最小化损失函数的锐度来提高模型泛化能力[^1]。传统上,机器学习中的优化目标主要是使训练集上的经验风险最小化。然而,在某些情况下,这种做法可能导致过拟合现象。 为了克服这一挑战,Generalized SAM引入了一个新的视角——不仅关注参数空间中局部极小值的位置,还考虑其形状特性。具体来说,该方法试图找到那些具有较低锐度(即平坦)解附近的最优解路径。研究表明,这样的解决方案往往能够提供更好的测试性能和更强的鲁棒性。 #### 数学表达形式 给定一个神经网络 \( f_\theta(x) \),其中\( \theta \)表示可调参数向量,则原始SAM可以被定义为: \[ \min_{\theta} L(\theta)+\lambda R(\theta+\epsilon),|\epsilon|_p=\delta \] 这里, - \(L\) 表示标准监督任务下的损失; - \(R\) 则衡量扰动后的最大可能增长幅度; - 参数 \(\lambda,\delta>0\) 控制着正则项强度以及允许的最大偏离程度; 对于广义版本而言,上述公式进一步扩展到多维度约束条件,并支持更灵活的距离测度方式,从而适应不同类型的任务需求[^2]。 ```python import torch from torch import nn, optim class GSAMOptimizer(optim.Optimizer): def __init__(self, params, base_optimizer, rho=0.05, adaptive=False): defaults = dict(rho=rho, adaptive=adaptive) super(GSAMOptimizer, self).__init__(params, defaults) self.base_optimizer = base_optimizer(self.param_groups) @torch.no_grad() def step(self, closure=None): if closure is None: raise ValueError('closure should be provided for sharpness-aware update') # Compute perturbed parameters w' grad_norm = 0. for group in self.param_groups: for p in group['params']: if p.grad is not None: param_state = self.state[p] state.setdefault('grad', []).append(p.grad.clone()) grad_norm += torch.norm(p.grad)**2 grad_norm = grad_norm.sqrt().item() for group in self.param_groups: scale = group["rho"] / max(grad_norm, 1e-8) for p in group['params']: if p.grad is not None: e_w = (scale * p.grad).data.sign() p.add_(e_w) loss_at_perturbation = closure().float() # Restore original weights & compute gradient at w' for group in self.param_groups: for i, p in enumerate(group['params']): if 'original_params' not in self.state[p]: self.state[p]['original_params'] = p.data.clone() p.data.copy_(self.state[p]['original_params']) self.zero_grad() loss_at_original = closure().float() final_loss = loss_at_original + ((loss_at_perturbation - loss_at_original)).mean() final_loss.backward(create_graph=True) self.base_optimizer.step(closure=closure) # Example usage with SGD as a base optimizer model = ... # Define your model here base_optim = torch.optim.SGD(model.parameters(), lr=0.01) gsam = GSAMOptimizer(model.parameters(), base_optim, rho=0.05) for input_, target in data_loader: def closure(): output = model(input_) loss = criterion(output, target) gsam.zero_grad() loss.backward() return loss gsam.step(closure) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值