使用Flux.jl实现基础生成对抗网络(GAN)教程
前言
生成对抗网络(GAN)是深度学习领域最具创造力的模型之一,它通过让两个神经网络相互对抗来学习数据分布。本教程将详细介绍如何使用Flux.jl框架实现一个基础的GAN模型,并在MNIST手写数字数据集上进行训练。
GAN基本原理
GAN由两个核心组件构成:
- 生成器(Generator):负责从随机噪声中生成假数据
- 判别器(Discriminator):负责区分真实数据和生成器产生的假数据
这两个网络在训练过程中相互对抗,最终目标是让生成器产生足以"欺骗"判别器的逼真数据。
环境准备
首先需要安装必要的Julia包:
using Pkg
Pkg.add("MLDatasets")
Pkg.add("Flux")
Pkg.add("CUDA")
Pkg.add("Zygote")
Pkg.add("UnicodePlots")
超参数设置
lr_g = 2e-4 # 生成器学习率
lr_d = 2e-4 # 判别器学习率
batch_size = 128 # 批量大小
num_epochs = 1000 # 训练轮数
output_period = 100 # 生成样本展示间隔
n_features = 28 * 28 # MNIST图像特征数
latent_dim = 100 # 潜在空间维度
opt_dscr = ADAM(lr_d) # 判别器优化器
opt_gen = ADAM(lr_g) # 生成器优化器
数据准备
MNIST数据集包含60,000张28×28的手写数字灰度图像。我们将数据归一化到[-1,1]范围,这有助于GAN训练的稳定性。
# 加载并预处理数据
train_x, _ = MNIST.traindata(Float32)
train_x = 2f0 * reshape(train_x, 28, 28, 1, :) .- 1f0 |> gpu
train_loader = DataLoader(train_x, batchsize=batch_size, shuffle=true)
网络架构设计
判别器网络
判别器是一个多层感知机,使用LeakyReLU激活函数和Dropout层防止过拟合:
discriminator = Chain(
Dense(n_features => 1024, x -> leakyrelu(x, 0.2f0)),
Dropout(0.3),
Dense(1024 => 512, x -> leakyrelu(x, 0.2f0)),
Dropout(0.3),
Dense(512 => 256, x -> leakyrelu(x, 0.2f0)),
Dropout(0.3),
Dense(256 => 1, sigmoid)
) |> gpu
生成器网络
生成器将潜在空间的随机噪声映射到图像空间:
generator = Chain(
Dense(latent_dim, 256, x -> leakyrelu(x, 0.2f0)),
Dense(256 => 512, x -> leakyrelu(x, 0.2f0)),
Dense(512 => 1024, x -> leakyrelu(x, 0.2f0)),
Dense(1024 => n_features, tanh)
) |> gpu
训练过程
判别器训练
判别器需要同时学习识别真实数据和生成数据:
function train_dscr!(discriminator, real_data, fake_data)
batch_size = size(real_data)[end]
all_data = hcat(real_data, fake_data)
all_target = [ones(eltype(real_data), 1, batch_size)
zeros(eltype(fake_data), 1, batch_size)] |> gpu
ps = Flux.params(discriminator)
loss, pullback = Zygote.pullback(ps) do
preds = discriminator(all_data)
Flux.Losses.binarycrossentropy(preds, all_target)
end
grads = pullback(1f0)
Flux.update!(opt_dscr, Flux.params(discriminator), grads)
return loss
end
生成器训练
生成器试图欺骗判别器:
function train_gen!(discriminator, generator)
noise = randn(latent_dim, batch_size) |> gpu
ps = Flux.params(generator)
testmode!(discriminator)
loss, back = Zygote.pullback(ps) do
preds = discriminator(generator(noise))
Flux.Losses.binarycrossentropy(preds, 1.)
end
grads = back(1.0f0)
Flux.update!(opt_gen, Flux.params(generator), grads)
trainmode!(discriminator, mode=:auto)
return loss
end
主训练循环
lossvec_gen = zeros(num_epochs)
lossvec_dscr = zeros(num_epochs)
for n in 1:num_epochs
loss_sum_gen = 0.0f0
loss_sum_dscr = 0.0f0
for x in train_loader
real_data = flatten(x)
noise = randn(latent_dim, size(x)[end]) |> gpu
fake_data = generator(noise)
loss_dscr = train_dscr!(discriminator, real_data, fake_data)
loss_gen = train_gen!(discriminator, generator)
loss_sum_dscr += loss_dscr
loss_sum_gen += loss_gen
end
lossvec_gen[n] = loss_sum_gen / size(train_x)[end]
lossvec_dscr[n] = loss_sum_dscr / size(train_x)[end]
if n % output_period == 0
@show n
noise = randn(latent_dim, 4) |> gpu
fake_data = reshape(generator(noise), 28, 4*28)
p = heatmap(fake_data, colormap=:inferno)
print(p)
end
end
训练结果分析
随着训练进行,生成器产生的图像质量会逐步提高:
- 初期(约100轮):生成模糊、无结构的噪声图像
- 中期(约1000轮):开始出现可辨认的数字形状
- 后期(约5000轮):生成与真实MNIST难以区分的清晰数字
性能优化建议
- 使用GPU加速训练过程
- 适当调整学习率和批量大小
- 尝试不同的网络架构和激活函数
- 使用更先进的GAN变体如DCGAN、WGAN等
常见问题解决
-
模式崩溃:生成器只产生有限的几种样本
- 解决方案:尝试增加噪声维度或使用mini-batch判别
-
训练不稳定:损失值剧烈波动
- 解决方案:调整学习率,使用梯度裁剪
-
生成质量差:
- 解决方案:增加网络深度,调整激活函数
通过本教程,您应该已经掌握了使用Flux.jl实现基础GAN的方法。GAN训练需要耐心和反复调试,但一旦成功,它能产生令人惊叹的结果。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



