【一文学会】Gumbel-Softmax的采样技巧

目录

基于softmax的采样

基于gumbel-max的采样

基于gumbel-softmax的采样

基于ST-gumbel-softmax的采样

Gumbel分布

回答问题一

回答问题二

回答问题三

附录


 

以强化学习为例,假设网络输出的三维向量代表三个动作(前进、停留、后退)在下一步的收益,value=[-10,10,15],那么下一步我们就会选择收益最大的动作(后退)继续执行,于是输出动作[0,0,1]。选择值最大的作为输出动作,这样做本身没问题,但是在网络中这种取法有个问题是不能计算梯度,也就不能更新网络。

基于softmax的采样

这时通常的做法是加上softmax函数,把向量归一化,这样既能计算梯度,同时值的大小还能表示概率的含义(多项分布)。

                                                    \fn_phv \large \pi_k = \frac{e^{x_k}}{\sum_{i=1}^{K} e^{x_{i}}}

于是value=[-10,10,15]通过softmax函数后有σ(value)=[0,0.007,0.993],这样做不会改变动作或者说类别的选取,同时softmax倾向于让最大值的概率显著大于其他值,比如这里15和10经过softmax放缩之后变成了0.993和0.007,这有利于把网络训成一个one-hot输出的形式,这种方式在分类问题中是常用方法。

但这样就不会体现概率的含义了,因为σ(value)=[0,0.007,0.993]与σ(value)=[0.3,0.2,0.5]在类别选取的结果看来没有任何差别,都是选择第三个类别,但是从概率意义上讲差别是巨大的。

很直接的方法是依概率采样完事了,比如直接用np.random.choice函数依照概率生成样本值,这样概率就有意义了。所以,经典的采样方法就是用softmax函数加上轮盘赌方法(np.random.choice)。但这样还是会有个问题,这种方式怎么计算梯度?不能计算梯度怎么更新网络?

def sample_with_softmax(logits, size):
# logits为输入数据
# size为采样数
    pro = softmax(logits)
    return np.random.choice(len(logits), size, p=pro)

 

基于gumbel-max的采样

gumbel分布的具体介绍会放在后文,我们先看看结论。对于K维概率向量\large \alpha,对\large \alpha对应的离散变量x_{i}=log(\alpha _i)添加Gumbel噪声,再取样

                                   \large x=\mathop{argmax}_i(\log(\alpha _i)+G_i)

其中,\fn_phv \large G_i是独立同分布的标准Gumbel分布的随机变量,标准Gumbel分布的CDF为F(x)=e^{-e^{-x}}.所以\large G_i可以通过Gumbel分布求逆从均匀分布生成,即G_i=-\log(-\log(U_i)),U_i\sim U(0,1)x_{i}=log(\alpha _i)代入计算可知,这里的\large \alpha就是上面softmax采样的\large \pi,这样就得到了基于gumbel-max的采样过程:

  • 对于网络输出的一个K维向量v,生成K个服从均匀分布U(0,1)的独立样本ϵ1,...,ϵK;

  • 通过G_i=-\log(-\log(\varepsilon _i))计算得到G_i;

  • 对应相加得到新的值向量v′=[v1+G1,v2+G2,...,vK+GK];

  • 取最大值作为最终的类别

可以证明,gumbel-max 方法的采样效果等效于基于 softmax 的方式(后文也会证明)。由于 Gumbel 随机数可以预先计算好,采样过程也不需要计算 softmax,因此,某些情况下,gumbel-max 方法相比于 softmax,在采样速度上会有优势。当然,可以看到由于这中间有一个argmax操作,这是不可导的,依旧没法用于计算网络梯度。

def sample_with_gumbel_noise(logits, size):
    noise = sample_gumbel((size, len(logits)))    # 产生gumbel noise
    return np.argmax(logits + noise, axis=1)

 

基于gumbel-softmax的采样

