生成对抗网络:合成新数据的强大工具
1. 生成对抗网络基础
生成对抗网络(GAN)的总体目标是合成与训练数据集具有相同分布的新数据。原始形式的 GAN 属于无监督学习任务,因为不需要标记数据。不过,对原始 GAN 的扩展可以应用于半监督和监督任务。
2014 年,Ian Goodfellow 及其同事首次提出了 GAN 的一般概念,用于使用深度神经网络合成新图像。最初提出的 GAN 架构基于全连接层,类似于多层感知机架构,并用于生成低分辨率的类似 MNIST 的手写数字,主要作为概念验证。
自提出以来,原始作者和其他研究人员提出了许多改进和不同领域的应用。在计算机视觉领域,GAN 可用于图像到图像的翻译、图像超分辨率、图像修复等。例如,近期的研究成果使 GAN 能够生成新的高分辨率人脸图像,可在 https://www.thispersondoesnotexist.com/ 上查看示例。
2. 从自编码器开始
在讨论 GAN 的工作原理之前,先了解自编码器。自编码器由编码器网络和解码器网络串联组成。编码器接收与示例 x 相关的 d 维输入特征向量 𝒙𝒙∈𝑅𝑅𝑑𝑑,并将其编码为 p 维向量 𝒛𝒛∈𝑅𝑅𝑝𝑝,即学习函数 𝒛𝒛 = 𝑓𝑓(𝒙𝒙)。编码后的向量 𝒛𝒛 也称为潜在向量或潜在特征表示,通常其维度小于输入示例的维度(p < d),因此编码器起到数据压缩的作用。解码器从低维潜在向量 𝒛𝒛 解压缩得到 𝒙𝒙̂ ,可看作函数 𝒙𝒙̂ = 𝑔𝑔(𝒛𝒛)。
2.1 自编码器与降维的关系
自编码器也可作为降维技术。当编码器和解码器两个子网络都没有非线性时,自编码器方法与主成分分析(PCA)几乎相同。
假设单层编码器(无隐藏层和非线性激活函数)的权重用矩阵 U 表示,则编码器建模为 𝒛𝒛 = 𝑼𝑼𝑇𝑇𝒙𝒙,单层线性解码器建模为 𝒙𝒙̂ = 𝑼𝑼𝑼𝑼,两者结合得到 𝒙𝒙̂ = 𝑼𝑼𝑼𝑼𝑇𝑇𝒙𝒙,与 PCA 相同,只是 PCA 有额外的正交归一化约束 𝑼𝑼𝑼𝑼𝑇𝑇 = 𝑰𝑰𝑛𝑛×𝑛𝑛。
2.2 基于潜在空间大小的其他类型自编码器
自编码器的潜在空间维度通常低于输入维度(p < d),这种配置称为欠完备自编码器,潜在向量常被称为“瓶颈”。还有一种过完备自编码器,其潜在向量 𝒛𝒛 的维度大于输入示例的维度(p > d)。
训练过完备自编码器时,存在编码器和解码器简单复制输入特征到输出层的平凡解,这并不实用。但通过对训练过程进行一些修改,过完备自编码器可用于降噪。在训练时,向输入示例添加随机噪声 𝝐𝝐,网络学习从噪声信号 𝒙𝒙 + 𝝐𝝐 中重建干净示例 x。评估时,提供自然带有噪声的新示例以去除现有噪声,这种自编码器架构和训练方法称为去噪自编码器。
以下是自编码器相关信息的表格总结:
| 类型 | 潜在空间维度与输入维度关系 | 特点 | 应用 |
| — | — | — | — |
| 欠完备自编码器 | p < d | 适合降维,潜在向量为“瓶颈” | 降维 |
| 过完备自编码器 | p > d | 训练有平凡解,需修改训练过程 | 降噪 |
3. 合成新数据的生成模型
自编码器是确定性模型,训练后只能从低维压缩版本重建输入,无法生成新数据。而生成模型可以从随机向量 𝒛𝒛(对应潜在表示)生成新示例 𝒙𝒙̃ 。随机向量 𝒛𝒛 来自具有完全已知特征的简单分布,如均匀分布 [–1, 1] 或标准正态分布。
自编码器的解码器组件与生成模型有相似之处,都接收潜在向量 𝒛𝒛 作为输入,并在与 𝒙𝒙 相同的空间中返回输出。但两者的主要区别在于,自编码器中 𝒛𝒛 的分布未知,而生成模型中 𝒛𝒛 的分布完全可表征。
可以将自编码器推广为生成模型,如变分自编码器(VAE)。VAE 接收输入示例 x 时,编码器网络计算潜在向量分布的两个矩:均值 𝝁𝝁 和方差 𝝈𝝈2。训练时,网络被迫使这些矩与标准正态分布匹配。训练完成后,丢弃编码器,使用解码器网络通过输入来自“学习”的高斯分布的随机 𝒛𝒛 向量生成新示例 𝒙𝒙̃ 。
除了 VAE,还有其他类型的生成模型,如自回归模型和归一化流模型。本文将重点介绍 GAN 模型,它是深度学习中最新且最流行的生成模型之一。
以下是生成模型相关信息的 mermaid 流程图:
graph LR
A[生成模型] --> B[VAE]
A --> C[自回归模型]
A --> D[归一化流模型]
A --> E[GAN模型]
4. 用 GAN 生成新样本
简单来说,假设有一个网络接收从已知分布采样的随机向量 𝒛𝒛 并生成输出图像 𝒙𝒙̃ ,称为生成器(G),表示为 𝒙𝒙̃ = 𝐺𝐺(𝒛𝒛)。目标是生成各种图像,如人脸图像、建筑物图像、动物图像或手写数字。
初始化网络时,权重是随机的,因此在调整权重之前,生成的图像看起来像白噪声。假设存在一个评估图像质量的函数(评估器函数),可以根据其反馈调整生成器的权重,使生成的图像更逼真。
然而,是否存在这样一个通用的评估图像质量的函数以及如何定义它是个问题。人类可以轻松评估输出图像的质量,但无法将大脑的评估结果反向传播到网络。因此,GAN 的一般思想是设计一个神经网络模型(判别器 D)来完成这个任务,判别器是一个分类器,用于区分合成图像 𝒙𝒙̃ 和真实图像 𝒙𝒙。
在 GAN 模型中,生成器和判别器一起训练。起初,生成器生成的图像不真实,判别器也难以区分真实图像和合成图像。但随着训练的进行,两个网络相互作用并不断改进。生成器学习改进输出以欺骗判别器,判别器则变得更擅长检测合成图像。
5. 理解 GAN 模型中生成器和判别器网络的损失函数
GAN 的目标函数为:
𝑉𝑉(𝜃𝜃(𝐷𝐷), 𝜃𝜃(𝐺𝐺)) = 𝐸𝐸𝒙𝒙∼𝑝𝑝𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑(𝒙𝒙)[log 𝐷𝐷(𝒙𝒙)] + 𝐸𝐸𝒛𝒛∼𝑝𝑝𝒛𝒛(𝒛𝒛) [log (1 −𝐷𝐷(𝐺𝐺(𝒛𝒛)))]
其中,𝑉𝑉(𝜃𝜃(𝐷𝐷), 𝜃𝜃(𝐺𝐺)) 称为价值函数,可解释为收益。我们希望相对于判别器(D)最大化其值,相对于生成器(G)最小化其值,即 min
𝐺𝐺max
𝐷𝐷
𝑉𝑉(𝜃𝜃(𝐷𝐷), 𝜃𝜃(𝐺𝐺))。
D(x) 表示输入示例 x 是真实还是假(即生成)的概率。𝐸𝐸𝒙𝒙∼𝑝𝑝𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑(𝒙𝒙)[log𝐷𝐷(𝒙𝒙)] 指相对于数据分布(真实示例的分布)中示例的括号内数量的期望值;𝐸𝐸𝒛𝒛∼𝑝𝑝𝒛𝒛(𝒛𝒛) [log (1 −𝐷𝐷(𝐺𝐺(𝒛𝒛)))] 指相对于输入 𝒛𝒛 向量分布的数量的期望值。
训练 GAN 模型的一个步骤需要两个优化步骤:
1. 最大化判别器的收益。
2. 最小化生成器的收益。
实际训练中,交替进行这两个优化步骤:固定一个网络的参数并优化另一个网络的权重,然后固定第二个网络并优化第一个网络,每个训练迭代都重复这个过程。
当固定生成器网络并优化判别器时,价值函数的两项都有助于优化判别器,第一项对应真实示例的损失,第二项对应假示例的损失,目标是最大化 𝑉𝑉(𝜃𝜃(𝐷𝐷), 𝜃𝜃(𝐺𝐺)),使判别器更好地区分真实和生成的图像。
优化判别器后,固定判别器并优化生成器。此时,只有价值函数的第二项对生成器的梯度有贡献,目标是最小化 𝑉𝑉(𝜃𝜃(𝐷𝐷), 𝜃𝜃(𝐺𝐺)),可写为 min
𝐺𝐺𝐸𝐸𝒛𝒛∼𝑝𝑝𝒛𝒛(𝒛𝒛) [log (1 −𝐷𝐷(𝐺𝐺(𝒛𝒛)))]。但在训练早期,这个函数会出现梯度消失的问题,因为学习过程早期生成器的输出与真实示例相差很大,D(G(z)) 会非常接近零,这种现象称为饱和。为解决这个问题,将最小化目标改写为 max
𝐺𝐺
𝐸𝐸𝒛𝒛∼𝑝𝑝𝒛𝒛(𝒛𝒛) [log (𝐷𝐷(𝐺𝐺(𝒛𝒛)))]。
这意味着在训练生成器时,可以交换真实和假示例的标签,并进行常规的函数最小化。判别器是一个二元分类器,使用二元交叉熵损失函数,其地面真值标签如下:
| 数据类型 | 地面真值标签 |
| — | — |
| 真实图像 𝒙𝒙 | 1 |
| 生成器输出 𝐺𝐺(𝒛𝒛) | 0 |
训练生成器时,希望生成器合成逼真的图像,因此当生成器的输出未被判别器分类为真实时进行惩罚,计算生成器损失函数时,假设生成器输出的地面真值标签为 1。
以下是简单 GAN 模型的步骤列表:
1. 初始化生成器和判别器的权重。
2. 交替进行判别器和生成器的优化:
- 固定生成器,优化判别器,最大化价值函数。
- 固定判别器,优化生成器,最小化价值函数(处理梯度消失问题后)。
3. 重复步骤 2 直到训练完成。
6. 从零实现 GAN
接下来将介绍如何实现和训练一个 GAN 模型来生成新的图像,如 MNIST 数字。由于在普通中央处理器(CPU)上训练可能需要很长时间,因此将介绍如何设置 Google Colab 环境,以便在图形处理单元(GPU)上运行计算。
6.1 在 Google Colab 上训练 GAN 模型
一些代码示例可能需要大量的计算资源,超出普通笔记本电脑或没有 GPU 的工作站的能力。如果有支持 NVIDIA GPU 的计算设备并安装了 CUDA 和 cuDNN 库,可以使用它来加速计算。
若没有高性能计算资源,可以使用 Google Colaboratory 环境(Google Colab),这是一个免费的云计算服务(在大多数国家可用)。Google Colab 提供在云端运行的 Jupyter Notebook 实例,笔记本可以保存在 Google Drive 或 GitHub 上。平台提供各种计算资源,如 CPU、GPU 和张量处理单元(TPU),但执行时间目前限制为 12 小时,因此运行时间超过 12 小时的笔记本会被中断。
访问 Google Colab 很简单,访问 https://colab.research.google.com,会自动进入一个提示窗口,显示现有的 Jupyter 笔记本。从该提示窗口点击“GOOGLE DRIVE”标签,将笔记本保存在 Google Drive 上。
创建新笔记本的步骤如下:
1. 在提示窗口底部点击“NEW PYTHON 3 NOTEBOOK”链接,将创建并打开一个新笔记本,在其中编写的代码示例会自动保存,可在 Google Drive 的“Colab Notebooks”目录中访问。
2. 为了利用 GPU 运行代码示例,从笔记本菜单栏的“Runtime”选项中点击“Change runtime type”,选择“GPU”。
通过以上步骤,可以在 Google Colab 上高效地训练 GAN 模型。
7. 深入探讨 GAN 训练细节
7.1 训练稳定性问题
在训练 GAN 模型时,稳定性是一个关键问题。由于生成器和判别器之间的对抗关系,训练过程可能会变得不稳定,出现梯度消失、梯度爆炸或模式崩溃等问题。
-
梯度消失
:如前面所述,在训练早期,生成器的输出与真实数据差异较大,判别器能够轻易区分,导致生成器的梯度接近零,训练难以进行。通过将生成器的损失函数改写为 max
𝐺𝐺
𝐸𝐸𝒛𝒛∼𝑝𝑝𝒛𝒛(𝒛𝒛) [log (𝐷𝐷(𝐺𝐺(𝒛𝒛)))] 可以缓解这一问题。 - 梯度爆炸 :当梯度在反向传播过程中不断累积,变得非常大时,会导致模型参数更新过度,使得训练无法收敛。可以通过梯度裁剪(Gradient Clipping)来解决,即限制梯度的最大值。
- 模式崩溃 :生成器可能会陷入只生成有限几种模式的情况,无法覆盖数据的全部分布。为了避免模式崩溃,可以采用一些技巧,如引入噪声、使用批量归一化(Batch Normalization)等。
7.2 超参数调整
GAN 模型的性能很大程度上取决于超参数的选择,以下是一些重要的超参数及其影响:
| 超参数 | 作用 | 影响 |
|---|---|---|
| 学习率 | 控制参数更新的步长 | 过大可能导致训练不稳定,过小则训练速度慢 |
| 批量大小 | 每次训练使用的样本数量 | 较大的批量可以使梯度估计更稳定,但可能会增加内存需求 |
| 优化器 | 用于更新模型参数 | 不同的优化器有不同的收敛速度和稳定性,如 Adam、SGD 等 |
在实际应用中,需要通过实验来选择合适的超参数组合,以达到最佳的训练效果。
7.3 训练过程监控
为了确保 GAN 模型的训练正常进行,需要对训练过程进行监控。可以监控以下指标:
- 损失函数值 :观察生成器和判别器的损失函数值的变化趋势,判断训练是否收敛。
- 生成图像质量 :定期查看生成器生成的图像,直观评估模型的性能。
- 分布匹配度 :可以使用一些统计指标,如 Frechet Inception Distance(FID)来衡量生成数据分布与真实数据分布的匹配程度。
以下是训练过程监控的 mermaid 流程图:
graph LR
A[开始训练] --> B[计算损失函数值]
B --> C[生成图像]
C --> D[评估图像质量]
B --> E[计算分布匹配度]
D --> F{是否满足停止条件}
E --> F
F -- 是 --> G[结束训练]
F -- 否 --> A
8. GAN 的应用领域
8.1 计算机视觉
- 图像生成 :可以生成各种类型的图像,如人脸、风景、动物等。例如,通过训练 GAN 模型,可以生成逼真的人脸图像,用于游戏、影视等领域。
- 图像翻译 :实现不同风格之间的图像转换,如将白天的图像转换为夜晚的图像,将卡通风格的图像转换为写实风格的图像。
- 图像超分辨率 :从低分辨率图像生成高分辨率图像,提高图像的清晰度和质量。
8.2 数据增强
在机器学习和深度学习中,数据量往往是一个限制因素。GAN 可以用于生成额外的训练数据,从而增强模型的泛化能力。例如,在图像分类任务中,可以使用 GAN 生成更多的图像样本,扩充训练数据集。
8.3 音频处理
GAN 也可以应用于音频领域,如生成语音、音乐等。通过训练合适的 GAN 模型,可以生成自然流畅的语音或具有特定风格的音乐。
以下是 GAN 应用领域的列表总结:
1. 计算机视觉
- 图像生成
- 图像翻译
- 图像超分辨率
2. 数据增强
3. 音频处理
9. 总结与展望
生成对抗网络(GAN)作为一种强大的生成模型,在深度学习领域取得了显著的成果。通过生成器和判别器之间的对抗训练,GAN 能够学习数据的分布并生成逼真的新数据。
在实际应用中,虽然 GAN 已经取得了很多成功,但仍然面临一些挑战,如训练稳定性、模式崩溃等问题。未来的研究方向可能包括进一步改进 GAN 的训练算法、探索新的应用领域以及提高生成数据的质量和多样性。
通过本文的介绍,希望读者对 GAN 有了更深入的了解,并能够尝试在自己的项目中应用 GAN 技术,创造出更多有价值的成果。
总之,GAN 为我们提供了一种全新的视角来处理数据生成问题,相信在未来会有更广泛的应用和发展。
超级会员免费看
2828

被折叠的 条评论
为什么被折叠?



