- 使用
PIL
(Python Imaging Library) 加载本地图片,并进行相应的预处理。 - 用加载的网络对每张图片进行预测,输出预测结果。
import numpy as np
import pickle
from PIL import Image
import matplotlib.pyplot as plt
from common.functions import sigmoid, softmax
# 初始化网络
def init_network():
with open("../dataset/sample_weight.pkl", 'rb') as f:
network = pickle.load(f)
return network
# 预处理图像
def preprocess_image(image_path):
# 打开图像并转换为灰度
img = Image.open(image_path).convert('L')
# 将图像调整为 28x28 像素
img = img.resize((28, 28))
# 转换为 numpy 数组
img_array = np.array(img)
# 归一化处理(将像素值从 [0, 255] 映射到 [0, 1])
img_array = img_array / 255.0
# 将图像展平为 1D 数组,适应模型输入格式
img_array = img_array.flatten()
return img_array
# 前向传播和预测
def predict(network, x):
W1, W2, W3 = network['W1'], network['W2'], network['W3']
b1, b2, b3 = network['b1'], network['b2'], network['b3']
# 前向传播
a1 = np.dot(x, W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1, W2) + b2
z2 = sigmoid(a2)
a3 = np.dot(z2, W3) + b3
y = softmax(a3)
return y
# 获取测试图片路径
image_path = '..\dataset\digit1.jpg' # 替换为你的图片路径
# 加载网络
network = init_network()
# 预处理并预测图像
img_array = preprocess_image(image_path)
y = predict(network, img_array)
predicted_digit = np.argmax(y)
# 显示图像和预测结果
img = Image.open(image_path).convert('L')
plt.imshow(img, cmap='gray')
plt.title(f"Prediction: {predicted_digit}")
plt.axis('off')
plt.show()
# 输出预测的结果
print(f"Predicted Digit: {predicted_digit}")
图片颜色反转之后只对了一个