训练生成手写体数字 对抗神经网络

下面是一个使用TensorFlow和Keras的生成对抗网络(GAN)的基本示例,用于生成手写体数字。这个示例基于MNIST数据集。

 

我没有包括所有可能的最佳实践,如模型保存、加载、超参数调整、日志记录等。

首先,确保你安装了所需的库,特别是TensorFlow:

pip install tensorflow

接下来是GAN的代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train / 255.0 * 2 - 1  # 将像素值缩放到[-1, 1]

# GAN参数
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100

# 生成器
def build_generator():
    model = Sequential()
    model.add(Dense(256, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(np.prod(img_shape), activation='tanh'))
    model.add(Reshape(img_shape))
    return model

# 判别器
def build_discriminator():
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    return model

# 编译判别器
discriminator = build_discriminator()
discriminator
生成对抗网络(GAN)是一种强大的深度学习技术,可以用于许多应用,包括手写体数字识别。在这里,我将向你展示如何使用Python实现手写体数字识别GAN。 首先,需要安装必要的库,包括TensorFlow和Keras。可以使用以下命令安装它们: ``` pip install tensorflow pip install keras ``` 接下来,我们需要准备MNIST数据集。可以使用以下代码加载数据集: ```python from keras.datasets import mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() ``` 现在,我们将创建两个神经网络 - 一个生成器和一个判别器。生成器将随机噪声作为输入并生成手写数字图像,而判别器将接受手写数字图像并输出它们的真假。 ```python from keras.models import Sequential from keras.layers import Dense, Flatten, Reshape, Dropout from keras.layers.convolutional import Conv2DTranspose, Conv2D from keras.layers.normalization import BatchNormalization from keras.layers.advanced_activations import LeakyReLU def create_generator(): generator = Sequential() generator.add(Dense(128 * 7 * 7, input_dim=100, activation=LeakyReLU(0.2))) generator.add(Reshape((7, 7, 128))) generator.add(Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', activation=LeakyReLU(0.2))) generator.add(BatchNormalization()) generator.add(Conv2DTranspose(1, (3,3), strides=(2,2), padding='same', activation='tanh')) return generator def create_discriminator(): discriminator = Sequential() discriminator.add(Conv2D(64, (3,3), padding='same', input_shape=(28,28,1))) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.25)) discriminator.add(Conv2D(128, (3,3), strides=(2,2), padding='same')) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.25)) discriminator.add(Flatten()) discriminator.add(Dense(1, activation='sigmoid')) return discriminator generator = create_generator() discriminator = create_discriminator() ``` 现在,我们将定义GAN模型。我们将训练生成器和判别器,以便它们可以通过相互竞争的方式共同学习并提高性能。 ```python from keras.optimizers import Adam def create_gan(generator, discriminator): discriminator.trainable = False gan = Sequential() gan.add(generator) gan.add(discriminator) gan.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5)) return gan gan = create_gan(generator, discriminator) ``` 现在我们将定义训练循环: ```python import numpy as np import matplotlib.pyplot as plt def train(generator, discriminator, gan, x_train, epochs=50, batch_size=128): # Rescale -1 to 1 x_train = x_train / 127.5 - 1. x_train = np.expand_dims(x_train, axis=3) real = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) for epoch in range(epochs): # Train discriminator idx = np.random.randint(0, x_train.shape[0], batch_size) real_imgs = x_train[idx] noise = np.random.normal(0, 1, (batch_size, 100)) fake_imgs = generator.predict(noise) d_loss_real = discriminator.train_on_batch(real_imgs, real) d_loss_fake = discriminator.train_on_batch(fake_imgs, fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # Train generator noise = np.random.normal(0, 1, (batch_size, 100)) g_loss = gan.train_on_batch(noise, real) # Plot the progress print("Epoch {}, Discriminator Loss: {}, Generator Loss: {}".format(epoch, d_loss, g_loss)) # Generate some digits to check the progress if epoch % 10 == 0: noise = np.random.normal(0, 1, (25, 100)) gen_imgs = generator.predict(noise) gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(5, 5) cnt = 0 for i in range(5): for j in range(5): axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray') axs[i,j].axis('off') cnt += 1 plt.show() train(generator, discriminator, gan, x_train, epochs=1000, batch_size=128) ``` 在训练之后,我们可以使用生成器来生成一些手写数字图像: ```python noise = np.random.normal(0, 1, (25, 100)) gen_imgs = generator.predict(noise) gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(5, 5) cnt = 0 for i in range(5): for j in range(5): axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray') axs[i,j].axis('off') cnt += 1 plt.show() ``` 这就是用Python实现手写体数字识别GAN的方法。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值