[深度学习从入门到女装]keras实战-GAN(MNIST)

本文介绍使用GAN进行mnist的数据生成,但目前未调试完成,存在bug,训练几个batch后loss会固定,作者正在寻找原因,还给出了相关代码文件名。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文使用GAN进行mnist的数据生成

----------------------------------------------------

注意:未调试完成,存在bug,训练几个batch之后,loss会固定

还在寻找原因,如果有大神能找到的话,麻烦评论告诉我

------------------------------------------------

 

GAN_mnist.py

from keras.engine import Model,Input
from keras.layers import Conv2DTranspose,Conv2D,MaxPooling2D,GlobalAveragePooling2D,Dense
from keras.optimizers import Adam
from keras.losses import binary_crossentropy
import keras.backend as K


def log_loss_discriminator(y_true, y_pred):
    return - K.log(K.maximum(K.epsilon(), y_pred))


def log_loss_generator(y_true, y_pred):
    return K.log(K.maximum(K.epsilon(), 1. - y_pred))

def generative_model(input_shape):

    print("generative_model:")
    #(?,7,7,1)
    inputs=Input(input_shape)

    layer=inputs
    print(str(layer.get_shape()))
    #(?,14,14,32)
    layer=Conv2DTranspose(32,[2,2],strides=[2,2])(layer)
    print(str(layer.get_shape()))
    #(?,28,28,64)
    layer=Conv2DTranspose(64,[2,2],strides=[2,2])(layer)
    print(str(layer.get_shape()))
    #(?,28,28,32)
    layer=Conv2D(32,[3,3],strides=[1,1],padding='same')(layer)
    print(str(layer.get_shape()))
    #(?,28,28,1)
    layer=Conv2D(1,[3,3],strides=[1,1],padding='same')(layer)
    print(str(layer.get_shape()))

    outputs=layer

    model = Model(inputs=inputs, outputs=outputs)

    #model.compile(optimizer=Adam(), loss=categorical_crossentropy, metrics=['accuracy'])
    print("---------------------")
    return model,outputs


def discriminator_model(input_shape):
    #(?,28,28,1)
    print("discriminator_model:")
    inputs=Input(input_shape)

    layer=inputs
    print(str(layer.get_shape()))

    layer=Conv2D(32,[3,3],strides=[1,1],padding='same')(layer)
    print(str(layer.get_shape()))
    #(?,14,14,32)
    layer=MaxPooling2D([2,2],strides=[2,2])(layer)
    print(str(layer.get_shape()))
    layer=Conv2D(64,[3,3],strides=[1,1],padding='same')(layer)
    print(str(layer.get_shape()))
    #(?,7,7,64)
    layer = MaxPooling2D([2, 2], strides=[2, 2])(layer)
    print(str(layer.get_shape()))
    layer=Conv2D(128,[3,3],strides=[1,1],padding='same')(layer)
    print(str(layer.get_shape()))
    #(?,4,4,128)
    layer = MaxPooling2D([2, 2], strides=[2, 2])(layer)
    print(str(layer.get_shape()))
    layer = Conv2D(256, [3, 3], strides=[1, 1], padding='same')(layer)
    print(str(layer.get_shape()))
    layer=GlobalAveragePooling2D()(layer)
    print(str(layer.get_shape()))
    layer=Dense(1,activation='sigmoid')(layer)
    print(str(layer.get_shape()))
    outputs=layer

    model=Model(inputs=inputs,outputs=outputs)

    model.compile(optimizer=Adam(),loss=log_loss_discriminator)
    print("---------------------")
    return model,outputs

def gan_model(gan_in_shape,g_model,d_model):



    gan_in=Input(gan_in_shape)

    gan_out=d_model(g_model(gan_in))

    model=Model(inputs=gan_in,outputs=gan_out)

    model.compile(optimizer=Adam(),loss=log_loss_generator)

    return model

 

 

train.py

import GAN_mnist
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

def train_model():
    mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)
    g_model,g_out=GAN_mnist.generative_model([7,7,1])
    d_model,d_out=GAN_mnist.discriminator_model([28,28,1])
    gan_model = GAN_mnist.gan_model([7, 7, 1],g_model,d_model)
    batch_size=32

    for i in range(500):

        for k in range(10):
            d_model.trainable=True
            x_d,y_d=get_dis_input_output(batch_size,g_model,mnist)
            d_loss=d_model.train_on_batch(x_d,y_d)


        d_model.trainable=False
        x_gan=np.random.uniform(0,1,size=[batch_size,7,7,1])
        y_gan=np.zeros([batch_size,1])
        gan_loss=gan_model.train_on_batch(x_gan,y_gan)



        print(d_loss,gan_loss)



def get_dis_input_output(batch_size,g_model,mnist):

    g_in=np.random.rand(batch_size,7,7,1)
    x_false =g_model.predict(g_in)
    x_true,_=mnist.train.next_batch(batch_size)
    x_true = np.reshape(x_true, [batch_size,28,28,1])


    y_false=np.zeros([batch_size,1])
    y_true=np.zeros([batch_size,1])

    x=np.concatenate([x_true,x_false],axis=0)
    y=np.concatenate([y_true,y_false],axis=0)

    return x,y

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值