手写数字识别 neuralnet_mnist.py 代码解读 来自GPT

这段代码是一个手写数字识别程序,使用的是一个简单的神经网络模型。通过加载训练好的模型(sample_weight.pkl),它对 MNIST 测试集进行预测,并计算模型的准确率。接下来,我会逐步解析这段代码的主要部分。

1. 导入所需库

import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax
  • sysos 用于设置路径,使得代码能够导入父目录中的文件。
  • numpy 用于处理数据,特别是矩阵运算。
  • pickle 用于加载保存的神经网络模型权重。
  • load_mnist 用于加载 MNIST 数据集。
  • sigmoidsoftmax 是常用的激活函数,sigmoid 用于隐藏层,softmax 用于输出层。

2. 获取数据

def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test
  • load_mnist 函数加载 MNIST 数据集。
    • normalize=True:将图像像素值归一化到 [0, 1] 之间。
    • flatten=True:将图像展平成一个向量,大小为 784(28x28)。
    • one_hot_label=False:标签为整数形式,而不是 one-hot 编码形式。
  • 返回测试集的输入数据 x_test 和标签 t_test

3. 初始化神经网络

def init_network():
    with open("../dataset/sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network
  • sample_weight.pkl 文件加载训练好的神经网络权重(包括权重矩阵 W1, W2, W3 和偏置 b1, b2, b3)。
  • 这些权重是用来对输入数据进行前向传播的。

4. 前向传播与预测

def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    return y
  • 通过前向传播过程,计算每个输入 x 的输出。
    • a1, a2, a3 是每层的加权输入。
    • z1, z2 是经过激活函数(sigmoid)处理后的结果。
    • y 是通过 softmax 函数得到的输出概率。

5. 计算准确率

x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):
    y = predict(network, x[i])
    p = np.argmax(y)  # 获取概率最高的元素的索引
    if p == t[i]:
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
  • 获取测试数据 x 和标签 t
  • network 是加载的训练好的神经网络。
  • accuracy_cnt 变量用于记录正确预测的数量。
  • 对每一张图像进行预测:
    • 通过 predict 函数获得预测结果 y
    • 使用 np.argmax(y) 获取最大概率对应的数字。
    • 如果预测值 p 与真实标签 t[i] 相等,增加 accuracy_cnt
  • 最后,输出准确率,即正确预测的数量除以测试集的总样本数。

6. 问题和改进建议

  1. 加载模型路径检查: 确保 ../dataset/sample_weight.pkl 文件存在并包含有效的训练权重文件。如果文件路径错误或文件损坏,模型将无法加载,导致预测失败。

  2. 性能改进: 如果想提高代码效率,可以将每个样本的预测批量处理,而不是逐个处理。例如,可以使用 NumPy 的矩阵运算来加速预测过程。

  3. 数据集问题: 如果你想使用自己的数据集来进行测试,需要确保数据格式与 MNIST 数据集一致(例如图片大小和颜色通道)。对于自定义图片,建议按如下方式预处理:

    • 将图片调整为28x28大小。
    • 灰度化(如果是彩色图片)。
    • 将像素值归一化到 [0, 1]。
  4. 验证训练效果: 如果在测试集上的准确率不高,可能是因为训练数据不足或者网络架构有问题。可以尝试调试训练过程,或者在更大的数据集上重新训练。

希望这段解析有助于你理解代码的工作原理。如果有更多问题,或者需要进一步的帮助,欢迎继续提问!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

阿崽meitoufa

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值