生成对抗网络 GAN 的数学原理

本文深入探讨生成对抗网络(GAN)的数学原理,从概率分布、极大似然估计、KL散度、JS散度出发,详细阐述了GAN的内在工作机制,并通过实例介绍了参数求解的两步迭代方法,最后简要提及了GAN在工程实践中的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

摘要

本文从概率分布及参数估计说起, 通过介绍极大似然估计, KL 散度, JS 散度, 详细的介绍了 GAN 生成对抗网络的数学原理.

相关

系列文章索引 :

https://blog.youkuaiyun.com/oBrightLamp/article/details/85067981

正文

无论是黑白图片或彩色图片, 都是使用 0 ~ 255 的数值表示像素. 将所有的像素值除以 255 我们就可以将一张图片转化为 0 ~ 1 的概率分布, 而且这种转化是可逆的, 乘以 255 就可以还原.

从某种意义上来讲, GAN 图片生成任务就是生成概率分布. 因此, 我们有必要结合概率分布来理解 GAN 生成对抗网络的原理.

1. 概率分布及参数估计

假设一个抽奖盒子里有45个球, 其编号是 1~9 共9个数字. 每个编号的球拥有的数量是:

编号 1 2 3 4 5 6 7 8 9
数量 2 4 6 8 9 7 5 3 1
占比 0.044 0.088 0.133 0.178 0.200 0.156 0.111 0.066 0.022

占比是指用每个编号的数量除以所有编号的总数量, 在数理统计学中, 在不引起误会的情况下, 这里的占比也可以被称为 概率/频率.

使用向量 q q q 表示上述的概率分布 :

q = ( 2 , 4 , 6 , 8 , 9 , 7 , 5 , 3 , 1 ) / 45    = ( 0.044 , 0.088 , 0.133 , 0.178 , 0.200 , 0.156 , 0.111 , 0.066 , 0.022 ) q = (2,4,6,8,9,7,5,3,1)/45 \;\\ =(0.044, 0.088, 0.133, 0.178, 0.200, 0.156, 0.111, 0.066, 0.022) q=(2,4,6,8,9,7,5,3,1)/45=(0.044,0.088,0.133,0.178,0.200,0.156,0.111,0.066,0.022)

将上述分布使用图像绘制如下 :

在这里插入图片描述

现在我们希望构建一个函数 p = p ( x ; θ ) p = p(x;\theta) p=p(x;θ), 以 x x x 为编号作为输入数据, 输出编号 x x x 的概率. θ \theta θ 是参与构建这个函数的参数, 一经选定就不再变化.

假设上述概率分布服从二次抛物线函数 :
p = p ( x ; θ ) = θ 1 ( x + θ 2 ) 2 + θ 3    x = ( 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ) p=p(x;\theta)=\theta_1 (x+\theta_2)^2+\theta_3\\ \;\\ x = (1,2,3,4,5,6,7,8,9) p=p(x;θ)=θ1(x+θ2)2+θ3x=(1,2,3,4,5,6,7,8,9)

使用 L2 误差作为评价拟合效果的损失函数, 总误差值为 error (标量 e) :
e = ∑ i = 1 9 ( p i − q i ) 2 e = \sum_{i=1}^{9}(p_i-q_i)^2 e=i=19(piqi)2
我们希望求得一个 θ ∗ \theta^* θ, 使得 e e e 的值越小越好, 使用数学公式表达是这样子的 :
θ ∗ = a r g m i n θ ( e ) \theta^* = \underset{\theta}{argmin}(e) θ=θargmin(e)
argmin 是 argument minimum 的缩写.

如何求 θ ∗ \theta^* θ 不是本文的重点, 这是生成对抗网络的任务. 为了帮助理解, 取其中一个可能的数值作为示例 :

θ ∗ = ( θ 1 , θ 2 , θ 3 ) = ( − 0.01 , − 5.0 , 0.2 )    p = p ( x ; θ ) = − 0.01 ( x − 5.0 ) 2 + 0.2 \theta^* = (\theta_1,\theta_2,\theta_3)=(-0.01,-5.0,0.2)\\ \;\\ p=p(x;\theta)=-0.01 (x-5.0)^2+0.2 θ=(θ1,θ2,θ3)=(0.01,5.0,0.2)p=p(x;θ)=0.01(x5.0)2+0.2
绘制函数图像如下 :

在这里插入图片描述

在生成对抗网络中, 本例的估计函数 p ( x ; θ ) p(x;\theta) p(x;θ) 相当于生成模型 (generator), 损失函数相当于鉴别模型 (discriminator).

2. 极大似然估计

