import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import time
time_start = time.time()
# The generator outputs values in [-1, 1], so normalize the real images to the same range
transform = transforms.Compose([
    transforms.ToTensor(),                 # scales pixels to [0, 1] and reorders to (channels, height, width)
    transforms.Normalize((0.5,), (0.5,))   # maps [0, 1] to [-1, 1]
])
train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)  # download if not already in ./data
dataLoader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
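# Optional sanity check (a minimal sketch, not part of the original script): confirm
# that the transform really maps pixel values into [-1, 1], the same range as the
# generator's Tanh output. sample_imgs is a name introduced here for illustration.
sample_imgs, _ = next(iter(dataLoader))
print(sample_imgs.shape, sample_imgs.min().item(), sample_imgs.max().item())
# expected: torch.Size([64, 1, 28, 28]) -1.0 1.0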
# Generator definition
# Input: a noise vector of length 100 (samples from a standard normal distribution)
# Output: a (1, 28, 28) image
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Tanh()                      # squash outputs into [-1, 1]
        )
    def forward(self, x):
        img = self.main(x)
        img = img.view(-1, 1, 28, 28)      # reshape to (batch, channels, height, width)
        return img
# Discriminator definition
# Input: a (1, 28, 28) image; output: the probability that the image is real
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()                   # output in (0, 1), as required by BCELoss
        )
    def forward(self, x):
        x = x.view(-1, 28*28)              # flatten the image
        x = self.main(x)
        return x
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
gen = Generator().to(device)
dis = Discriminator().to(device)
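# Optional shape check (a small sketch added here, not in the original post): a single
# forward pass through both networks to confirm the shapes described above.
with torch.no_grad():
    z = torch.randn(4, 100, device=device)
    fake = gen(z)                          # expected shape: (4, 1, 28, 28)
    score = dis(fake)                      # expected shape: (4, 1), values in (0, 1)
    print(fake.shape, score.shape)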
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0001)
# Loss function: binary cross-entropy, applied to the discriminator's sigmoid outputs
loss_fn = torch.nn.BCELoss()
# Plotting helper: draw a 4x4 grid of generated images
def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i] + 1) / 2, cmap='gray')  # map [-1, 1] back to [0, 1] for display
        plt.axis('off')
    plt.show()
test_input = torch.randn(16, 100, device=device)
# GAN training
D_loss = []
G_loss = []
# Training loop
for epoch in range(20):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataLoader)  # number of batches per epoch
    for step, (img, _) in enumerate(dataLoader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device=device)
        # Discriminator loss and optimization
        d_optimizer.zero_grad()
        real_output = dis(img)  # discriminator's predictions on real images
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # discriminator loss on real images
        d_real_loss.backward()
        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())  # detach so no generator gradients are computed here
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # discriminator loss on generated images
        d_fake_loss.backward()
        d_loss = d_real_loss + d_fake_loss
        d_optimizer.step()
        # Generator loss and optimization
        g_optimizer.zero_grad()
        fake_output = dis(gen_img)  # no detach here: gradients must flow back into the generator
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))  # generator loss: fakes should be classified as real
        g_loss.backward()
        g_optimizer.step()
        d_epoch_loss += d_loss.item()
        g_epoch_loss += g_loss.item()
    d_epoch_loss /= count
    g_epoch_loss /= count
    D_loss.append(d_epoch_loss)
    G_loss.append(g_epoch_loss)
    print("Epoch:", epoch)
    gen_img_plot(gen, test_input)
time_end = time.time()
print("花费总时间为:", time_end - time_start)