27,PyTorch 生成器与判别器的实现

在这里插入图片描述

27, PyTorch 生成器与判别器的实现

在上一节我们梳理了 GAN 的核心原理与训练范式,本节给出在 PyTorch 2.x 环境下最常用、最稳定的 DCGAN 生成器(Generator)与判别器(Discriminator)完整代码实现。代码遵循以下设计原则:

  • 模块化:每个子块独立成类,方便后续替换为残差块、注意力块等。
  • 可配置:通过 __init__ 形参即可调整通道数、图像尺寸、激活函数等。
  • 易扩展:注释清晰,关键维度一目了然,方便快速迁移到 128×128 或 RGBA 四通道任务。

以下代码可直接保存为 models.py,在任何训练脚本中通过 from models import Generator, Discriminator 调用。


27.1 生成器 Generator

27.1.1 设计要点

  • 输入:长度为 nz 的一维噪声向量(默认 100)。
  • 上采样:使用 ConvTranspose2d + BatchNorm2d + ReLU 四连级,逐步将 4×4 特征图放大到 64×64。
  • 输出:Tanh 激活,像素值落在 [-1, 1],与 torchvision 标准化区间一致。
  • 初始化:采用 DCGAN 论文推荐的 N(0, 0.02) 正态分布。

27.1.2 代码

import torch
import torch.nn as nn

class GenBlock(nn.Module):
    """
    生成器基础块:
    ConvTranspose2d -> BatchNorm2d -> ReLU
    """
    def __init__(self, in_ch, out_ch, kernel=4, stride=2, padding=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, padding, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)

class Generator(nn.Module):
    def __init__(self, nz: int = 100, ngf: int = 64, nc: int = 3):
        """
        nz: 噪声维度
        ngf: 生成器基础通道数
        nc: 输出图像通道数
        """
        super().__init__()
        self.main = nn.Sequential(
            # 100 -> 4*4*512
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),

            # 4*4 -> 8*8
            GenBlock(ngf * 8, ngf * 4),

            # 8*8 -> 16*16
            GenBlock(ngf * 4, ngf * 2),

            # 16*16 -> 32*32
            GenBlock(ngf * 2, ngf),

            # 32*32 -> 64*64
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        # z: (B, nz) -> (B, nz, 1, 1)
        z = z.view(z.size(0), -1, 1, 1)
        return self.main(z)

27.2 判别器 Discriminator

27.2.1 设计要点

  • 输入:RGB 图像,尺寸 64×64,已归一化到 [-1, 1]。
  • 下采样:使用 Conv2d + BatchNorm2d + LeakyReLU(0.2) 四连级。
  • 输出:无 Sigmoid,直接给出 logit,配合 BCEWithLogitsLoss 数值更稳。
  • 初始化:同生成器,权重 N(0, 0.02)

27.2.2 代码

class DisBlock(nn.Module):
    """
    判别器基础块:
    Conv2d -> BatchNorm2d -> LeakyReLU(0.2)
    """
    def __init__(self, in_ch, out_ch, kernel=4, stride=2, padding=1, bn=True):
        super().__init__()
        layers = [nn.Conv2d(in_ch, out_ch, kernel, stride, padding, bias=False)]
        if bn:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class Discriminator(nn.Module):
    def __init__(self, nc: int = 3, ndf: int = 64):
        """
        nc: 输入图像通道数
        ndf: 判别器基础通道数
        """
        super().__init__()
        self.main = nn.Sequential(
            # 64*64 -> 32*32, 无 BN
            DisBlock(nc, ndf, bn=False),

            # 32*32 -> 16*16
            DisBlock(ndf, ndf * 2),

            # 16*16 -> 8*8
            DisBlock(ndf * 2, ndf * 4),

            # 8*8 -> 4*4
            DisBlock(ndf * 4, ndf * 8),

            # 4*4 -> 1*1
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)
            # 不经过 Sigmoid
        )

    def forward(self, x):
        return self.main(x).view(-1)  # (B, 1) -> (B,)

27.3 权重初始化函数

在实例化网络后,调用一次即可:

def weights_init(m):
    classname = m.__class__.__name__
    if "Conv" in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# 使用示例
netG = Generator(nz=100, ngf=64, nc=3).apply(weights_init)
netD = Discriminator(nc=3, ndf=64).apply(weights_init)

27.4 迁移到 128×128 或更高分辨率

若需生成 128×128 图像,只需在 GeneratorDiscriminator 各插入一个 GenBlock / DisBlock

  • 生成器:在 GenBlock(ngf*2, ngf) 后再加一个 GenBlock(ngf, ngf//2),最后输出层通道改为 ngf//2
  • 判别器:在 DisBlock(ndf*4, ndf*8) 前插入 DisBlock(ndf*2, ndf*4),对应调整通道数即可。

无需改动初始化函数,所有卷积核尺寸、步长保持不变,代码复用率 100%。


27.5 小结

本节给出了可直接落地的 DCGAN 网络实现,重点展示了 PyTorch 中 “小步快跑” 的卷积–反卷积范式。通过 GenBlock / DisBlock 的复用,你可以在五分钟内把它扩展为 128×128、256×256 甚至 StyleGAN 的雏形。下一节我们将基于这两个网络,用 CelebA 人脸数据完成一次端到端训练,并给出 TensorBoard 可视化脚本。
更多技术文章见公众号: 大城市小农民

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

乔丹搞IT

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值