关于sampling softmax 中重要性采样的论文阅读笔记

本文探讨了在词向量训练、神经网络语言模型及机器翻译等任务中,如何通过重要性采样来优化softmax函数的计算效率。通过数学推导,详细介绍了如何避免直接计算softmax分母,从而减少计算量。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

首先列出参考资料:

  1. word2vec Parameter Learning Explained
  2. Quick Training of Probabilistic Neural Nets by Importance Sampling
  3. On Using Very Large Target Vocabulary for Neural Machine Translation
  4. Adaptive importance sampling to accelerate training of a neural probabilistic language model. IEEE Transactions on Neural Networks
  5. Sebastian Ruder’s blog(博客很不错)

主要是对重要性采样softmax的学习过程做一些笔记。
在词向量训练、神经网络语言模型、神经网络机器翻译等任务中,softmax函数有如下形式:

p(w|c)=exp(hvw)wiVexp(hvwi)=exp(hvw)Z(h)

其中,h是导数层的输出, vwi 是w对应的输出词向量(即softmax的权重矩阵,具体可参考 (1)),V是词典,c是上下文。 在神经网络语言模型中,一般会把c压缩为h
sampling softmax解决了softmax分母部分计算量大的问题。

如果损失函数采用交叉熵损失函数:

H(q,p)=xq(x)logp(x)

这里q是真是期望分布,例如 q=[0,1,,0] ,p是模型输出分布,对应上面的softmax公式。
对于一个样本,可得交叉熵损失函数(这里把模型的参数统称为 θ )
Jθ=logexp(hvw)wiVexp(hvwi)

根据简单的
logxy=logxlogy

得到:
Jθ=hvw+logwiVexp(hvwi)

令:
E(w)=hvw

得到:
Jθ=E(w)+logwiVexp(E(wi))

θ 求梯度得:

θJθ=θE(w)+θlogwiVexp(E(wi))

根据:
xlogx=1x

