import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from IPython import display
import time
import imageio
import glob
from PIL import Image
# Hyperparameters
EPOCHS = 30
BATCH_SIZE = 128

# Pick the compute device (GPU when available, CPU otherwise).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A fixed latent batch, sampled once before training, so the image grids
# produced after each epoch are directly comparable.  Shape: (16, 100).
num_examples_to_generate = 16
seed = torch.randn(num_examples_to_generate, 100, device=device)

# MNIST preprocessing: tensor conversion, then rescale pixels from [0, 1]
# to [-1, 1] so real images match the generator's tanh output range.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Training data loader (download=False: the files must already exist
# under ./data/mnist).
train_dataset = datasets.MNIST(root="./data/mnist", train=True, download=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Generator model: input [BATCH_SIZE, 100] -> output [BATCH_SIZE, 1, 28, 28]
class Generator(nn.Module):
    """DCGAN generator: latent vectors (N, 100) -> digit images (N, 1, 28, 28) in [-1, 1]."""

    def __init__(self):
        super(Generator, self).__init__()
        # Project the latent vector onto a 7x7x256 feature volume.
        self.fc = nn.Linear(100, 7 * 7 * 256)
        self.batch_norm1 = nn.BatchNorm2d(256)
        # Three transposed convolutions upsample 7x7 -> 7x7 -> 14x14 -> 28x28.
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=5, stride=1, padding=2, output_padding=0, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)
        self.batch_norm3 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)

    def forward(self, x):
        # Dense projection, then reshape to a conv feature map.
        out = self.fc(x).view(x.size(0), 256, 7, 7)
        # Each stage: batch-norm + leaky ReLU (default 0.01 slope), then deconv.
        for bn, deconv in ((self.batch_norm1, self.deconv1),
                           (self.batch_norm2, self.deconv2),
                           (self.batch_norm3, self.deconv3)):
            out = deconv(nn.functional.leaky_relu(bn(out)))
        # tanh squashes pixels into [-1, 1], matching the normalized MNIST data.
        return torch.tanh(out)
