MNIST Dataset (with Python Code)

This article shows how to run predictions on the MNIST dataset with TensorFlow, and how to train a generative adversarial network (GAN) on MNIST with PyTorch. The TensorFlow part builds a simple classifier that loads or trains a model and runs predictions. The PyTorch part walks through the GAN training process, including building the generator and discriminator and computing the losses, and presents the training results.


1. Predicting on the MNIST dataset with TensorFlow

Use the tensorflow.keras API to assemble the network layers, then train the model and run predictions.

import tensorflow as tf
from tensorflow.keras.datasets import mnist
import os


class Classifier:
    def __init__(self) -> None:
        self.x_train = None
        self.x_test = None
        self.model_file = '/tmp/mnist_model.keras'  # use a .keras extension so model.save works across Keras versions
        self.has_load = False
        self.has_model_load = False

    def load_or_train(self):
        if self.has_load:
            return {'err': 0}
        try:

            if os.path.exists(self.model_file):
                self.__load_model()
            else:
                self.__train_model()
            self.has_load = True
            return {'err': 0}
        except Exception as e:
            return {'err': 1, 'msg': str(e)}

    def __create_model(self):
        # Simple MLP: flatten the 28x28 image -> 128-unit ReLU layer -> dropout -> 10-way softmax
        model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation='softmax')
        ])

        model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

        return model

    def __train_model(self):
        if self.has_model_load:
            return
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        x_train, x_test = x_train / 255.0, x_test / 255.0
        self.model = self.__create_model()
        self.model.fit(x_train, y_train, epochs=5)
        self.model.evaluate(x_test,  y_test, verbose=2)
        self.model.save(self.model_file)
        self.has_model_load = True

    def __load_model(self):
        if self.has_model_load:
            return
        self.model = tf.keras.models.load_model(self.model_file)
        self.model.summary()
        self.has_model_load = True

    def predict(self, test_images):
        if not self.has_load:
            return {'err': 1, 'msg': 'classifier has not been loaded yet'}
        result = self.model.predict(test_images)
        return {'err': 0, 'result': result}

if __name__ == "__main__":

    cl = Classifier()

    # Load or train the model
    ret = cl.load_or_train()
    if ret['err'] != 0:
        print(ret['msg'])
    else:
        # Test data: keep the (28, 28) image shape that the Flatten layer expects
        (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
        test_images = test_images[:1000] / 255.0

        # Predict
        ret = cl.predict(test_images)
        if ret['err'] == 0:
            print('Prediction result:', ret['result'])
        else:
            print('Prediction failed:', ret['msg'])
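
The result returned by predict is an array of shape (num_samples, 10) holding the softmax probabilities for the ten digit classes. As a minimal sketch (not part of the original script), assuming the cl, test_images and test_labels variables from the __main__ block above, the probabilities can be turned into class labels and roughly checked against the ground truth:

import numpy as np

ret = cl.predict(test_images)
if ret['err'] == 0:
    probs = ret['result']                   # shape (1000, 10): softmax probabilities
    pred_labels = np.argmax(probs, axis=1)  # most likely digit for each sample
    accuracy = np.mean(pred_labels == test_labels[:1000])
    print('Predicted labels:', pred_labels[:10])
    print('Accuracy on the first 1000 test images:', accuracy)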

2. Training a Generative Adversarial Network (GAN) on MNIST with PyTorch

2.1 Training

Training script with epochs = 1000:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
 
 
class Generator(nn.Module):  # generator network
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
 
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img
 
 
class Discriminator(nn.Module):  # discriminator network
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
 
    def forward(self, img):
        img = img.view(img.size(0), -1)
        validity = self.model(img)
        return validity
 
 
def gen_img_plot(model, test_input):
    # Generate images from the given noise and show a 4x4 grid of samples
    pred = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow((pred[i] + 1) / 2)
        plt.axis('off')
    plt.show(block=False)
    plt.pause(3)  # keep the figure open for 3 seconds
    plt.close()
 
 
# Use the GPU if one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
# Hyperparameters
lr = 0.0001
batch_size = 128
latent_dim = 100
epochs = 1000
 
# Dataset loading and transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
 
# Training data (download MNIST on first run if it is not already under ./data)
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
 
# Fixed test noise: torch.randn samples from a standard normal distribution (mean 0, variance 1)
test_data = torch.randn(batch_size, latent_dim).to(device)
 
# Instantiate the generator and discriminator, define the loss function and the optimizers
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
 
# Training loop
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(train_loader):
        batch_size = imgs.shape[0]
        real_imgs = imgs.to(device)
 
        # Train the discriminator
        z = torch.FloatTensor(batch_size, latent_dim).to(device)
        z.data.normal_(0, 1)
        fake_imgs = generator(z)  # generator produces fake images
 
        real_labels = torch.full((batch_size, 1), 1.0).to(device)
        fake_labels = torch.full((batch_size, 1), 0.0).to(device)
 
        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2
 
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()
 
        # Train the generator
        z.data.normal_(0, 1)
        fake_imgs = generator(z)
 
        g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()
 
    # Save the generator weights once per epoch (saving inside the batch loop is unnecessary)
    torch.save(generator.state_dict(), "Generator_mnist.pth")

    print(f"Epoch [{epoch}/{epochs}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}")
 
gen_img_plot(generator, test_data)

2.2 Testing

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
 
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
 
 
class Generator(nn.Module):  # generator network (same architecture as in training)
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
 
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img
 
 
test_data = torch.randn(128, 100).to(device)  # random noise from a standard normal distribution
 
model = Generator(100).to(device)
model.load_state_dict(torch.load('Generator_mnist.pth', map_location=device))  # map_location lets CPU-only machines load GPU-trained weights
model.eval()
 
pred = np.squeeze(model(test_data).detach().cpu().numpy())
 
# Plot an 8x8 grid of generated digits
for i in range(64):
    plt.subplot(8, 8, i + 1)
    plt.imshow((pred[i] + 1) / 2)  # rescale the Tanh output from [-1, 1] to [0, 1]
    plt.axis('off')
plt.savefig('image.png')
plt.show()
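
As an alternative to assembling the grid with matplotlib subplots, torchvision's save_image utility can tile the generator output directly. A minimal sketch, assuming the model and test_data defined above:

from torchvision.utils import save_image

with torch.no_grad():
    samples = model(test_data[:64])                # (64, 1, 28, 28), Tanh output in [-1, 1]
save_image((samples + 1) / 2, 'grid.png', nrow=8)  # rescale to [0, 1] and tile 8 images per row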

2.3 Results

With the hyperparameters epochs=1000, batch_size=128, lr=0.0001 and latent_dim=100, the images generated from the trained GAN weights are shown below.
(Figure: MNIST digits generated by the trained GAN)

2.4 GAN loss curves

At the beginning of training, the GAN's loss curves looked roughly like the figure below: the generator loss kept diverging. A curve like this is an immediate sign that training has collapsed. In a normal run, d_loss keeps decreasing and then converges, while g_loss first rises, then falls, and eventually converges as well. Also, when you first get a network in hand, don't train it for too many epochs right away, since that easily leads to overfitting.
(Figure: loss curves from the diverging run)
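
The training script above only prints the losses; here is a minimal sketch of how the per-epoch values could be recorded and plotted afterwards (the two lists and the append calls are hypothetical additions to the training loop, not part of the original code):

import matplotlib.pyplot as plt

d_losses, g_losses = [], []

# Inside the epoch loop of the training script, right after the print statement, add:
#     d_losses.append(d_loss.item())
#     g_losses.append(g_loss.item())

plt.plot(d_losses, label='d_loss')
plt.plot(g_losses, label='g_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.savefig('gan_loss_curves.png')
plt.show()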

(Figure)

Result:

(Figure)
