Argmax是不可求导的,Gumbel Softmax允许模型能从网络层的离散分布(比如类别分布categorical distribution)中稀疏采样的这个过程变得可微,从而允许反向传播时可以用梯度更新模型参数。
算法流程
- 对于某个网络层输出的 n \mathrm{n} n 维向量 v = [ v 1 , v 2 , … , v n ] v=\left[v_1, v_2, \ldots, v_n\right] v=[v1,v2,…,vn],生成 n \mathrm{n} n 个服从均匀分布 U ( 0 , 1 ) \mathrm{U}(0,1) U(0,1) 的独立样本 ϵ 1 , … , ϵ n \epsilon_1, \ldots, \epsilon_n ϵ1,…,ϵn
- 通过 G i = − log ( − log ( ϵ i ) ) G_i=-\log \left(-\log \left(\epsilon_i\right)\right) Gi=−log(−lo