生成对抗网络(GAN)的原理、构建与训练
1. GAN 训练步骤与原理
1.1 训练步骤
GAN 的训练过程包含四个主要步骤,每个步骤都有其特定的目标和作用:
1.
寻找假阴性(False Negatives)
:将真实票据输入鉴别器。若鉴别器将真实票据误判为伪造票据,误差函数会驱动反向传播更新鉴别器的权重,使其更好地识别真实票据。
2.
寻找真阴性(True Negatives)
:将生成器的输出直接连接到鉴别器的输入。随机数输入生成器,生成伪造票据后输入鉴别器。若鉴别器正确识别出伪造票据,会通过反向传播计算梯度,但仅更新生成器的权重,此时鉴别器的权重被冻结。
3.
寻找假阳性(False Positives)
:给鉴别器一个伪造票据。若鉴别器将其误判为真实票据,则更新鉴别器的权重,使其更好地识别伪造票据。
4.
重复真阴性步骤
:重复第二步,确保在每一轮训练中,鉴别器和生成器都有两次更新的机会。
1.2 训练原理
通过这四个步骤的循环训练,鉴别器会不断提高识别真实票据和发现伪造票据错误的能力,而生成器则会不断学习如何创建难以被鉴别器识别的伪造票据。这种训练方式使得两个网络在“学习战斗”中相互促进,最终达到一种平衡状态,即纳什均衡。
1.3 为何称为“对抗”
虽然从表面上看,生成器和鉴别器似乎是在合作,但“对抗”这个词源于博弈论。可以将鉴别器想象成警察侦探,生成器想象成独自作案的伪造者,伪造者试图欺骗侦探,而侦探则努力识别伪造品,二者处于对立关系。这种对抗性的视角为理解 GAN 提供了一种不同的思路。
2. GAN 的网络结构
2.1 鉴别器(Discriminator)
鉴别器是三个模型中最简单的一个。它以样本作为输入,输出一个单一的值,表示网络对输入样本来自训练集而非伪造样本的置信度。输出值为 1 表示鉴别器确定输入是真实票据,值为 0 表示确定是伪造票据,值为 0.5 表示无法确定。鉴别器的构建没有太多限制,可以是浅层或深层网络,使用各种类型的层,如全连接层、卷积层、循环层、变压器层等。
2.2 生成器(Generator)
生成器以一组随机数作为输入,输出一个合成样本。在货币伪造的例子中,输出将是一张图像。生成器的损失函数单独来看没有意义,通常在一些实现中甚至不会定义。训练生成器时,将其与鉴别器连接,生成器从组合网络的损失函数中学习。训练完成后,通常会丢弃鉴别器,保留生成器,因为生成器的目的是用于生成新的数据。生成器的构建也没有太多限制,可以是浅层或深层网络,使用各种类型的层。
以下是生成器和鉴别器的构建示例:
-
简单生成器
:
- 输入:四个随机值,均匀选自 0 到 1 的范围。
- 第一层:全连接层,有 16 个神经元,使用 Leaky ReLU 激活函数(负数值缩放因子为 0.1)。
- 第二层:全连接层,有 2 个神经元,无激活函数。输出为一个点的 (x, y) 坐标。
-
简单鉴别器
:
- 输入:一个 (x, y) 点。
- 前两层:全连接层,每层有 16 个神经元,使用 Leaky ReLU 激活函数。
- 最后一层:全连接层,有 1 个神经元,使用 sigmoid 激活函数。输出为网络对输入来自训练集分布的置信度。
2.3 组合模型
将生成器和鉴别器连接在一起形成组合模型。生成器的输出是一个 (x, y) 对,正好作为鉴别器的输入。需要注意的是,组合模型中的生成器和鉴别器与单独的生成器和鉴别器是同一模型,只是连接在一起形成一个更大的模型。
3. GAN 训练过程的详细说明
3.1 训练步骤的具体操作
-
第一步:寻找假阴性
- 操作 :将真实票据输入鉴别器。
- 目标 :若鉴别器将真实票据误判为伪造票据,通过误差函数驱动反向传播更新鉴别器权重,提高其识别真实票据的能力。
- 示例图 :
graph LR
A[真实票据数据库] --> B[鉴别器]
B --> C{判断结果}
C -->|假| D[误差函数]
D --> E[反向传播更新权重]
C -->|真| F[无更新]
-
第二步:寻找真阴性
- 操作 :随机数输入生成器,生成伪造票据后输入鉴别器。
- 目标 :若鉴别器正确识别出伪造票据,通过反向传播计算梯度,但仅更新生成器权重,鉴别器权重冻结。
- 示例图 :
graph LR
A[随机数] --> B[生成器]
B --> C[伪造票据]
C --> D[鉴别器]
D --> E{判断结果}
E -->|假| F[误差函数]
F --> G[反向传播计算梯度]
G --> H[更新生成器权重]
E -->|真| I[无更新]
-
第三步:寻找假阳性
- 操作 :给鉴别器一个伪造票据。
- 目标 :若鉴别器将伪造票据误判为真实票据,更新鉴别器权重,提高其识别伪造票据的能力。
- 示例图 :
graph LR
A[随机数] --> B[生成器]
B --> C[伪造票据]
C --> D[鉴别器]
D --> E{判断结果}
E -->|真| F[误差函数]
F --> G[反向传播更新权重]
E -->|假| H[无更新]
-
第四步:重复真阴性步骤
- 操作 :同第二步。
- 目标 :确保鉴别器和生成器在每一轮训练中都有两次更新机会。
3.2 训练中的注意事项
在训练生成器时,使用组合模型时不希望同时训练鉴别器。虽然需要通过鉴别器进行反向传播来计算生成器的梯度,但只更新生成器的权重。训练鉴别器和生成器时应交替进行,以确保两个模型以大致相同的速率进行训练。
4. GAN 在实际中的应用示例
4.1 简单数据集设定
为了便于理解和可视化,我们选择一个简单的二维数据集。将训练集中的所有样本想象成抽象空间中的点云,这里的“真实”样本是具有高斯分布的二维点云,中心点为 (5, 5),标准差为 1。
4.2 生成器和鉴别器的目标
- 生成器 :尝试学习如何将输入的随机数转换为看起来属于该高斯分布的点,使鉴别器无法区分真实点和生成器生成的合成点。
- 鉴别器 :能够准确判断输入的点是来自原始高斯分布还是生成器生成的伪造点。
4.3 训练的挑战与应对
GAN 训练具有挑战性,因为它对网络架构和超参数非常敏感。微小的架构变化或超参数调整可能会导致模型性能的巨大差异。因此,在开发 GAN 时,需要使用特定的数据进行实验,尽快找到合适的设计和超参数。这通常意味着对训练数据的小片段进行大量小实验,以寻找合适的网络和超参数。
4.4 简单模型的构建
我们构建了一个简单的生成器和鉴别器:
| 模型 | 输入 | 中间层 | 输出 |
| ---- | ---- | ---- | ---- |
| 生成器 | 4 个随机数(0 - 1 均匀分布) | 16 个神经元的全连接层(Leaky ReLU 激活) -> 2 个神经元的全连接层(无激活) | (x, y) 坐标 |
| 鉴别器 | (x, y) 点 | 16 个神经元的全连接层(Leaky ReLU 激活) -> 16 个神经元的全连接层(Leaky ReLU 激活) -> 1 个神经元的全连接层(sigmoid 激活) | 置信度(0 - 1) |
通过以上步骤和方法,我们可以构建并训练一个简单的 GAN 系统,使其在二维数据集中实现生成与鉴别功能。在实际应用中,根据不同的数据集和任务需求,需要进一步调整网络架构和超参数,以获得更好的性能。
5. 构建与训练简单 GAN 的详细流程
5.1 构建网络模型
生成器模型
我们已经了解到简单生成器的结构,以下是对其构建的更详细解释:
- 输入层接收 4 个随机数,这些随机数均匀分布在 0 到 1 的范围内。
- 第一个全连接层包含 16 个神经元,使用 Leaky ReLU 激活函数。Leaky ReLU 与普通 ReLU 类似,但对于负数值,它会将其乘以一个小的缩放因子(这里是 0.1),避免了神经元死亡的问题。
- 第二个全连接层有 2 个神经元,没有使用激活函数。这一层的输出就是一个点的 (x, y) 坐标。
鉴别器模型
简单鉴别器的构建步骤如下:
- 输入层接收一个 (x, y) 点。
- 前两个全连接层每层都有 16 个神经元,同样使用 Leaky ReLU 激活函数,以增加模型的非线性表达能力。
- 最后一个全连接层只有 1 个神经元,使用 sigmoid 激活函数。sigmoid 函数将输出值映射到 0 到 1 的范围内,表示网络对输入点来自训练集分布的置信度。
5.2 训练流程
整体流程概述
GAN 的训练过程是一个迭代的过程,每一轮训练包含四个步骤,不断交替更新生成器和鉴别器的权重,以达到纳什均衡。
具体步骤详解
-
寻找假阴性
- 从真实样本数据库中选取一批真实票据作为输入,将其输入到鉴别器中。
- 鉴别器对输入样本进行判断,输出判断结果。
- 如果鉴别器将真实票据误判为伪造票据,误差函数会计算误差,并通过反向传播算法更新鉴别器的权重,使其更好地识别真实票据。
- 示例代码(伪代码):
for real_sample in real_samples:
output = discriminator(real_sample)
if output < threshold: # 误判为假
error = error_function(output, 1)
discriminator.backward(error)
discriminator.update_weights()
-
寻找真阴性
- 生成一批随机数,将其输入到生成器中,生成伪造票据。
- 将生成的伪造票据输入到鉴别器中。
- 如果鉴别器正确识别出伪造票据,通过反向传播算法计算梯度,但只更新生成器的权重,鉴别器的权重保持不变(即冻结鉴别器)。
- 示例代码(伪代码):
for _ in range(batch_size):
random_numbers = generate_random_numbers()
fake_sample = generator(random_numbers)
output = discriminator(fake_sample)
if output < threshold: # 正确识别为假
error = error_function(output, 0)
combined_network.backward(error)
generator.update_weights()
-
寻找假阳性
- 再次生成一批随机数,输入到生成器中生成伪造票据。
- 将伪造票据输入到鉴别器中。
- 如果鉴别器将伪造票据误判为真实票据,误差函数计算误差,并通过反向传播算法更新鉴别器的权重,提高其识别伪造票据的能力。
- 示例代码(伪代码):
for _ in range(batch_size):
random_numbers = generate_random_numbers()
fake_sample = generator(random_numbers)
output = discriminator(fake_sample)
if output > threshold: # 误判为真
error = error_function(output, 0)
discriminator.backward(error)
discriminator.update_weights()
-
重复真阴性步骤
- 重复第二步的操作,确保在每一轮训练中,生成器和鉴别器都有两次更新的机会,以保证它们以大致相同的速率进行学习。
5.3 训练过程的可视化
为了更好地理解训练过程,我们可以对训练过程进行可视化。例如,在二维数据集中,我们可以绘制生成器生成的点和真实样本的分布,观察随着训练的进行,生成器生成的点是否越来越接近真实样本的分布。
graph LR
classDef startend fill:#F5EBFF,stroke:#BE8FED,stroke-width:2px;
classDef process fill:#E5F6FF,stroke:#73A6FF,stroke-width:2px;
classDef decision fill:#FFF6CC,stroke:#FFBC52,stroke-width:2px;
A([开始训练]):::startend --> B(寻找假阴性):::process
B --> C(寻找真阴性):::process
C --> D(寻找假阳性):::process
D --> E(重复真阴性步骤):::process
E --> F{是否达到训练轮数}:::decision
F -->|否| B
F -->|是| G([结束训练]):::startend
6. 总结与展望
6.1 总结
生成对抗网络(GAN)是一种强大的深度学习模型,由生成器和鉴别器两个网络组成。通过对抗训练的方式,生成器不断学习如何生成更逼真的样本,而鉴别器则不断提高识别真实样本和伪造样本的能力。在训练过程中,通过交替更新生成器和鉴别器的权重,最终使两个网络达到纳什均衡。
在实际应用中,我们可以选择合适的数据集和网络架构,构建并训练 GAN 系统。例如,在二维数据集中,我们可以构建简单的生成器和鉴别器,通过不断调整超参数和网络架构,使生成器生成的样本越来越接近真实样本的分布。
6.2 展望
虽然 GAN 已经取得了很多成功的应用,但仍然面临一些挑战。例如,GAN 训练不稳定,容易出现模式崩溃、梯度消失等问题。未来的研究可以从以下几个方面进行探索:
-
改进训练算法
:研究更稳定、高效的训练算法,解决 GAN 训练中的不稳定问题。
-
拓展应用领域
:将 GAN 应用到更多的领域,如医学图像生成、自然语言处理等。
-
融合其他技术
:将 GAN 与其他深度学习技术(如卷积神经网络、循环神经网络等)相结合,提高模型的性能。
通过不断的研究和实践,相信 GAN 将会在更多的领域发挥重要作用,为人工智能的发展做出更大的贡献。
超级会员免费看
969

被折叠的 条评论
为什么被折叠?



