在 PyTorch 中,自定义优化器需要遵循特定的规范以兼容 PyTorch 的训练流程。下面从核心方法、closure的作用,到 Ranger 优化器的实现,逐步展开说明。
一、PyTorch 自定义优化器的必要方法
自定义优化器必须继承 torch.optim.Optimizer,并实现以下核心方法:
- init(self, params, defaults)
- 作用:初始化优化器,定义超参数(如学习率、动量等),并为参数组(param_groups)设置状态。
- 参数:
- params:模型参数迭代器(通常是model.parameters())。
- defaults:超参数默认字典(如{‘lr’: 0.001, ‘momentum’: 0.9}),用于统一管理参数组的超参数。
- 关键操作:需调用父类__init__方法,并初始化参数组的状态(如动量缓存、二阶矩估计等)。
- step(self, closure=None)
- 作用:执行单次参数更新(核心逻辑),根据梯度和超参数计算新的参数值。
- 参数:closure是一个可选的函数(见下文详解),主要用于需要多次计算梯度的优化器(如 LBFGS)。
- 关键操作:
- 遍历param_groups中的参数,获取梯度(param.grad)。
- 根据优化器的更新规则(如 SGD 的动量更新、Adam 的自适应学习率)计算参数更新量。
- 用param.data更新参数(需在torch.no_grad()上下文下避免构建计算图)。
二、closure 是什么?
closure 是一个函数对象,其作用是重新计算模型的损失和梯度,通常用于需要多次反向传播的优化器(如 LBFGS)。
- 为什么需要closure?
某些优化器(如 LBFGS)在单次参数更新中需要多次计算梯度(例如线搜索过程),closure函数会封装前向传播、损失计算和反向传播的逻辑,供优化器在需要时反复调用。 - closure的典型实现:
def closure(): optimizer.zero_grad() # 清空梯度 outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() # 重新计算梯度 return loss
- 使用场景:
大多数优化器(如 SGD、Adam)不需要closure,但如果自定义优化器涉及线搜索或多次梯度计算,则必须在step方法中处理closure。
三、Ranger 优化器的实现(最佳实践)
Ranger(RAdam + Lookahead)是结合了 RAdam(Rectified Adam,自适应学习率)和 Lookahead(慢更新策略)的优化器,兼具收敛速度和稳定性。下面是符合 PyTorch 规范的实现:
- Ranger 的核心原理
RAdam:在 Adam 基础上修正了自适应学习率的方差估计,解决了训练初期样本量不足导致的学习率不稳定问题。
Lookahead:维护两组参数(“快参数” 和 “慢参数”),快参数按 RAdam 更新,每k步将慢参数更新为快参数的移动平均,并将快参数重置为慢参数,增强收敛稳定性。import torch from torch.optim.optimizer import Optimizer import math class Ranger(Optimizer): """ 继承PyTorch的Optimizer基类 所有 PyTorch 优化器必须继承torch.optim.Optimizer,该基类提供了参数管理、状态保存等基础功能,self.state和self.param_groups正是从这里继承的。 """ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, alpha=0.5, k=6, N_sma_threshhold=5): """ Ranger优化器初始化 :param params: 模型参数 :param lr: 学习率(默认1e-3) :param betas: RAdam的动量参数(默认(0.9, 0.999)) :param eps: 数值稳定性参数(默认1e-8) :param weight_decay: 权重衰减(默认0) :param alpha: Lookahead的移动平均系数(默认0.5) :param k: Lookahead的慢更新间隔(默认6步) :param N_sma_threshhold: RAdam的方差修正阈值(默认5) """ # 检查超参数合法性 if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= betas[0] < 1.0: raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") if not 0.0 <= betas[1] < 1.0: raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") if not 0.0 <= alpha <= 1.0: raise ValueError(f"Invalid alpha parameter: {alpha}") if not 1 <= k: raise ValueError(f"Invalid k parameter: {k}") # 定义超参数默认字典 defaults = dict( lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, alpha=alpha, k=k, N_sma_threshhold=N_sma_threshhold, # RAdam状态变量 step=0, exp_avg=0, exp_avg_sq=0, # Lookahead状态变量 slow_step=0, slow_weights=None ) # 调用父类构造函数(关键!) #super().__init__(params, defaults)会初始化self.param_groups(参数组),并为后续状态管理提供基础。 super().__init__(params, defaults) def __setstate__(self, state): """反序列化时恢复状态(兼容模型保存/加载)""" super().__setstate__(state) def step(self, closure=None): """执行单次参数更新(RAdam + Lookahead)""" loss = None if closure is not None: loss = closure() # 若提供closure,先计算损失和梯度 for group in self.param_groups: # 遍历参数组,获取超参数和状态 lr = group['lr'] beta1, beta2 = group['betas'] eps = group['eps'] weight_decay = group['weight_decay'] alpha = group['alpha'] k = group['k'] N_sma_threshhold = group['N_sma_threshhold'] # 遍历组内每个参数 for p in group['params']: if p.grad is None: continue # 无梯度的参数跳过 grad = p.grad.data # 获取梯度 state = self.state[p] # 获取当前参数的状态字典 # 初始化参数状态(首次更新时) if len(state) == 0: state['step'] = 0 # 记录更新步数 # RAdam的一阶矩和二阶矩缓存 state['exp_avg'] = torch.zeros_like(p.data) state['exp_avg_sq'] = torch.zeros_like(p.data) # Lookahead的慢参数(初始与原参数相同) state['slow_weights'] = torch.clone(p.data).detach() state['slow_step'] = 0 # 取出状态变量 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] slow_weights = state['slow_weights'] state['step'] += 1 state['slow_step'] += 1 # 1. 权重衰减(可选) if weight_decay != 0: grad = grad.add(p.data, alpha=weight_decay) # 2. RAdam更新(快参数更新) # 一阶矩(动量)更新:exp_avg = beta1*exp_avg + (1-beta1)*grad exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # 二阶矩(方差)更新:exp_avg_sq = beta2*exp_avg_sq + (1-beta2)*grad^2 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # 计算RAdam的方差修正项 beta2_t = beta2 ** state['step'] N_sma = 2 / (1 - beta2) - 1 # 理论上的平滑系数 N_sma = N_sma * (1 - beta2_t) # 实际平滑系数 # # 根据N_sma动态调整更新方式修正学习率(根据N_sma是否超过阈值) if N_sma >= N_sma_threshhold: # 方差修正(避免训练初期学习率波动) rect = math.sqrt((1 - beta2_t) / (exp_avg_sq + eps)) # 偏差修正(一阶矩) step_size = lr * math.sqrt(1 - beta2_t) / (1 - beta1 ** state['step']) # 应用更新:p = p - step_size * exp_avg * rect p.data.addcdiv_(exp_avg, rect, value=-step_size) else: # # 初期用简单动量更新, N_sma较小时,使用普通动量更新 step_size = lr / (1 - beta1 ** state['step']) p.data.add_(exp_avg, alpha=-step_size) # 3. Lookahead更新(慢参数更新) if state['slow_step'] % k == 0: # 慢参数 = 慢参数 * alpha + 快参数 * (1 - alpha) slow_weights.mul_(alpha).add_(p.data, alpha=1 - alpha) # 将快参数重置为慢参数(同步) p.data.copy_(slow_weights) return loss
四、代码解析与最佳实践
4.1 代码逻辑解析
- 初始化(init)
- 对超参数进行合法性校验,避免无效值(如负学习率)。
- 通过defaults字典统一管理超参数,支持参数组(不同层使用不同超参数)。
- 初始化 RAdam 需要的状态(exp_avg:一阶矩缓存,exp_avg_sq:二阶矩缓存)和 Lookahead 需要的状态(slow_weights:慢参数,slow_step:慢更新计数器)。
- 参数更新(step)
- 权重衰减:先对梯度应用权重衰减(若启用),再进行参数更新。
- RAdam 更新:
- 维护一阶矩(动量)和二阶矩(方差)的指数移动平均。
- 根据N_sma(平滑系数)动态调整学习率:当N_sma较小时(训练初期),使用简单动量更新;当N_sma足够大时,启用方差修正,实现自适应学习率。
- Lookahead 更新:
- 每k步将慢参数更新为快参数的移动平均(slow_weights = slow_weights * alpha + p.data * (1 - alpha))。
- 同步快参数与慢参数,确保后续更新基于慢参数的 “稳定值”。
- 兼容性处理
- 实现__setstate__方法,确保优化器状态能被正确序列化(支持torch.save保存模型时同时保存优化器状态)。
- 对无梯度的参数(p.grad is None)跳过更新,避免错误。
- 核心逻辑拆解:
- 参数组遍历:支持不同层使用不同超参数(如对卷积层和全连接层设置不同lr)。
- 状态初始化:首次更新时为每个参数创建专属状态(如动量缓存、慢参数)。
- RAdam 更新:通过一阶矩(动量)和二阶矩(方差)动态调整学习率,解决训练初期不稳定性。
- Lookahead 更新:每k步同步快参数和慢参数,用慢参数的平滑特性提升收敛稳定性。
4.2 self.state和self.param_groups的详解
这两个属性均继承自Optimizer基类,是 PyTorch 优化器的核心数据结构。这两个继承来的属性是 PyTorch 优化器的基础,self.param_groups管理 “参数与超参数”,self.state管理 “参数更新状态”,缺一不可。
- self.param_groups:参数组管理
- 定义:是一个列表,每个元素是一个字典,代表一组参数及其对应的超参数。
- 来源:由Optimizer基类在__init__中初始化,基于用户传入的params和defaults构建。
- 结构示例:
[ { 'params': [param1, param2, ...], # 模型参数(如卷积层权重) 'lr': 0.001, # 该组的学习率 'betas': (0.9, 0.999), # 该组的动量参数 # ... 其他超参数(来自defaults) }, # 可能有多个参数组(如不同层单独设置超参数) ]
- 作用:
- 支持对不同参数(如特征提取层和分类头)设置不同超参数(如冻结特征提取层时将其lr设为 0)。
- 遍历self.param_groups即可访问所有需要更新的参数。
拓展
在 PyTorch 中,self.param_groups 支持对不同参数组设置不同超参数(如学习率、权重衰减等),其核心原理是通过将模型参数划分为多个组,并为每个组单独配置超参数。这种灵活性在实际训练中非常有用(例如:对预训练模型的冻结层设置较小的学习率,对新层设置较大的学习率)。
self.param_groups 是一个列表,其中每个元素是一个字典,代表一组参数及其对应的超参数。每个字典包含两个核心键:
‘params’:该组包含的模型参数(如 nn.Linear 的权重或偏置)。
其他键:该组的超参数(如 ‘lr’、‘weight_decay’ 等,来自优化器初始化时的 defaults)。参数分组的实现逻辑:
- 初始化优化器时,用户可传入参数组列表(而非单一的 model.parameters())。
- 优化器会为每个参数组创建一个字典,默认超参数来自 defaults,但用户可在分组时覆盖特定超参数。
- 训练时,step 方法遍历 self.param_groups,对每个组内的参数应用其专属超参数进行更新。
示例场景
假设我们有一个简单的神经网络,包含:
一个预训练的特征提取层(feature_extractor),希望用较小的学习率微调。
一个新初始化的分类头(classifier),希望用较大的学习率训练。import torch import torch.nn as nn from torch.optim import SGD # 以SGD为例,Ranger等自定义优化器同样适用 # 1. 定义模型(含特征提取层和分类头) class MyModel(nn.Module): def __init__(self): super().__init__() # 预训练特征提取层(假设已加载预训练权重) self.feature_extractor = nn.Sequential( nn.Linear(10, 20), nn.ReLU() ) # 新初始化的分类头 self.classifier = nn.Linear(20, 2) def forward(self, x): x = self.feature_extractor(x) return self.classifier(x) # 2. 初始化模型 model = MyModel() # 3. 定义参数组(核心步骤) # 组1:特征提取层参数(小学习率,无权重衰减) # 组2:分类头参数(大学习率,有权重衰减) param_groups = [ { 'params': model.feature_extractor.parameters(), 'lr': 1e-4, # 特征提取层学习率:1e-4(较小) 'weight_decay': 0.0 # 不使用权重衰减 }, { 'params': model.classifier.parameters(), 'lr': 1e-3, # 分类头学习率:1e-3(较大) 'weight_decay': 1e-5 # 使用权重衰减 } ] # 4. 初始化优化器(传入参数组) # 注意:这里的默认lr和weight_decay会被param_groups中的设置覆盖 optimizer = SGD( param_groups, lr=0.1, # 默认学习率(会被分组中的lr覆盖) momentum=0.9, # 所有组共享的超参数(未在分组中指定) weight_decay=0.1 # 默认权重衰减(会被分组中的weight_decay覆盖) ) # 5. 查看self.param_groups的结构(验证分组效果) print("参数组结构:") for i, group in enumerate(optimizer.param_groups): print(f"\n组{i}超参数:") print(f"学习率: {group['lr']}") print(f"权重衰减: {group['weight_decay']}") print(f"包含参数数量: {len(group['params'])}")
关键解析:
- 分组覆盖机制:参数组中显式指定的超参数(如 lr=1e-4)会覆盖优化器初始化时的默认值(如 lr=0.1)。
- 共享超参数:未在分组中指定的超参数(如 momentum=0.9)会沿用优化器的默认值,实现跨组共享。
- 参数隔离:feature_extractor 和 classifier 的参数被分为两组,后续 optimizer.step() 会分别用其组内的超参数更新。
常见应用场景
- 迁移学习:对预训练层用小学习率微调,对新层用大学习率训练。
- 分层学习率:深层网络中,浅层参数用小学习率,深层参数用大学习率。
- 冻结部分参数:将不需要更新的参数组的 lr 设为 0(等价于冻结)。
param_groups = [ {'params': frozen_params, 'lr': 0.0}, # 冻结参数 {'params': trainable_params, 'lr': 1e-3} # 可训练参数 ]
- self.state:参数状态管理
- 定义:是一个字典,键为参数对象(param),值为该参数的状态字典(存储更新所需的中间变量)。
- 来源:由Optimizer基类初始化,是一个空字典,用户需在step中为每个参数动态添加状态。
- 结构示例:
{
param1: { # param1是模型中的一个参数(如nn.Linear的weight)
'step': 100, # 该参数已更新100步
'exp_avg': tensor(...), # RAdam的一阶矩缓存
'exp_avg_sq': tensor(...), # RAdam的二阶矩缓存
'slow_weights': tensor(...),# Lookahead的慢参数
'slow_step': 100 # 慢参数更新步数
},
# 其他参数的状态...
}
- 作用:
- 存储每个参数的更新历史(如动量缓存、步数计数),避免使用全局变量导致的冲突。
- 确保参数状态与参数一一对应,支持多参数组和复杂更新逻辑(如 Ranger 的快慢参数)。
五、使用示例
import torch.nn as nn
# 定义模型
model = nn.Linear(10, 2)
# 初始化Ranger优化器
optimizer = Ranger(
model.parameters(),
lr=1e-3,
betas=(0.9, 0.999),
weight_decay=1e-5,
k=6, # 每6步更新一次慢参数
alpha=0.5 # 慢参数移动平均系数
)
# 训练示例
inputs = torch.randn(32, 10)
labels = torch.randint(0, 2, (32,))
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
optimizer.zero_grad() # 清空梯度
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward() # 计算梯度
optimizer.step() # 执行Ranger更新
print(f"Epoch {epoch}, Loss: {loss.item()}")
六、总结
自定义优化器的核心是实现__init__(初始化状态和超参数)和step(参数更新逻辑)。closure用于需要多次梯度计算的场景,大多数优化器(包括 Ranger)无需依赖它。Ranger 的实现结合了 RAdam 的自适应学习率和 Lookahead 的慢更新策略,代码中需注意状态维护、超参数校验和兼容性处理,以确保在 PyTorch 生态中正常工作。
七、拓展
在混合精度训练的时候,需要注意优化器尽量使用FP32,如果不在代码里强制指定数据类型,Pytorch会默认优化器的中间状态为指定的精度FP16。
For Example:
如果是混合精度训练时,使用Adam自带优化器的话,可以强制使用FP32
optimizer = Adam([p.float() for p in model.parameters()], lr=1e-3)
自定义Ranger优化器
import math
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
import torch
from torch.optim.optimizer import Optimizer
Params = Union[Iterable[torch.Tensor], Iterable[dict]]
LossClosure = Callable[[], float]
OptLossClosure = Optional[LossClosure]
Betas2 = Tuple[float, float]
State = Dict[str, Any]
OptFloat = Optional[float]
Nus2 = Tuple[float, float]
class Ranger(Optimizer):
"""
Implements Ranger optimization algorithm (Lookahead with RAdam).
Implementation is modified version from ``pytorch-ranger`` package which build upon
its `original implementation <https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer>`_.
Ranger seems to be benefiting most models.
Args:
params: iterable of parameters to optimize or dicts defining
parameter groups
lr: learning rate (default: 1e-3)
alpha: linear interpolation factor. 1.0 recovers the inner optimizer.
(default: 0.5)
k: number of lookahead steps (default: 6)
N_sma_threshhold: Maximum length of the simple moving average (SMA)
betas: coefficients used for computing
running averages of gradient and its square (default: (0.95, 0))
eps: term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay: weight decay (L2 penalty) (default: 0)
Example:
>>> from pytorch_forecasting.optim import Ranger
>>> optimizer = Ranger(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
>>> optimizer.step()
>>> scheduler.step()
"""
def __init__(
self,
params: Params,
lr: float = 1e-3,
alpha: float = 0.5,
k: int = 6,
N_sma_threshhold: int = 5,
betas: Betas2 = (0.95, 0.999),
eps: float = 1e-5,
weight_decay: float = 0,
):
# parameter checks
if not 0.0 <= alpha <= 1.0:
raise ValueError("Invalid slow update rate: {}".format(alpha))
if not 1 <= k:
raise ValueError("Invalid lookahead steps: {}".format(k))
if not lr > 0:
raise ValueError("Invalid Learning Rate: {}".format(lr))
if not eps > 0:
raise ValueError("Invalid eps: {}".format(eps))
# parameter comments:
# beta1 (momentum) of .95 seems to work better than .90...
# N_sma_threshold of 5 seems better in testing than 4.
# In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to
# make sure which works best for you.
# prep defaults and init torch.optim base
defaults = dict(
lr=lr,
alpha=alpha,
k=k,
step_counter=0,
betas=betas,
N_sma_threshhold=N_sma_threshhold,
eps=eps,
weight_decay=weight_decay,
)
super().__init__(params, defaults)
# adjustable threshold
self.N_sma_threshhold = N_sma_threshhold
# now we can get to work...
# removed as we now use step from RAdam...no need for
# duplicate step counting
# for group in self.param_groups:
# group["step_counter"] = 0
# print("group step counter init")
# look ahead params
self.alpha = alpha
self.k = k
# radam buffer for state
self.radam_buffer = [[None, None, None] for ind in range(10)]
# self.first_run_check=0
# lookahead weights
# 9/2/19 - lookahead param tensors have been moved to state storage.
# This should resolve issues with load/save where weights were left in
# GPU memory from first load, slowing down future runs.
# self.slow_weights = [[p.clone().detach() for p in group['params']]
# for group in self.param_groups]
# don't use grad for lookahead weights
# for w in it.chain(*self.slow_weights):
# w.requires_grad = False
def __setstate__(self, state: dict) -> None:
super().__setstate__(state)
def step(self, closure: OptLossClosure = None) -> OptFloat:
r"""Performs a single optimization step.
Arguments:
closure: A closure that reevaluates the model and returns the loss.
"""
if closure is not None:
_ = closure()
loss = None
# note - below is commented out b/c I have other work that passes back
# the loss as a float, and thus not a callable closure.
# Uncomment if you need to use the actual closure...
# if closure is not None:
# loss = closure()
# Evaluate averages and grad, update param tensors
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError("Ranger optimizer does not support " "sparse gradients")
p_data_fp32 = p.data.float()
state = self.state[p] # get state dict for this param
if len(state) == 0: # if first time to run...init dictionary
# with our desired entries
# if self.first_run_check==0:
# self.first_run_check=1
# print("Initializing slow buffer...should not see this
# at load from saved model!")
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p_data_fp32)
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
# look ahead weight storage now in state dict
state["slow_buffer"] = torch.empty_like(p.data)
state["slow_buffer"].copy_(p.data)
else:
state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
# begin computations
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
# compute variance mov avg
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# compute mean moving avg
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state["step"] += 1
buffered = self.radam_buffer[int(state["step"] % 10)]
if state["step"] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state["step"]
beta2_t = beta2 ** state["step"]
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
if N_sma > self.N_sma_threshhold:
step_size = math.sqrt(
(1 - beta2_t)
* (N_sma - 4)
/ (N_sma_max - 4)
* (N_sma - 2)
/ N_sma
* N_sma_max
/ (N_sma_max - 2)
) / (1 - beta1 ** state["step"])
else:
step_size = 1.0 / (1 - beta1 ** state["step"])
buffered[2] = step_size
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
if N_sma > self.N_sma_threshhold:
denom = exp_avg_sq.sqrt().add_(group["eps"])
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"])
else:
p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"])
p.data.copy_(p_data_fp32)
# integrated look ahead...
# we do it at the param level instead of group level
if state["step"] % group["k"] == 0:
slow_p = state["slow_buffer"] # get access to slow param tensor
slow_p.add_(p.data - slow_p, alpha=self.alpha) # (fast weights - slow weights) * alpha
p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor
return loss