GAN简单介绍—使用PyTorch框架搭建GAN对MNIST数据集进行训练

GAN 的简单介绍

GAN,全称Generative Adversarial Networks,即生成对抗网络,是深度学习中一种强大的生成模型。GAN 是由 Ian Goodfellow 等人在 2014 年提出的,通过让两个神经网络相互对抗来生成新的数据。

GAN(Generative Adversarial Networks,生成对抗网络)是一种深度学习框架,用于生成新的、与训练数据相似的数据。GAN由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。这两个网络在训练过程中相互竞争和协作,使得生成器能够生成越来越逼真的数据。

本文将介绍如何使用 PyTorch 实现 GAN。

GAN 介绍

GAN 的工作原理是,生成器(Generator)和判别器(Discriminator)两个神经网络相互对抗地学习。生成器用于生成图像或数据,而判别器则用于判断输入的数据是否真实。两个神经网络不断地交替训练,直到生成器可以生成接近于真实样本的样本。

GAN 可以用于生成各类数据,如图像、音频、文本等。在图像生成方面,GAN 用于生成无限多张与训练数据相似却并不存在于训练数据中的图像。GAN 可以应用于许多领域,如计算机图形学、自然语言处理等等。

实现 GAN

加载数据集

使用 PyTorch 中的 MNIST 数据集进行训练。MNIST 数据集包含手写数字图像,大小为 28x28 像素。

import torch
from torchvision import datasets, transforms

# 使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载 MNIST 数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
mnist_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(mnist_data, batch_size=128, shuffle=True)

定义生成器和判别器

生成器和判别器都是神经网络模型。生成器将从随机噪声中生成图像,而判别器将对输入图像进行分类,判断它是真实图像还是生成器生成的伪造图像。

import torch.nn as nn

# 定义生成器
class Generator(nn.Module):
    def __init__(self, z_dim=100, hidden_dim=128):
        super(Generator, self).__init__() # 继承 nn.Module 类并初始化
        self.z_dim = z_dim # 输入向量的维度
        self.hidden_dim = hidden_dim # 隐藏层维度
        self.fc1 = nn.Linear(z_dim, hidden_dim) # 全连接层,输入为 z_dim 维,输出为 hidden_dim 维
        self.fc2 = nn.Linear(hidden_dim, 28 * 28) # 全连接层,将隐藏层映射到 28*28 的图像

    def forward(self, z):
        x = self.fc1(z) # 输入 z,通过全连接层 fc1 得到隐藏层向量 x
        x = nn.functional.leaky_relu(x, 0.2) # 在隐藏层中应用 LeakyReLU 激活函数
        x = self.fc2(x) # 将隐藏层映射到生成的图像
        x = nn.functional.tanh(x) # 将输出值映射到 [-1, 1] 的范围
        return x.view(-1, 1, 28, 28) # 将输出展平成图片张量形式

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, hidden_dim=128):
        super(Discriminator, self).__init__() # 继承 nn.Module 类并初始化
        self.hidden_dim = hidden_dim # 隐藏层维度
        self.fc1 = nn.Linear(28 * 28, hidden_dim) # 输入为 28*28 的图像,输出为隐藏层向量
        self.fc2 = nn.Linear(hidden_dim, 1) # 只有一个输出,表示输入是否是真实的图像。

    def forward(self, x):
        x = x.view(-1, 28*28) # 将输入展平为向量
        x = self.fc1(x) # 将展平后的向量通过全连接层 fc1 映射到隐藏层
        x = nn.functional.leaky_relu(x, 0.2) # 在隐藏层中应用 LeakyReLU 激活函数
        x = self.fc2(x) # 将隐藏层输出映射到单个输出值
        x = nn.functional.sigmoid(x) # 将输出值映射到 [0, 1] 的范围,表示输入是否是真实图像的概率
        return x

这段代码实现了一个基本的生成对抗网络(GAN)。GAN 由两个神经网络组成,一个生成器(Generator)和一个判别器(Discriminator),它们在博弈中对抗地进行训练。生成器从随机噪声中生成假图像,而判别器则试图区分真实图像和生成图像。训练过程目的是使得生成器生成的图像更逼真,从而欺骗判别器,使其将生成的图像和真实图像分类错误。以下是代码实现的具体步骤:

  1. 导入 torch.nn 库,以使用 PyTorch 提供的深度学习模型库;
  2. 定义生成器:继承 nn.Module 类并初始化输入随机噪声向量的维度(z_dim)和隐藏层维度(hidden_dim),以及两个全连接层 fc1 和 fc2。在 forward 方法中,通过全连接层将输入向量 z 映射成隐藏层向量 x,然后在隐藏层中应用 leaky ReLU 激活函数,将隐藏层映射到输出图像,最后将输出值映射至 [-1, 1] 的范围并将其展平成图片张量形式;
  3. 定义判别器:同样继承 nn.Module 类并初始化隐藏层维度(hidden_dim)和两个全连接层 fc1 和 fc2。在 forward 方法中,将输入的二维图片张量展平为向量 x,通过全连接层 fc1 映射到隐藏层,应用 leaky ReLU 激活函数,最后将隐藏层输出通过全连接层 fc2 映射为单个输出值,并将其映射至 [0, 1] 的范围,表示输入是否是真实图像的概率;
  4. GAN 的训练过程:定义生成器和判别器的优化器(optimizer),设置目标损失函数(loss function),以及迭代训练的循环。每轮迭代都以真实图像或噪声输入生成器作为输入,计算生成器对应的生成图像,将真实的和生成的数据放入判别器中分别计算输出,分别计算判别器对真实数据的损失和对生成数据的损失,然后利用这两个损失更新生成器和判别器的参数。最后的目标是使得生成器生成尽可能逼真的数据,欺骗判别器,从而将其指导生成尽可能逼真的数据。

