pytorch的torch.distributions中可以定义正态分布
如下:
import torch
from torch.distributions import Normal
mean=torch.Tensor([0,2])
normal=Normal(mean,1)
sample()
sample()就是直接在定义的正太分布(均值为mean,标准差std是1)上采样:
c=normal.sample()
print("c:",c)
输出:
c: tensor([-1.3362, 3.1730])
rsample()
rsample()不是在定义的正太分布上采样,而是先对标准正太分布N(0,1)N(0,1)N(0,1)进行采样,然后输出:mean+std×采样值mean+std\times采样值mean+std×采样值
a=normal.rsample()
输出:
a: tensor([ 0.0530, 2.8396])
log_prob(value)
log_prob(value)是计算value在定义的正态分布(mean,1)中对应的概率的对数,正太分布概率密度函数是f(x)=12πσe−(x−μ)22σ2f(x)=\frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}}f(x)=2πσ1e−2σ2(x−μ)2,对其取对数可得log(f(x))=−(x−μ)22σ2−log(σ)−log(2π)log(f(x))=-\frac{(x-\mu)^2}{2\sigma^2}-log(\sigma)-log(\sqrt{2\pi})log(f(x))=−2σ2(x−μ)2−log(σ)−log(2π)
这里我们通过对数概率还原其对应的真实概率:
print("c log_prob:",normal.log_prob(c).exp())
输出:
c log_prob: tensor([ 0.1634, 0.2005])
PyTorch正态分布采样详解

本文详细介绍了在PyTorch中使用torch.distributions模块定义正态分布,并对比了sample()和rsample()两种采样方法的区别,同时解释了log_prob()函数如何计算给定值在正态分布中的对数概率。
4214





