基于生成对抗网络(GAN)生成手写数字(MNIST)

GAN生成手写数字识别

生成对抗网络介绍:

生成对抗网络就是两个网络进行相互对抗,相互进行的过程。GAN的主要部分为生成器Generator和判别器Discriminator。
生成器:输入一个随机向量,生成一个图片。希望生成的图片越像真的越好。
判别器:输出一个图片,判别图片的真伪。希望生成的图片判别为假,数据集中的图片判别为真。

生成器:

     def Generator(self):
        #构建模型
        gen=models.Sequential()
        gen.add(layers.Dense(256,input_dim=self.dims))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(512))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(1024))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(self.row*self.col*self.channel,activation='sigmoid'))
        gen.add(Reshape(self.shape))

        #输出模型
        gen.summary()

        #生成随机噪声
        noise=Input(shape=(self.dims,))
        #生成图片
        img=gen(noise)

        return Model(noise,img)

判别器:

    def Discrimation(self):
        #构建模型
        model=models.Sequential()
        model.add(layers.Flatten(input_shape=self.shape))
        model.add(layers.Dense(512))
        model.add(layers.LeakyReLU(alpha=0.2))
        model.add(layers.Dense(256))
        model.add(layers.LeakyReLU(alpha=0.2))
        model.add(layers.Dense(1,activation='sigmoid'))
        #输出模型
        model.summary()
        #输入
        input=Input(shape=(self.shape))
        #输出
        output=model(input)

        return Model(input,output)

训练代码:

#coding=utf-8
from keras import layers,datasets,models,optimizers
from keras.layers import Reshape,Input
from keras.models import Model
import numpy as np
import matplotlib.pyplot as plt
from keras.optimizers import Adam
from keras.datasets import mnist

class GAN():
    def __init__(self):

        self.row=28
        self.col=28
        self.channel=1
        self.shape=(self.row,self.col,self.channel)
        self.dims=100

        optimizer=optimizers.Adam(0.0002, 0.5)

        # 判别器
        self.Dis = self.Discrimation()
        self.Dis.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])

        # 生成器
        self.Gen = self.Generator()
        # 随机噪声
        z = Input(shape=(self.dims,))
        img = self.Gen(z)
        # 判别器不训练
        self.Dis.trainable = False
        # 对假图片进行预测
        validity = self.Dis(img)
        #返回模型
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def Generator(self):
        #构建模型
        gen=models.Sequential()
        gen.add(layers.Dense(256,input_dim=self.dims))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(512))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(1024))
        gen.add(layers.LeakyReLU(alpha=0.2))
        gen.add(layers.BatchNormalization(momentum=0.8))

        gen.add(layers.Dense(self.row*self.col*self.channel,activation='sigmoid'))
        gen.add(Reshape(self.shape))

        #输出模型
        gen.summary()

        #生成随机噪声
        noise=Input(shape=(self.dims,))
        #生成图片
        img=gen(noise)

        return Model(noise,img)

    def Discrimation(self):
        #构建模型
        model=models.Sequential()
        model.add(layers.Flatten(input_shape=self.shape))
        model.add(layers.Dense(512))
        model.add(layers.LeakyReLU(alpha=0.2))
        model.add(layers.Dense(256))
        model.add(layers.LeakyReLU(alpha=0.2))
        model.add(layers.Dense(1,activation='sigmoid'))
        #输出模型
        model.summary()
        #输入
        input=Input(shape=(self.shape))
        #输出
        output=model(input)

        return Model(input,output)

    def train(self,epochs,batch_size,save_size):
        #加载数据集
        (train_x,_),(_,_)=mnist.load_data()
        #数据集归一化
        train_x=train_x/255.
        train_x=np.expand_dims(train_x,axis=3)

        #生成标签
        real=np.ones((batch_size,1))
        fake=np.zeros((batch_size,1))

        for epoch in range(epochs):

            # 随机选取一批图片
            idx = np.random.randint(0, train_x.shape[0], batch_size)
            imgs = train_x[idx]

            # 生成一批噪声
            noise = np.random.normal(0, 1, (batch_size, self.dims))

            gen_image=self.Gen(noise)

            d_loss_real=self.Dis.train_on_batch(imgs,real)
            d_loss_fake=self.Dis.train_on_batch(gen_image,fake)
            d_loss=0.5*np.add(d_loss_real,d_loss_fake)

            noise=np.random.normal(0,1,(batch_size,self.dims))

            g_loss=self.combined.train_on_batch(noise,real)

            print('Epoch:{} ,D_loss:{}, D_acc:{} ,G_loss:{} '.format(epoch,d_loss[0],100*d_loss[1],g_loss))

            if epoch%save_size==0:
                self.save_image(epoch)

    def save_image(self,epoch):
        r,c=5,5

        noise=np.random.normal(0,1,(r*c,self.dims))
        gen_img=self.Gen.predict(noise)
        cnt=0

        fig,axs=plt.subplots(r,c)
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_img[cnt,:,:,0],cmap='gray')
                axs[i, j].axis('off')
                cnt+=1

        fig.savefig("image/{}.png".format(epoch))
        plt.close()

    def save(self):
        self.Dis.save('model/Discrimator.h5')
        self.Gen.save('model/Generator.h5')

if __name__ == '__main__':
    gan=GAN()
    gan.train(20000,512,100)
    gan.save()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值