import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import keras
plt.rcParams['font.sans-serif']=['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False # 用来正常显示负号
# 载入 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
# 将图像数据展开成一维向量
train_images_flat = train_images.reshape(train_images.shape[0], -1)
test_images_flat = test_images.reshape(test_images.shape[0], -1)
# 初始化 KNN 分类器
knn_classifier = KNeighborsClassifier(n_neighbors=5)
# 训练 KNN 模型
knn_classifier.fit(train_images_flat, train_labels)
# 在测试集上进行预测
predicted_labels = knn_classifier.predict(test_images_flat)
# 计算准确率
accuracy = accuracy_score(test_labels, predicted_labels)
print('测试集准确率:', accuracy)
# 随机选择一张测试图像进行展示和预测
index = np.random.randint(0, test_images.shape[0])
sample_image = test_images[index]
sample_label = test_labels[index]
# 展示图像
plt.title(f'真实标签: {sample_label}, 预测标签: {predicted_labels[index]}')
plt.imshow(sample_image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()
