Paper: Generative Adversarial Networks
Author: Ian J. Goodfellow
Year: 2014
I started reading about networks in early March 2020, and this is the first paper that I both read and managed to run the code for, so here is a brief write-up; I will add to it when I have time.
For more on GANs, see my other post: https://blog.youkuaiyun.com/demo_jie/article/details/106724016
Let's go straight to the code. It trains a GAN on the MNIST dataset using PyTorch.
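For reference, what the code below optimizes is the minimax game from the paper: the discriminator D and the generator G are trained against each other via

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

where z is noise drawn from a simple prior (here a 100-dimensional standard normal). The script alternates one discriminator update and one generator update per batch.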
Real images:
Code:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import os

if not os.path.exists('img'):
    os.mkdir('img')

def to_img(x):
    out = 0.5 * (x + 1)            # map from [-1, 1] back to [0, 1]
    out = out.clamp(0, 1)          # clamp the output to the [0, 1] range
    out = out.view(-1, 1, 28, 28)
    return out

# Hyperparameters
batch_size = 128
num_epoch = 10
z_dimension = 100

# Preprocessing: normalize the images to [-1, 1], the same range as the
# generator's Tanh output (so to_img above can undo it)
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download the dataset
mnist = datasets.MNIST(
    root='E:/low-light/deep learning/GAN/data/', train=True,
    transform=img_transform, download=True)

# Load the dataset
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)

# Discriminator network
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid())  # sigmoid turns the output into a probability in (0, 1) for binary classification

    def forward(self, x):
        x = self.dis(x)
        return x

# Generator network
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh())  # Tanh keeps the generated (fake) image values in the range -1 to 1

    def forward(self, x):
        x = self.gen(x)
        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
D = discriminator().to(device)
G = generator().to(device)

# Training the discriminator has two parts: real images should be judged real and
# fake images should be judged fake; during both, the generator's parameters are not updated.
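# (In the paper's terms, this is the inner maximization of V(D, G): it ascends
#  E_x[log D(x)] + E_z[log(1 - D(G(z)))]. With BCELoss this is exactly
#  BCE(D(x), 1) + BCE(D(G(z)), 0), i.e. d_loss_real + d_loss_fake below.)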
# Binary cross-entropy loss and the optimizers
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

# Start training
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # ====================================================== train the discriminator
        img = img.view(num_img, -1)  # flatten each image into 28x28 = 784 values
        real_img = img.to(device)
        real_label = torch.ones(num_img, 1, device=device)   # the label for real images is 1
        fake_label = torch.zeros(num_img, 1, device=device)  # the label for fake images is 0

        # loss on the real images
        real_out = D(real_img)                          # feed the real images to the discriminator
        d_loss_real = criterion(real_out, real_label)   # loss of the real images
        real_scores = real_out                          # the closer to 1 the better

        # loss on the fake images
        z = torch.randn(num_img, z_dimension, device=device)  # sample some random noise
        fake_img = G(z)                                 # let the generator produce fake images
        fake_out = D(fake_img)                          # let the discriminator judge the fake images
        d_loss_fake = criterion(fake_out, fake_label)   # loss of the fake images
        fake_scores = fake_out                          # the closer to 0 the better

        # backpropagation and optimization
        d_loss = d_loss_real + d_loss_fake  # add the real-image and fake-image losses
        d_optimizer.zero_grad()             # clear the gradients
        d_loss.backward()                   # backpropagate
        d_optimizer.step()                  # update the parameters

        # ====================================================== train the generator
        # loss on the fake images
        z = torch.randn(num_img, z_dimension, device=device)  # sample fresh random noise
        fake_img = G(z)                     # generate fake images
        output = D(fake_img)                # run them through the discriminator
        g_loss = criterion(output, real_label)  # loss of the fake images against the "real" label
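        # Using real_label here is the non-saturating generator loss suggested in the paper:
        # instead of minimizing log(1 - D(G(z))), the generator maximizes log D(G(z)),
        # which gives much stronger gradients early in training.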
        # backpropagation and optimization
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f}, D real: {:.6f}, D fake: {:.6f}'.format(
                epoch, num_epoch, d_loss.item(), g_loss.item(),
                real_scores.data.mean(), fake_scores.data.mean()))

    # save the real images once, and the fake images after every epoch
    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './img/real_images.png')
    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')
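Once training has finished, the saved weights can be reused to sample new digits without re-running the training loop. A minimal sketch (assuming the generator.pth file written above, plus the generator class, the to_img helper and z_dimension from the script):

# Reload the trained generator and sample a grid of new digits
G = generator()
G.load_state_dict(torch.load('./generator.pth', map_location='cpu'))
G.eval()

with torch.no_grad():
    z = torch.randn(64, z_dimension)   # 64 random noise vectors
    samples = G(z)                     # fake images in [-1, 1]
save_image(to_img(samples), './img/sampled_digits.png', nrow=8)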
Results:
This run trained for 10 epochs in total. Below are the images generated from noise after epochs 1, 3, 5, 7, 9, and 10 (10 epochs is too few for the effect to be obvious; you can increase the number of epochs yourself).
The real images saved during the run: