基于PyTorch的Wasserstein GAN实现详解
引言
Wasserstein GAN(WGAN)是生成对抗网络(GAN)的一种重要变体,由Martin Arjovsky等人在2017年提出。与传统的GAN相比,WGAN通过使用Wasserstein距离(又称Earth-Mover距离)作为损失函数,显著改善了训练过程的稳定性。本文将详细介绍如何使用PyTorch实现一个简单的WGAN模型,并应用于MNIST手写数字生成任务。
Wasserstein GAN的核心原理
WGAN相较于传统GAN有几个关键改进:
- 损失函数:使用Wasserstein距离替代Jensen-Shannon散度,解决了传统GAN训练不稳定的问题
- 权重裁剪:强制限制判别器(在WGAN中称为"critic")的权重范围,满足Lipschitz约束
- 训练策略:通常critic需要比生成器训练更多次(如5:1的比例)
- 输出层:critic使用线性输出层而非sigmoid激活函数
环境准备与数据加载
首先我们需要准备PyTorch环境并加载MNIST数据集:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 设备配置
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 超参数设置
random_seed = 0
generator_lr = 0.0005
critic_lr = 0.0005
num_epochs = 100
batch_size = 128
latent_dim = 50
img_shape = (1, 28, 28)
# WGAN特定参数
num_critic_iters = 5 # 每次生成器更新前critic的训练次数
clip_value = 0.01 # 权重裁剪值
# MNIST数据加载
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
模型架构实现
WGAN模型包含生成器(Generator)和判别器(Critic)两部分:
class WGAN(nn.Module):
def __init__(self):
super(WGAN, self).__init__()
img_size = 1 * 28 * 28 # MNIST图像展平后的尺寸
# 生成器网络
self.generator = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.LeakyReLU(inplace=True),
nn.Linear(128, img_size),
nn.Tanh() # 输出在[-1,1]范围
)
# 判别器网络(在WGAN中称为critic)
self.critic = nn.Sequential(
nn.Linear(img_size, 128),
nn.LeakyReLU(inplace=True),
nn.Linear(128, 1), # 线性输出,不使用sigmoid
)
def forward_generator(self, z):
return self.generator(z)
def forward_critic(self, img):
return self.critic(img.view(img.size(0), -1)).view(-1)
关键点说明:
- 生成器接收潜在空间向量(latent vector)作为输入,输出与真实图像相同维度的数据
- Critic直接输出一个标量分数,不经过sigmoid激活
- 生成器最后使用Tanh激活,将输出限制在[-1,1]范围
训练过程实现
WGAN的训练过程与传统GAN有所不同:
# 初始化模型和优化器
model = WGAN().to(device)
optim_g = torch.optim.Adam(model.generator.parameters(), lr=generator_lr)
optim_c = torch.optim.Adam(model.critic.parameters(), lr=critic_lr)
# 训练循环
for epoch in range(num_epochs):
for i, (real_imgs, _) in enumerate(train_loader):
real_imgs = real_imgs.to(device)
# 训练Critic多次
for _ in range(num_critic_iters):
# 生成假图像
z = torch.randn(real_imgs.size(0), latent_dim).to(device)
fake_imgs = model.forward_generator(z)
# 计算Critic损失
real_scores = model.forward_critic(real_imgs)
fake_scores = model.forward_critic(fake_imgs.detach())
# WGAN损失函数
loss_c = -torch.mean(real_scores) + torch.mean(fake_scores)
# 更新Critic
optim_c.zero_grad()
loss_c.backward()
optim_c.step()
# 权重裁剪
for p in model.critic.parameters():
p.data.clamp_(-clip_value, clip_value)
# 训练生成器
z = torch.randn(real_imgs.size(0), latent_dim).to(device)
fake_imgs = model.forward_generator(z)
fake_scores = model.forward_critic(fake_imgs)
# 生成器希望Critic给假图像高分
loss_g = -torch.mean(fake_scores)
optim_g.zero_grad()
loss_g.backward()
optim_g.step()
训练关键点:
- Critic的训练次数多于生成器(通常5:1)
- 每次Critic更新后需要裁剪权重到[-0.01,0.01]范围
- 损失函数直接使用Critic输出的差值
- 生成器训练目标是让Critic给假图像高分
结果分析与改进方向
训练过程中可以观察到:
- 初期Critic损失波动较大,随着训练逐渐稳定
- 生成器损失逐渐降低,表示生成质量提高
- 与传统GAN相比,训练过程更加稳定
可能的改进方向:
- 使用梯度惩罚(WGAN-GP)替代权重裁剪,能获得更好的效果
- 增加网络深度或使用卷积结构提升生成质量
- 调整学习率和训练比例参数
- 增加批归一化层加速收敛
总结
本文详细介绍了Wasserstein GAN的PyTorch实现,包括:
- WGAN的核心原理和优势
- 模型架构的具体实现
- 训练过程的特殊处理
- 关键代码解析和注意事项
WGAN通过改进损失函数和训练策略,有效解决了传统GAN训练不稳定的问题,是生成模型领域的重要进展。读者可以基于这个简单实现进行扩展,应用于更复杂的生成任务。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考