一、判别模型和生成模型
有监督的机器学习中,我们可以概述为通过很多有标记的数据,训练出一个模型,然后利用这个,对输入的X进行预测输出的Y。这个模型一般有两种:
决策函数:
Y=f(X)
条件概率分布:
P(Y|X)
根据通过学习数据来获取这两种模型的方法,我们可以分为判别方法和生成方法 :
判别方法
由数据 直接学习决策函数Y=f(X)或条件概率分布P(Y|X)作为预测模型,即为判
别模型。判别方法关心的是对于给定的输入X,应该预测什么样的输出Y 。
经典判别模型
支持向量机(SVM)、k近邻法、感知机、决策树、逻辑回归、线性回归、最大熵模型、提升方法、条件随机场(CRF)
生成方法
由数据学习 联合概率分布P(X,Y), 然后由P(Y|X)=P(X,Y)/P(X)求出概率分布P(Y|X)作为预测的模型。该方法表示了给定输入X与产生输出Y的生成关系。P(Y|X)作为的预测的模型就是生成模型 ;
经典生成模型
朴素贝叶斯、隐马尔可夫(em算法)
两种模型的对比 :
1、生成模型可以还原出联合概率分布(还原数据本身相似度),而判别方法不能;即
生成模型不只可以用于做预测。也可以用来做模拟。
2、生成方法的学习收敛速度更快,当样本容量增加的时候,学到的模型可以更快的收
敛于真实模型;
3、当存在隐变量时,仍可以利用生成方法学习,此时判别方法不能用;
4、判别学习不能反映训练数据本身的特性,但它寻找不同类别之间的最优分类面,反
映的是异类数据之间的差异,直接面对预测,往往学习的准确率更高,由于直接学习
P(Y|X)或Y=f(X),从而可以简化学习; 简单的说,生成模型是从大量的数据中找规
律,属于统计学习;而判别模型只关心不同类型的数据的差别,利用差别来分类。
6. 生成模型需要尽量大量的数据,要不还原出的联合概率会不准。且传统概率生成模型
一般都需要进行马可夫链式的采样和推断,训练速度比较慢
7. 由生成模型可以得到判别模型,但由判别模型得不到生成模型
举例说明:
1、识别语言种类
2、识别人脸
二、什么是GAN 网络?
GAN 启发自博弈论中的二人零和博弈(two-player game)(即二人的利益之和为零, 一方的所
得正是另一方的所失),GAN 模型中的两位博弈方分别由生成式模型(generative model)和
判别式模型(discriminative model)充当。作者为Ian J. Goodfellow。
生成模型 G 捕捉样本数据的分布,用服从某一分布(均匀分布,高斯分布等)的噪声 z 生成
一个类似真实训练数据的样本,追求效果是越像真实样本越好;判别模型 D 是一个二分类器,
估计一个样本来自于训练数据(而非生成数据)的概率,如果样本来自于真实的训练数据,D
输出大概率,否则,D 输出小概率。
论文《Generative Adversarial Nets》- 2014论文中理论证明了,这场游戏收敛时,存在一个全局最优解:生成器生成的分布=数据本身的分布。 并且采用论文中的训练方式,GAN模型终将收敛 。
训练方式:
目标函数:
胡伯的PM 模型与“ 好小伙” 的GAN网络
胡伯的PM模型
2.1 GAN的缺点
1.GAN 采用对抗学习的准则, 理论上还 不能判断模型的收敛性和均衡点的存在性. 训练过 程需要保证两个对抗网络的平衡和同步, 否则难以得到很好的训练效果. 而实际过程中两个对抗网络 的同步不易把控, 训练过程可能不稳定.
2.另外, 作为以神经网络为基础的生成式模型, GAN存在神 经网络类模型的一般性缺陷, 即可解释性差.
3.另外, GAN 生成的样本虽然具有多样性, 但是多模态样本容易导致崩溃模 式 (Collapse mode) 现象:生成器只会生成一两种类别的样本
2.2 GAN的发展:针对训练不稳定
2.2.1 针对训练不稳定:WGAN
2.3 GAN的发展:针对模型结构的优化
GAN网络只能随机产生一个类别 → CGAN , 指定类别来生成 指定类别来生成
不太善于生成离散数据,如文本 → Seq-GAN,能够生成离散序列的生成式模型
Pix2Pix,CycleGAN,StarGAN: 图像翻译
InfoGAN: 有一定可解释性的生成模型
Auxiliary Classifier GAN (AC-GAN) :判别器多分类
……
2.3.1 CGAN: 按指定类别来生成样本
2.3.2 CGAN: 按指定类别来生成样本
三、GAN 与 强化学习
问:GAN与强化学习(RL)原则之间有什么相似之处(如果有的话)?我对这两者都不是专家(只有非常基本的了解),我觉得GAN的“generator - discriminator”的想法和RL的“agent -environment interaction”有着紧密的联系。是这样吗?
“我也不是RL的专家,但我认为GAN是使用RL来解决生成建模问题的一种方式。GAN的不同之处在于,奖励函数对行为是完全已知和可微分的,奖励是非固定的,以及奖励是agent的策略的一个函数。但我认为GAN基本上可以说就是RL。”——Ian J.Goodfellow
四、精彩纷呈的GAN 网络
Pix2Pix
五、GAN源代码
牢记 Keras 四步走 :Build:构建模型 → G生成器, D判别器
Compile:模型(定义学习过程)
Fit:准备数据+训练
Evaluate/Predict: 查看结果
(详见代码)
GAN
from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import sys,os
import numpy as np
class GAN():
def __init__(self):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
# Build the generator
self.generator = self.build_generator()
# The generator takes noise as input and generates imgs
z = Input(shape=(self.latent_dim,))
img = self.generator(z)
# For the combined model we will only train the generator
self.discriminator.trainable = False
# The discriminator takes generated images as input and determines validity
validity = self.discriminator(img)
# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
self.combined = Model(z, validity)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
def build_generator(self):
model = Sequential()
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
#Batch Normalizationn的思想则是对于每一组batch,在网络的每一层中,分feature对输入进行normalization,对各个feature分别normalization,即对网络中每一层的单个神经元输入,计算均值和方差后,再进行normalization。
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=<