pytorch优化器Adam和Adamw中的weight_decay的区别

本文详细介绍了PyTorch中的优化器Adam和AdamW,重点讨论了两者在weight_decay上的区别。Adam中的weight_decay实际上实现了L2正则化,而AdamW则是对参数直接进行约束,以防止过拟合。两者的更新方式不同,但都旨在控制权重大小。
PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

pytorch优化器Adam和Adamw

只介绍二者的weight_decay的区别

一般的梯度下降的方法是θj=θj−α∂∂θjJ\theta_j=\theta_j-\alpha\frac{\partial}{\partial\theta_j}Jθj=θjαθjJ

Adam

算法

给出pytorch.optim.Adam类中具体实现的算法,来自pytorch中class Adam(Optimizer)
input:γ (lr),β1,β2 (betas),θ0 (params),f(θ) (objective)λ (weight decay), amsgradinitialize:m0←0 ( first moment),v0←0 (second moment), v0^max←0for t=1 to … dogt←∇θft(θt−1)if λ≠0gt←gt+λθt−1mt←β1mt−1+(1−β1)gtvt←β2vt−1+(1−β2)gt2mt^←mt/(1−β1t)vt^←vt/(1−β2t)if amsgradvt^max←max(vt^max,vt^)θt←θt−1−γmt^/(vt^max+ϵ)elseθt←θt−1−γmt^/(vt^+ϵ)return θt \begin{aligned} &\rule{110mm}{0.4pt} \\ &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\ &\hspace{13mm} \lambda \text{ (weight decay)}, \: amsgrad \\ &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex] &\rule{110mm}{0.4pt} \\ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ &\hspace{5mm}\textbf{if} \: amsgrad \\ &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, \widehat{v_t}) \\ &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ &\hspace{5mm}\textbf{else} \\ &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ &\rule{110mm}{0.4pt} \\[-1.ex] &\bf{return} \: \theta_t \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] \end{aligned} input:γ (lr),β1,β2 (betas),θ0 (params),f(θ) (objective)λ (weight decay),amsgradinitialize:m00 ( first moment),v00 (second moment),v0max0fort=1todogtθft(θt1)ifλ=0gtgt+λθt1mtβ1mt1+(1β1)gtvtβ2vt1+(1β2)gt2mtmt/(1β1t)vtvt/(1β2t)ifamsgradvtmaxmax(vtmax,vt)θtθt1γmt/(vtmax+ϵ)elseθtθt1γmt/(vt+ϵ)returnθt
Adam中的weight_decay实际上实现的是L2 norm,对梯度计算时加上了λθt−1\lambda \theta_{t-1}λθt1,相当于对cost function加上了L2范数,J=cost+λ2∣∣wi∣∣2J=cost+\frac{\lambda}{2}||w_i||^2J=cost+2λwi2,这样对JJJ求梯度gti=∂J∂θi+λθigt_i=\frac{\partial J}{\partial \theta_i}+\lambda \theta_igti=θiJ+λθi,与算法思想一致

Adamw

算法

来自pytorch class AdamW(Optimizer):

input:γ(lr), β1,β2(betas), θ0(params), f(θ)(objective), ϵ (epsilon)λ(weight decay), amsgradinitialize:m0←0 (first moment),v0←0 ( second moment), v0^max←0for t=1 to … dogt←∇θft(θt−1)θt←θt−1−γλθt−1mt←β1mt−1+(1−β1)gtvt←β2vt−1+(1−β2)gt2mt^←mt/(1−β1t)vt^←vt/(1−β2t)if amsgradvt^max←max(vt^max,vt^)θt←θt−1−γmt^/(vt^max+ϵ)elseθt←θt−1−γmt^/(vt^+ϵ)return θt \begin{aligned} &\rule{110mm}{0.4pt} \\ &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, \: \epsilon \text{ (epsilon)} \\ &\hspace{13mm} \lambda \text{(weight decay)}, \: amsgrad \\ &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex] &\rule{110mm}{0.4pt} \\ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ &\hspace{5mm}\textbf{if} \: amsgrad \\ &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, \widehat{v_t}) \\ &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ &\hspace{5mm}\textbf{else} \\ &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ &\rule{110mm}{0.4pt} \\[-1.ex] &\bf{return} \: \theta_t \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] \end{aligned} input:γ(lr),β1,β2(betas),θ0(params),f(θ)(objective),ϵ (epsilon)λ(weight decay),amsgradinitialize:m00 (first moment),v00 ( second moment),v0max0fort=1todogtθft(θt1)θtθt1γλθt1mtβ1mt1+(1β1)gtvtβ2vt1+(1β2)gt2mtmt/(1β1t)vtvt/(1β2t)ifamsgradvtmaxmax(vtmax,vt)θtθt1γmt/(vtmax+ϵ)elseθtθt1γmt/(vt+ϵ)returnθt
Adamw中的weight_dacay则是对第一项θj\theta_jθj进行一个约束,即θj=(1−lr∗wd)∗θj−α∂∂θjJ\theta_j=(1-lr*wd)*\theta_j-\alpha\frac{\partial}{\partial\theta_j}Jθj=(1lrwd)θjαθjJ,更新t时刻的参数时对t-1时刻的参数进行一个约束。

总结

二者都是为了限制权重,使其尽量的小,避免过拟合的发生,在权重更新方式上有所不同。

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值