模式识别与机器学习课程笔记(12):生成对抗网络(GAN)
文章目录
1. 引言:为什么需要GAN?
在传统生成模型(如变分自编码器VAE)中,生成过程依赖于显式的概率分布假设(如高斯分布),且生成样本的质量和多样性常受限于模型对数据分布的拟合能力。2014年,Ian Goodfellow等人提出的生成对抗网络(Generative Adversarial Networks, GAN) 打破了这一局限——它通过“对抗训练”的思想,让两个网络(生成器与判别器)相互博弈,无需显式定义数据分布即可学习到数据的真实特征,最终生成高度逼真的样本。
GAN的核心价值在于:
- 摆脱对显式概率分布的依赖,更贴合真实世界数据的复杂分布;
- 生成样本的细节丰富度更高(如图像、音频),在计算机视觉、自然语言处理等领域应用广泛;
- 启发了大量变种模型(如DCGAN、WGAN、CycleGAN),成为生成式AI的重要基石。
2. GAN的基本原理
GAN的架构由两个核心网络组成:生成器(Generator, G) 和判别器(Discriminator, D),二者通过“极小极大游戏(Min-Max Game)”实现共同优化。
2.1 核心思想:对抗与协作
- 判别器D:本质是一个二分类器,目标是区分输入样本“是真实数据(来自数据分布 p d a t a ( x ) p_{data}(x) pdata(x))”还是“生成器伪造的数据(来自生成分布 p g ( x ) p_g(x) pg(x))”,输出样本为真实数据的概率(范围 [ 0 , 1 ] [0,1] [0,1])。
- 生成器G:接收一个随机噪声向量 z z z(来自噪声分布 p z ( z ) p_z(z) pz(z),如高斯分布),通过非线性变换将其映射为与真实数据维度一致的生成样本 G ( z ) G(z) G(z),目标是让生成样本足够逼真,以至于判别器无法区分(即让 D ( G ( z ) ) → 1 D(G(z)) \to 1 D(G(z))→1)。
二者的对抗关系可概括为:判别器努力提高区分准确率,生成器努力降低判别器的区分准确率,最终达到纳什均衡——判别器无法区分真实与生成样本(输出概率恒为 0.5 0.5 0.5),生成器学到了真实数据的分布。
2.2 网络结构与数学定义
(1)符号定义
- x x x:真实数据样本, x ∼ p d a t a ( x ) x \sim p_{data}(x) x∼pdata(x)( p d a t a p_{data} pdata为真实数据分布);
- z z z:随机噪声向量, z ∼ p z ( z ) z \sim p_z(z) z∼pz(z)( p z p_z pz为噪声分布,常用 N ( 0 , 1 ) \mathcal{N}(0,1) N(0,1));
- G ( z ) G(z) G(z):生成器输出的伪造样本, G : R d → R n G: \mathbb{R}^d \to \mathbb{R}^n G:Rd→Rn( d d d为噪声维度, n n n为真实数据维度);
- D ( x ) D(x) D(x):判别器对样本 x x x的输出(真实概率), D : R n → [ 0 , 1 ] D: \mathbb{R}^n \to [0,1] D:Rn→[0,1]。
(2)目标函数(极小极大游戏)
GAN的核心是求解以下极小极大优化问题:
min
G
max
D
V
(
D
,
G
)
=
E
x
∼
p
d
a
t
a
(
x
)
[
log
D
(
x
)
]
+
E
z
∼
p
z
(
z
)
[
log
(
1
−
D
(
G
(
z
)
)
)
]
\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))]
GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
- 对判别器 D D D:需最大化目标函数 V ( D , G ) V(D,G) V(D,G)。第一项 log D ( x ) \log D(x) logD(x)表示“正确识别真实样本”的对数损失( D ( x ) D(x) D(x)越大,该项越大);第二项 log ( 1 − D ( G ( z ) ) ) \log (1-D(G(z))) log(1−D(G(z)))表示“正确识别伪造样本”的对数损失( D ( G ( z ) ) D(G(z)) D(G(z))越小,该项越大)。
- 对生成器 G G G:需最小化目标函数 V ( D , G ) V(D,G) V(D,G)。生成器希望伪造样本被误认为真实样本,即 D ( G ( z ) ) → 1 D(G(z)) \to 1 D(G(z))→1,此时 log ( 1 − D ( G ( z ) ) ) → − ∞ \log (1-D(G(z))) \to -\infty log(1−D(G(z)))→−∞,需通过最小化 V ( D , G ) V(D,G) V(D,G)让该项尽可能大(等价于最大化 log D ( G ( z ) ) \log D(G(z)) logD(G(z)),后续常用此等价形式避免梯度消失)。
3. GAN的训练流程
GAN的训练采用“交替迭代”方式,即先固定生成器训练判别器,再固定判别器训练生成器,重复此过程直到收敛。具体步骤如下:
-
初始化参数:
- 初始化生成器 G G G的参数 θ G \theta_G θG(如卷积层、全连接层权重);
- 初始化判别器 D D D的参数 θ D \theta_D θD。
-
迭代训练(直到收敛):
-
步骤1:训练判别器 D D D(固定 θ G \theta_G θG)
- 从真实数据集中采样 m m m个样本: x ( 1 ) , x ( 2 ) , . . . , x ( m ) ∼ p d a t a ( x ) x^{(1)}, x^{(2)}, ..., x^{(m)} \sim p_{data}(x) x(1),x(2),...,x(m)∼pdata(x);
- 从噪声分布中采样 m m m个噪声向量: z ( 1 ) , z ( 2 ) , . . . , z ( m ) ∼ p z ( z ) z^{(1)}, z^{(2)}, ..., z^{(m)} \sim p_z(z) z(1),z(2),...,z(m)∼pz(z),生成伪造样本 G ( z ( 1 ) ) , . . . , G ( z ( m ) ) G(z^{(1)}), ..., G(z^{(m)}) G(z(1)),...,G(z(m));
- 计算判别器的损失(基于交叉熵),并对
θ
D
\theta_D
θD求导更新:
L D = − 1 m [ ∑ i = 1 m log D ( x ( i ) ) + ∑ i = 1 m log ( 1 − D ( G ( z ( i ) ) ) ) ] \mathcal{L}_D = -\frac{1}{m}\left[ \sum_{i=1}^m \log D(x^{(i)}) + \sum_{i=1}^m \log (1 - D(G(z^{(i)}))) \right] LD=−m1[i=1∑mlogD(x(i))+i=1∑mlog(1−D(G(z(i))))]
通过梯度上升(或优化器如Adam)最小化 L D \mathcal{L}_D LD,即最大化 V ( D , G ) V(D,G) V(D,G)。
-
步骤2:训练生成器 G G G(固定 θ D \theta_D θD)
- 再次从噪声分布中采样 m m m个噪声向量: z ( 1 ) , . . . , z ( m ) ∼ p z ( z ) z^{(1)}, ..., z^{(m)} \sim p_z(z) z(1),...,z(m)∼pz(z);
- 计算生成器的损失(等价形式,避免梯度消失),并对
θ
G
\theta_G
θG求导更新:
L G = − 1 m ∑ i = 1 m log D ( G ( z ( i ) ) ) \mathcal{L}_G = -\frac{1}{m} \sum_{i=1}^m \log D(G(z^{(i)})) LG=−m1i=1∑mlogD(G(z(i)))
通过梯度下降最小化 L G \mathcal{L}_G LG,即让 D ( G ( z ) ) D(G(z)) D(G(z))尽可能接近1。
-
-
收敛判断:
- 当判别器对真实样本和伪造样本的输出概率均接近 0.5 0.5 0.5(即 D ( x ) ≈ 0.5 D(x) \approx 0.5 D(x)≈0.5且 D ( G ( z ) ) ≈ 0.5 D(G(z)) \approx 0.5 D(G(z))≈0.5);
- 生成样本的视觉/语义质量不再提升(如图像生成中,样本无明显模糊或重复)。
4. 原始GAN的核心问题
原始GAN虽思想巧妙,但在实际训练中存在两大核心问题,限制了其应用:
4.1 训练不稳定(梯度消失)
当生成器生成的样本质量较差时,判别器会轻易识别( D ( G ( z ) ) ≈ 0 D(G(z)) \approx 0 D(G(z))≈0),此时生成器的损失 log ( 1 − D ( G ( z ) ) ) ≈ log 1 = 0 \log(1-D(G(z))) \approx \log 1 = 0 log(1−D(G(z)))≈log1=0,梯度 ∇ θ G log ( 1 − D ( G ( z ) ) ) ≈ 0 \nabla_{\theta_G} \log(1-D(G(z))) \approx 0 ∇θGlog(1−D(G(z)))≈0,导致生成器参数无法更新(梯度消失)。
原因:判别器能力过强,生成器与判别器“实力失衡”,生成器无法从失败中学习。
4.2 模式崩溃(Mode Collapse)
模式崩溃是指生成器仅能生成少数几种类型的样本(即生成分布 p g ( x ) p_g(x) pg(x)仅覆盖真实分布 p d a t a ( x ) p_{data}(x) pdata(x)的一小部分),但判别器无法区分(因这些样本已能“欺骗”判别器)。
原因:生成器为最小化损失,倾向于生成判别器最易误判的样本(而非覆盖全部真实分布),导致生成样本多样性缺失。
5. 常见GAN变种及其改进
为解决原始GAN的问题,研究者提出了多种变种,以下是最常用的三类:
5.1 DCGAN(深度卷积GAN)
- 核心改进:用卷积层(判别器)和反卷积层(生成器)替代全连接层,减少参数并保留空间特征(适配图像数据)。
- 架构原则:
- 判别器:使用strided卷积(替代池化)下采样,增加感受野;
- 生成器:使用反卷积(转置卷积)上采样,逐步恢复图像分辨率;
- 移除全连接层,批量归一化(Batch Norm)加速训练。
- 应用场景:图像生成(如人脸、风景)、图像修复。
5.2 WGAN(Wasserstein GAN)
- 核心改进:用Wasserstein距离(Earth-Mover距离) 替代原始GAN的JS散度,衡量真实分布与生成分布的差异,解决训练不稳定和梯度消失问题。
- 关键设计:
- 判别器改为“评论家(Critic)”,输出连续值(非概率),用于计算Wasserstein距离;
- 加入权重裁剪(Weight Clipping),确保评论家满足Lipschitz条件(Wasserstein距离计算的前提)。
- 优势:损失值可直接反映生成样本质量(损失越小,质量越高),训练更稳定。
5.3 CycleGAN(循环一致性GAN)
- 核心改进:针对“无监督图像风格迁移”场景,引入循环一致性损失,解决“一对一”风格迁移的映射问题。
- 架构设计:
- 两个生成器: G : X → Y G: X \to Y G:X→Y(将X域图像转为Y域)、 F : Y → X F: Y \to X F:Y→X(将Y域图像转为X域);
- 两个判别器: D Y D_Y DY(区分Y域真实/生成图像)、 D X D_X DX(区分X域真实/生成图像);
- 循环一致性损失:确保
F
(
G
(
x
)
)
≈
x
F(G(x)) \approx x
F(G(x))≈x和
G
(
F
(
y
)
)
≈
y
G(F(y)) \approx y
G(F(y))≈y,避免生成图像脱离原域特征:
L c y c ( G , F ) = E x ∼ p X [ ∥ F ( G ( x ) ) − x ∥ 1 ] + E y ∼ p Y [ ∥ G ( F ( y ) ) − y ∥ 1 ] \mathcal{L}_{cyc}(G,F) = \mathbb{E}_{x \sim p_X}[\|F(G(x)) - x\|_1] + \mathbb{E}_{y \sim p_Y}[\|G(F(y)) - y\|_1] Lcyc(G,F)=Ex∼pX[∥F(G(x))−x∥1]+Ey∼pY[∥G(F(y))−y∥1]
- 应用场景:风格迁移(如照片转油画、白天转黑夜)、跨域图像转换(如马转斑马)。
6. GAN的典型应用领域
GAN凭借强大的生成能力,已在多个领域落地,典型应用包括:
| 应用领域 | 具体场景举例 | 常用模型 |
|---|---|---|
| 图像生成 | 高清人脸生成(StyleGAN)、动漫角色生成 | DCGAN、StyleGAN |
| 图像风格迁移 | 照片转梵高风格、黑白图转彩色 | CycleGAN、StarGAN |
| 超分辨率重建 | 低清图像提升至4K/8K分辨率 | SRGAN、ESRGAN |
| 文本生成图像 | 基于文字描述生成逼真图像(如“红色跑车”) | StackGAN、AttnGAN |
| 语音合成 | 基于文本生成自然语音(TTS) | GAN-TTS、MelGAN |
7. 总结与展望
7.1 核心总结
- GAN的本质是“对抗训练”:通过生成器与判别器的极小极大游戏,隐式学习真实数据分布;
- 原始GAN的痛点是训练不稳定和模式崩溃,后续变种(DCGAN、WGAN、CycleGAN)通过架构调整或损失函数改进逐步解决;
- GAN的优势在于生成样本的高逼真度,尤其适合图像、音频等复杂结构化数据。
7.2 未来方向
- 多模态GAN:融合文本、图像、音频等多模态信息,实现更精准的生成(如文本+图像生成视频);
- 效率优化:减少GAN的训练时间和计算资源消耗(如轻量化架构、分布式训练);
- 可控生成:实现对生成样本的精细控制(如指定人脸表情、图像风格强度)。
1083

被折叠的 条评论
为什么被折叠?



