# 使用生成对抗网络（GAN）生成手写数字：目标是基于MNIST数据集生成新的手写数字图像。
# 首次运行时会自动下载MNIST数据集，无需手动准备数据。
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, sampler
import torchvision
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torch.autograd import Variable
# Global hyper-parameters shared by the data loaders, the generator and training.
batch_sizes: int = 64  # images per training batch
Noise_dim: int = 96  # length of the generator's input noise vector
def show_images(images, cmap='gray'):
    """Plot a batch of images on a square grid and display the figure.

    Each image is flattened to D values and redrawn as a sqrt(D) x sqrt(D)
    picture, so D is assumed to be a perfect square (e.g. 784 for 28x28).

    Args:
        images: array with the batch on axis 0 (e.g. (N, 784) or (N, 1, 28, 28)).
        cmap: matplotlib colormap for the single-channel images. Defaults to
            grayscale so digits are not false-coloured.
    """
    if images.shape[0] == 0:
        # Nothing to draw; an empty batch can be neither reshaped with -1 nor
        # laid out on a 0x0 GridSpec.
        return
    images = np.reshape(images, [images.shape[0], -1])  # (batch_size, D)
    grid_side = int(np.ceil(np.sqrt(images.shape[0])))  # grid cells per side
    img_side = int(np.ceil(np.sqrt(images.shape[1])))  # pixels per image side
    fig = plt.figure(figsize=(grid_side, grid_side))
    gs = gridspec.GridSpec(grid_side, grid_side)
    gs.update(wspace=0.05, hspace=0.05)
    for i, img in enumerate(images):
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        # Without an explicit cmap, matplotlib applies its default colormap
        # (viridis) to 2-D data, which renders digits in false colour.
        ax.imshow(img.reshape([img_side, img_side]), cmap=cmap)
    plt.show()
# Preprocessing
def preprocess_img(x):
    """Convert one dataset image into a tensor rescaled from [0, 1] to [-1, 1].

    Args:
        x: image as handed over by the torchvision MNIST dataset (a PIL image,
            per torchvision's documentation).

    Returns:
        Tensor produced by ``ToTensor`` (values in [0, 1]), shifted by 0.5 and
        divided by 0.5 so it lies in [-1, 1].
    """
    to_tensor = torchvision.transforms.ToTensor()
    tensor = to_tensor(x)
    return (tensor - 0.5) / 0.5
# Inverse of preprocessing
def deprocess_img(x):
    """Map values from [-1, 1] back to [0, 1] (inverse of ``preprocess_img``).

    Works on anything supporting ``+`` and ``/`` with floats (tensors, NumPy
    arrays, Python scalars).
    """
    shifted = x + 1.0
    return shifted / 2.0
# Sampler subclassing PyTorch's sampler.Sampler base class
class ChunkSampler(sampler.Sampler):
    """Yield a contiguous run of dataset indices in ascending order.

    Produces ``start, start + 1, ..., start + num_samples - 1``, which lets
    one dataset be split into non-overlapping chunks (e.g. train/validation).

    Args:
        num_samples: how many indices to yield.
        start: first index to yield.
    """

    def __init__(self, num_samples, start=0):
        self.start = start
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        stop = self.start + self.num_samples
        return iter(range(self.start, stop))
# Compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST train/test splits under ../data, downloaded if missing; every image
# goes through preprocess_img.
mnist_train_set = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=preprocess_img)
mnist_test_set = torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=preprocess_img)

# The training split is cut into two contiguous index ranges:
# 0-54999 for training and 55000-59999 for validation.
train_loader = DataLoader(mnist_train_set, batch_size=batch_sizes, sampler=ChunkSampler(55000, 0))
val_data = DataLoader(mnist_train_set, batch_size=batch_sizes, sampler=ChunkSampler(5000, 55000))