定义超参数、优化器、损失函数

定义超参数、优化器、损失函数。

# 定义超参数
lr = 0.0002
z_dim = 100
num_epochs = 50

generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)

# 定义优化器和损失函数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)

criterion = nn.BCELoss()

训练网络

# 训练网络,使用给定的 epoch 数量
for epoch in range(num_epochs):
    # 遍历数据集的 mini-batches
    for i, (real_images, _) in enumerate(data_loader):
        # 将真实图像传递给设备
        real_images = real_images.to(device)

        # 定义真实标签为 1,假的标签为 0
        real_labels = torch.ones(real_images.size(0), 1).to(device)
        fake_labels = torch.zeros(real_images.size(0), 1).to(device)

        # 训练生成器
        # 生成随机的噪音 z
        z = torch.randn(real_images.size(0), z_dim).to(device)

        # 生成 fake_images
        fake_images = generator(z)

        # 将 fake_images 传递给判别器,得到输出 fake_output
        fake_output = discriminator(fake_images)

        # 计算生成器的损失值 loss_G
        # 需要将 fake_output 与真实标签 real_labels 进行比较
        loss_G = criterion(fake_output, real_labels)

        # 重置 generator 的优化器,清除梯度
        optimizer_G.zero_grad()

        # 计算 generator 的梯度
        loss_G.backward()

        # 更新 generator 参数,使用优化器 optimizer_G
        optimizer_G.step()

        # 训练判别器
        # 将真实图像 real_images 通过判别器得到输出 real_output
        real_output = discriminator(real_images)

        # 计算判别器对真实图像的损失值 loss_D_real
        # 需要将 real_output 与真实标签 real_labels 进行比较
        loss_D_real = criterion(real_output, real_labels)

        # 计算判别器对生成器生成的 fake_images 的损失值 loss_D_fake
        # 需要将 fake_output 与假的标签 fake_labels 进行比较
        fake_output = discriminator(fake_images.detach())
        loss_D_fake = criterion(fake_output, fake_labels)

        # 计算判别器总损失值 loss_D,即 loss_D_real 和 loss_D_fake 的和
        loss_D = loss_D_real + loss_D_fake

        # 重置 discriminator 的优化器,清除梯度
        optimizer_D.zero_grad()

        # 计算 discriminator 的梯度
        loss_D.backward()

        # 更新 discriminator 参数,使用优化器 optimizer_D
        optimizer_D.step()

        # 输出损失值
        # 每100次迭代输出一次
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, len(data_loader), loss_D.item(), loss_G.item()))

这段代码表示了一个完整的 GAN 利用交替梯度下降算法进行训练的过程。在每个 epoch 内,我们遍历数据集中的所有 mini-batches,并对每个 mini-batch 迭代进行以下步骤:

  1. 将真实图片 real_images 传递给设备。
  2. 定义标签,其中真实图像对应的标签为 1,生成器生成的假图像对应的标签为 0。
  3. 训练生成器。从一个经过随机初始化的噪声向量 z 开始,使用生成器学习生成伪造图片。将生成的图片 fake_images 传递给判别器,并使用判别器输出来计算生成器的损失值。这个损失值反映了生成器生成的图片与真实标签之间的差异。
  4. 重置生成器的优化器,并清空梯度。
  5. 对生成器的损失值进行反向传播计算梯度。
  6. 使用优化器来更新生成器的参数,使其在下一个迭代中能够更好地生成伪造图片。
  7. 训练判别器。使用真实图片来计算判别器对真实图片的损失值,用假图片来计算判别器对伪造图片的损失值,并将这些值相加以得到判别器的总损失。使用反向传播来计算梯度并更新判别器参数。
  8. 每 100 次迭代输出一次损失值,以跟踪训练的进展。
  9. 循环在所有训练数据上结束时,一次完整的迭代称为一个 epoch。重复执行此步骤 num_epochs 次来完成 GAN 的训练。

使用生成器生成新图像

在训练结束后,可以使用生成器来生成新图像。

# 生成一些测试数据
# 随机生成一些长度为 z_dim 的、位于设备上的向量 z
z = torch.randn(16, z_dim).to(device)

# 使用生成器从 z 中生成一些假的图片
fake_images = generator(z).detach().cpu()

# 显示生成的图像
# 创建一个图形对象,大小为 4x4 英寸
fig = plt.figure(figsize=(4, 4))

# 在图形对象中创建4x4的网格,以显示输出的16张假图像
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(fake_images[i][0], cmap='gray')
    plt.axis('off')

# 显示绘制的图形
plt.show()

这段代码的目的是在训练 GAN 结束后,使用生成器生成一些随机的假图像进行展示。下面是具体步骤:

  1. 随机生成一些长度为 z_dim 的、位于设备上的向量 z
  2. 使用生成器从 z 中生成一些假的图片。
  3. 创建一个 4x4 英寸大小的图形对象。
  4. 在图形对象中创建一个 4x4 的网格,用于显示生成的 16 张图片。
  5. 使用 plt.imshow() 函数将图像附加到当前子图中。该函数会在 matplotlib 窗口中显示生成的图像。
  6. 从绘图中删除坐标轴。
  7. 显示生成的图形。

结论

GAN 是一种强大的生成模型,可以用于生成各种不同类型的数据。本教程展示了如何使用 PyTorch 实现 GAN,并使用 MNIST 数据集进行训练。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

百年孤独百年

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值