# Discriminator model: input [BATCH_SIZE, 1, 28, 28] -> output [BATCH_SIZE, 1]
class Discriminator(nn.Module):
    """DCGAN discriminator: image batch (N, 1, 28, 28) -> raw realness logits (N, 1)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        # Two strided convolutions downsample 28x28 -> 14x14 -> 7x7.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=2)
        self.leaky_relu1 = nn.LeakyReLU(0.2)
        self.dropout1 = nn.Dropout(0.3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.leaky_relu2 = nn.LeakyReLU(0.2)
        self.dropout2 = nn.Dropout(0.3)
        self.flatten = nn.Flatten()
        # One unbounded logit per image; no sigmoid here because the loss
        # is BCEWithLogitsLoss.
        self.fc = nn.Linear(7 * 7 * 128, 1)

    def forward(self, x):
        # The layers form a straight pipeline, so fold the input through
        # them in registration order.
        for layer in (self.conv1, self.leaky_relu1, self.dropout1,
                      self.conv2, self.leaky_relu2, self.dropout2,
                      self.flatten, self.fc):
            x = layer(x)
        return x
# Instantiate both networks on the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# Loss and optimizers: logit-based binary cross-entropy, and an
# independent Adam optimizer for each player of the GAN game.
criterion = nn.BCEWithLogitsLoss()
optimizer_generator = optim.Adam(generator.parameters(), lr=1e-4)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=1e-4)
# Training step
def train_step(images):
    """Run one adversarial update on a batch of real images.

    First updates the discriminator on the real batch plus a batch of fresh
    fakes (detached, so no gradients reach the generator), then updates the
    generator to fool the discriminator.

    Args:
        images: real image batch of shape (B, 1, 28, 28), already on `device`.
    """
    batch_size = images.size(0)  # actual batch size (the last batch may be short)
    real_labels = torch.ones(batch_size, 1, device=device)
    fake_labels = torch.zeros(batch_size, 1, device=device)

    noise = torch.randn(batch_size, 100, device=device)
    generated_images = generator(noise)

    # --- Discriminator update ---------------------------------------------
    # detach() keeps this backward pass from propagating into the generator.
    optimizer_discriminator.zero_grad()
    disc_loss_real = criterion(discriminator(images), real_labels)
    disc_loss_fake = criterion(discriminator(generated_images.detach()), fake_labels)
    disc_loss = disc_loss_real + disc_loss_fake
    disc_loss.backward()
    optimizer_discriminator.step()

    # --- Generator update --------------------------------------------------
    # BUGFIX: the original zeroed both optimizers once at the top, then ran
    # gen_loss.backward() (which deposits gradients on the discriminator's
    # parameters via discriminator(generated_images)) BEFORE
    # disc_loss.backward(); the discriminator step therefore used the sum of
    # both losses' gradients.  Zeroing each player's gradients immediately
    # before its own backward/step keeps the two updates independent.
    optimizer_generator.zero_grad()
    gen_loss = criterion(discriminator(generated_images), real_labels)
    gen_loss.backward()
    optimizer_generator.step()
def generate_and_save_images(model, epoch, test_input):
    """Render a 4x4 grid of generated digits and save it as a PNG.

    Args:
        model: the generator network.
        epoch: 1-based epoch number, interpolated into the output filename.
        test_input: fixed latent batch of shape (16, 100) so successive
            grids are directly comparable.
    """
    # eval() so BatchNorm uses its running statistics for the preview;
    # no_grad() makes the redundant .detach() of the original unnecessary.
    model.eval()
    with torch.no_grad():
        predictions = model(test_input).cpu()
    model.train()

    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Undo the [-1, 1] normalization back to [0, 1] for display.
        plt.imshow(predictions[i, 0, :, :] * 0.5 + 0.5, cmap='gray')
        plt.axis('off')  # hide the axes
    plt.savefig('C:/Users/Administrator/Desktop/GAN/images2/image_at_epoch_{:04d}.png'.format(epoch))
    # BUGFIX: close the figure after saving.  The original created a new
    # figure on every call and never closed any, leaking one figure (and its
    # pixel buffers) per epoch for the whole training run.
    plt.close(fig)
def train_model(generator, discriminator, train_loader, num_epochs):
    """Train the GAN, saving a sample grid each epoch and a checkpoint every 10.

    NOTE(review): the actual parameter updates happen through the
    module-level `train_step`, optimizers, and `seed`; the `generator` /
    `discriminator` parameters shadow the globals of the same name, so the
    same objects must be passed in for the checkpoints to match what
    `train_step` trains.
    """
    for epoch in range(num_epochs):
        start = time.time()
        # Labels are irrelevant for GAN training, hence the `_`.
        # (The original used enumerate() but never read the index.)
        for images, _ in train_loader:
            images = images.to(device)
            train_step(images)

        # Refresh the inline preview with this epoch's fixed-seed samples.
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)

        # Periodic checkpoint of both networks plus optimizer state.
        if (epoch + 1) % 10 == 0:
            checkpoint_path = 'C:/Users/Administrator/Desktop/GAN/checkpoints/checkpoint_epoch_{}.tar'.format(epoch + 1)
            torch.save({
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'optimizer_discriminator_state_dict': optimizer_discriminator.state_dict(),
            }, checkpoint_path)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    # Final grid after the last epoch.
    display.clear_output(wait=True)
    generate_and_save_images(generator, num_epochs, seed)
# Train the model
train_model(generator, discriminator, train_loader, EPOCHS)

# Restore the latest checkpoint (written every 10 epochs; EPOCHS=30 is a
# multiple of 10, so this file exists after a full run).
latest_checkpoint_path = 'C:/Users/Administrator/Desktop/GAN/checkpoints/checkpoint_epoch_{}.tar'.format(EPOCHS)
# ROBUSTNESS FIX: map_location remaps storages to the current device, so a
# checkpoint written on a GPU can still be restored on a CPU-only machine
# (the original torch.load would raise in that case).
checkpoint = torch.load(latest_checkpoint_path, map_location=device)
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
# Evaluate the model
def display_image(epoch_no):
    """Show the sample grid that was saved for the given (1-based) epoch.

    Args:
        epoch_no: epoch number interpolated into the saved PNG's filename.
    """
    image_path = 'C:/Users/Administrator/Desktop/GAN/images2/image_at_epoch_{:04d}.png'.format(epoch_no)
    # RESOURCE FIX: the original never closed the file handle opened by
    # Image.open; the context manager releases it after rendering.
    with Image.open(image_path) as img:
        display.display(img)

# Show the grid from the final epoch
display_image(EPOCHS)