GAN学习记录(三)——半监督生成对抗网络(SGAN)

本文介绍半监督生成对抗网络(SGAN),其鉴别器能区分多种类别,不仅能辨别真假样本还能正确分类真实样本。通过少量带标签数据训练,SGAN的分类准确性接近全监督分类器。

半监督生成对抗网络(SGAN)

半监督生成对抗网络(Semi-Supervised GAN,SGAN)是一种生成对抗网络,其鉴别器是多分类器。这里的鉴别器不只是区分两个类(真和假),而是学会区分N+1类,其中N是训练数据集中的类数,生成器生成的伪样本增加了一个类。

结构区别

在这里插入图片描述

与传统GAN相比,SGAN区分多个类的任务不仅影响了鉴别器本身,还增加了SGAN架构、训练过程和训练目标的复杂性。

SGAN生成器的目的与原始GAN相同:接收一个随机数向量并生成伪样本,力求使伪样本与训练数据集别无二致。但是,SGAN鉴别器与原始GAN实现有很大不同。它接收3种输入:生成器生成的伪样本X*、训练数据集中无标签的真实样本X和有标签的真实样本X,y。其中y表示给定样本X的标签。

训练区别

除了计算判别器的损失值,还必须计算有监督训练样本的损失:D(x,y)。所以说,SCAN有两种损失值:有监督损失和无监督损失。

在SGAN中主要关心的反而是鉴别器。训练过程的目标是使该网络成为仅使用一小部分标签数据的半监督分类器,其准确率尽可能接近全监督的分类器(其训练数据集中的每个样本都有标签)。生成器的目标是通过提供附加信息(它生成的伪数据)来帮助鉴别器学习数据中的相关模式,从而提高其分类准确率。训练结束时,生成器将被丢弃,而训练有素的鉴别器将被用作分类器。

SGAN的实现

结构图

在这里插入图片描述

为了解决区分真实标签的多分类问题,鉴别器使用了softmax函数,该函数给出了在给定数量的类别(本例中为10类)上的概率分布。给一个给定类别标签分配的概率越高,鉴别器就越确信该样本属于这一给定的类。为了计算分类误差,使用了交叉熵损失,以测量输出概率与目标独热编码标签之间的差异。

# 导入包
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras import backend as K
from keras.datasets import mnist
from keras.layers import Dropout, Lambda, Concatenate, Input, Dense, Flatten, Reshape, Activation, BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical
# 模型输入维度
img_rows = 28
img_cols = 28
channels = 1
# 图像大小
img_shape = (img_rows, img_cols, channels)
# 噪声向量大小
z_dim = 100

num_classes = 10

尽管MNIST训练数据集里有50000个有标签的训练图像,但我们仅将其中的一小部分(由num_labeled参数决定)用于训练,并假设其余图像都是无标签的。我们这样来实现这一点:取批量有标签数据时仅从前num_labeled个图像采样,而在取批量无标签数据时从其余(50000 – num_labeled)个图像中采样。

class Dataset:
    def __init__(self, num_labeled):
        self.num_labeled = num_labeled
        mnist=tf.keras.datasets.mnist
        (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data('./MNIST')
        def preprocess_imgs(X):
            X = (X.astype(np.float32) - 127.5)/127.5
            X = np.expand_dims(X, axis=3)
            return X
        def preprocess_labels(y):
            return y.reshape(-1, 1)
        self.x_train = preprocess_imgs(self.x_train)
        self.y_train = preprocess_labels
评论 6
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值