G1 - 生成对抗网络(GAN)

本文是深度学习训练营的学习记录,介绍了生成对抗网络(GAN)。GAN由生成器和判别器组成,基于博弈思想,是无监督学习。文中说明了其基本原理,还给出了Python和Pytorch环境下的实践步骤,包括环境设置、数据准备、模型设计、训练及效果展示。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >



理论知识

生成对抗网络(Generative Adversarial Networks, GAN)并不是指某一个具体的神经网络,而是指一类基于博弈思想而设计的神经网络。

GAN通常由两个部分组成,分别是:生成器(Generator)和判别器(Discriminator)。其中,生成器从某种噪声分布中随机采样作为输入,输出与训练集中真实样本非常相似的人工样本;判别器的输入则为真实样本或人工样本,其目的是将人工样本与真实的样本尽可能的区分出来。

理想情况下,经过足够多次的博弈,判别器会无法分辨出样本的真实性,这时可以认为生成器的结果已经逼真到让判别器无法分辨,就可以停止博弈了。

生成器

GANs中,生成器G选取随机噪声z作为输入,通过生成器的不断拟合,最终输出一个和真实样本尺寸相同,分布相似的伪造样本G(z)。生成器的本质是一个使用生成式方法的模型,它对数据的分布假设和分布参数进行学习,然后根据学习到的模型重新采样出新的样本。

从数据角度来说,生成式的方法对于特定的真实数据,首先要对数据的显式变量或隐含变量做分布假设;然后再将真实的数据输入到模型中对变量、参数进行训练;最后得到一个学习后的近似分布,这个分布可以用来生成新的数据。

从机器学习的角度来说,模型不会做分布假设,而是通过不断地学习真实的数据,对模型进行修正,最后也可以得到一个学习后的模型来做样本的生成任务。这种方法不同于数学方法,学习的过程对人类理解较不直观。

判别器

GANs中,判别器D对于输入的样本x,输出一个[0, 1]之间的概率数值D(x)。x可以是来自于原始数据集中的真实样本x,也可以是来自于生成器G的人工样本G(z)。通常约定,概率值 D(x) 越接近于1就代表样本为真实样本的可能性越大;反之概率值越小则此样本为伪造样本的可能性更大。也就是说,这里的判别器是一个二分类的神经网络分类器,目的不是判定输入数据的原始类别,而是区分输入样本的真伪。可以注意到,不管是在生成器中还是在判别器中,样本的类别信息都没有用到,也表明GAN是一个无监督学习的过程。

基本原理

GAN是博弈论和机器学习相结合的产物。于2014年Ian Goodfellow的论文中问世。
GAN模型结构示意图

环境

Python: 3.11
Pytorch: 2.3.0+cu121
显卡:GTX3070

步骤

环境设置

首先设置数据的目录

PARENT_DIR = 'GAN01/'

然后引用本次需要的包

import torch.nn as nn
import torch
import numpy as np
import os
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision import transforms, datasets
import torch.optim as optim

创建需要用到的文件夹

os.makedirs(PARENT_DIR + 'images/', exist_ok=True) # 保存生成的图像
os.makedirs(PARENT_DIR + 'save/', exist_ok=True) # 保存模型参数
os.makedirs(PARENT_DIR + 'datasets', exist_ok=True) # 保存下载的数据集

超参数设置

n_epochs = 50  # 训练轮数
batch_size = 64 # 批次大小
lr = 2e-4 # 学习率
b1 = 0.5 # Adam参数1
b2 = 0.999 # Adam参数2
n_cpu = 2 # 数据加载时使用的cpu数量
latent_dim = 100 # 随机向量的维度
img_size = 28 # 图像的大小
channels = 1 # 图像的通道数
sample_intervals = 500 # 保存生成图像的间隔

img_shape = (channels, img_size, img_size) # 图像的尺寸
img_area = np.prod(img_shape)

# 全局设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

数据准备

下载数据集

mnist = datasets.MNIST(root=PARENT_DIR+'/datasets', train=True, download=True, transform=transforms.Compose([
	transforms.Resize(img_size),
	transforms.ToTensor(),
	transforms.Normalize([0.5], [0.5]),
]))

配置数据加载器

dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

模型设计

鉴别器模块

class Discriminator(nn.Module):
	def __init__(self):
		super().__init__()
		self.model = nn.Sequential(
			nn.Linear(img_area, 512),
			nn.LeakyReLU(0.2, inplace=True),
			nn.Linear(512, 256),
			nn.LeakyReLU(0.2, inplace=True),
			nn.Linear(256, 1),
			nn.Sigmoid(),
		)
	def forward(self, img):
		img_flat = img.view(img.size(0), -1)
		validity = self.model(img_flat)
		return validity

生成器模块

class Generator(nn.Module):
	def __init__(self):
		super().__init__()
		def block(in_feat, out_feat, normalize=True):
			layers = [nn.Linear(in_feat, out_feat)]
			if normalize:
				layers.append(nn.BatchNorm1d(out_feat, 0.8))
			layers.append(nn.LeakyReLU(0.2, inplace=True))
			return layers
		self.model = nn.Sequential(
			*block(latent_dim, 128, normalize=False),
			*block(128, 256),
			*block(256, 512),
			*block(512, 1024),
			nn.Linear(1024, img_area),
			nn.Tanh(),
		)
	def forward(self, z):
		imgs = self.model(z)
		imgs = imgs.view(imgs.size(0), *img_shape)
		return imgs

模型训练

创建模型实例

# 生成器
generator = Generator().to(device)
# 判别器
discriminator = Discriminator().to(device)
# 损失函数
criterion = nn.BCELoss()

optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

训练过程

for epoch in n_epochs:
	for i, (imgs, _ ) in enumerate(dataloader):
		imgs = imgs.view(imgs.size(0), -1)
		real_img = Variable(imgs).to(device)
		real_label = Variable(torch.ones(imgs.size(0), -1).to(device)
		fake_label = Variable(torch.zeros(imgs.size(0), -1).to(device)
		
		# 训练判别器 - 正例
		real_out = discriminator(real_img)
		loss_real_D = criterion(real_out, real_label)
		real_scores = real_out

		# 训练判别器 - 反例
		z = Variable(torch.randn(imgs.size(0), latent_dim).to(device)
		fake_img = generator(z).detach()
		fake_out = discriminator(fake_img)
		loss_fake_D = criterion(fake_out, fake_label)
		fake_scores = fake_out

		# 训练判别器
		loss_D = loss_real_D + loss_fake_D
		optimizer_D.zero_grad()
		loss_D.backward()
		optimizer_D.step()

		# 训练生成器
		z = Variable(torch.randn(imgs.size(0), latent_dim).to(device)
		fake_img = generator(z)
		output = discriminator(fake_img)
		loss_G = criterion(output, real_label)
		optimizer_G.zero_grad()
		loss_G.backward()
		optimizer_G.step()

		# 日志打印
		if (i+1) % 300 == 0:
			print('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]' % (epoch, n_epochs, i, len(dataloader), loss_D.item(). loss_G.item(), real_scores.data.mean(). fake_scores.data.mean()))
		# 保存训练过的图片
		batches_done = epoch * len(dataloader) + i
		if batches_done % sample_intervals == 0:
			save_image(fake_img.data[:25], (PARENT_DIR + 'images/%d.png') % batches_done, nrow-5, normalize=
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值