Pytorch自定义优化器最佳实践

在 PyTorch 中,自定义优化器需要遵循特定的规范以兼容 PyTorch 的训练流程。下面从核心方法、closure的作用,到 Ranger 优化器的实现,逐步展开说明。

一、PyTorch 自定义优化器的必要方法

自定义优化器必须继承 torch.optim.Optimizer,并实现以下核心方法:

  1. init(self, params, defaults)
  • 作用:初始化优化器,定义超参数(如学习率、动量等),并为参数组(param_groups)设置状态。
  • 参数:
    • params:模型参数迭代器(通常是model.parameters())。
    • defaults:超参数默认字典(如{‘lr’: 0.001, ‘momentum’: 0.9}),用于统一管理参数组的超参数。
  • 关键操作:需调用父类__init__方法,并初始化参数组的状态(如动量缓存、二阶矩估计等)。
  1. step(self, closure=None)
  • 作用:执行单次参数更新(核心逻辑),根据梯度和超参数计算新的参数值。
  • 参数:closure是一个可选的函数(见下文详解),主要用于需要多次计算梯度的优化器(如 LBFGS)。
  • 关键操作:
    • 遍历param_groups中的参数,获取梯度(param.grad)。
    • 根据优化器的更新规则(如 SGD 的动量更新、Adam 的自适应学习率)计算参数更新量。
    • 用param.data更新参数(需在torch.no_grad()上下文下避免构建计算图)。

二、closure 是什么?

closure 是一个函数对象,其作用是重新计算模型的损失和梯度,通常用于需要多次反向传播的优化器(如 LBFGS)。

  • 为什么需要closure?
    某些优化器(如 LBFGS)在单次参数更新中需要多次计算梯度(例如线搜索过程),closure函数会封装前向传播、损失计算和反向传播的逻辑,供优化器在需要时反复调用。
  • closure的典型实现:
    def closure():
      optimizer.zero_grad()  # 清空梯度
      outputs = model(inputs)
      loss = criterion(outputs, labels)
      loss.backward()  # 重新计算梯度
      return loss
    
  • 使用场景:
    大多数优化器(如 SGD、Adam)不需要closure,但如果自定义优化器涉及线搜索或多次梯度计算,则必须在step方法中处理closure。

三、Ranger 优化器的实现(最佳实践)

