Gumbel Softmax方法

在很多场景中,我们需要对离散数据进行采样,然而神经网络等深度学习模型更擅长处理连续数据。为此,Gumbel Softmax通过引入Gumbel分布,将离散的选择过程转换为连续的分布,从而使得模型可以端到端地进行训练。本文将简单梳理一下这一方法。

一、传统离散数据采样方法

假设给定的离散概率分布为:
P ( x = 1 ) = p 1 P ( x = 2 ) = p 2 P ( x = 3 ) = p 3 … P ( x = n ) = p n (1) P(x=1)=p_1\\P(x=2)=p_2\\P(x=3)=p_3\\\dots \\P(x=n)=p_n\tag1 P(x=1)=p1P(x=2)=p2P(x=3)=p3P(x=n)=pn(1)
如果我们要从上述分布中采样出 x x x, 一种简单的方法可以这样做:

  1. 将数轴划分为n段即 [ 0 , p 1 ) , [ p 1 , p 2 ) , [ p 2 , p 3 ) , ⋯   , [ p n − 1 , p n ] [0,p_1),[p_1,p_2),[p_2,p_3),\cdots,[p_{n-1},p_n] [0,p1),[p1,p2),[p2,p3),,[pn1,pn]
  2. 从均匀分布中随机采样一个值 u ∼ U ( 0 , 1 ) u \sim U(0,1) uU(0,1)
  3. 检查 u u u落在哪个区间,落在第 n n n个区间则采样 x = n x=n x=n 即可。

或者我们也可以采样类似逆变换的方法进行采样,具体而言:

  1. 从均匀分布中随机采样一个值 u ∼ U ( 0 , 1 ) u \sim U(0,1) uU(0,1)
  2. p 0 = 0 p_0 = 0 p0=0 x = arg ⁡ max ⁡ i ( p 0 + p 1 + p 2 + ⋯ + p i − 1 ≤ u ) x= \mathop{\arg\max}\limits_{i}(p_0+p_1+p_2+\cdots+p_{i-1} \leq u) x=iargmax(p0+p1+p2++pi1u)

尽管上述两种方法都可以从离散分布中采样,但这一方法在深度学习中并不可导,也就是说上述方法不能表示成一个平滑的函数形式: x = f ( P = [ p 1 , p 2 , ⋯   , p n ] ) x=f(P=[p_1,p_2,\cdots,p_n]) x=f(P=[p1,p2,,pn])。为了使得离散采样可导,Gumbel Softmax应运而生。

二、Gumbel分布采样

Gumbel分布的概率密度函数为:
f ( x ; μ , β ) = 1 β e − z − e − z and  z = ( − x − μ β ) (2) f(x; \mu, \beta) = \frac{1}{\beta} e^{-z-e^{-z}} \text{and} \ z=\left(-\frac{x - \mu}{\beta}\right)\tag2 f(x;μ,β)=β1ezezand z=(βxμ)(2)

其中 x x x 是随机变量, μ \mu μ 是位置参数, β \beta β 是尺度参数。在标准Gumbel 分布中 μ = 0 , β = 1 \mu = 0, \beta =1 μ=0β=1 。pdf的函数图像如下所示。
在这里插入图片描述
那么如何从这个分布中采样呢?采样方法如下:
x = arg ⁡ max ⁡ i ( l o g ( p i ) + g i ) (3) x=\mathop{\arg\max}\limits_{i}(log(p_i) + g_i)\tag3 x=iargmax(log(pi)+gi)(3)
这里的 p i p_i pi 指的是各个离散数值对应的概率, g i = − l o g ( − l o g ( u i ) ) , u i ∼ U ( 0 , 1 ) g_i = -log(-log(u_i)), u_i\sim U(0,1) gi=log(log(ui)),uiU(0,1), 是从 Gumbel 分布采样得到的噪声,目的是使得的返回结果不固定,它是标准gumbel分布的CDF的逆函数。其实,上述公式类似于VAE常用的从正太分布中采样的参数重采样方法。

总之,上述公式可以从给定的离散分布中采样出 x x x, 且采样的值服从原始离散分布(废话)。具体证明参见:

  1. 漫谈重参数:从正态分布到Gumbel Softmax
  2. Gumbel-Softmax Trick和Gumbel分布

下面借用一个例子(来源:通俗易懂地理解Gumbel Softmax):

假设某人每天都会喝很多次水(比如100次),每次喝水的量服从正态分布 N ( μ , σ 2 ) N(\mu,\sigma^2) N(μ,σ2)(假设 μ = 5 \mu=5 μ=5 σ = 1 \sigma=1 σ=1,不必在乎喝水次数为负的不合理情况),那么每天100次喝水里总会有一个最大值,这个最大值服从的分布就是Gumbel分布。下面给出模拟。

from scipy.optimize import curve_fit
import numpy as np
import matplotlib.pyplot as plt
n_days = 10000
samples_per_day = 100
samples = np.random.normal(loc=5.0, scale=1.0, size=(n_days, samples_per_day))
daily_maxes = np.max(samples, axis=1)

