目录
前言
🍨 本文为
中的学习记录博客
[🔗365天深度学习训练营]
🍖 原作者:
[K同学啊]
说在前面
本周任务:基础任务——了解什么是生成对抗网络,生成对抗网络结果是怎么样的,学习本文代码并跑通代码;进阶任务——调用训练好的模型生成新图像
我的环境:Python3.6、Pycharm2020、TensorFlow2.4.0
一、理论基础
生成对抗网络(Generative Adversarial Networks, GAN)是近年来深度学习领域的一个热点方向,GAN并不指代某一个具体的神经网络,而是指一类基于博弈思想而设计的神经网络。GAN由两个分别被称为生成器和判别器的神经网络组成。其中生成器从某种噪声分布中随机采样作为输入,输出与训练集中真实样本非常相似的人工样本;判别器的输入则是真实样本或人工样本,其目的是将人工样本与真实样本尽可能地区分出来。生成器和判别器交替运行,相互博弈,各自的能力都得到提升。理想情况下,经过足够次数的博弈之后,判别器无法判断给定样本的真实性,即对于所有样本1都输出50%真、50%假的判断。此时,生成器输出的人工样本以及逼真到使判别器无法分辨真假,停止博弈。这样就可以得到一个具有“伪造”真实样本能力的生成器。
1.1 生成器
GANs中,生成器G选取随机噪声z作为输入,通过生成器的不断拟合,最终输出一个和真实样本尺寸相同,分布相似的伪造样本G(z)。生成器的本质是一个使用生成式方法的模型,他对数据的分布假设和分布参数进行学习,然后根据学习到的模型重新采样出新的样本。
从数学上来说,生成式方法对于给定的真实数据,首先需要对数据的显式变量或隐含变量做分布假设,然后再将真实数据输入到模型中对变量、参数进行训练;最后得到一个学习后的近似分布,这个分布可以用来生成新的数据。从机器学习的角度来说,模型不会去做分布假设,而是通过不断地学习真实数据,对模型进行修正,最后也可以得到一个学习后的模型来做样本生成任务,这种方法不同于数学方法,学习的过程对人类的理解较不直观。
1.2 判别器
GANs中,判别器D对于输入的样本x,输出一个[0,1]之间的概率数值D(x)。x可能是来自于原始数据集中的真实样本x,也可能是来自于生成器G的人工样本G(z)。通常约定,概率值D(x)越接近于1就代表此样本为真实样本的可能性更大;反之概率值越小则此样本为伪造样本的可能性越大。也就是说,这里的判别器是一个二分类的神经网络分类器,目的不是判定输入数据的原始类别,而是区分输入样本的真伪。可以注意到,不管在生成器还是判别器中,样本的类别信息都没有用到,也表明GAN是一个无监督学习的过程。
1.3 基本原理
GAN是博弈论和机器学习相结合的产物,于2014年lan Goodfellow的论文中问世,一经问世即火爆足以看出人们对于这种算法的认可和狂热的研究热忱。想要更详细的了解GAN,就要知道它是怎么来的,以及这种算法出现的意义是什么。研究者最初想要通过计算机完成自动生成数据的功能,例如通过训练某种算法模型,让某模型学习过一些苹果的图片后能够自动生成苹果的图片,具备这些功能的算法即任务具有生成功能。但是GAN不是第一个生成算法,而是以往的生成算法在衡量生成图片和真实图片的差距时,采用均方误差作为损失函数,但是研究者发现有时均方误差一样的两张生成图片效果却截然不同,鉴于此不足lan Goodfellow提出来GAN。
那么GAN是如何完成生成图片这项功能的呢,如上图所示,GAN是由两个模型组成的:生成模型G和判别模型D。首先第一代生成模型1G的输入是随机噪声z,然后生成模型会生成一张初级照片,训练一代判别器模型1D令其进行二分类操作,将生成的图片判别为0,而真实图片判别为1;为了期满一代鉴别器,于是一代生成模型开始优化,然后它进阶成了二代,当它生成的数据成功欺瞒1D时,鉴别模型也会优化更新,进而升级为2D,按照同样的过程也会不断更新出N代的G和D。
二、前期准备工作
2.1 定义超参数
- n_epochs:这个参数决定了模型训练的总轮数,轮数越多,模型有更多机会学习数据中的模型,但也可能导致过拟合。
- batch_size:批次大小影响模型每次更新时使用的数据量。较小批次可能导致训练过程波动较大,但可能有助于模型逃离局部最小值;较大的批次则可能使训练更稳定,但需要更多的内存空间。
- lr:学习率控制着模型权重更新的步长,学习率过大可能导致模型在最优解附近震荡甚至发散;学习率过小则可能导致模型收敛速度缓慢或陷入局部最小值。
- b1和b2:这两个参数是Adam优化器的一部分,分别控制一阶矩(梯度的指数移动平均)和二阶矩(梯度平方的指数移动平均)的指数衰减率。它们影响模型更新的稳定性和收敛速度。
- n_cpu:这个参数指定了用于数据加载的CPU数量,可以影响数据预处理和加载的速度,从而影响训练的效率。
- latent_dim:随机向量的维度,它影响生成器生成图像的多样性和质量,维度过低可能导致生成图像缺乏多样性,而维度过高