Ranger(RAdam + Lookahead)是结合了 RAdam(Rectified Adam,自适应学习率)和 Lookahead(慢更新策略)的优化器,兼具收敛速度和稳定性。下面是符合 PyTorch 规范的实现:

  1. Ranger 的核心原理
    RAdam:在 Adam 基础上修正了自适应学习率的方差估计,解决了训练初期样本量不足导致的学习率不稳定问题。
    Lookahead:维护两组参数(“快参数” 和 “慢参数”),快参数按 RAdam 更新,每k步将慢参数更新为快参数的移动平均,并将快参数重置为慢参数,增强收敛稳定性。
    import torch
    from torch.optim.optimizer import Optimizer
    import math
    
    class Ranger(Optimizer): 
        """
            继承PyTorch的Optimizer基类
            所有 PyTorch 优化器必须继承torch.optim.Optimizer,该基类提供了参数管理、状态保存等基础功能,self.state和self.param_groups正是从这里继承的。
        """
        def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                    weight_decay=0, alpha=0.5, k=6, N_sma_threshhold=5):
            """
            Ranger优化器初始化
            :param params: 模型参数
            :param lr: 学习率(默认1e-3)
            :param betas: RAdam的动量参数(默认(0.9, 0.999))
            :param eps: 数值稳定性参数(默认1e-8)
            :param weight_decay: 权重衰减(默认0)
            :param alpha: Lookahead的移动平均系数(默认0.5)
            :param k: Lookahead的慢更新间隔(默认6步)
            :param N_sma_threshhold: RAdam的方差修正阈值(默认5)
            """
            # 检查超参数合法性
            if not 0.0 <= lr:
                raise ValueError(f"Invalid learning rate: {lr}")
            if not 0.0 <= eps:
                raise ValueError(f"Invalid epsilon value: {eps}")
            if not 0.0 <= betas[0] < 1.0:
                raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
            if not 0.0 <= betas[1] < 1.0:
                raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
            if not 0.0 <= alpha <= 1.0:
                raise ValueError(f"Invalid alpha parameter: {alpha}")
            if not 1 <= k:
                raise ValueError(f"Invalid k parameter: {k}")
            
            # 定义超参数默认字典
            defaults = dict(
                lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                alpha=alpha, k=k, N_sma_threshhold=N_sma_threshhold,
                # RAdam状态变量
                step=0, exp_avg=0, exp_avg_sq=0,
                # Lookahead状态变量
                slow_step=0, slow_weights=None
            )
            # 调用父类构造函数(关键!)
            #super().__init__(params, defaults)会初始化self.param_groups(参数组),并为后续状态管理提供基础。
            super().__init__(params, defaults)
    
        def __setstate__(self, state):
            """反序列化时恢复状态(兼容模型保存/加载)"""
            super().__setstate__(state)
    
        def step(self, closure=None):
            """执行单次参数更新(RAdam + Lookahead)"""
            loss = None
            if closure is not None:
                loss = closure()  # 若提供closure,先计算损失和梯度
    
            for group in self.param_groups:
                # 遍历参数组,获取超参数和状态
                lr = group['lr']
                beta1, beta2 = group['betas']
                eps = group['eps']
                weight_decay = group['weight_decay']
                alpha = group['alpha']
                k = group['k']
                N_sma_threshhold = group['N_sma_threshhold']
                # 遍历组内每个参数
                for p in group['params']:
                    if p.grad is None:
                        continue  # 无梯度的参数跳过
                    grad = p.grad.data # 获取梯度
                    state = self.state[p] # 获取当前参数的状态字典
                    
                    # 初始化参数状态(首次更新时)
                    if len(state) == 0:
                        state['step'] = 0 # 记录更新步数
                        # RAdam的一阶矩和二阶矩缓存
                        state['exp_avg'] = torch.zeros_like(p.data)
                        state['exp_avg_sq'] = torch.zeros_like(p.data)
                        # Lookahead的慢参数(初始与原参数相同)
                        state['slow_weights'] = torch.clone(p.data).detach()
                        state['slow_step'] = 0
    
                    # 取出状态变量
                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                    slow_weights = state['slow_weights']
                    state['step'] += 1
                    state['slow_step'] += 1
    
                    # 1. 权重衰减(可选)
                    if weight_decay != 0:
                        grad = grad.add(p.data, alpha=weight_decay)
    
                    # 2. RAdam更新(快参数更新)
                    # 一阶矩(动量)更新:exp_avg = beta1*exp_avg + (1-beta1)*grad
                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                    # 二阶矩(方差)更新:exp_avg_sq = beta2*exp_avg_sq + (1-beta2)*grad^2
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    
                    # 计算RAdam的方差修正项
                    beta2_t = beta2 ** state['step']
                    N_sma = 2 / (1 - beta2) - 1  # 理论上的平滑系数
                    N_sma = N_sma * (1 - beta2_t)  # 实际平滑系数
    
                    # # 根据N_sma动态调整更新方式修正学习率(根据N_sma是否超过阈值)
                    if N_sma >= N_sma_threshhold:
                        # 方差修正(避免训练初期学习率波动)
                        rect = math.sqrt((1 - beta2_t) / (exp_avg_sq + eps))
                        # 偏差修正(一阶矩)
                        step_size = lr * math.sqrt(1 - beta2_t) / (1 - beta1 ** state['step'])
                        # 应用更新:p = p - step_size * exp_avg * rect
                        p.data.addcdiv_(exp_avg, rect, value=-step_size)
                    else:
                        # # 初期用简单动量更新, N_sma较小时,使用普通动量更新
                        step_size = lr / (1 - beta1 ** state['step'])
                        p.data.add_(exp_avg, alpha=-step_size)
    
                    # 3. Lookahead更新(慢参数更新)
                    if state['slow_step'] % k == 0:
                        # 慢参数 = 慢参数 * alpha + 快参数 * (1 - alpha)
                        slow_weights.mul_(alpha).add_(p.data, alpha=1 - alpha)
                        # 将快参数重置为慢参数(同步)
                        p.data.copy_(slow_weights)
            return loss
    

