模式识别与机器学习课程笔记(12):生成对抗网络

模式识别与机器学习课程笔记(12):生成对抗网络(GAN)

1. 引言:为什么需要GAN?

在传统生成模型(如变分自编码器VAE)中,生成过程依赖于显式的概率分布假设(如高斯分布),且生成样本的质量和多样性常受限于模型对数据分布的拟合能力。2014年,Ian Goodfellow等人提出的生成对抗网络(Generative Adversarial Networks, GAN) 打破了这一局限——它通过“对抗训练”的思想,让两个网络(生成器与判别器)相互博弈,无需显式定义数据分布即可学习到数据的真实特征,最终生成高度逼真的样本。

GAN的核心价值在于:

  • 摆脱对显式概率分布的依赖,更贴合真实世界数据的复杂分布;
  • 生成样本的细节丰富度更高(如图像、音频),在计算机视觉、自然语言处理等领域应用广泛;
  • 启发了大量变种模型(如DCGAN、WGAN、CycleGAN),成为生成式AI的重要基石。

2. GAN的基本原理

GAN的架构由两个核心网络组成:生成器(Generator, G)判别器(Discriminator, D),二者通过“极小极大游戏(Min-Max Game)”实现共同优化。

2.1 核心思想:对抗与协作

  • 判别器D:本质是一个二分类器,目标是区分输入样本“是真实数据(来自数据分布 p d a t a ( x ) p_{data}(x) pdata(x))”还是“生成器伪造的数据(来自生成分布 p g ( x ) p_g(x) pg(x))”,输出样本为真实数据的概率(范围 [ 0 , 1 ] [0,1] [0,1])。
  • 生成器G:接收一个随机噪声向量 z z z(来自噪声分布 p z ( z ) p_z(z) pz(z),如高斯分布),通过非线性变换将其映射为与真实数据维度一致的生成样本 G ( z ) G(z) G(z),目标是让生成样本足够逼真,以至于判别器无法区分(即让 D ( G ( z ) ) → 1 D(G(z)) \to 1 D(G(z))1)。

二者的对抗关系可概括为:判别器努力提高区分准确率,生成器努力降低判别器的区分准确率,最终达到纳什均衡——判别器无法区分真实与生成样本(输出概率恒为 0.5 0.5 0.5),生成器学到了真实数据的分布。

2.2 网络结构与数学定义

(1)符号定义
  • x x x:真实数据样本, x ∼ p d a t a ( x ) x \sim p_{data}(x) xpdata(x) p d a t a p_{data} pdata为真实数据分布);
  • z z z:随机噪声向量, z ∼ p z ( z ) z \sim p_z(z) zpz(z) p z p_z pz为噪声分布,常用 N ( 0 , 1 ) \mathcal{N}(0,1) N(0,1));
  • G ( z ) G(z) G(z):生成器输出的伪造样本, G : R d → R n G: \mathbb{R}^d \to \mathbb{R}^n G:RdRn d d d为噪声维度, n n n为真实数据维度);
  • D ( x ) D(x) D(x):判别器对样本 x x x的输出(真实概率), D : R n → [ 0 , 1 ] D: \mathbb{R}^n \to [0,1] D:Rn[0,1]
(2)目标函数(极小极大游戏)

GAN的核心是求解以下极小极大优化问题:
min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

  • 对判别器 D D D:需最大化目标函数 V ( D , G ) V(D,G) V(D,G)。第一项 log ⁡ D ( x ) \log D(x) logD(x)表示“正确识别真实样本”的对数损失( D ( x ) D(x) D(x)越大,该项越大);第二项 log ⁡ ( 1 − D ( G ( z ) ) ) \log (1-D(G(z))) log(1D(G(z)))表示“正确识别伪造样本”的对数损失( D ( G ( z ) ) D(G(z)) D(G(z))越小,该项越大)。
  • 对生成器 G G G:需最小化目标函数 V ( D , G ) V(D,G) V(D,G)。生成器希望伪造样本被误认为真实样本,即 D ( G ( z ) ) → 1 D(G(z)) \to 1 D(G(z))1,此时 log ⁡ ( 1 − D ( G ( z ) ) ) → − ∞ \log (1-D(G(z))) \to -\infty log(1D(G(z))),需通过最小化 V ( D , G ) V(D,G) V(D,G)让该项尽可能大(等价于最大化 log ⁡ D ( G ( z ) ) \log D(G(z)) logD(G(z)),后续常用此等价形式避免梯度消失)。