def gumbel_pdf(prob,loc,scale):
    z = (prob-loc)/scale
    return np.exp(-z-np.exp(-z))/scale

def plot_maxes(daily_maxes):
    probs,bins,_ = plt.hist(daily_maxes,density=True,bins=100)
    print(f"==>> probs: {probs}") # 每个bin的概率
    print(f"==>> bins: {bins}") # 即横坐标的tick值
    print(f"==>> _: {_}")
    print(f"==>> probs.shape: {probs.shape}") # (100,)
    print(f"==>> bins.shape: {bins.shape}") # (101,)
    plt.xlabel('Volume')
    plt.ylabel('Probability of Volume being daily maximum')

    # 基于直方图,下面拟合出它的曲线。
    (fitted_loc, fitted_scale), _ = curve_fit(gumbel_pdf, bins[:-1],probs)
    print(f"==>> fitted_loc: {fitted_loc}")
    print(f"==>> fitted_scale: {fitted_scale}")
    plt.plot(bins, gumbel_pdf(bins, fitted_loc, fitted_scale))

plt.figure()
plot_maxes(daily_maxes)

采样结果可视化如下:
在这里插入图片描述
但是我们发现,上述方法仍然需要求 arg ⁡ max ⁡ \mathop{\arg\max} argmax, 这一过程仍然不可导。怎么变得可导呢,这个trick就是将Gumbel分布跟Softmax结合,也就是Gumbel Softmax。

三、Gumbel Softmax

Gumbel Softmax是一种用于从离散分布中采样的技术,其数学形式是基于Gumbel分布和Softmax函数的结合。

Gumbel Softmax的步骤如下:

  1. 为每个类别生成一个服从标准Gumbel分布的随机变量 g i = − l o g ( − l o g ( u i ) ) g_i = -log(-log(u_i)) gi=log(log(ui)) (Gumbel分布的CDF的逆函数)。其中, u i u_i ui 是从 [ 0 , 1 ] [0, 1] [0,1] 区间内均匀分布采样的随机变量。

  2. 将生成的Gumbel变量与类别概率相加,得到 y i = l o g ( p i ) + g i y_i = log(p_i) + g_i yi=log(pi)+gi

  3. 应用Softmax函数来获得一个软化的概率分布 p i ^ \hat{p_i} pi^
    p i ^ = exp ⁡ ( y i / τ ) ∑ j = 1 n exp ⁡ ( y j / τ ) (4) \hat{p_i} = \frac{\exp(y_i / \tau)}{\sum_{j=1}^{n} \exp(y_j / \tau)}\tag4 pi^=j=1nexp(yj/τ)exp(yi/τ)(4)

    其中, τ \tau τ 是温度参数,控制着采样的软化程度。当 τ \tau τ 接近0时,Softmax的输出趋近于一个one-hot编码向量,等价于公式(3), 也就实现了离散采样,且采样的随机数服从Gumbel 分布;当 τ \tau τ 增大时,输出更加平滑,接近均匀分布。

  4. 最终, p i ^ \hat{p_i } pi^可以被视为从离散变量 x x x 的类别分布中采样的连续近似。

在温度参数 τ \tau τ 趋近于0的极限情况下,Gumbel Softmax的输出将接近于一个one-hot向量,其中只有一个元素接近1,其余元素接近0,这相当于从类别分布中进行硬采样。而在较高的温度下,Gumbel Softmax的输出更加平滑,允许梯度在离散选择上传播,从而可以在神经网络中进行端到端的训练。

接下里,给出Gumbel Softmax采样示例代码:

# Gumbel softmax trick:
import torch
import torch.nn.functional as F
import numpy as np

def inverse_gumbel_cdf(y, mu, beta):
    return mu - beta * np.log(-np.log(y))

def gumbel_softmax_sampling(h, mu=0, beta=1, tau=0.1):
    """
    h : (N x K) tensor. Assume we need to sample a NxK tensor, each row is an independent r.v.
    """
    shape_h = h.shape
    p = F.softmax(h, dim=1)
    y = torch.rand(shape_h) + 1e-25  # ensure all y is positive.
    g = inverse_gumbel_cdf(y, mu, beta)
    x = torch.log(p) + g  # samples follow Gumbel distribution.
    # using softmax to generate one_hot vector:
    x = x/tau
    x = F.softmax(x, dim=1)  # now, the x approximates a one_hot vector.
    return x

N = 10  # 假设 有N个独立的离散变量需要采样
K = 3   # 假设 每个离散变量有3个取值
h = torch.randn((N, K))  # 假设 h是由一个神经网络输出的tensor。

mu = 0
beta = 1
tau = 0.1

samples = gumbel_softmax_sampling(h, mu, beta, tau)

参考资料

[1] 通俗易懂地理解Gumbel Softmax
[2] Gumbel softmax trick (快速理解附代码)
[3] 漫谈重参数:从正态分布到Gumbel Softmax
[4] Gumbel-Softmax Trick和Gumbel分布

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Researcher-Du

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值