如果仅仅是提供一种常规 softmax 采样的替代方案, gumbel 分布似乎应用价值并不大。幸运的是,我们可以利用 gumbel 实现多项分布采样的 reparameterization(再参数化)。

VAE中,假设隐变量(latent variables)服从标准正态分布。而现在,利用 gumbel-softmax 技巧,我们可以将隐变量建模为服从离散的多项分布。在前面的两种方法中,random.choice和argmax注定了这两种方法不可导,但我们可以将后一种方法中的argmax soft化,变为softmax。

                              \large x=\mathop{softmax}((\log(\alpha _i)+G_i)/temperature)

temperature 是在大于零的参数,它控制着 softmax 的 soft 程度。温度越高,生成的分布越平滑;温度越低,生成的分布越接近离散的 one-hot 分布。训练中,可以通过逐渐降低温度,以逐步逼近真实的离散分布。

这样就得到了基于gumbel-max的采样过程:

  • 对于网络输出的一个K维向量v,生成K个服从均匀分布U(0,1)的独立样本ϵ1,...,ϵK;

  • 通过

### Argmax采样Gumbel-Softmax采样的对比 #### 定义与机制 Argmax操作是从给定的组概率中选取具有最高概率的类别作为输出。这过程是非随机性的,总是返回最可能的结果。 相比之下,Gumbel-Softmax引入了随机性,在logits上加上来自Gumbel分布的噪声之后再执行Softmax函数来近似离散分布抽样[^2]。这使得模型能够在训练期间保留梯度信息并允许反向传播算法更新参数。 #### 应用场景 当需要确切的最大值预测时,比如分类任务中的最终决策阶段,通常采用argmax方法。然而,在涉及强化学习或变分自编码器等情况下,其中期望对不确定性和探索有所表示,则更倾向于使用带有温度参数控制平滑程度的Gumbel-Softmax技巧[^1]。 #### 优点分析 ##### Argmax采样 - **简单直观**:直接选择最大可能性选项,易于理解和实现。 - **确定性强**:每次都会给出相同的决定,适合于那些不需要考虑多样性的场合。 ##### Gumbel-Softmax采样 - **可微分性**:由于加入了连续松弛(即通过调整τ),因此可以在神经网络框架内方便地计算损失相对于输入logit的导数。 - **支持不确定性建模**:能够表达不同样本间的差异以及提供定程度上的多样性,有助于防止过拟合现象的发生。 #### 缺点探讨 ##### Argmax采样 - **缺乏灵活性**:旦做出选择就无法改变,不利于处理动态环境下的问题求解。 - **不可导特性**:因为它是硬性分配而不是软性加权平均的形式,所以在某些特定架构下难以与其他组件协同工作。 ##### Gumbel-Softmax采样 - **复杂度增加**:相较于简单的取大运算而言,增加了额外的操作步骤和超参调节需求(如设置合适的τ值)。 - **潜在偏差风险**:如果设定不当的话可能会导致估计结果偏离真实情况较远;另外,随着维度的增长也可能面临数值稳定性挑战。 ```python import torch from torch.distributions import gumbel def sample_gumbel_softmax(logits, temperature=1.0): """Sample from the Gumbel-Softmax distribution and optionally discretize. Args: logits: [batch_size, n_class] unnormalized log-probs temperature: non-negative scalar Returns: Sampled tensor of shape [batch_size, n_class] """ y = logits + gumbel.Gumbel(0, 1).sample(logits.shape) return torch.nn.functional.softmax(y / temperature, dim=-1) # Example usage with argmax vs gumbel-softmax sampling logits = torch.tensor([[1., 2., 3.]]) print("Argmax Sampling:", torch.argmax(logits, dim=-1)) for temp in [0.5, 1.0, 2.0]: print(f"Gumbel-Softmax Sampling (temp={temp}):", sample_gumbel_softmax(logits, temperature=temp)) ```
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值