载入数据
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn import datasets
digits = datasets.load_digits()
得到数据集中的数据
X = digits.data
y = digits.target
可视化一下,二进制图像显示
随便选一个样本
some_digit = X[111]
some_digit_image = some_digit.reshape(8,8)
plt.imshow(some_digit_image, cmap = matplotlib.cm.binary)
plt.show()

y[111]
输出为4。
调用sklearn库
数据预处理
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
创建kNN模型并训练
from sklearn.neighbors import KNeighborsClassifier
kNN_classifier = KNeighborsClassifier(n_neighbors=3)
kNN_classifier.fit(X_train, y_train)
测试模型准确率
kNN_classifier.score(X_test,y_test)
输出
0.9861111111111112
本文通过使用sklearn库中的手写数字数据集,演示了如何加载数据、进行数据预处理,并应用k近邻(kNN)算法进行分类预测。通过可视化部分样本,读者可以直观了解数据特点。在划分训练集和测试集后,kNN模型达到了98.6%的准确率。
3万+

被折叠的 条评论
为什么被折叠?



