26, PyTorch GAN 的原理与结构
生成对抗网络(Generative Adversarial Network,GAN)自 2014 年 Ian Goodfellow 提出以来,已成为深度学习领域最具影响力的生成模型之一。PyTorch 的动态图机制与模块化设计,使得实现与调试 GAN 变得直观、高效。本节在冻结与微调的基础上,继续深入 PyTorch 中 GAN 的核心原理、网络结构、训练流程与常见实现细节,帮助读者快速搭建并稳定训练自己的 GAN 模型。
26.1 GAN 的核心思想
GAN 由两个互相对抗的神经网络构成:
- 生成器(Generator, G):接收随机噪声 z(通常从 N(0, I) 或 U(−1, 1) 采样),输出伪造样本 G(z)。
- 判别器(Discriminator, D):接收真实样本 x 或伪造样本 G(z),输出一个标量,表示输入样本为真的概率。
训练目标是一个极小–极大博弈:
[
\min_G \max_D V(D,G)= \mathbb{E}{x\sim p{\text{data}}}[\log D(x)] + \mathbb{E}_{z\sim p_z}[\log(1-D(G(z)))]
]
- 判别器希望最大化 V,即正确区分真伪;
- 生成器希望最小化 V,即让判别器将伪造样本误判为真实。
26.2 GAN 的网络结构
26.2.1 生成器(G)
以 DCGAN 为例,输入 100 维噪声向量,通过若干转置卷积(ConvTranspose2d)逐步上采样到目标尺寸(如 64×64×3)。典型结构:
层 | 配置 | 输出尺寸 |
---|---|---|
Linear | 100 → 4×4×512 | 4×4×512 |
ConvTranspose2d | k=4, s=2, p=1 | 8×8×256 |
ConvTranspose2d | k=4, s=2, p=1 | 16×16×128 |
ConvTranspose2d | k=4, s=2, p=1 | 32×32×64 |
ConvTranspose2d | k=4, s=2, p=1 | 64×64×3 |
激活函数:除最后一层用 Tanh 外,中间层使用 ReLU 或 LeakyReLU;批归一化(BatchNorm)通常紧跟卷积层。
26.2.2 判别器(D)
镜像生成器:将 64×64×3 图像逐步下采样至 1×1×1 的概率值。结构示例:
层 | 配置 | 输出尺寸 |
---|---|---|
Conv2d | k=4, s=2, p=1 | 32×32×64 |
Conv2d | k=4, s=2, p=1 | 16×16×128 |
Conv2d | k=4, s=2, p=1 | 8×8×256 |
Conv2d | k=4, s=2, p=1 | 4×4×512 |
Conv2d | k=4, s=1, p=0 | 1×1×1 |
激活函数:除最后一层外,使用 LeakyReLU(0.2);不使用 BatchNorm 的最后一层。
26.3 PyTorch 实现要点
26.3.1 权重初始化
DCGAN 论文指出,G 与 D 的权重需做特定初始化以稳定训练:
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
创建网络后统一调用:
netG = Generator().apply(weights_init)
netD = Discriminator().apply(weights_init)
26.3.2 优化器与学习率
- 使用 Adam,β1=0.5,β2=0.999;
- 学习率通常 2e-4,也可使用 TTUR(G 1e-4,D 4e-4);
- 避免使用带动量的 SGD,容易产生振荡。
26.3.3 损失函数
原始 GAN 使用 BCEWithLogitsLoss;若用 WGAN-GP,则需自定义梯度惩罚项。
criterion = nn.BCEWithLogitsLoss()
real_label = 1.
fake_label = 0.
# 训练 D
optimizerD.zero_grad()
output = netD(real_imgs).view(-1)
errD_real = criterion(output, torch.full_like(output, real_label))
output = netD(fake_imgs.detach()).view(-1)
errD_fake = criterion(output, torch.full_like(output, fake_label))
errD = errD_real + errD_fake
errD.backward()
optimizerD.step()
# 训练 G
optimizerG.zero_grad()
output = netD(fake_imgs).view(-1)
errG = criterion(output, torch.full_like(output, real_label))
errG.backward()
optimizerG.step()
26.3.4 训练循环范式
- 固定 G,训练 k 步 D(通常 k=1);
- 固定 D,训练 1 步 G;
- 交替进行,直至收敛。
for epoch in range(num_epochs):
for i, (real_batch, _) in enumerate(dataloader):
# 1. 更新 D
...
# 2. 更新 G
...
# 3. 记录损失、采样图像
26.4 稳定训练技巧
问题 | 现象 | 解决策略 |
---|---|---|
梯度消失 | D 输出 0,G 无法学习 | 使用 LeakyReLU、标签平滑、TTUR |
模式崩塌 | G 生成少量重复样本 | Minibatch stddev、Unrolling、谱归一化 |
训练振荡 | 损失剧烈波动 | 历史平均参数、梯度惩罚(WGAN-GP) |
26.5 进阶结构速览
- cGAN:引入条件向量 y,实现类别控制;
- WGAN/WGAN-GP:Earth-Mover 距离、梯度惩罚,缓解梯度消失;
- SAGAN:引入自注意力机制,捕捉长程依赖;
- StyleGAN/StyleGAN2:逐层风格调制,生成高质量人脸;
- Diffusion-GAN 混合:用扩散模型改进 GAN 的生成能力。
26.6 小结
GAN 通过对抗博弈实现高质量生成,PyTorch 的灵活性让实验变得轻松。掌握 G 与 D 的网络设计、初始化、损失与优化器选择,以及稳定训练技巧,是构建成功 GAN 的关键。下一节将在此基础上,实现一个完整的 DCGAN 并在 CelebA 人脸数据集上训练,展示从数据到可视化的完整流程。
更多技术文章见公众号: 大城市小农民