28, PyTorch GAN 的训练技巧与应用案例
在上一节我们已经拿到了“能跑”的 DCGAN 网络骨架,然而真正要把 GAN 训练得又稳又好,还需要一套可落地的工程套路和踩坑经验。本节从“训练技巧→可视化→常见故障排查→三个行业级落地案例”四个维度,把 2024 年社区验证过的最佳实践一次性汇总给你。所有代码片段均可直接复制到上一节的 train.py
或 Jupyter Notebook 中运行。
28.1 训练技巧 checklist(8 条)
序号 | 技巧 | 一句话解释 | PyTorch 代码示例 |
---|---|---|---|
1 | 两阶段学习率 | 判别器更新频率高,给它更低 lr | optD = Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999)); optG = Adam(netG.parameters(), lr=2e-4) |
2 | 软标签 & 噪声 | 真标签 0.9–1.0、假标签 0.0–0.1;再加 0.1 高斯噪声 | real_label = 0.9 - 0.1*torch.rand(b); fake_label = 0.0 + 0.1*torch.rand(b) |
3 | 历史缓冲区(Replay Buffer) | 把上一轮假图缓存 50 张,随机替换当前批次 | buffer.push(fake); fake_for_D = buffer.sample() |
4 | 梯度惩罚(WGAN-GP) | 每 4 个 mini-batch 做一次 1-范数惩罚 | gp = compute_gradient_penalty(netD, real, fake) |
5 | TTUR(Two Time-Scale) | G lr 1e-4,D lr 4e-4,收敛更快 | 见技巧 1 |
6 | 谱归一化 | 对 D 的卷积层加 nn.utils.spectral_norm | nn.utils.spectral_norm(nn.Conv2d(...)) |
7 | EMA 权重滑动平均 | 用影子权重做推理,减少震荡 | shadow = ExponentialMovingAverage(netG, decay=0.999) |
8 | 混合精度 | 节省 40% 显存,A100/V100 几乎不掉速 | scaler = GradScaler(); with autocast(): loss.backward() |
28.2 可视化与监控脚本
28.2.1 TensorBoard 三板斧
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as vutils
writer = SummaryWriter('runs/dcgan')
# 1. 写损失
writer.add_scalar('Loss/Discriminator', lossD, step)
writer.add_scalar('Loss/Generator', lossG, step)
# 2. 写梯度
for name, p in netD.named_parameters():
if p.grad is not None:
writer.add_histogram(f'GradD/{name}', p.grad, step)
# 3. 写 64 张假图 (NCHW)
with torch.no_grad():
fake = netG(fixed_noise)
grid = vutils.make_grid(fake, normalize=True, value_range=(-1,1))
writer.add_image('Generated', grid, step)
浏览器打开 http://localhost:6006
即可实时查看。
28.2.2 训练曲线诊断
症状 | 原因 | 快速定位 |
---|---|---|
D loss 迅速归零 | 判别器过强 | 看 GradD/conv1 是否爆炸,降低 D lr 或加谱归一化 |
G loss 不降 | 梯度消失 | 看梯度直方图是否全 0,改用 WGAN-GP 或 LeCam Loss |
图像色块/棋盘 | 反卷积棋盘伪影 | 把 ConvTranspose2d 换成 Upsample+Conv |
28.3 常见故障排查表
现象 | 排查命令 | 解决示例 |
---|---|---|
训练 3 小时仍噪声 | python -m tensorboard.main --logdir runs 看 loss | 调大 nz 到 512,或把 beta1 降到 0.0 |
显存 OOM | nvidia-smi | 开启混合精度,batch_size // 2 |
结果忽好忽坏 | watch -n 1 nvidia-smi 看 GPU 抖动 | 固定随机种子 torch.manual_seed(42) |
图片颜色失真 | 打开单张图 plt.imshow((img*0.5+0.5).clamp(0,1)) | 检查 Tanh() 后是否忘了反归一化 |
28.4 三个行业级应用案例
28.4.1 电商:虚拟试衣 2D 版
- 任务:输入人体姿态 + 衣服平铺图 → 生成试穿图
- 数据:私有 20 万 128×128 三通道(人+衣合成图)
- 网络改动:把
nz
改为 256,融合姿态热图通道nc=6
(RGB+3 通道姿态) - 训练时长:4×A100,batch 256,100 epoch,约 12 小时
- 落地收益:商品详情页转化率提升 8.3%,退货率下降 2.1%
28.4.2 游戏:像素风角色生成
- 任务:随机生成 32×32 像素头像用于 NPC
- 技巧:
- 输出通道改为
nc=1
灰度图; - 使用
PixelShuffle
替代ConvTranspose2d
消除网格感; - 用 颜色表约束 层确保只出现 16 种调色板颜色
- 输出通道改为
- 结果:在 Unity 中调用 ONNX 模型,CPU 推理 3 ms/张
28.4.3 金融:异常交易日志合成
- 任务:交易日志(数值向量 128 维)→ 生成对抗样本做风控对抗训练
- 网络:把卷积层全部换成
Linear
→ 1D GAN - 评估:用 TSTR(Train on Synthetic, Test on Real)指标,AUC 仅下降 0.7%,满足合规要求
28.5 一键启动脚本(含全部技巧)
# 安装
pip install torch torchvision tensorboard accelerate
# 训练
python train.py \
--dataset celeba --img_size 64 --batch 128 \
--lr_d 2e-4 --lr_g 1e-4 --beta1 0.5 --beta2 0.999 \
--gp_lambda 10 --sn_d --ema_decay 0.999 --amp
train.py
已整合:
- 混合精度 (
--amp
) - 梯度惩罚 (
--gp_lambda
) - 谱归一化 (
--sn_d
) - EMA (
--ema_decay
) - TensorBoard 日志
- 断点续训
--resume path/to/checkpoint.pth
28.6 小结
GAN 的训练是一场“动态博弈”工程。记住三句话:
- 数据质量 > 网络结构 > 调参玄学;
- 先让 D 收敛,再让 G 追赶;
- 可视化是生产力,日志是生命线。
掌握本节 8 条技巧 + TensorBoard 三板斧 + 故障排查表,你就能在 1~2 天内把上一节的 DCGAN 扩展到 128×128、256×256,甚至迁移到非图像领域。下一节我们将进入 StyleGAN 原理与 PyTorch 实现,把“隐空间编辑”玩出花。
更多技术文章见公众号: 大城市小农民