- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
理论知识
生成对抗网络(Generative Adversarial Networks, GAN)并不是指某一个具体的神经网络,而是指一类基于博弈思想而设计的神经网络。
GAN通常由两个部分组成,分别是:生成器(Generator)和判别器(Discriminator)。其中,生成器从某种噪声分布中随机采样作为输入,输出与训练集中真实样本非常相似的人工样本;判别器的输入则为真实样本或人工样本,其目的是将人工样本与真实的样本尽可能的区分出来。
理想情况下,经过足够多次的博弈,判别器会无法分辨出样本的真实性,这时可以认为生成器的结果已经逼真到让判别器无法分辨,就可以停止博弈了。
生成器
GANs中,生成器G选取随机噪声z作为输入,通过生成器的不断拟合,最终输出一个和真实样本尺寸相同,分布相似的伪造样本G(z)
。生成器的本质是一个使用生成式方法的模型,它对数据的分布假设和分布参数进行学习,然后根据学习到的模型重新采样出新的样本。
从数据角度来说,生成式的方法对于特定的真实数据,首先要对数据的显式变量或隐含变量做分布假设;然后再将真实的数据输入到模型中对变量、参数进行训练;最后得到一个学习后的近似分布,这个分布可以用来生成新的数据。
从机器学习的角度来说,模型不会做分布假设,而是通过不断地学习真实的数据,对模型进行修正,最后也可以得到一个学习后的模型来做样本的生成任务。这种方法不同于数学方法,学习的过程对人类理解较不直观。
判别器
GANs中,判别器D对于输入的样本x,输出一个[0, 1]
之间的概率数值D(x)
。x可以是来自于原始数据集中的真实样本x,也可以是来自于生成器G的人工样本G(z)
。通常约定,概率值 D(x)
越接近于1就代表样本为真实样本的可能性越大;反之概率值越小则此样本为伪造样本的可能性更大。也就是说,这里的判别器是一个二分类的神经网络分类器,目的不是判定输入数据的原始类别,而是区分输入样本的真伪。可以注意到,不管是在生成器中还是在判别器中,样本的类别信息都没有用到,也表明GAN是一个无监督学习的过程。
基本原理
GAN是博弈论和机器学习相结合的产物。于2014年Ian Goodfellow的论文中问世。
环境
Python: 3.11
Pytorch: 2.3.0+cu121
显卡:GTX3070
步骤
环境设置
首先设置数据的目录
PARENT_DIR = 'GAN01/'
然后引用本次需要的包
import torch.nn as nn
import torch
import numpy as np
import os
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision import transforms, datasets
import torch.optim as optim
创建需要用到的文件夹
os.makedirs(PARENT_DIR + 'images/', exist_ok=True) # 保存生成的图像
os.makedirs(PARENT_DIR + 'save/', exist_ok=True) # 保存模型参数
os.makedirs(PARENT_DIR + 'datasets', exist_ok=True) # 保存下载的数据集
超参数设置
n_epochs = 50 # 训练轮数
batch_size = 64 # 批次大小
lr = 2e-4 # 学习率
b1 = 0.5 # Adam参数1
b2 = 0.999 # Adam参数2
n_cpu = 2 # 数据加载时使用的cpu数量
latent_dim = 100 # 随机向量的维度
img_size = 28 # 图像的大小
channels = 1 # 图像的通道数
sample_intervals = 500 # 保存生成图像的间隔
img_shape = (channels, img_size, img_size) # 图像的尺寸
img_area = np.prod(img_shape)
# 全局设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
数据准备
下载数据集
mnist = datasets.MNIST(root=PARENT_DIR+'/datasets', train=True, download=True, transform=transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]))
配置数据加载器
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
模型设计
鉴别器模块
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(img_area, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
生成器模块
class Generator(nn.Module):
def __init__(self):
super().__init__()
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, img_area),
nn.Tanh(),
)
def forward(self, z):
imgs = self.model(z)
imgs = imgs.view(imgs.size(0), *img_shape)
return imgs
模型训练
创建模型实例
# 生成器
generator = Generator().to(device)
# 判别器
discriminator = Discriminator().to(device)
# 损失函数
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
训练过程
for epoch in n_epochs:
for i, (imgs, _ ) in enumerate(dataloader):
imgs = imgs.view(imgs.size(0), -1)
real_img = Variable(imgs).to(device)
real_label = Variable(torch.ones(imgs.size(0), -1).to(device)
fake_label = Variable(torch.zeros(imgs.size(0), -1).to(device)
# 训练判别器 - 正例
real_out = discriminator(real_img)
loss_real_D = criterion(real_out, real_label)
real_scores = real_out
# 训练判别器 - 反例
z = Variable(torch.randn(imgs.size(0), latent_dim).to(device)
fake_img = generator(z).detach()
fake_out = discriminator(fake_img)
loss_fake_D = criterion(fake_out, fake_label)
fake_scores = fake_out
# 训练判别器
loss_D = loss_real_D + loss_fake_D
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
# 训练生成器
z = Variable(torch.randn(imgs.size(0), latent_dim).to(device)
fake_img = generator(z)
output = discriminator(fake_img)
loss_G = criterion(output, real_label)
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
# 日志打印
if (i+1) % 300 == 0:
print('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]' % (epoch, n_epochs, i, len(dataloader), loss_D.item(). loss_G.item(), real_scores.data.mean(). fake_scores.data.mean()))
# 保存训练过的图片
batches_done = epoch * len(dataloader) + i
if batches_done % sample_intervals == 0:
save_image(fake_img.data[: