26,PyTorch GAN 的原理与结构

在这里插入图片描述

26, PyTorch GAN 的原理与结构

生成对抗网络(Generative Adversarial Network,GAN)自 2014 年 Ian Goodfellow 提出以来,已成为深度学习领域最具影响力的生成模型之一。PyTorch 的动态图机制与模块化设计,使得实现与调试 GAN 变得直观、高效。本节在冻结与微调的基础上,继续深入 PyTorch 中 GAN 的核心原理、网络结构、训练流程与常见实现细节,帮助读者快速搭建并稳定训练自己的 GAN 模型。


26.1 GAN 的核心思想

GAN 由两个互相对抗的神经网络构成:

  • 生成器(Generator, G):接收随机噪声 z(通常从 N(0, I) 或 U(−1, 1) 采样),输出伪造样本 G(z)。
  • 判别器(Discriminator, D):接收真实样本 x 或伪造样本 G(z),输出一个标量,表示输入样本为真的概率。

训练目标是一个极小–极大博弈:

[
\min_G \max_D V(D,G)= \mathbb{E}{x\sim p{\text{data}}}[\log D(x)] + \mathbb{E}_{z\sim p_z}[\log(1-D(G(z)))]
]

  • 判别器希望最大化 V,即正确区分真伪;
  • 生成器希望最小化 V,即让判别器将伪造样本误判为真实。

26.2 GAN 的网络结构

26.2.1 生成器(G)

以 DCGAN 为例,输入 100 维噪声向量,通过若干转置卷积(ConvTranspose2d)逐步上采样到目标尺寸(如 64×64×3)。典型结构:

配置输出尺寸
Linear100 → 4×4×5124×4×512
ConvTranspose2dk=4, s=2, p=18×8×256
ConvTranspose2dk=4, s=2, p=116×16×128
ConvTranspose2dk=4, s=2, p=132×32×64
ConvTranspose2dk=4, s=2, p=164×64×3

激活函数:除最后一层用 Tanh 外,中间层使用 ReLU 或 LeakyReLU;批归一化(BatchNorm)通常紧跟卷积层。

26.2.2 判别器(D)

镜像生成器:将 64×64×3 图像逐步下采样至 1×1×1 的概率值。结构示例:

配置输出尺寸
Conv2dk=4, s=2, p=132×32×64
Conv2dk=4, s=2, p=116×16×128
Conv2dk=4, s=2, p=18×8×256
Conv2dk=4, s=2, p=14×4×512
Conv2dk=4, s=1, p=01×1×1

激活函数:除最后一层外,使用 LeakyReLU(0.2);不使用 BatchNorm 的最后一层。


26.3 PyTorch 实现要点

26.3.1 权重初始化

DCGAN 论文指出,G 与 D 的权重需做特定初始化以稳定训练:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

创建网络后统一调用:

netG = Generator().apply(weights_init)
netD = Discriminator().apply(weights_init)

26.3.2 优化器与学习率

  • 使用 Adam,β1=0.5,β2=0.999;
  • 学习率通常 2e-4,也可使用 TTUR(G 1e-4,D 4e-4);
  • 避免使用带动量的 SGD,容易产生振荡。

26.3.3 损失函数

原始 GAN 使用 BCEWithLogitsLoss;若用 WGAN-GP,则需自定义梯度惩罚项。

criterion = nn.BCEWithLogitsLoss()

real_label = 1.
fake_label = 0.

# 训练 D
optimizerD.zero_grad()
output = netD(real_imgs).view(-1)
errD_real = criterion(output, torch.full_like(output, real_label))
output = netD(fake_imgs.detach()).view(-1)
errD_fake = criterion(output, torch.full_like(output, fake_label))
errD = errD_real + errD_fake
errD.backward()
optimizerD.step()

# 训练 G
optimizerG.zero_grad()
output = netD(fake_imgs).view(-1)
errG = criterion(output, torch.full_like(output, real_label))
errG.backward()
optimizerG.step()

26.3.4 训练循环范式

  1. 固定 G,训练 k 步 D(通常 k=1);
  2. 固定 D,训练 1 步 G;
  3. 交替进行,直至收敛。
for epoch in range(num_epochs):
    for i, (real_batch, _) in enumerate(dataloader):
        # 1. 更新 D
        ...
        # 2. 更新 G
        ...
        # 3. 记录损失、采样图像

26.4 稳定训练技巧

问题现象解决策略
梯度消失D 输出 0,G 无法学习使用 LeakyReLU、标签平滑、TTUR
模式崩塌G 生成少量重复样本Minibatch stddev、Unrolling、谱归一化
训练振荡损失剧烈波动历史平均参数、梯度惩罚(WGAN-GP)

26.5 进阶结构速览

  • cGAN:引入条件向量 y,实现类别控制;
  • WGAN/WGAN-GP:Earth-Mover 距离、梯度惩罚,缓解梯度消失;
  • SAGAN:引入自注意力机制,捕捉长程依赖;
  • StyleGAN/StyleGAN2:逐层风格调制,生成高质量人脸;
  • Diffusion-GAN 混合:用扩散模型改进 GAN 的生成能力。

26.6 小结

GAN 通过对抗博弈实现高质量生成,PyTorch 的灵活性让实验变得轻松。掌握 G 与 D 的网络设计、初始化、损失与优化器选择,以及稳定训练技巧,是构建成功 GAN 的关键。下一节将在此基础上,实现一个完整的 DCGAN 并在 CelebA 人脸数据集上训练,展示从数据到可视化的完整流程。
更多技术文章见公众号: 大城市小农民

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

乔丹搞IT

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值