3. GAN的训练流程

GAN的训练采用“交替迭代”方式,即先固定生成器训练判别器,再固定判别器训练生成器,重复此过程直到收敛。具体步骤如下:

  1. 初始化参数

    • 初始化生成器 G G G的参数 θ G \theta_G θG(如卷积层、全连接层权重);
    • 初始化判别器 D D D的参数 θ D \theta_D θD
  2. 迭代训练(直到收敛)

    • 步骤1:训练判别器 D D D(固定 θ G \theta_G θG

      1. 从真实数据集中采样 m m m个样本: x ( 1 ) , x ( 2 ) , . . . , x ( m ) ∼ p d a t a ( x ) x^{(1)}, x^{(2)}, ..., x^{(m)} \sim p_{data}(x) x(1),x(2),...,x(m)pdata(x)
      2. 从噪声分布中采样 m m m个噪声向量: z ( 1 ) , z ( 2 ) , . . . , z ( m ) ∼ p z ( z ) z^{(1)}, z^{(2)}, ..., z^{(m)} \sim p_z(z) z(1),z(2),...,z(m)pz(z),生成伪造样本 G ( z ( 1 ) ) , . . . , G ( z ( m ) ) G(z^{(1)}), ..., G(z^{(m)}) G(z(1)),...,G(z(m))
      3. 计算判别器的损失(基于交叉熵),并对 θ D \theta_D θD求导更新:
        L D = − 1 m [ ∑ i = 1 m log ⁡ D ( x ( i ) ) + ∑ i = 1 m log ⁡ ( 1 − D ( G ( z ( i ) ) ) ) ] \mathcal{L}_D = -\frac{1}{m}\left[ \sum_{i=1}^m \log D(x^{(i)}) + \sum_{i=1}^m \log (1 - D(G(z^{(i)}))) \right] LD=m1[i=1mlogD(x(i))+i=1mlog(1D(G(z(i))))]
        通过梯度上升(或优化器如Adam)最小化 L D \mathcal{L}_D LD,即最大化 V ( D , G ) V(D,G) V(D,G)
    • 步骤2:训练生成器 G G G(固定 θ D \theta_D θD

      1. 再次从噪声分布中采样 m m m个噪声向量: z ( 1 ) , . . . , z ( m ) ∼ p z ( z ) z^{(1)}, ..., z^{(m)} \sim p_z(z) z(1),...,z(m)pz(z)
      2. 计算生成器的损失(等价形式,避免梯度消失),并对 θ G \theta_G θG求导更新:
        L G = − 1 m ∑ i = 1 m log ⁡ D ( G ( z ( i ) ) ) \mathcal{L}_G = -\frac{1}{m} \sum_{i=1}^m \log D(G(z^{(i)})) LG=m1i=1mlogD(G(z(i)))
        通过梯度下降最小化 L G \mathcal{L}_G LG,即让 D ( G ( z ) ) D(G(z)) D(G(z))尽可能接近1。
  3. 收敛判断

    • 当判别器对真实样本和伪造样本的输出概率均接近 0.5 0.5 0.5(即 D ( x ) ≈ 0.5 D(x) \approx 0.5 D(x)0.5 D ( G ( z ) ) ≈ 0.5 D(G(z)) \approx 0.5 D(G(z))0.5);
    • 生成样本的视觉/语义质量不再提升(如图像生成中,样本无明显模糊或重复)。

4. 原始GAN的核心问题

原始GAN虽思想巧妙,但在实际训练中存在两大核心问题,限制了其应用:

4.1 训练不稳定(梯度消失)

当生成器生成的样本质量较差时,判别器会轻易识别( D ( G ( z ) ) ≈ 0 D(G(z)) \approx 0 D(G(z))0),此时生成器的损失 log ⁡ ( 1 − D ( G ( z ) ) ) ≈ log ⁡ 1 = 0 \log(1-D(G(z))) \approx \log 1 = 0 log(1D(G(z)))log1=0,梯度 ∇ θ G log ⁡ ( 1 − D ( G ( z ) ) ) ≈ 0 \nabla_{\theta_G} \log(1-D(G(z))) \approx 0 θGlog(1D(G(z)))0,导致生成器参数无法更新(梯度消失)。

原因:判别器能力过强,生成器与判别器“实力失衡”,生成器无法从失败中学习。

4.2 模式崩溃(Mode Collapse)

模式崩溃是指生成器仅能生成少数几种类型的样本(即生成分布 p g ( x ) p_g(x) pg(x)仅覆盖真实分布 p d a t a ( x ) p_{data}(x) pdata(x)的一小部分),但判别器无法区分(因这些样本已能“欺骗”判别器)。

原因:生成器为最小化损失,倾向于生成判别器最易误判的样本(而非覆盖全部真实分布),导致生成样本多样性缺失。

5. 常见GAN变种及其改进

为解决原始GAN的问题,研究者提出了多种变种,以下是最常用的三类:

5.1 DCGAN(深度卷积GAN)

  • 核心改进:用卷积层(判别器)和反卷积层(生成器)替代全连接层,减少参数并保留空间特征(适配图像数据)。
  • 架构原则
    • 判别器:使用strided卷积(替代池化)下采样,增加感受野;
    • 生成器:使用反卷积(转置卷积)上采样,逐步恢复图像分辨率;
    • 移除全连接层,批量归一化(Batch Norm)加速训练。
  • 应用场景:图像生成(如人脸、风景)、图像修复。

5.2 WGAN(Wasserstein GAN)

  • 核心改进:用Wasserstein距离(Earth-Mover距离) 替代原始GAN的JS散度,衡量真实分布与生成分布的差异,解决训练不稳定和梯度消失问题。
  • 关键设计
    • 判别器改为“评论家(Critic)”,输出连续值(非概率),用于计算Wasserstein距离;
    • 加入权重裁剪(Weight Clipping),确保评论家满足Lipschitz条件(Wasserstein距离计算的前提)。
  • 优势:损失值可直接反映生成样本质量(损失越小,质量越高),训练更稳定。

5.3 CycleGAN(循环一致性GAN)

  • 核心改进:针对“无监督图像风格迁移”场景,引入循环一致性损失,解决“一对一”风格迁移的映射问题。
  • 架构设计
    • 两个生成器: G : X → Y G: X \to Y G:XY(将X域图像转为Y域)、 F : Y → X F: Y \to X F:YX(将Y域图像转为X域);
    • 两个判别器: D Y D_Y DY(区分Y域真实/生成图像)、 D X D_X DX(区分X域真实/生成图像);
    • 循环一致性损失:确保 F ( G ( x ) ) ≈ x F(G(x)) \approx x F(G(x))x G ( F ( y ) ) ≈ y G(F(y)) \approx y G(F(y))y,避免生成图像脱离原域特征:
      L c y c ( G , F ) = E x ∼ p X [ ∥ F ( G ( x ) ) − x ∥ 1 ] + E y ∼ p Y [ ∥ G ( F ( y ) ) − y ∥ 1 ] \mathcal{L}_{cyc}(G,F) = \mathbb{E}_{x \sim p_X}[\|F(G(x)) - x\|_1] + \mathbb{E}_{y \sim p_Y}[\|G(F(y)) - y\|_1] Lcyc(G,F)=ExpX[F(G(x))x1]+EypY[G(F(y))y1]
  • 应用场景:风格迁移(如照片转油画、白天转黑夜)、跨域图像转换(如马转斑马)。

6. GAN的典型应用领域

GAN凭借强大的生成能力,已在多个领域落地,典型应用包括:

应用领域具体场景举例常用模型
图像生成高清人脸生成(StyleGAN)、动漫角色生成DCGAN、StyleGAN
图像风格迁移照片转梵高风格、黑白图转彩色CycleGAN、StarGAN
超分辨率重建低清图像提升至4K/8K分辨率SRGAN、ESRGAN
文本生成图像基于文字描述生成逼真图像(如“红色跑车”)StackGAN、AttnGAN
语音合成基于文本生成自然语音(TTS)GAN-TTS、MelGAN

7. 总结与展望

7.1 核心总结

  • GAN的本质是“对抗训练”:通过生成器与判别器的极小极大游戏,隐式学习真实数据分布;
  • 原始GAN的痛点是训练不稳定和模式崩溃,后续变种(DCGAN、WGAN、CycleGAN)通过架构调整或损失函数改进逐步解决;
  • GAN的优势在于生成样本的高逼真度,尤其适合图像、音频等复杂结构化数据。

7.2 未来方向

  • 多模态GAN:融合文本、图像、音频等多模态信息,实现更精准的生成(如文本+图像生成视频);
  • 效率优化:减少GAN的训练时间和计算资源消耗(如轻量化架构、分布式训练);
  • 可控生成:实现对生成样本的精细控制(如指定人脸表情、图像风格强度)。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BetterInsight

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值