基于Keras的生成对抗网络(1)——利用Keras搭建简单GAN生成手写体数字

本文介绍了如何使用Keras搭建一个简单的生成对抗网络(GAN),用于生成手写数字。文章涵盖GAN的基本结构,生成器和判别器的实现,以及训练过程的详细解释,并提供了完整代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

0、前言

GAN网络是近两年深度学习领域的新秀,本文和之后的基于Keras的生成对抗网络系列旨在浅显理解传统GAN,分享学习心得。

关于为什么使用Keras进行搭建,而不是如日中天的TensorFlow和pytorch,因为keras是对新手最友好,最简单实用的深度学习框架,因此首先利用keras进行搭建,等充分理解GAN的结构和参数之后,再利用TensorFlow或pytorch进行搭建。

一、GAN结构

普通GAN的结构就是利用全连接层+激活函数,这里不过多介绍。

二、函数代码

2.1 生成器Generator

生成器的目标是输入一行正态分布随机数,生成手写体图片,因此它的输入是一个长度为N的一维的向量,输出一个28,28,1维的图片。
函数代码:

    #定义生成器模型
    #输入:100维向量;输出:28*28图像
    
生成对抗网络GAN)是一种强大的深度学习技术,可以用于许多应用,包括手写体数字识别。在这里,我将向你展示如何使用Python实现手写体数字识别GAN。 首先,需要安装必要的库,包括TensorFlow和Keras。可以使用以下命令安装它们: ``` pip install tensorflow pip install keras ``` 接下来,我们需要准备MNIST数据集。可以使用以下代码加载数据集: ```python from keras.datasets import mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() ``` 现在,我们将创建两个神经网络 - 一个生成器和一个判别器。生成器将随机噪声作为输入并生成手写数字图像,而判别器将接受手写数字图像并输出它们的真假。 ```python from keras.models import Sequential from keras.layers import Dense, Flatten, Reshape, Dropout from keras.layers.convolutional import Conv2DTranspose, Conv2D from keras.layers.normalization import BatchNormalization from keras.layers.advanced_activations import LeakyReLU def create_generator(): generator = Sequential() generator.add(Dense(128 * 7 * 7, input_dim=100, activation=LeakyReLU(0.2))) generator.add(Reshape((7, 7, 128))) generator.add(Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', activation=LeakyReLU(0.2))) generator.add(BatchNormalization()) generator.add(Conv2DTranspose(1, (3,3), strides=(2,2), padding='same', activation='tanh')) return generator def create_discriminator(): discriminator = Sequential() discriminator.add(Conv2D(64, (3,3), padding='same', input_shape=(28,28,1))) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.25)) discriminator.add(Conv2D(128, (3,3), strides=(2,2), padding='same')) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.25)) discriminator.add(Flatten()) discriminator.add(Dense(1, activation='sigmoid')) return discriminator generator = create_generator() discriminator = create_discriminator() ``` 现在,我们将定义GAN模型。我们将训练生成器和判别器,以便它们可以通过相互竞争的方式共同学习并提高性能。 ```python from keras.optimizers import Adam def create_gan(generator, discriminator): discriminator.trainable = False gan = Sequential() gan.add(generator) gan.add(discriminator) gan.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5)) return gan gan = create_gan(generator, discriminator) ``` 现在我们将定义训练循环: ```python import numpy as np import matplotlib.pyplot as plt def train(generator, discriminator, gan, x_train, epochs=50, batch_size=128): # Rescale -1 to 1 x_train = x_train / 127.5 - 1. x_train = np.expand_dims(x_train, axis=3) real = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) for epoch in range(epochs): # Train discriminator idx = np.random.randint(0, x_train.shape[0], batch_size) real_imgs = x_train[idx] noise = np.random.normal(0, 1, (batch_size, 100)) fake_imgs = generator.predict(noise) d_loss_real = discriminator.train_on_batch(real_imgs, real) d_loss_fake = discriminator.train_on_batch(fake_imgs, fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # Train generator noise = np.random.normal(0, 1, (batch_size, 100)) g_loss = gan.train_on_batch(noise, real) # Plot the progress print("Epoch {}, Discriminator Loss: {}, Generator Loss: {}".format(epoch, d_loss, g_loss)) # Generate some digits to check the progress if epoch % 10 == 0: noise = np.random.normal(0, 1, (25, 100)) gen_imgs = generator.predict(noise) gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(5, 5) cnt = 0 for i in range(5): for j in range(5): axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray') axs[i,j].axis('off') cnt += 1 plt.show() train(generator, discriminator, gan, x_train, epochs=1000, batch_size=128) ``` 在训练之后,我们可以使用生成器来生成一些手写数字图像: ```python noise = np.random.normal(0, 1, (25, 100)) gen_imgs = generator.predict(noise) gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(5, 5) cnt = 0 for i in range(5): for j in range(5): axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray') axs[i,j].axis('off') cnt += 1 plt.show() ``` 这就是用Python实现手写体数字识别GAN的方法。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

wendy_ya

您的鼓励将是我创作的最大动力~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值