这段代码是一个手写数字识别程序,使用的是一个简单的神经网络模型。通过加载训练好的模型(sample_weight.pkl
),它对 MNIST 测试集进行预测,并计算模型的准确率。接下来,我会逐步解析这段代码的主要部分。
1. 导入所需库
import sys, os
sys.path.append(os.pardir) # 为了导入父目录的文件而进行的设定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax
sys
和os
用于设置路径,使得代码能够导入父目录中的文件。numpy
用于处理数据,特别是矩阵运算。pickle
用于加载保存的神经网络模型权重。load_mnist
用于加载 MNIST 数据集。sigmoid
和softmax
是常用的激活函数,sigmoid
用于隐藏层,softmax
用于输出层。
2. 获取数据
def get_data():
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
return x_test, t_test
load_mnist
函数加载 MNIST 数据集。normalize=True
:将图像像素值归一化到 [0, 1] 之间。flatten=True
:将图像展平成一个向量,大小为 784(28x28)。one_hot_label=False
:标签为整数形式,而不是 one-hot 编码形式。
- 返回测试集的输入数据
x_test
和标签t_test
。
3. 初始化神经网络
def init_network():
with open("../dataset/sample_weight.pkl", 'rb') as f:
network = pickle.load(f)
return network
- 从
sample_weight.pkl
文件加载训练好的神经网络权重(包括权重矩阵W1
,W2
,W3
和偏置b1
,b2
,b3
)。 - 这些权重是用来对输入数据进行前向传播的。
4. 前向传播与预测
def predict(network, x):
W1, W2, W3 = network['W1'], network['W2'], network['W3']
b1, b2, b3 = network['b1'], network['b2'], network['b3']
a1 = np.dot(x, W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1, W2) + b2
z2 = sigmoid(a2)
a3 = np.dot(z2, W3) + b3
y = softmax(a3)
return y
- 通过前向传播过程,计算每个输入
x
的输出。a1
,a2
,a3
是每层的加权输入。z1
,z2
是经过激活函数(sigmoid
)处理后的结果。y
是通过softmax
函数得到的输出概率。
5. 计算准确率
x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):
y = predict(network, x[i])
p = np.argmax(y) # 获取概率最高的元素的索引
if p == t[i]:
accuracy_cnt += 1
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
- 获取测试数据
x
和标签t
。 network
是加载的训练好的神经网络。accuracy_cnt
变量用于记录正确预测的数量。- 对每一张图像进行预测:
- 通过
predict
函数获得预测结果y
。 - 使用
np.argmax(y)
获取最大概率对应的数字。 - 如果预测值
p
与真实标签t[i]
相等,增加accuracy_cnt
。
- 通过
- 最后,输出准确率,即正确预测的数量除以测试集的总样本数。
6. 问题和改进建议
-
加载模型路径检查: 确保
../dataset/sample_weight.pkl
文件存在并包含有效的训练权重文件。如果文件路径错误或文件损坏,模型将无法加载,导致预测失败。 -
性能改进: 如果想提高代码效率,可以将每个样本的预测批量处理,而不是逐个处理。例如,可以使用 NumPy 的矩阵运算来加速预测过程。
-
数据集问题: 如果你想使用自己的数据集来进行测试,需要确保数据格式与 MNIST 数据集一致(例如图片大小和颜色通道)。对于自定义图片,建议按如下方式预处理:
- 将图片调整为28x28大小。
- 灰度化(如果是彩色图片)。
- 将像素值归一化到 [0, 1]。
-
验证训练效果: 如果在测试集上的准确率不高,可能是因为训练数据不足或者网络架构有问题。可以尝试调试训练过程,或者在更大的数据集上重新训练。
希望这段解析有助于你理解代码的工作原理。如果有更多问题,或者需要进一步的帮助,欢迎继续提问!