四、代码解析与最佳实践

4.1 代码逻辑解析

  1. 初始化(init
  • 对超参数进行合法性校验,避免无效值(如负学习率)。
  • 通过defaults字典统一管理超参数,支持参数组(不同层使用不同超参数)。
  • 初始化 RAdam 需要的状态(exp_avg:一阶矩缓存,exp_avg_sq:二阶矩缓存)和 Lookahead 需要的状态(slow_weights:慢参数,slow_step:慢更新计数器)。
  1. 参数更新(step)
  • 权重衰减:先对梯度应用权重衰减(若启用),再进行参数更新。
  • RAdam 更新:
    • 维护一阶矩(动量)和二阶矩(方差)的指数移动平均。
    • 根据N_sma(平滑系数)动态调整学习率:当N_sma较小时(训练初期),使用简单动量更新;当N_sma足够大时,启用方差修正,实现自适应学习率。
  • Lookahead 更新:
    • 每k步将慢参数更新为快参数的移动平均(slow_weights = slow_weights * alpha + p.data * (1 - alpha))。
    • 同步快参数与慢参数,确保后续更新基于慢参数的 “稳定值”。
  1. 兼容性处理
  • 实现__setstate__方法,确保优化器状态能被正确序列化(支持torch.save保存模型时同时保存优化器状态)。
  • 对无梯度的参数(p.grad is None)跳过更新,避免错误。
  1. 核心逻辑拆解:
  • 参数组遍历:支持不同层使用不同超参数(如对卷积层和全连接层设置不同lr)。
  • 状态初始化:首次更新时为每个参数创建专属状态(如动量缓存、慢参数)。
  • RAdam 更新:通过一阶矩(动量)和二阶矩(方差)动态调整学习率,解决训练初期不稳定性。
  • Lookahead 更新:每k步同步快参数和慢参数,用慢参数的平滑特性提升收敛稳定性。

4.2 self.state和self.param_groups的详解

这两个属性均继承自Optimizer基类,是 PyTorch 优化器的核心数据结构。这两个继承来的属性是 PyTorch 优化器的基础,self.param_groups管理 “参数与超参数”,self.state管理 “参数更新状态”,缺一不可。

  1. self.param_groups:参数组管理
  • 定义:是一个列表,每个元素是一个字典,代表一组参数及其对应的超参数。
  • 来源:由Optimizer基类在__init__中初始化,基于用户传入的params和defaults构建。
  • 结构示例:
    [
    {
        'params': [param1, param2, ...],  # 模型参数(如卷积层权重)
        'lr': 0.001,                     # 该组的学习率
        'betas': (0.9, 0.999),           # 该组的动量参数
        # ... 其他超参数(来自defaults)
    },
    # 可能有多个参数组(如不同层单独设置超参数)
    ]
    
  • 作用:
    • 支持对不同参数(如特征提取层和分类头)设置不同超参数(如冻结特征提取层时将其lr设为 0)。
    • 遍历self.param_groups即可访问所有需要更新的参数。

拓展
在 PyTorch 中,self.param_groups 支持对不同参数组设置不同超参数(如学习率、权重衰减等),其核心原理是通过将模型参数划分为多个组,并为每个组单独配置超参数。这种灵活性在实际训练中非常有用(例如:对预训练模型的冻结层设置较小的学习率,对新层设置较大的学习率)。
self.param_groups 是一个列表,其中每个元素是一个字典,代表一组参数及其对应的超参数。每个字典包含两个核心键:
‘params’:该组包含的模型参数(如 nn.Linear 的权重或偏置)。
其他键:该组的超参数(如 ‘lr’、‘weight_decay’ 等,来自优化器初始化时的 defaults)。

参数分组的实现逻辑:

  • 初始化优化器时,用户可传入参数组列表(而非单一的 model.parameters())。
  • 优化器会为每个参数组创建一个字典,默认超参数来自 defaults,但用户可在分组时覆盖特定超参数。
  • 训练时,step 方法遍历 self.param_groups,对每个组内的参数应用其专属超参数进行更新。

示例场景
假设我们有一个简单的神经网络,包含:
一个预训练的特征提取层(feature_extractor),希望用较小的学习率微调。
一个新初始化的分类头(classifier),希望用较大的学习率训练。

import torch
import torch.nn as nn
from torch.optim import SGD  # 以SGD为例,Ranger等自定义优化器同样适用
# 1. 定义模型(含特征提取层和分类头)
class MyModel(nn.Module):
   def __init__(self):
       super().__init__()
       # 预训练特征提取层(假设已加载预训练权重)
       self.feature_extractor = nn.Sequential(
           nn.Linear(10, 20),
           nn.ReLU()
       )
       # 新初始化的分类头
       self.classifier = nn.Linear(20, 2)

   def forward(self, x):
       x = self.feature_extractor(x)
       return self.classifier(x)

# 2. 初始化模型
model = MyModel()

# 3. 定义参数组(核心步骤)
# 组1:特征提取层参数(小学习率,无权重衰减)
# 组2:分类头参数(大学习率,有权重衰减)
param_groups = [
   {
       'params': model.feature_extractor.parameters(),
       'lr': 1e-4,           # 特征提取层学习率:1e-4(较小)
       'weight_decay': 0.0   # 不使用权重衰减
   },
   {
       'params': model.classifier.parameters(),
       'lr': 1e-3,           # 分类头学习率:1e-3(较大)
       'weight_decay': 1e-5  # 使用权重衰减
   }
]

# 4. 初始化优化器(传入参数组)
# 注意:这里的默认lr和weight_decay会被param_groups中的设置覆盖
optimizer = SGD(
   param_groups,
   lr=0.1,          # 默认学习率(会被分组中的lr覆盖)
   momentum=0.9,    # 所有组共享的超参数(未在分组中指定)
   weight_decay=0.1 # 默认权重衰减(会被分组中的weight_decay覆盖)
)

# 5. 查看self.param_groups的结构(验证分组效果)
print("参数组结构:")
for i, group in enumerate(optimizer.param_groups):
   print(f"\n组{i}超参数:")
   print(f"学习率: {group['lr']}")
   print(f"权重衰减: {group['weight_decay']}")
   print(f"包含参数数量: {len(group['params'])}")

关键解析:

  • 分组覆盖机制:参数组中显式指定的超参数(如 lr=1e-4)会覆盖优化器初始化时的默认值(如 lr=0.1)。
  • 共享超参数:未在分组中指定的超参数(如 momentum=0.9)会沿用优化器的默认值,实现跨组共享。
  • 参数隔离:feature_extractor 和 classifier 的参数被分为两组,后续 optimizer.step() 会分别用其组内的超参数更新。

常见应用场景

  • 迁移学习:对预训练层用小学习率微调,对新层用大学习率训练。
  • 分层学习率:深层网络中,浅层参数用小学习率,深层参数用大学习率。
  • 冻结部分参数:将不需要更新的参数组的 lr 设为 0(等价于冻结)。
param_groups = [
   {'params': frozen_params, 'lr': 0.0},  # 冻结参数
   {'params': trainable_params, 'lr': 1e-3}  # 可训练参数
]
  1. self.state:参数状态管理
  • 定义:是一个字典,键为参数对象(param),值为该参数的状态字典(存储更新所需的中间变量)。
  • 来源:由Optimizer基类初始化,是一个空字典,用户需在step中为每个参数动态添加状态。
  • 结构示例:
{
    param1: {  # param1是模型中的一个参数(如nn.Linear的weight)
        'step': 100,                # 该参数已更新100步
        'exp_avg': tensor(...),     # RAdam的一阶矩缓存
        'exp_avg_sq': tensor(...),  # RAdam的二阶矩缓存
        'slow_weights': tensor(...),# Lookahead的慢参数
        'slow_step': 100            # 慢参数更新步数
    },
    # 其他参数的状态...
}
  • 作用:
    • 存储每个参数的更新历史(如动量缓存、步数计数),避免使用全局变量导致的冲突。
    • 确保参数状态与参数一一对应,支持多参数组和复杂更新逻辑(如 Ranger 的快慢参数)。

五、使用示例

import torch.nn as nn

# 定义模型
model = nn.Linear(10, 2)
# 初始化Ranger优化器
optimizer = Ranger(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=1e-5,
    k=6,  # 每6步更新一次慢参数
    alpha=0.5  # 慢参数移动平均系数
)

# 训练示例
inputs = torch.randn(32, 10)
labels = torch.randint(0, 2, (32,))
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    optimizer.zero_grad()  # 清空梯度
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()  # 计算梯度
    optimizer.step()  # 执行Ranger更新
    print(f"Epoch {epoch}, Loss: {loss.item()}")

六、总结

自定义优化器的核心是实现__init__(初始化状态和超参数)和step(参数更新逻辑)。closure用于需要多次梯度计算的场景,大多数优化器(包括 Ranger)无需依赖它。Ranger 的实现结合了 RAdam 的自适应学习率和 Lookahead 的慢更新策略,代码中需注意状态维护、超参数校验和兼容性处理,以确保在 PyTorch 生态中正常工作。

七、拓展

在混合精度训练的时候,需要注意优化器尽量使用FP32,如果不在代码里强制指定数据类型,Pytorch会默认优化器的中间状态为指定的精度FP16。
For Example:
如果是混合精度训练时,使用Adam自带优化器的话,可以强制使用FP32

optimizer = Adam([p.float() for p in model.parameters()], lr=1e-3)

自定义Ranger优化器

import math
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import torch
from torch.optim.optimizer import Optimizer

Params = Union[Iterable[torch.Tensor], Iterable[dict]]

LossClosure = Callable[[], float]
OptLossClosure = Optional[LossClosure]
Betas2 = Tuple[float, float]
State = Dict[str, Any]
OptFloat = Optional[float]
Nus2 = Tuple[float, float]


class Ranger(Optimizer):
    """
    Implements Ranger optimization algorithm (Lookahead with RAdam).

    Implementation is modified version from ``pytorch-ranger`` package which build upon
    its `original implementation <https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer>`_.
    Ranger seems to be benefiting most models.

    Args:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-3)
        alpha: linear interpolation factor. 1.0 recovers the inner optimizer.
            (default: 0.5)
        k: number of lookahead steps (default: 6)
        N_sma_threshhold: Maximum length of the simple moving average (SMA)
        betas: coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0))
        eps: term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay: weight decay (L2 penalty) (default: 0)

    Example:
        >>> from pytorch_forecasting.optim import Ranger
        >>> optimizer =  Ranger(model.parameters(), lr=0.1)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        >>> optimizer.step()
        >>> scheduler.step()
    """

    def __init__(
        self,
        params: Params,
        lr: float = 1e-3,
        alpha: float = 0.5,
        k: int = 6,
        N_sma_threshhold: int = 5,
        betas: Betas2 = (0.95, 0.999),
        eps: float = 1e-5,
        weight_decay: float = 0,
    ):
        # parameter checks
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("Invalid slow update rate: {}".format(alpha))
        if not 1 <= k:
            raise ValueError("Invalid lookahead steps: {}".format(k))
        if not lr > 0:
            raise ValueError("Invalid Learning Rate: {}".format(lr))
        if not eps > 0:
            raise ValueError("Invalid eps: {}".format(eps))

        # parameter comments:
        # beta1 (momentum) of .95 seems to work better than .90...
        # N_sma_threshold of 5 seems better in testing than 4.
        # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to
        # make sure which works best for you.

        # prep defaults and init torch.optim base
        defaults = dict(
            lr=lr,
            alpha=alpha,
            k=k,
            step_counter=0,
            betas=betas,
            N_sma_threshhold=N_sma_threshhold,
            eps=eps,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

        # adjustable threshold
        self.N_sma_threshhold = N_sma_threshhold

        # now we can get to work...
        # removed as we now use step from RAdam...no need for
        # duplicate step counting
        # for group in self.param_groups:
        #    group["step_counter"] = 0
        # print("group step counter init")

        # look ahead params
        self.alpha = alpha
        self.k = k

        # radam buffer for state
        self.radam_buffer = [[None, None, None] for ind in range(10)]

        # self.first_run_check=0

        # lookahead weights
        # 9/2/19 - lookahead param tensors have been moved to state storage.
        # This should resolve issues with load/save where weights were left in
        # GPU memory from first load, slowing down future runs.

        # self.slow_weights = [[p.clone().detach() for p in group['params']]
        #                     for group in self.param_groups]

        # don't use grad for lookahead weights
        # for w in it.chain(*self.slow_weights):
        #    w.requires_grad = False
    def __setstate__(self, state: dict) -> None:
        super().__setstate__(state)
    
    def step(self, closure: OptLossClosure = None) -> OptFloat:
        r"""Performs a single optimization step.
        Arguments:
            closure: A closure that reevaluates the model and returns the loss.
        """
        if closure is not None:
            _ = closure()
        loss = None
        # note - below is commented out b/c I have other work that passes back
        # the loss as a float, and thus not a callable closure.
        # Uncomment if you need to use the actual closure...

        # if closure is not None:
        # loss = closure()
        # Evaluate averages and grad, update param tensors
        for group in self.param_groups:

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("Ranger optimizer does not support " "sparse gradients")

                p_data_fp32 = p.data.float()

                state = self.state[p]  # get state dict for this param

                if len(state) == 0:  # if first time to run...init dictionary
                    # with our desired entries
                    # if self.first_run_check==0:
                    # self.first_run_check=1
                    # print("Initializing slow buffer...should not see this
                    # at load from saved model!")
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)

                    # look ahead weight storage now in state dict
                    state["slow_buffer"] = torch.empty_like(p.data)
                    state["slow_buffer"].copy_(p.data)

                else:
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)

                # begin computations
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                # compute variance mov avg
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # compute mean moving avg
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state["step"] += 1

                buffered = self.radam_buffer[int(state["step"] % 10)]
                if state["step"] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state["step"]
                    beta2_t = beta2 ** state["step"]
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        step_size = math.sqrt(
                            (1 - beta2_t)
                            * (N_sma - 4)
                            / (N_sma_max - 4)
                            * (N_sma - 2)
                            / N_sma
                            * N_sma_max
                            / (N_sma_max - 2)
                        ) / (1 - beta1 ** state["step"])
                    else:
                        step_size = 1.0 / (1 - beta1 ** state["step"])
                    buffered[2] = step_size

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])

                if N_sma > self.N_sma_threshhold:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"])
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"])

                p.data.copy_(p_data_fp32)

                # integrated look ahead...
                # we do it at the param level instead of group level
                if state["step"] % group["k"] == 0:
                    slow_p = state["slow_buffer"]  # get access to slow param tensor
                    slow_p.add_(p.data - slow_p, alpha=self.alpha)  # (fast weights - slow weights) * alpha
                    p.data.copy_(slow_p)  # copy interpolated weights to RAdam param tensor
        return loss
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

贝塔西塔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值