生成对抗网络:合成新数据的利器
1. 生成对抗网络(GAN)简介
GAN 的总体目标是合成与训练数据集具有相同分布的新数据。原始形式的 GAN 属于无监督学习任务,因为不需要标记数据。不过,对原始 GAN 的扩展可用于半监督和监督任务。
2014 年,Ian Goodfellow 及其同事首次提出了 GAN 的概念,用于使用深度神经网络合成新图像。最初的 GAN 架构基于全连接层,类似于多层感知机架构,用于生成低分辨率的类似 MNIST 的手写数字,主要作为概念验证。
自提出以来,GAN 得到了众多改进,并在工程和科学的不同领域有广泛应用,如计算机视觉中的图像到图像翻译、图像超分辨率、图像修复等。例如,现在的 GAN 模型能够生成高分辨率的人脸图像,可在 https://www.thispersondoesnotexist.com/ 上查看示例。
2. 从自编码器开始
在讨论 GAN 的工作原理之前,先了解自编码器。自编码器可以压缩和解压缩训练数据,虽然标准自编码器不能生成新数据,但理解其功能有助于后续理解 GAN。
自编码器由编码器网络和解码器网络串联组成。编码器接收一个 d 维的输入特征向量 𝒙(𝒙∈𝑅𝑑),并将其编码为一个 p 维的向量 𝒛(𝒛∈𝑅𝑝),即 𝒛 = 𝑓(𝒙)。编码后的向量 𝒛 也称为潜在向量或潜在特征表示,通常 p < d,因此编码器起到数据压缩的作用。解码器则从低维的潜在向量 𝒛 中解压缩出 𝒙̂,可表示为 𝒙̂ = 𝑔(𝒛)。
2.1 自编码器与降维的联系
自编码器也可作为一种降维技术。当编码器和解码器两个子网络都没有非线性时,自编码器方法与主成分分析(PCA)几乎相同。
如果假设单层编码器(无隐藏层和非线性激活函数)的权重用矩阵 U 表示,则编码器建模为 𝒛 = 𝑼𝑇𝒙,单层线性解码器建模为 𝒙̂ = 𝑼𝒛,将两者结合得到 𝒙̂ = 𝑼𝑼𝑇𝒙,这与 PCA 类似,只是 PCA 有额外的正交归一约束:𝑼𝑼𝑇 = 𝑰𝑛×𝑛。
2.2 基于潜在空间大小的其他类型自编码器
- 欠完备自编码器 :潜在空间的维度通常低于输入的维度(p < d),潜在向量常被称为“瓶颈”,这种配置适合降维。
- 过完备自编码器 :潜在向量 𝒛 的维度大于输入示例的维度(p > d)。训练过完备自编码器时,存在一个简单的解决方案,即编码器和解码器直接将输入特征复制到输出层,但这并不实用。不过,通过对训练过程进行一些修改,过完备自编码器可用于降噪。在训练时,向输入示例添加随机噪声 𝜖,网络学习从含噪信号 𝒙 + 𝜖 中重建干净的示例 𝒙。在评估时,提供自然含噪的新示例以去除其中的噪声,这种自编码器架构和训练方法称为去噪自编码器。
3. 用于合成新数据的生成模型
自编码器是确定性模型,训练后只能通过压缩表示的转换来重建输入,无法生成新数据。而生成模型可以从随机向量 𝒛(对应潜在表示)生成新示例 𝒙̃。随机向量 𝒛 来自具有完全已知特征的简单分布,例如均匀分布 [–1, 1] 或标准正态分布。
自编码器的解码器组件与生成模型有一些相似之处,它们都以潜在向量 𝒛 为输入,并在与 𝒙 相同的空间中返回输出。但两者的主要区别在于,自编码器中 𝒛 的分布未知,而生成模型中 𝒛 的分布是完全可表征的。可以将自编码器推广为生成模型,例如变分自编码器(VAEs)。
在 VAE 中,接收输入示例 𝒙 时,编码器网络会计算潜在向量分布的两个矩:均值 𝜇 和方差 𝜎²。训练时,网络会迫使这些矩与标准正态分布(零均值和单位方差)相匹配。训练完成后,丢弃编码器,使用解码器网络通过输入来自“学习到”的高斯分布的随机 𝒛 向量来生成新示例 𝒙̃。
除了 VAEs,还有其他类型的生成模型,如自回归模型和归一化流模型。但本文主要关注 GAN 模型,它是深度学习中最新且最流行的生成模型之一。
4. 使用 GAN 生成新样本
假设存在一个网络,接收从已知分布采样的随机向量 𝒛 并生成输出图像 𝒙,将其称为生成器(G),记为 𝒙̃ = 𝐺(𝒛)。目标是生成各种图像,如人脸图像、建筑物图像、动物图像或手写数字。
初始时,生成器网络的权重随机初始化,输出图像看起来像白噪声。如果存在一个评估图像质量的函数(评估器函数),就可以根据其反馈调整生成器网络的权重,以提高生成图像的质量。
实际上,GAN 模型包含一个额外的神经网络,称为判别器(D),它是一个分类器,用于区分合成图像 𝒙̃ 和真实图像 𝒙。生成器和判别器两个网络一起训练,在训练过程中,它们相互对抗。生成器学习改进其输出以欺骗判别器,同时判别器变得更善于检测合成图像。
4.1 GAN 模型中生成器和判别器网络的损失函数
GAN 的目标函数如下:
[
V(\theta^{(D)}, \theta^{(G)}) = E_{x \sim p_{data}(x)}[\log D(x)] + E_{z \sim p_{z}(z)} [\log (1 - D(G(z)))]
]
其中,$V(\theta^{(D)}, \theta^{(G)})$ 称为价值函数,我们希望相对于判别器(D)最大化其值,相对于生成器(G)最小化其值,即 $\min_{G} \max_{D} V(\theta^{(D)}, \theta^{(G)})$。
$D(x)$ 表示输入示例 𝒙 是真实还是合成的概率。$E_{x \sim p_{data}(x)}[\log D(x)]$ 是关于数据分布(真实示例的分布)中示例的括号内数量的期望值;$E_{z \sim p_{z}(z)} [\log (1 - D(G(z)))]$ 是关于输入向量 𝒛 的分布的括号内数量的期望值。
训练 GAN 模型的一个步骤需要两个优化步骤:
1. 最大化判别器的收益。
2. 最小化生成器的收益。
实际训练时,交替进行这两个优化步骤:固定一个网络的参数,优化另一个网络的权重,然后交换。
当固定生成器网络,优化判别器时,价值函数的两项都有助于优化判别器,第一项对应真实示例的损失,第二项对应合成示例的损失,目标是最大化 $V(\theta^{(D)}, \theta^{(G)})$,使判别器更好地区分真实和合成图像。
当固定判别器,优化生成器时,只有价值函数的第二项对生成器的梯度有贡献,目标是最小化 $V(\theta^{(D)}, \theta^{(G)})$,即 $\min_{G} E_{z \sim p_{z}(z)} [\log (1 - D(G(z)))]$。但在训练早期,这个函数会出现梯度消失的问题,因为早期生成的输出与真实示例相差很大,$D(G(z))$ 会非常接近零。为解决这个问题,可将最小化目标改写为 $\max_{G} E_{z \sim p_{z}(z)} [\log (D(G(z)))]$。
4.2 训练 GAN 时的数据标签
判别器是一个二分类器,类别标签为 0(合成图像)和 1(真实图像),可使用二元交叉熵损失函数。判别器的真实标签确定如下:
| 数据类型 | 真实标签 |
| ---- | ---- |
| 真实图像(𝒙) | 1 |
| 生成器的输出($G(z)$) | 0 |
训练生成器时,希望其合成逼真的图像,因此当生成器的输出未被判别器分类为真实时,要对其进行惩罚。所以在计算生成器的损失函数时,假设生成器输出的真实标签为 1。
以下是一个简单 GAN 模型的步骤流程图:
graph LR
classDef startend fill:#F5EBFF,stroke:#BE8FED,stroke-width:2px
classDef process fill:#E5F6FF,stroke:#73A6FF,stroke-width:2px
classDef decision fill:#FFF6CC,stroke:#FFBC52,stroke-width:2px
A([开始]):::startend --> B(初始化生成器和判别器的权重):::process
B --> C(从已知分布采样随机向量 𝒛):::process
C --> D(生成器 G 生成图像 𝒙̃ = G(𝒛)):::process
D --> E(判别器 D 对真实图像 𝒙 和生成图像 𝒙̃ 进行分类):::process
E --> F(计算判别器的损失并优化判别器的权重):::process
F --> G(固定判别器,计算生成器的损失并优化生成器的权重):::process
G --> H{是否达到训练轮数}:::decision
H -- 否 --> C
H -- 是 --> I([结束]):::startend
5. 从零实现 GAN
接下来将介绍如何实现和训练一个 GAN 模型来生成新的手写数字,如 MNIST 数字。由于在普通中央处理器(CPU)上训练可能需要很长时间,下面将介绍如何设置 Google Colab 环境,以便在图形处理单元(GPU)上运行计算。
5.1 在 Google Colab 上训练 GAN 模型
本章的一些代码示例可能需要大量计算资源,超出普通笔记本电脑或无 GPU 的工作站的能力。如果有支持 NVIDIA GPU 的计算设备,并安装了 CUDA 和 cuDNN 库,可以使用其加速计算。但对于大多数人来说,可使用 Google Colaboratory 环境(Google Colab),这是一个免费的云计算服务。
Google Colab 提供在云端运行的 Jupyter Notebook 实例,笔记本可保存在 Google Drive 或 GitHub 上。平台提供多种计算资源,如 CPU、GPU 和张量处理单元(TPU),但执行时间目前限制为 12 小时。本章的代码块最多需要两到三小时的计算时间,不会有问题。如果用于其他需要超过 12 小时的项目,建议使用检查点并保存中间检查点。
访问 Google Colab 很简单,访问 https://colab.research.google.com ,会自动进入一个提示窗口,可看到现有的 Jupyter 笔记本。从提示窗口中点击“GOOGLE DRIVE”标签,将笔记本保存在 Google Drive 上。
要创建新笔记本,点击提示窗口底部的“NEW PYTHON 3 NOTEBOOK”链接,将创建并打开一个新笔记本,代码示例会自动保存,可在 Google Drive 的“Colab Notebooks”目录中访问。
为了利用 GPU 运行代码示例,从笔记本菜单栏的“Runtime”选项中点击“Change runtime type”,选择“GPU”。
以下是在 Google Colab 上设置 GPU 环境的步骤总结:
1. 访问 https://colab.research.google.com 。
2. 点击“GOOGLE DRIVE”标签保存笔记本。
3. 点击“NEW PYTHON 3 NOTEBOOK”创建新笔记本。
4. 从“Runtime”选项中点击“Change runtime type”,选择“GPU”。
通过以上步骤,就可以在 Google Colab 上利用 GPU 加速 GAN 模型的训练。
6. 代码实现细节
6.1 导入必要的库
在开始实现 GAN 之前,我们需要导入一些必要的 Python 库,以下是示例代码:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
6.2 定义生成器和判别器网络
生成器网络
生成器接收一个随机向量 z 并生成图像。以下是一个简单的生成器网络示例:
class Generator(nn.Module):
def __init__(self, input_dim, output_dim):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, output_dim),
nn.Tanh()
)
def forward(self, z):
return self.model(z)
判别器网络
判别器接收图像并判断其是真实图像还是生成的图像。以下是一个简单的判别器网络示例:
class Discriminator(nn.Module):
def __init__(self, input_dim):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
6.3 数据加载
我们将使用 MNIST 数据集进行训练。以下是数据加载的代码:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True,
transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
6.4 训练 GAN 模型
以下是训练 GAN 模型的主要代码:
# 定义超参数
input_dim = 100
output_dim = 784 # 28x28 图像
learning_rate = 0.0002
num_epochs = 50
# 初始化生成器和判别器
generator = Generator(input_dim, output_dim)
discriminator = Discriminator(output_dim)
# 定义优化器和损失函数
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)
criterion = nn.BCELoss()
for epoch in range(num_epochs):
for i, (real_images, _) in enumerate(train_loader):
batch_size = real_images.size(0)
real_images = real_images.view(batch_size, -1)
# 训练判别器
d_optimizer.zero_grad()
# 真实图像标签为 1
real_labels = torch.ones(batch_size, 1)
real_output = discriminator(real_images)
d_real_loss = criterion(real_output, real_labels)
# 生成假图像
z = torch.randn(batch_size, input_dim)
fake_images = generator(z)
# 假图像标签为 0
fake_labels = torch.zeros(batch_size, 1)
fake_output = discriminator(fake_images.detach())
d_fake_loss = criterion(fake_output, fake_labels)
# 判别器总损失
d_loss = d_real_loss + d_fake_loss
d_loss.backward()
d_optimizer.step()
# 训练生成器
g_optimizer.zero_grad()
# 生成器希望判别器将生成的图像判断为真实图像,标签为 1
fake_labels = torch.ones(batch_size, 1)
fake_output = discriminator(fake_images)
g_loss = criterion(fake_output, fake_labels)
g_loss.backward()
g_optimizer.step()
print(f'Epoch [{epoch + 1}/{num_epochs}], D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}')
6.5 生成新样本
训练完成后,我们可以使用生成器生成新的手写数字图像。以下是生成新样本的代码:
# 生成一些新样本
num_samples = 16
z = torch.randn(num_samples, input_dim)
generated_images = generator(z).detach().numpy()
# 显示生成的图像
fig, axes = plt.subplots(4, 4, figsize=(4, 4))
axes = axes.flatten()
for i in range(num_samples):
img = generated_images[i].reshape(28, 28)
axes[i].imshow(img, cmap='gray')
axes[i].axis('off')
plt.show()
7. 总结
通过以上内容,我们详细介绍了生成对抗网络(GAN)的基本原理、相关概念以及如何从零实现一个简单的 GAN 模型来生成手写数字图像。
7.1 关键知识点回顾
- GAN 基础 :GAN 的目标是合成与训练数据集具有相同分布的新数据,由生成器和判别器两个网络组成,它们在训练过程中相互对抗。
- 自编码器 :自编码器由编码器和解码器组成,可用于数据压缩和解压缩,还可作为降维技术。根据潜在空间大小可分为欠完备和过完备自编码器,过完备自编码器可通过修改训练过程用于降噪。
- 生成模型 :生成模型可以从随机向量生成新示例,与自编码器的解码器有相似之处,但生成模型中随机向量的分布是可表征的。常见的生成模型包括 VAEs、自回归模型和归一化流模型等,本文主要关注 GAN 模型。
- GAN 训练 :GAN 训练需要交替优化生成器和判别器的权重,使用特定的损失函数和数据标签。在训练早期,生成器的损失函数可能会出现梯度消失问题,可通过改写目标函数来解决。
- 代码实现 :我们使用 PyTorch 实现了一个简单的 GAN 模型,包括定义生成器和判别器网络、加载数据、训练模型以及生成新样本等步骤。
7.2 未来展望
GAN 作为一种强大的生成模型,在图像生成、数据增强、图像编辑等领域有广泛的应用前景。未来,我们可以进一步探索 GAN 的改进和扩展,例如:
- 提高生成图像的质量 :通过改进网络架构、优化损失函数等方法,生成更加逼真、高质量的图像。
- 拓展应用领域 :将 GAN 应用于更多领域,如视频生成、语音合成、医学图像分析等。
- 解决训练稳定性问题 :GAN 训练过程中可能会出现不稳定的情况,如模式崩溃、梯度消失等,需要研究更有效的训练方法和技巧来解决这些问题。
总之,GAN 是一个充满挑战和机遇的研究领域,相信在未来会有更多的突破和创新。
以下是整个 GAN 实现流程的总结表格:
| 步骤 | 操作 | 代码示例 |
| ---- | ---- | ---- |
| 1 | 导入必要的库 | import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt |
| 2 | 定义生成器和判别器网络 | class Generator(nn.Module): ...
class Discriminator(nn.Module): ... |
| 3 | 数据加载 | transform = transforms.Compose([...])
train_dataset = datasets.MNIST(...)
train_loader = torch.utils.data.DataLoader(...) |
| 4 | 训练 GAN 模型 | for epoch in range(num_epochs): ... |
| 5 | 生成新样本 | z = torch.randn(num_samples, input_dim)
generated_images = generator(z).detach().numpy()
fig, axes = plt.subplots(...) |
以下是整个流程的 mermaid 流程图:
graph LR
classDef startend fill:#F5EBFF,stroke:#BE8FED,stroke-width:2px
classDef process fill:#E5F6FF,stroke:#73A6FF,stroke-width:2px
classDef decision fill:#FFF6CC,stroke:#FFBC52,stroke-width:2px
A([开始]):::startend --> B(导入必要的库):::process
B --> C(定义生成器和判别器网络):::process
C --> D(数据加载):::process
D --> E(训练 GAN 模型):::process
E --> F(生成新样本):::process
F --> G([结束]):::startend
超级会员免费看

46

被折叠的 条评论
为什么被折叠?



