Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles

这篇论文介绍了一种结合集成学习与神经网络的方法,通过集成MMM个不同模型预测不确定性。作者使用高斯分布假设训练,增加方差输出,并利用对抗样本增强预测平滑性。预测过程中,通过最速梯度方法获取对抗样本并采用平均策略集成预测结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

这篇论文中作者用集成学习的思想来进行不确定性的度量。

问题设定

训练数据用 D={xn,yn}n=1N\mathcal{D} = \{\mathbf{x}_{n}, y_{n}\}_{n=1}^{N}D={xn,yn}n=1N 表示,x∈RD\mathbf{x} \in \mathbb{R}^{D}xRD 代表 DDD 维的数据,对于分类任务,y∈{1,…,K}y \in \{1, \ldots, K\}y{1,,K}, 对于回归任务,y∈Ry \in \mathbb{R}yR。给定输入特征 x\mathrm{x}x, 用神经网络对概率性预测分布进行建模为 pθ(y∣x)p_\theta(y\mid\mathbf{x})pθ(yx), 其中 θ\thetaθ 是神经网络的权重参数。

本文中采用集成学习(emsemble)的方法进行不确定性的预测,就是将 MMM 个不同的神经网络进行集成,用 {θm}m=1M\{\theta_{m}\}_{m=1}^{M}{θm}m=1M 来表示这些神经网络的权重参数。

回归任务上的训练

文章将训练数据看作是来自某种高斯分布中采样,所以神经网络应该能通过{xn,yn}n=1N\{\mathbf{x}_{n}, y_{n}\}_{n=1}^{N}{xn,yn}n=1N 的对应关系预测出这种高斯分布,也就是得到该分布的均值和方差,普通的神经网络的输出可以看作是均值 μ(x)\mu(\mathrm{x})μ(x),该文章在此基础上增加了方差的输出 σ2>0\sigma^{2} > 0σ2>0 , 则优化目标可以指定为最小化以下 negative log-likelihood criterion:
−log⁡pθ(yn∣xn)=log⁡σθ2(x)2+(y−μθ(x))22σθ2(x)+constant⁡(1) -\log p_{\theta}\left(y_{n} \mid \mathbf{x}_{n}\right)=\frac{\log \sigma_{\theta}^{2}(\mathbf{x})}{2}+\frac{\left(y-\mu_{\theta}(\mathbf{x})\right)^{2}}{2 \sigma_{\theta}^{2}(\mathbf{x})}+\operatorname{constant} \tag{1} logpθ(ynxn)=2logσθ2(x)+2σθ2(x)(yμθ(x))2+constant(1)
作者除了用采用集成学习的策略,还利用了对抗样本来对预测结果进行平滑,本文采用最速梯度轨迹方法(fastfastfast gradientgradientgradient signsignsign methodmethodmethod)来获取对抗样本,具体来说就是,给定输入 x\mathbf{x}x 以及对应标签 yyy,网络前馈计算获取损失值 ℓ(θ,x,y)\ell(\theta, \mathbf{x}, y)(θ,x,y),所求取的对抗样本为 x′=x+ϵ\mathbf{x}^{\prime} = \mathbf{x} + \boldsymbol\epsilonx=x+ϵ,其中 ϵ\boldsymbol\epsilonϵ 是某个方向的一个比较小的值,获取方法就是求损失函数对于 x\mathbf{x}x 的导数,即 ϵ=sign⁡(∇x ℓ(θ,x,y))\boldsymbol\epsilon = \operatorname{sign}(\nabla_{\mathbf{x}} ~\ell(\theta, \mathbf{x}, y))ϵ=sign(x (θ,x,y)),将 (x′,y)(\mathbf{x}^{\prime}, y)(x,y) 也看作是一个训练样本。

最后训练过程的伪代码如下图所示:
训练过程的伪代码
对于集成的策略,作者采用最简单的加起来就平均的方式,则最终的预测结果为:
p(y∣x)=1M∑m=1M p(y∣x,θm)(2) p(y\mid\mathbf{x}) = \frac{1}{M} \sum_{m=1}^{M} ~p(y\mid\mathbf{x},\theta_{m}) \tag{2} p(yx)=M1m=1M p(yx,θm)(2)预测均值为:
μ∗(x)=1M∑m μθm(x) \mu_{*}(\mathbf{x}) = \frac{1}{M} \sum_{m} ~\mu_{\theta_{m}}(\mathbf{x}) μ(x)=M1m μθm(x)预测方差为:
σ∗2(x)=1M∑m (σθm2(x)+μθm2(x))−μ∗2(x)(3) \sigma_{*}^{2}(\mathbf{x}) = \frac{1}{M} \sum_{m} ~(\sigma_{\theta_{m}}^{2}(\mathbf{x})+\mu_{\theta_{m}}^{2}(\mathbf{x}))-\mu_{*}^{2}(\mathbf{x}) \tag{3} σ2(x)=M1m (σθm2(x)+μθm2(x))μ2(x)(3)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值