得:
θJθ=θE(w)+1wiVexp(E(wi))θwiVexp(E(wi)

θJθ=θE(w)+1wiVexp(E(wi))wiVθexp(E(wi))

根据 xexp(x)=exp(x) ,继续利用求导链式法则:
θJθ=θE(w)+1wiVexp(E(wi))wiVexp(E(wi))θ(E(wi))

θJθ=θE(w)+wiVexp(E(wi))wiVexp(E(wi))θ(E(wi))

可以看到 exp(E(wi))wiVexp(E(wi)) 就是softmax的输出,即 wi 的概率 P(wi)
于是得到如下:
θJθ=θE(w)+wiVP(wi)θ(E(wi))

最终得到:
θJθ=θE(w)wiVP(wi)θE(wi)

对于梯度公式的第二部分,可以认为是 θE(wi) 对于softmax输出 P(wi) 的期望,即:
wiVP(wi)θE(wi)=EwiP[θE(wi)]

这就是采样要优化的部分。
根据传统的重要性采样方法,按照如下公式计算期望:
EwiP[θE(wi)]1NwiQ(w)P(wi)Q(wi)θE(wi)

其中N是从分布Q(我们自己定义的一个容易采样的分布)中采样的样本数,但是这种方法仍然需要计算 P(wi) ,而 P(wi) 的计算又需要softmax做归一化,这是我们不想看到的,所以要使用一种有偏估计的方法。


现在,让我们来观察Softmax公式( exp(E(wi))wiVexp(E(wi)) )的分母部分:

Z(h)=wiVexp(E(wi))=MwiV(1M)exp(E(wi))

这样,我们可以把 wiV(1M)exp(E(wi)) 看出是一种期望的形式,进而可以采用采样的方法得到 Z(h) 。现在我们还是取候选分布为Q。
则:
Z(h)=Z^(h)=MNwiQ(w)R^(wi)exp(E(wi))Q(wi)=MNwiQ(w)exp(E(wi))MQ(wi)

上式中的 R^(wi) 代表概率 1M ,约去M可得:
Z^(h)=1NwiQ(w)exp(E(wi))Q(wi)

到这里,我们就可以用 Z^(h) 去近似 Z(h) 了。
现在理一下思路:
给定候选分布Q,传统采样方法需要计算P,也就是说需要计算分母Z,这是我们不想看到的。幸运的是分母Z仍然可以通过采样得到,采样Z的时候,仍然采用候选分布Q。

现在继续计算 θJθ 中系数为负的部分,即期望部分。

EwiP[θE(wi)]1NwiQ(w)P(wi)Q(wi)θE(wi)=1NwiQ(w)P^(wi)Q(wi)θE(wi)

其中 P^(wi) 代表采样方式获得的概率:
P^(wi)=exp(E(wi))Z^(h)

可得:
EwiP[θE(wi)]1NwiQ(w)exp(E(wi))Q(wi)Z^(h)θE(wi)

现在我们就从Q分布中采样N个样本,组成集合J,最终得到:
EwiP[θE(wi)]wjJexp(E(wj))θE(wj)/Q(wj)wjJexp(E(wj))/Q(wj)

整体梯度为:

θJθ=θE(w)wjJexp(E(wj))θE(wj)/Q(wj)wjJexp(E(wj))/Q(wj)

下面给出算法步骤(来自 Quick Training of Probabilistic Neural Nets by Importance Sampling):
这里写图片描述

OK,如果你正在使用tensorflow的Seq2Seq模型,并且正在阅读On Using Very Large Target Vocabulary for Neural Machine Translation,对于论文中的公式(10)和公式(11),本篇笔记可以给出大致思路的解释,有不完善或错误的地方,欢迎批评指正!

### PPO算法中的重要性采样 #### 什么是重要性采样? 在强化学习中,重要性采样Importance Sampling, IS)是一种用于估计不同分布下期望的技术。具体而言,在PPO算法中,重要性采样被用来评估旧策略下的动作如何适应新策略的更新过程[^3]。 #### PPO算法中的重要性采样实现 在PPO算法中,重要性采样的核心在于通过概率比来调整奖励信号。假设当前策略为 \(\pi_{\theta}\),而之前的策略为 \(\pi_{\text{old}}\),则对于某个状态 \(s_t\) 和动作 \(a_t\),可以定义概率比为: ```python ratio = pi_theta.prob(action) / pi_old.prob(action) ``` 其中 `pi_theta` 表示新的策略分布,`pi_old` 表示旧的策略分布,`action` 则表示所采取的动作。此概率比会进一步影响损失函数的设计[^1]。 #### 原理分析 PPO的核心思想之一就是利用重要性采样技术减少方差并提高样本利用率。为了防止策略更新过程中发生过大幅度的变化,引入了截断机制(Clipping Mechanism)。该机制通过对上述提到的概率比施加一定的范围限制,从而确保策略不会偏离原有轨迹过多。具体表达式如下所示: \[ L^{CLIP} (\theta) = \hat{\mathbb{E}}_t [\min(r_t(\theta)\hat{A}_t,\; clip(r_t(\theta),\; 1-\epsilon,\; 1+\epsilon)\hat{A}_t)] \] 这里 \(r_t(\theta)=\frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}\) 即为我们之前讨论过的概率比;\(\hat{A}_t\) 是优势函数的估计值;参数 \(\epsilon\) 控制着允许的最大偏差程度[^4]。 #### Python代码示例 以下是基于PyTorch框架的一个简化版PPO训练循环片段,展示了如何应用重要性采样以及clipping技巧: ```python import torch from torch.distributions import Categorical def compute_loss(new_log_probs, old_log_probs, advantages, epsilon=0.2): ratio = (new_log_probs - old_log_probs).exp() surr1 = ratio * advantages surr2 = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages return -torch.min(surr1, surr2).mean() # Example usage within a training loop... for epoch in range(num_epochs): for state_batch, action_batch, old_log_prob_batch, advantage_batch in dataloader: policy_dist = model(state_batch) new_log_probs = policy_dist.log_prob(action_batch) loss = compute_loss(new_log_probs, old_log_prob_batch.detach(), advantage_batch) optimizer.zero_grad() loss.backward() optimizer.step() ```
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值