# Uncomment to preview the first training batch:
# images, _ = next(iter(train_loader))
# show_images(deprocess_img(images.view(images.shape[0], -1)).numpy())
# Discriminator
class discriminator(nn.Module):
    """Convolutional discriminator that returns one raw logit per image.

    Sized for single-channel 28x28 inputs: each 5x5 valid convolution followed
    by 2x2 max-pooling maps 28 -> 24 -> 12 and then 12 -> 8 -> 4, leaving
    64 * 4 * 4 = 1024 features for the fully connected head. No sigmoid is
    applied, so the output pairs with a with-logits BCE loss.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
        ]
        head_layers = [
            nn.Linear(1024, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1),
        ]
        # Attribute names and layer order are kept identical so state_dict
        # keys (conv.0.weight, fc.2.bias, ...) do not change.
        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        features = self.conv(x)
        flat = features.view(features.size(0), -1)  # (N, 1024)
        return self.fc(flat)
# Generator
class generator(nn.Module):
    """Generator mapping noise vectors to single-channel 28x28 images in [-1, 1].

    Two fully connected layers expand the noise to 128 feature maps of 7x7.
    Two stride-2 transposed convolutions then upsample 7x7 -> 14x14 -> 28x28.

    Args:
        noise_dim: length of each input noise vector.
    """

    def __init__(self, noise_dim=Noise_dim):
        super(generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, padding=1),
            # Real images are normalised to [-1, 1] by preprocess_img. Without
            # this squashing, generated pixels are unbounded and trivially
            # distinguishable. Tanh has no parameters, so existing
            # state_dict keys are unchanged.
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 7, 7)  # unflatten into 128 maps of 7x7
        x = self.conv(x)  # (N, 1, 28, 28)
        return x
# Loss functions. The discriminator outputs raw logits (no final sigmoid), so
# BCEWithLogitsLoss fuses the sigmoid with binary cross-entropy.
bce_loss = nn.BCEWithLogitsLoss()


def discriminator_loss(logits_real, logits_fake):
    """Return the discriminator's loss: mean BCE on real (target 1) plus fake (target 0).

    Args:
        logits_real: discriminator logits for real images.
        logits_fake: discriminator logits for generated images.

    Returns:
        Scalar loss tensor.
    """
    # ones_like/zeros_like build each target with its own input's shape, dtype
    # and device. The two batches may therefore differ in size, and no
    # per-step CPU -> device copy is needed.
    real_targets = torch.ones_like(logits_real)
    fake_targets = torch.zeros_like(logits_fake)
    return bce_loss(logits_real, real_targets) + bce_loss(logits_fake, fake_targets)


def generator_loss(logits_fake):
    """Return the generator's loss: mean BCE of the fake logits against target 1.

    The generator is rewarded when the discriminator labels its samples real.

    Args:
        logits_fake: discriminator logits for generated images.

    Returns:
        Scalar loss tensor.
    """
    return bce_loss(logits_fake, torch.ones_like(logits_fake))
# Train the generative adversarial network
def Train_GAN(D, G, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
              batch_size=batch_sizes, noise_dim=Noise_dim, num_epochs=10):
    """Train a GAN, alternating one discriminator step and one generator step per batch.

    Batches come from the module-level ``train_loader`` and are moved to the
    module-level ``device``.

    Args:
        D: discriminator module mapping images to real/fake logits.
        G: generator module mapping noise vectors to images.
        D_optimizer: optimizer over D's parameters.
        G_optimizer: optimizer over G's parameters.
        discriminator_loss: callable ``(logits_real, logits_fake) -> loss``.
        generator_loss: callable ``(logits_fake) -> loss``.
        show_every: every this many iterations, print both losses and plot 16
            generated samples. A falsy value (0/None) disables reporting.
        batch_size: kept for backward compatibility and not used; the batch
            size is taken from each batch yielded by ``train_loader``.
        noise_dim: length of the generator's input noise vector.
        num_epochs: number of passes over ``train_loader``.
    """
    iter_count = 0  # number of batches processed so far
    for _epoch in range(num_epochs):
        for x, _ in train_loader:  # class labels are not used
            bs = x.shape[0]  # actual batch size; the last batch may be smaller
            real_data = x.to(device)

            # Uniform noise in [-1, 1), created directly on the target device.
            sample_noise = (torch.rand(bs, noise_dim, device=device) - 0.5) / 0.5
            fake_images = G(sample_noise)

            # Discriminator step. detach() keeps this loss from
            # back-propagating into G, since only D is updated here.
            logits_real = D(real_data)
            logits_fake = D(fake_images.detach())
            d_total_error = discriminator_loss(logits_real, logits_fake)
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()

            # Generator step. The same fakes are re-scored by the just-updated D.
            # G's weights did not change during the D step, so a second G
            # forward pass would produce the same images.
            g_logits_fake = D(fake_images)
            g_error = generator_loss(g_logits_fake)
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()

            if show_every and iter_count % show_every == 0:
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
                imgs_numpy = deprocess_img(fake_images.detach().cpu().numpy())
                show_images(imgs_numpy[0:16])
                print()
            iter_count += 1
# Build both networks on the compute device, each with its own Adam optimizer
# (lr 3e-4, beta1 0.5), then train for 10 epochs.
D = discriminator().to(device)
D_optimizer = torch.optim.Adam(D.parameters(), lr=3e-4, betas=(0.5, 0.999))
G = generator().to(device)
G_optimizer = torch.optim.Adam(G.parameters(), lr=3e-4, betas=(0.5, 0.999))
Train_GAN(
    D,
    G,
    D_optimizer,
    G_optimizer,
    discriminator_loss,
    generator_loss,
    num_epochs=10,
)