基于PyTorch的Wasserstein GAN实现详解

基于PyTorch的Wasserstein GAN实现详解

deeplearning-models A collection of various deep learning architectures, models, and tips deeplearning-models 项目地址: https://gitcode.com/gh_mirrors/de/deeplearning-models

引言

Wasserstein GAN(WGAN)是生成对抗网络(GAN)的一种重要变体,由Martin Arjovsky等人在2017年提出。与传统的GAN相比,WGAN通过使用Wasserstein距离(又称Earth-Mover距离)作为损失函数,显著改善了训练过程的稳定性。本文将详细介绍如何使用PyTorch实现一个简单的WGAN模型,并应用于MNIST手写数字生成任务。

Wasserstein GAN的核心原理

WGAN相较于传统GAN有几个关键改进:

  1. 损失函数:使用Wasserstein距离替代Jensen-Shannon散度,解决了传统GAN训练不稳定的问题
  2. 权重裁剪:强制限制判别器(在WGAN中称为"critic")的权重范围,满足Lipschitz约束
  3. 训练策略:通常critic需要比生成器训练更多次(如5:1的比例)
  4. 输出层:critic使用线性输出层而非sigmoid激活函数

环境准备与数据加载

首先我们需要准备PyTorch环境并加载MNIST数据集:

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 设备配置
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 超参数设置
random_seed = 0
generator_lr = 0.0005
critic_lr = 0.0005
num_epochs = 100
batch_size = 128
latent_dim = 50
img_shape = (1, 28, 28)

# WGAN特定参数
num_critic_iters = 5  # 每次生成器更新前critic的训练次数
clip_value = 0.01     # 权重裁剪值

# MNIST数据加载
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

模型架构实现

WGAN模型包含生成器(Generator)和判别器(Critic)两部分:

class WGAN(nn.Module):
    def __init__(self):
        super(WGAN, self).__init__()
        img_size = 1 * 28 * 28  # MNIST图像展平后的尺寸
        
        # 生成器网络
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(inplace=True),
            nn.Linear(128, img_size),
            nn.Tanh()  # 输出在[-1,1]范围
        )
        
        # 判别器网络(在WGAN中称为critic)
        self.critic = nn.Sequential(
            nn.Linear(img_size, 128),
            nn.LeakyReLU(inplace=True),
            nn.Linear(128, 1),  # 线性输出,不使用sigmoid
        )
    
    def forward_generator(self, z):
        return self.generator(z)
    
    def forward_critic(self, img):
        return self.critic(img.view(img.size(0), -1)).view(-1)

关键点说明:

  1. 生成器接收潜在空间向量(latent vector)作为输入,输出与真实图像相同维度的数据
  2. Critic直接输出一个标量分数,不经过sigmoid激活
  3. 生成器最后使用Tanh激活,将输出限制在[-1,1]范围

训练过程实现

WGAN的训练过程与传统GAN有所不同:

# 初始化模型和优化器
model = WGAN().to(device)
optim_g = torch.optim.Adam(model.generator.parameters(), lr=generator_lr)
optim_c = torch.optim.Adam(model.critic.parameters(), lr=critic_lr)

# 训练循环
for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(train_loader):
        real_imgs = real_imgs.to(device)
        
        # 训练Critic多次
        for _ in range(num_critic_iters):
            # 生成假图像
            z = torch.randn(real_imgs.size(0), latent_dim).to(device)
            fake_imgs = model.forward_generator(z)
            
            # 计算Critic损失
            real_scores = model.forward_critic(real_imgs)
            fake_scores = model.forward_critic(fake_imgs.detach())
            
            # WGAN损失函数
            loss_c = -torch.mean(real_scores) + torch.mean(fake_scores)
            
            # 更新Critic
            optim_c.zero_grad()
            loss_c.backward()
            optim_c.step()
            
            # 权重裁剪
            for p in model.critic.parameters():
                p.data.clamp_(-clip_value, clip_value)
        
        # 训练生成器
        z = torch.randn(real_imgs.size(0), latent_dim).to(device)
        fake_imgs = model.forward_generator(z)
        fake_scores = model.forward_critic(fake_imgs)
        
        # 生成器希望Critic给假图像高分
        loss_g = -torch.mean(fake_scores)
        
        optim_g.zero_grad()
        loss_g.backward()
        optim_g.step()

训练关键点:

  1. Critic的训练次数多于生成器(通常5:1)
  2. 每次Critic更新后需要裁剪权重到[-0.01,0.01]范围
  3. 损失函数直接使用Critic输出的差值
  4. 生成器训练目标是让Critic给假图像高分

结果分析与改进方向

训练过程中可以观察到:

  • 初期Critic损失波动较大,随着训练逐渐稳定
  • 生成器损失逐渐降低,表示生成质量提高
  • 与传统GAN相比,训练过程更加稳定

可能的改进方向:

  1. 使用梯度惩罚(WGAN-GP)替代权重裁剪,能获得更好的效果
  2. 增加网络深度或使用卷积结构提升生成质量
  3. 调整学习率和训练比例参数
  4. 增加批归一化层加速收敛

总结

本文详细介绍了Wasserstein GAN的PyTorch实现,包括:

  1. WGAN的核心原理和优势
  2. 模型架构的具体实现
  3. 训练过程的特殊处理
  4. 关键代码解析和注意事项

WGAN通过改进损失函数和训练策略,有效解决了传统GAN训练不稳定的问题,是生成模型领域的重要进展。读者可以基于这个简单实现进行扩展,应用于更复杂的生成任务。

deeplearning-models A collection of various deep learning architectures, models, and tips deeplearning-models 项目地址: https://gitcode.com/gh_mirrors/de/deeplearning-models

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

余媛奕Lowell

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值