生成对抗网络GAN

本文详细介绍了生成对抗网络(GAN)的概念,包括判别模型和生成模型的区别,GAN的运作原理,以及其与强化学习的相似性。文中提到了GAN在图像生成、类别指定生成、文本生成等方面的应用,并探讨了GAN的训练问题和一些解决方案,如WGAN和CGAN。同时,文章提供了使用Keras实现GAN的基本步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、判别模型和生成模型

有监督的机器学习中,我们可以概述为通过很多有标记的数据,训练出一个模型,然后利用这个,对输入的X进行预测输出的Y。这个模型一般有两种:
决策函数:
Y=f(X)
条件概率分布:
P(Y|X)
根据通过学习数据来获取这两种模型的方法,我们可以分为判别方法和生成方法 :

判别方法

由数据 直接学习决策函数Y=f(X)或条件概率分布P(Y|X)作为预测模型,即为判
别模型。判别方法关心的是对于给定的输入X,应该预测什么样的输出Y 。

经典判别模型
支持向量机(SVM)、k近邻法、感知机、决策树、逻辑回归、线性回归、最大熵模型、提升方法、条件随机场(CRF)

在这里插入图片描述

生成方法

由数据学习 联合概率分布P(X,Y), 然后由P(Y|X)=P(X,Y)/P(X)求出概率分布P(Y|X)作为预测的模型。该方法表示了给定输入X与产生输出Y的生成关系。P(Y|X)作为的预测的模型就是生成模型 ;
经典生成模型
朴素贝叶斯、隐马尔可夫(em算法)
在这里插入图片描述

两种模型的对比 :

1、生成模型可以还原出联合概率分布(还原数据本身相似度),而判别方法不能;即
生成模型不只可以用于做预测。也可以用来做模拟。
2、生成方法的学习收敛速度更快,当样本容量增加的时候,学到的模型可以更快的收
敛于真实模型;
3、当存在隐变量时,仍可以利用生成方法学习,此时判别方法不能用;
4、判别学习不能反映训练数据本身的特性,但它寻找不同类别之间的最优分类面,反
映的是异类数据之间的差异,直接面对预测,往往学习的准确率更高,由于直接学习
P(Y|X)或Y=f(X),从而可以简化学习; 简单的说,生成模型是从大量的数据中找规
律,属于统计学习;而判别模型只关心不同类型的数据的差别,利用差别来分类。
6. 生成模型需要尽量大量的数据,要不还原出的联合概率会不准。且传统概率生成模型
一般都需要进行马可夫链式的采样和推断,训练速度比较慢
7. 由生成模型可以得到判别模型,但由判别模型得不到生成模型

举例说明:
1、识别语言种类
在这里插入图片描述
2、识别人脸
在这里插入图片描述

二、什么是GAN 网络?

GAN 启发自博弈论中的二人零和博弈(two-player game)(即二人的利益之和为零, 一方的所
得正是另一方的所失),GAN 模型中的两位博弈方分别由生成式模型(generative model)和
判别式模型(discriminative model)充当。作者为Ian J. Goodfellow。
生成模型 G 捕捉样本数据的分布,用服从某一分布(均匀分布,高斯分布等)的噪声 z 生成
一个类似真实训练数据的样本,追求效果是越像真实样本越好;判别模型 D 是一个二分类器,
估计一个样本来自于训练数据(而非生成数据)的概率,如果样本来自于真实的训练数据,D
输出大概率,否则,D 输出小概率。
在这里插入图片描述在这里插入图片描述论文《Generative Adversarial Nets》- 2014论文中理论证明了,这场游戏收敛时,存在一个全局最优解:生成器生成的分布=数据本身的分布。 并且采用论文中的训练方式,GAN模型终将收敛 。
在这里插入图片描述训练方式:
在这里插入图片描述目标函数:
在这里插入图片描述胡伯的PM 模型与“ 好小伙” 的GAN网络
在这里插入图片描述
胡伯的PM模型
在这里插入图片描述
2.1 GAN的缺点
1.GAN 采用对抗学习的准则, 理论上还 不能判断模型的收敛性和均衡点的存在性. 训练过 程需要保证两个对抗网络的平衡和同步, 否则难以得到很好的训练效果. 而实际过程中两个对抗网络 的同步不易把控, 训练过程可能不稳定.
2.另外, 作为以神经网络为基础的生成式模型, GAN存在神 经网络类模型的一般性缺陷, 即可解释性差.
3.另外, GAN 生成的样本虽然具有多样性, 但是多模态样本容易导致崩溃模 式 (Collapse mode) 现象:生成器只会生成一两种类别的样本

在这里插入图片描述2.2 GAN的发展:针对训练不稳定
在这里插入图片描述
2.2.1 针对训练不稳定:WGAN
在这里插入图片描述
2.3 GAN的发展:针对模型结构的优化
GAN网络只能随机产生一个类别 → CGAN , 指定类别来生成 指定类别来生成
不太善于生成离散数据,如文本 → Seq-GAN,能够生成离散序列的生成式模型
Pix2Pix,CycleGAN,StarGAN: 图像翻译
InfoGAN: 有一定可解释性的生成模型
Auxiliary Classifier GAN (AC-GAN) :判别器多分类
……

2.3.1 CGAN: 按指定类别来生成样本
在这里插入图片描述
2.3.2 CGAN: 按指定类别来生成样本
在这里插入图片描述

三、GAN 与 强化学习

问:GAN与强化学习(RL)原则之间有什么相似之处(如果有的话)?我对这两者都不是专家(只有非常基本的了解),我觉得GAN的“generator - discriminator”的想法和RL的“agent -environment interaction”有着紧密的联系。是这样吗?

“我也不是RL的专家,但我认为GAN是使用RL来解决生成建模问题的一种方式。GAN的不同之处在于,奖励函数对行为是完全已知和可微分的,奖励是非固定的,以及奖励是agent的策略的一个函数。但我认为GAN基本上可以说就是RL。”——Ian J.Goodfellow

四、精彩纷呈的GAN 网络

在这里插入图片描述在这里插入图片描述在这里插入图片描述Pix2Pix
在这里插入图片描述
在这里插入图片描述

五、GAN源代码

牢记 Keras 四步走 :Build:构建模型 → G生成器, D判别器
Compile:模型(定义学习过程)
Fit:准备数据+训练
Evaluate/Predict: 查看结果
(详见代码)

GAN

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys,os

import numpy as np

class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


    def build_generator(self):

        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        #Batch Normalizationn的思想则是对于每一组batch,在网络的每一层中,分feature对输入进行normalization,对各个feature分别normalization,即对网络中每一层的单个神经元输入,计算均值和方差后,再进行normalization。
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=<
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值