# The Implementation of Adam in PyTorch

This article walks through the implementation of the Adam optimizer in PyTorch, covering the key steps of the update: parameter updates, momentum, and the exponential moving averages of the gradient. Adam is widely used to optimize the training of deep learning models; tuning hyperparameters such as the learning rate and the decay rates can noticeably improve training results. The full source of `torch.optim.Adam` is reproduced below, followed by notes on usage and the underlying update rule.
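Before reading the source, it helps to see how the optimizer is typically driven from a training loop. The following is a minimal sketch; the linear model, MSE loss, and random data are placeholders chosen for illustration, not part of the original source:

```python
import torch
from torch import nn, optim

# Placeholder model and loss; any nn.Module works the same way.
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

for _ in range(100):
    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)

    optimizer.zero_grad()              # clear gradients from the previous step
    loss = criterion(model(inputs), targets)
    loss.backward()                    # populate p.grad for every parameter
    optimizer.step()                   # run the Adam update implemented below
```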
```python
import torch
from . import _functional as F
from .optimizer import Optimizer


class Adam(Optimizer):
    r"""Implements Adam algorithm.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
    The implementation of the L2 penalty follows changes proposed in
    `Decoupled Weight Decay Regularization`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Adam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                    grads.append(p.grad)

                    state = self.state[p]
                    # Lazy state initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        if group['amsgrad']:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    if group['amsgrad']:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # update the steps for each param group update
                    state['step'] += 1
                    # record the step after step update
                    state_steps.append(state['step'])

            F.adam(params_with_grad,
                   grads,
                   exp_avgs,
                   exp_avg_sqs,
                   max_exp_avg_sqs,
                   state_steps,
                   amsgrad=group['amsgrad'],
                   beta1=beta1,
                   beta2=beta2,
                   lr=group['lr'],
                   weight_decay=group['weight_decay'],
                   eps=group['eps'])
        return loss
```
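The gathered tensors are handed to `F.adam` (`torch.optim._functional.adam`), which applies the update to each parameter in place. As a rough guide to what happens there, the sketch below re-implements the per-parameter step from the standard Adam/AMSGrad update rule with bias correction; it mirrors the in-place tensor operations but is a simplified illustration, not the exact library code:

```python
import math
import torch


@torch.no_grad()  # like step(), the update itself runs without autograd tracking
def adam_update(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, step,
                lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                weight_decay=0.0, amsgrad=False):
    """Simplified single-tensor sketch of the update F.adam applies per parameter."""
    if weight_decay != 0:
        # L2 penalty: fold weight decay into the gradient.
        grad = grad.add(param, alpha=weight_decay)

    # Biased first- and second-moment estimates, updated in place.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    if amsgrad:
        # AMSGrad keeps the running maximum of the second-moment estimate.
        torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
        denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    else:
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

    step_size = lr / bias_correction1
    # Parameter step: param <- param - step_size * exp_avg / denom
    param.addcdiv_(exp_avg, denom, value=-step_size)
```

The `exp_avg`, `exp_avg_sq`, and `max_exp_avg_sq` arguments correspond to the lazily initialized state tensors created inside `step()` above, and `step` is the per-parameter counter stored in `state['step']`.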

### Usage and Implementation of the PyTorch Adam Optimizer

#### Usage

In PyTorch's `torch.optim` module, Adam is a commonly used adaptive-learning-rate optimization algorithm that can be initialized and applied with a few lines of configuration. The following basic example shows how to create an Adam optimizer instance:

```python
import torch
from torch import nn, optim

model = nn.Linear(10, 1)  # assume the model is a single linear layer
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))  # initialize the Adam optimizer
```

The code above defines the optimizer by passing the model's parameters together with the hyperparameters as keyword arguments to the `Adam` constructor.

#### Advanced features

Beyond the basic learning rate (`lr`) and momentum coefficients (`betas`), a gradient-clipping option can be configured so that overly large gradients do not destabilize training. For example, the Pyro library exposes this as a second argument dict:

```python
import pyro.optim

optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)}, {"clip_norm": 10.0})
```

This snippet extends the standard Adam setup by additionally capping the maximum gradient norm.

#### Implementation details

The core idea of Adam is to adapt the effective step size of each weight individually, using an estimate of the gradient's first moment (its mean) and second moment (its uncentered variance). Concretely, the moment estimates are updated as

\[ m_t = \beta_1 m_{t-1} + (1-\beta_1)\nabla f(\theta_{t-1}) \]
\[ v_t = \beta_2 v_{t-1} + (1-\beta_2)\bigl(\nabla f(\theta_{t-1})\bigr)^2 \]

where \(m\) is the exponentially weighted first-moment vector and \(v\) is the uncentered second-moment vector. With the bias-corrected estimates \(\hat{m}_t = m_t/(1-\beta_1^t)\) and \(\hat{v}_t = v_t/(1-\beta_2^t)\), the parameter update is

\[ \theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon} \]

These computations are encapsulated inside PyTorch's source, so developers get an efficient C++ backend without writing the update by hand. For more demanding scenarios, such as mixed-precision training or performance tuning in distributed environments, the fused Adam shipped with Apex or other variants can be considered; a sketch of native mixed-precision training follows at the end of this section.

#### Adam tuning inside larger frameworks

When working on a large deep learning project such as NeRF, it is not enough to understand how a single component works; the overall architecture and the way its pieces interact matter just as much. For example, the CUDA implementation of instant-ngp builds a highly efficient rendering pipeline yet still depends on well-tuned Adam hyperparameters to iterate toward high-quality image synthesis.
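Following up on the mixed-precision scenario mentioned above, here is a minimal sketch using PyTorch's native `torch.cuda.amp` rather than Apex; the model, data, and hyperparameters are illustrative placeholders, and a CUDA device is assumed:

```python
import torch
from torch import nn, optim

# Placeholder model and data; assumes a CUDA device is available.
model = nn.Linear(10, 1).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(100):
    inputs = torch.randn(32, 10, device="cuda")
    targets = torch.randn(32, 1, device="cuda")

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():    # run the forward pass in mixed precision
        loss = nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)             # unscale the gradients, then run the Adam step
    scaler.update()                    # adjust the scale factor for the next iteration
```

A fused Adam variant such as Apex's slots into the same loop shape; the fusion changes how the per-parameter update is executed, not how the optimizer is driven.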