生成对抗性神经网络

本文详细介绍了生成对抗网络(GAN)的基本概念、工作原理和架构,以及在训练过程中生成器和鉴别器的角色。通过实例展示了如何在Python和TensorFlow环境下实现GAN,包括模型创建、损失函数定义和优化。此外,文章讨论了GAN的常见问题,如训练稳定性、不适用文本数据的原因,以及模型崩溃的可能原因。GAN在图像生成、无监督学习等领域有广泛应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

生成对抗性神经网络

!!代码地址!!

论文

作者: Lorna

邮箱: shiyipaisizuo@gmail.com

英文原版

配置需求

  • 显卡: A TiTAN V 或更高.
  • 硬盘: 128G SSD.
  • Python版本: python3.5 或更高.
  • CUDA: cuda10.
  • CUDNN: cudnn7.4.5 或更高.
  • Tensorflow-gpu: 2.0.0-alpla0.

运行以下代码。

pip install -r requirements.txt

GAN是什么?

生成对抗网络(GANs)是当今计算机科学中最有趣的概念之一。
两个模型通过对抗性过程同时训练。
生成器(“艺术家”)学会创建看起来真实的图像,而鉴别器(“艺术评论家”)学会区分真实图像和赝品。

在训练过程中,生成器逐渐变得更擅长创建看起来真实的图像,而鉴别器则变得更擅长区分它们。
当鉴别器无法分辨真伪图像时,该过程达到平衡。

下面的动画展示了生成器在经过50个时代的训练后生成的一系列图像。
这些图像一开始是随机噪声,随着时间的推移越来越像手写数字。

一、介绍

1.1 原理

这是一张关于GAN的流程图
GAN

GAN主要的灵感来源是零和游戏在博弈论思想,应用于深学习神经网络,是通过生成网络G(发电机)和判别D(鉴频器)网络游戏不断,从而使G学习数据分布,如果用在图像生成训练完成后,G可以从一个随机数生成逼真的图像。
G和D的主要功能是:

  • G是一个生成网络,它接收一个随机噪声z(随机数),通过噪声生成图像。

  • D是一个判断图像是否“真实”的网络。它的输入参数是x, x代表一张图片,输出D (x)代表x是一张真实图片的概率。如果是1,代表100%真实的图像,如果是0,代表不可能的图像。

在训练过程中,生成网络G的目标是生成尽可能多的真实图像来欺骗网络D,而D的目标是试图将G生成的假图像与真实图像区分开来。这样,G和D构成一个动态的“博弈过程”,最终的均衡点为纳什均衡点。

1.2 体系结构

通过对目标的优化,可以调整概率生成模型的参数,使概率分布与实际数据分布尽可能接近。

那么,如何定义适当的优化目标或损失呢?
在传统的生成模型中,一般采用数据的似然作为优化目标,而GAN创新性地使用了另一个优化目标。

  • 首先,引入判别模型(常用模型包括支持向量机和多层神经网络)。

  • 其次,其优化过程是在生成模型和判别模型之间找到纳什均衡。

GAN建立的学习框架实际上是生成模型和判别模型之间的模拟博弈。
生成模型的目的是尽可能多地模拟、建模和学习真实数据的分布规律。
判别模型是判断一个输入数据是来自真实的数据分布还是生成的模型。
通过这两个内部模型之间的持续竞争,提高了生成和区分这两个模型的能力。

当一个模型具有很强的区分能力时。
如果生成的模型数据仍然存在混淆,不能正确判断,那么我们认为生成的模型实际上已经了解了真实数据的分布情况。

1.3 GAN特性

特点:

  • low与传统模式相比,有两种不同的网络,而不是单一的网络,采用的是对抗训练方法和训练方式。

  • 更新信息中的低GAN梯度G来自判别式D,而不是来自样本数据。

优势:

  • low GAN是一个涌现模型,相对于其他生成模型(玻尔兹曼机和GSNs),它只通过反向传播,不需要复杂的马尔可夫链。

  • 与其它所有机型相比,GAN能生产出更清晰、真实的样品

  • low GAN是一种无监督学习训练,可广泛应用于半监督学习和无监督学习领域。

  • 与变分自编码器相比,GANs不引入任何确定性偏差,变分方法引入确定性偏差,因为它们优化了对数似然的下界而不是似然本身,这似乎导致VAEs生成的实例比GANs更加模糊。

  • 与VAE、GANs的变分下界相比较低,如果判别器训练良好,则生成器可以学习完善训练样本分布。换句话说,GANs是逐渐一致的,但是VAE是有偏见的。

  • GAN——应用于一些场景,比如图片风格转换、超分辨率,图像完成、噪声去除,避免了损失函数设计的困难,只要有一个基准,直接鉴别器,其余的对抗训练。

缺点:

  • 训练GAN需要达到Nash均衡,有时可以通过梯度下降法实现,有时则不能。我们还没有找到一个很好的方法来达到纳什均衡,所以与VAE或PixelRNN相比,GAN的训练是不稳定的,但我认为在实践中它比训练玻尔兹曼机更稳定。

  • GAN不适用于处理离散数据,如文本。

  • GAN存在训练不稳定、梯度消失和模态崩溃等问题。

二、实现

加载和准备数据集,将使用MNIST数据集来训练生成器和鉴别器。生成器将生成类似MNIST数据的手写数字。

import tensorflow as tf


def load_dataset(mnist_size, mnist_batch_size, cifar_size, cifar_batch_size,):
  """ load mnist and cifar10 dataset to shuffle.

  Args:
    mnist_size: mnist dataset size.
    mnist_batch_size: every train dataset of mnist.
    cifar_size: cifar10 dataset size.
    cifar_batch_size: every train dataset of cifar10.

  Returns:
    mnist dataset, cifar10 dataset

  """
  # load mnist data
  (mnist_train_images, mnist_train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

  # load cifar10 data
  (cifar_train_images, cifar_train_labels), (_, _) = tf.keras.datasets.cifar10.load_data()

  mnist_train_images = mnist_train_images.reshape(mnist_train_images.shape[0], 28, 28, 1).astype('float32')
  mnist_train_images = (mnist_train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

  cifar_train_images = cifar_train_images.reshape(cifar_train_images.shape[0], 32, 32, 3).astype('float32')
  cifar_train_images = (cifar_train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

  # Batch and shuffle the data
  mnist_train_dataset = tf.data.Dataset.from_tensor_slices(mnist_train_images)
  mnist_train_dataset = mnist_train_dataset.shuffle(mnist_size).batch(mnist_batch_size)

  cifar_train_dataset = tf.data.Dataset
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值