在上例中, 我们很幸运的知道了所有可能的概率分布, 并让求解最优的概率分布估计函数 p ( x ; θ ) p(x;\theta) p(x;θ) 成为可能.

如果上例的抽奖盒子 (简称样本) 中的 45 个球是从更大的抽奖池 (简称总体) 中选择出来的, 而我们不知道抽奖池中所有球的数量及其编号. 那么, 我们如何根据现有的 45 个球来估计抽奖池的概率分布呢? 我们当然可以直接用上例求得的样本估计函数来代表抽奖池的概率分布, 但本例会介绍一种更常用的估计方法.

假设 p ( x ) = p ( x ; θ ) p(x)=p(x;\theta) p(x)=p(x;θ) 是总体的概率分布函数. 则编号 x = ( x 1 , x 2 , x 3 , ⋯   , x n ) x = (x_1,x_2,x_3,\cdots,x_n) x=(x1,x2,x3,,xn) 出现的概率为 :
p = p ( x 1 ) , p ( x 2 ) , p ( x 3 ) , ⋯   , p ( x n ) p = p(x_1),p(x_2),p(x_3),\cdots,p(x_n) p=p(x1),p(x2),p(x3),,p(xn)
在本例中, n = 9 n = 9 n=9, 即共 9 个编号.

d = ( d 1 , d 2 , d 3 , d 3 ⋯   , d m ) d=(d_1,d_2,d_3,d_3\cdots,d_m) d=(d1,d2,d3,d3,dm) 是所有的抽样的编号. 在本例中, m = 45 m = 45 m=45, 即样本中共有 45 个抽样. 假设所有的样本和抽样都是独立的, 则样本出现的概率为 :
ρ = p ( d 1 ) × p ( d 2 ) × p ( d 3 ) × ⋯ × p ( d m ) = ∏ i = 1 m p ( d i ) \rho= p(d_1)\times p(d_2)\times p(d_3)\times\cdots\times p(d_m)=\prod_{i=1}^{m}p(d_i) ρ=p(d1)×p(d2)×p(d3)××p(dm)=i=1mp(di)
p ( x ) = p ( x ; θ ) p(x)=p(x;\theta) p(x)=p(x;θ) 的函数结构是人为按经验选取的, 比如线性函数, 多元二次函数, 更复杂的非线性函数等, 一经选取则不再改变. 现在我们需要求解一个参数集 θ ∗ \theta^* θ, 使得 ρ \rho ρ 的值越大越好. 即
θ ∗ = a r g m a x θ ( ρ ) = a r g m a x θ ∏ i = 1 m p ( d i ; θ ) \theta^* = \underset{\theta}{argmax}(\rho)=\underset{\theta}{argmax}\prod_{i=1}^{m}p(d_i;\theta) θ=θargmax(ρ)=θargmaxi=1mp(di;θ)
argmax 是 argument maximum 的缩写.

通俗来讲, 因为样本是实际已发生的事实, 在函数结构已确定的情况下, 我们需要尽量优化参数, 使得样本的理论估计概率越大越好.

这里有一个前提, 就是人为选定的函数结构应当能够有效评估样本分布. 反之, 如果使用线性函数去拟合正态概率分布 (normal distribution), 则无论如何选择参数都无法得到满意的效果.

连乘运算不方便, 将之改为求和运算. 由于 l o g log log 对数函数的单调性, 上面的式子等价于 :
θ ∗ = a r g m a x θ    l o g ∏ i = 1 m p ( d i ; θ ) = a r g m a x θ ∑ i = 1 m l o g    p ( d i ; θ ) \theta^* =\underset{\theta}{argmax}\;log\prod_{i=1}^{m}p(d_i;\theta)=\underset{\theta}{argmax}\sum_{i=1}^{m}log\;p(d_i;\theta) θ=θargmaxlogi=1mp(di;θ)=θargmaxi=1mlogp(di;θ)
设样本分布为 q ( x ) q(x) q(x), 对于给定样本, 这个分布是已知的, 可以通过统计抽样的计算得出. 将上式转化成期望公式 :
θ ∗ = a r g m a x θ ∑ i = 1 m l o g    p ( d i ; θ ) = a r g m a x θ ∑ i = 1 n q ( x i )    l o g    p ( x i ; θ ) \theta^* =\underset{\theta}{argmax}\sum_{i=1}^{m}log\;p(d_i;\theta) =\underset{\theta}{argmax}\sum_{i=1}^{n}q(x_i)\;log\;p(x_i;\theta) θ=θargmaxi=1mlogp(di;θ)=θargm

评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值