在很多场景中,我们需要对离散数据进行采样,然而神经网络等深度学习模型更擅长处理连续数据。为此,Gumbel Softmax通过引入Gumbel分布,将离散的选择过程转换为连续的分布,从而使得模型可以端到端地进行训练。本文将简单梳理一下这一方法。
一、传统离散数据采样方法
假设给定的离散概率分布为:
P
(
x
=
1
)
=
p
1
P
(
x
=
2
)
=
p
2
P
(
x
=
3
)
=
p
3
…
P
(
x
=
n
)
=
p
n
(1)
P(x=1)=p_1\\P(x=2)=p_2\\P(x=3)=p_3\\\dots \\P(x=n)=p_n\tag1
P(x=1)=p1P(x=2)=p2P(x=3)=p3…P(x=n)=pn(1)
如果我们要从上述分布中采样出
x
x
x, 一种简单的方法可以这样做:
- 将数轴划分为n段即 [ 0 , p 1 ) , [ p 1 , p 2 ) , [ p 2 , p 3 ) , ⋯ , [ p n − 1 , p n ] [0,p_1),[p_1,p_2),[p_2,p_3),\cdots,[p_{n-1},p_n] [0,p1),[p1,p2),[p2,p3),⋯,[pn−1,pn]。
- 从均匀分布中随机采样一个值 u ∼ U ( 0 , 1 ) u \sim U(0,1) u∼U(0,1)。
- 检查 u u u落在哪个区间,落在第 n n n个区间则采样 x = n x=n x=n 即可。
或者我们也可以采样类似逆变换的方法进行采样,具体而言:
- 从均匀分布中随机采样一个值 u ∼ U ( 0 , 1 ) u \sim U(0,1) u∼U(0,1)。
- 令 p 0 = 0 p_0 = 0 p0=0, x = arg max i ( p 0 + p 1 + p 2 + ⋯ + p i − 1 ≤ u ) x= \mathop{\arg\max}\limits_{i}(p_0+p_1+p_2+\cdots+p_{i-1} \leq u) x=iargmax(p0+p1+p2+⋯+pi−1≤u)。
尽管上述两种方法都可以从离散分布中采样,但这一方法在深度学习中并不可导,也就是说上述方法不能表示成一个平滑的函数形式: x = f ( P = [ p 1 , p 2 , ⋯ , p n ] ) x=f(P=[p_1,p_2,\cdots,p_n]) x=f(P=[p1,p2,⋯,pn])。为了使得离散采样可导,Gumbel Softmax应运而生。
二、Gumbel分布采样
Gumbel分布的概率密度函数为:
f
(
x
;
μ
,
β
)
=
1
β
e
−
z
−
e
−
z
and
z
=
(
−
x
−
μ
β
)
(2)
f(x; \mu, \beta) = \frac{1}{\beta} e^{-z-e^{-z}} \text{and} \ z=\left(-\frac{x - \mu}{\beta}\right)\tag2
f(x;μ,β)=β1e−z−e−zand z=(−βx−μ)(2)
其中
x
x
x 是随机变量,
μ
\mu
μ 是位置参数,
β
\beta
β 是尺度参数。在标准Gumbel 分布中
μ
=
0
,
β
=
1
\mu = 0, \beta =1
μ=0,β=1 。pdf的函数图像如下所示。
那么如何从这个分布中采样呢?采样方法如下:
x
=
arg
max
i
(
l
o
g
(
p
i
)
+
g
i
)
(3)
x=\mathop{\arg\max}\limits_{i}(log(p_i) + g_i)\tag3
x=iargmax(log(pi)+gi)(3)
这里的
p
i
p_i
pi 指的是各个离散数值对应的概率,
g
i
=
−
l
o
g
(
−
l
o
g
(
u
i
)
)
,
u
i
∼
U
(
0
,
1
)
g_i = -log(-log(u_i)), u_i\sim U(0,1)
gi=−log(−log(ui)),ui∼U(0,1), 是从 Gumbel 分布采样得到的噪声,目的是使得的返回结果不固定,它是标准gumbel分布的CDF的逆函数。其实,上述公式类似于VAE常用的从正太分布中采样的参数重采样方法。
总之,上述公式可以从给定的离散分布中采样出 x x x, 且采样的值服从原始离散分布(废话)。具体证明参见:
下面借用一个例子(来源:通俗易懂地理解Gumbel Softmax):
假设某人每天都会喝很多次水(比如100次),每次喝水的量服从正态分布 N ( μ , σ 2 ) N(\mu,\sigma^2) N(μ,σ2)(假设 μ = 5 \mu=5 μ=5, σ = 1 \sigma=1 σ=1,不必在乎喝水次数为负的不合理情况),那么每天100次喝水里总会有一个最大值,这个最大值服从的分布就是Gumbel分布。下面给出模拟。
from scipy.optimize import curve_fit
import numpy as np
import matplotlib.pyplot as plt
n_days = 10000
samples_per_day = 100
samples = np.random.normal(loc=5.0, scale=1.0, size=(n_days, samples_per_day))
daily_maxes = np.max(samples, axis=1)
def gumbel_pdf(prob,loc,scale):
z = (prob-loc)/scale
return np.exp(-z-np.exp(-z))/scale
def plot_maxes(daily_maxes):
probs,bins,_ = plt.hist(daily_maxes,density=True,bins=100)
print(f"==>> probs: {probs}") # 每个bin的概率
print(f"==>> bins: {bins}") # 即横坐标的tick值
print(f"==>> _: {_}")
print(f"==>> probs.shape: {probs.shape}") # (100,)
print(f"==>> bins.shape: {bins.shape}") # (101,)
plt.xlabel('Volume')
plt.ylabel('Probability of Volume being daily maximum')
# 基于直方图,下面拟合出它的曲线。
(fitted_loc, fitted_scale), _ = curve_fit(gumbel_pdf, bins[:-1],probs)
print(f"==>> fitted_loc: {fitted_loc}")
print(f"==>> fitted_scale: {fitted_scale}")
plt.plot(bins, gumbel_pdf(bins, fitted_loc, fitted_scale))
plt.figure()
plot_maxes(daily_maxes)
采样结果可视化如下:
但是我们发现,上述方法仍然需要求
arg
max
\mathop{\arg\max}
argmax, 这一过程仍然不可导。怎么变得可导呢,这个trick就是将Gumbel分布跟Softmax结合,也就是Gumbel Softmax。
三、Gumbel Softmax
Gumbel Softmax是一种用于从离散分布中采样的技术,其数学形式是基于Gumbel分布和Softmax函数的结合。
Gumbel Softmax的步骤如下:
-
为每个类别生成一个服从标准Gumbel分布的随机变量 g i = − l o g ( − l o g ( u i ) ) g_i = -log(-log(u_i)) gi=−log(−log(ui)) (Gumbel分布的CDF的逆函数)。其中, u i u_i ui 是从 [ 0 , 1 ] [0, 1] [0,1] 区间内均匀分布采样的随机变量。
-
将生成的Gumbel变量与类别概率相加,得到 y i = l o g ( p i ) + g i y_i = log(p_i) + g_i yi=log(pi)+gi。
-
应用Softmax函数来获得一个软化的概率分布 p i ^ \hat{p_i} pi^:
p i ^ = exp ( y i / τ ) ∑ j = 1 n exp ( y j / τ ) (4) \hat{p_i} = \frac{\exp(y_i / \tau)}{\sum_{j=1}^{n} \exp(y_j / \tau)}\tag4 pi^=∑j=1nexp(yj/τ)exp(yi/τ)(4)其中, τ \tau τ 是温度参数,控制着采样的软化程度。当 τ \tau τ 接近0时,Softmax的输出趋近于一个one-hot编码向量,等价于公式(3), 也就实现了离散采样,且采样的随机数服从Gumbel 分布;当 τ \tau τ 增大时,输出更加平滑,接近均匀分布。
-
最终, p i ^ \hat{p_i } pi^可以被视为从离散变量 x x x 的类别分布中采样的连续近似。
在温度参数 τ \tau τ 趋近于0的极限情况下,Gumbel Softmax的输出将接近于一个one-hot向量,其中只有一个元素接近1,其余元素接近0,这相当于从类别分布中进行硬采样。而在较高的温度下,Gumbel Softmax的输出更加平滑,允许梯度在离散选择上传播,从而可以在神经网络中进行端到端的训练。
接下里,给出Gumbel Softmax采样示例代码:
# Gumbel softmax trick:
import torch
import torch.nn.functional as F
import numpy as np
def inverse_gumbel_cdf(y, mu, beta):
return mu - beta * np.log(-np.log(y))
def gumbel_softmax_sampling(h, mu=0, beta=1, tau=0.1):
"""
h : (N x K) tensor. Assume we need to sample a NxK tensor, each row is an independent r.v.
"""
shape_h = h.shape
p = F.softmax(h, dim=1)
y = torch.rand(shape_h) + 1e-25 # ensure all y is positive.
g = inverse_gumbel_cdf(y, mu, beta)
x = torch.log(p) + g # samples follow Gumbel distribution.
# using softmax to generate one_hot vector:
x = x/tau
x = F.softmax(x, dim=1) # now, the x approximates a one_hot vector.
return x
N = 10 # 假设 有N个独立的离散变量需要采样
K = 3 # 假设 每个离散变量有3个取值
h = torch.randn((N, K)) # 假设 h是由一个神经网络输出的tensor。
mu = 0
beta = 1
tau = 0.1
samples = gumbel_softmax_sampling(h, mu, beta, tau)
参考资料
[1] 通俗易懂地理解Gumbel Softmax
[2] Gumbel softmax trick (快速理解附代码)
[3] 漫谈重参数:从正态分布到Gumbel Softmax
[4] Gumbel-Softmax Trick和Gumbel分布