27, PyTorch 生成器与判别器的实现
在上一节我们梳理了 GAN 的核心原理与训练范式,本节给出在 PyTorch 2.x 环境下最常用、最稳定的 DCGAN 生成器(Generator)与判别器(Discriminator)完整代码实现。代码遵循以下设计原则:
- 模块化:每个子块独立成类,方便后续替换为残差块、注意力块等。
- 可配置:通过
__init__
形参即可调整通道数、图像尺寸、激活函数等。 - 易扩展:注释清晰,关键维度一目了然,方便快速迁移到 128×128 或 RGBA 四通道任务。
以下代码可直接保存为 models.py
,在任何训练脚本中通过 from models import Generator, Discriminator
调用。
27.1 生成器 Generator
27.1.1 设计要点
- 输入:长度为
nz
的一维噪声向量(默认 100)。 - 上采样:使用
ConvTranspose2d
+BatchNorm2d
+ReLU
四连级,逐步将 4×4 特征图放大到 64×64。 - 输出:Tanh 激活,像素值落在 [-1, 1],与 torchvision 标准化区间一致。
- 初始化:采用 DCGAN 论文推荐的
N(0, 0.02)
正态分布。
27.1.2 代码
import torch
import torch.nn as nn
class GenBlock(nn.Module):
"""
生成器基础块:
ConvTranspose2d -> BatchNorm2d -> ReLU
"""
def __init__(self, in_ch, out_ch, kernel=4, stride=2, padding=1):
super().__init__()
self.block = nn.Sequential(
nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, padding, bias=False),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.block(x)
class Generator(nn.Module):
def __init__(self, nz: int = 100, ngf: int = 64, nc: int = 3):
"""
nz: 噪声维度
ngf: 生成器基础通道数
nc: 输出图像通道数
"""
super().__init__()
self.main = nn.Sequential(
# 100 -> 4*4*512
nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(inplace=True),
# 4*4 -> 8*8
GenBlock(ngf * 8, ngf * 4),
# 8*8 -> 16*16
GenBlock(ngf * 4, ngf * 2),
# 16*16 -> 32*32
GenBlock(ngf * 2, ngf),
# 32*32 -> 64*64
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, z):
# z: (B, nz) -> (B, nz, 1, 1)
z = z.view(z.size(0), -1, 1, 1)
return self.main(z)
27.2 判别器 Discriminator
27.2.1 设计要点
- 输入:RGB 图像,尺寸 64×64,已归一化到 [-1, 1]。
- 下采样:使用
Conv2d
+BatchNorm2d
+LeakyReLU(0.2)
四连级。 - 输出:无 Sigmoid,直接给出 logit,配合
BCEWithLogitsLoss
数值更稳。 - 初始化:同生成器,权重
N(0, 0.02)
。
27.2.2 代码
class DisBlock(nn.Module):
"""
判别器基础块:
Conv2d -> BatchNorm2d -> LeakyReLU(0.2)
"""
def __init__(self, in_ch, out_ch, kernel=4, stride=2, padding=1, bn=True):
super().__init__()
layers = [nn.Conv2d(in_ch, out_ch, kernel, stride, padding, bias=False)]
if bn:
layers.append(nn.BatchNorm2d(out_ch))
layers.append(nn.LeakyReLU(0.2, inplace=True))
self.block = nn.Sequential(*layers)
def forward(self, x):
return self.block(x)
class Discriminator(nn.Module):
def __init__(self, nc: int = 3, ndf: int = 64):
"""
nc: 输入图像通道数
ndf: 判别器基础通道数
"""
super().__init__()
self.main = nn.Sequential(
# 64*64 -> 32*32, 无 BN
DisBlock(nc, ndf, bn=False),
# 32*32 -> 16*16
DisBlock(ndf, ndf * 2),
# 16*16 -> 8*8
DisBlock(ndf * 2, ndf * 4),
# 8*8 -> 4*4
DisBlock(ndf * 4, ndf * 8),
# 4*4 -> 1*1
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)
# 不经过 Sigmoid
)
def forward(self, x):
return self.main(x).view(-1) # (B, 1) -> (B,)
27.3 权重初始化函数
在实例化网络后,调用一次即可:
def weights_init(m):
classname = m.__class__.__name__
if "Conv" in classname:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif "BatchNorm" in classname:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
# 使用示例
netG = Generator(nz=100, ngf=64, nc=3).apply(weights_init)
netD = Discriminator(nc=3, ndf=64).apply(weights_init)
27.4 迁移到 128×128 或更高分辨率
若需生成 128×128 图像,只需在 Generator
与 Discriminator
各插入一个 GenBlock / DisBlock
:
- 生成器:在
GenBlock(ngf*2, ngf)
后再加一个GenBlock(ngf, ngf//2)
,最后输出层通道改为ngf//2
。 - 判别器:在
DisBlock(ndf*4, ndf*8)
前插入DisBlock(ndf*2, ndf*4)
,对应调整通道数即可。
无需改动初始化函数,所有卷积核尺寸、步长保持不变,代码复用率 100%。
27.5 小结
本节给出了可直接落地的 DCGAN 网络实现,重点展示了 PyTorch 中 “小步快跑” 的卷积–反卷积范式。通过 GenBlock / DisBlock
的复用,你可以在五分钟内把它扩展为 128×128、256×256 甚至 StyleGAN 的雏形。下一节我们将基于这两个网络,用 CelebA 人脸数据完成一次端到端训练,并给出 TensorBoard 可视化脚本。
更多技术文章见公众号: 大城市小农民