K邻算法实现手写字体的识别,python

该博客主要介绍了如何运用K近邻(KNN)算法进行图像分类。首先,它导入了必要的库,如`sklearn`和`tensorflow`,并加载了Fashion_MNIST数据集。然后,通过`train_test_split`对数据进行了划分,接着使用`StandardScaler`进行预处理。之后,定义了一个`GridSearchCV`来寻找最佳的K值和距离度量参数p,并用找到的最佳参数训练KNN模型。最后,展示了一部分预测结果,并计算了模型的准确率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

# -*- coding: utf-8 -*-

"K邻算法实现"
from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
import matplotlib.pyplot as plt
from tensorflow import keras
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

from sklearn.preprocessing import StandardScaler
def get_data():
    data = load_digits()
    x, y = data.data, data.target
    print(x.shape)
    train_x, test_x, train_y, test_y = train_test_split(x, y, test_size= 0.3, random_state= 20)
    model = StandardScaler()
    train_x = model.fit_transform(train_x)
    test_x = model.transform(test_x)
    return train_x, test_x, train_y, test_y

def model_fit(x, y):
    paras = {'n_neighbors': [5, 6, 7, 8, 9, 10], 'p': [1, 2]}
    model = KNeighborsClassifier()
    gs = GridSearchCV(model, paras, verbose=2, cv= 5)
    gs.fit(x, y)
    print('最佳模型:', gs.best_params_, '准确率:',gs.best_score_)
    
def train(train_x, test_x, train_y, test_y):
    model = KNeighborsClassifier(5, p= 1)
    model.fit(train_x, train_y)
    pre_y = model.predict(test_x)
    show_img(test_x, pre_y, test_y)
    print(model.score(test_x, test_y))
def show_img(test_x, pre_x, test_y):
    num_row = 5
    num_col = 3
    
    plt.figure(figsize = (num_row, num_col* 2))
    plt.grid(False)
    for i in range(num_row * num_col):
        plt.subplot(num_row, num_col, i + 1)
        show_num_img(test_x[i], pre_x[i], test_y[i])
    plt.tight_layout()    
    plt.show()
def show_num_img(img, pre, y):
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img.reshape(8, 8), cmap=plt.cm.binary)
    color = 'green'
    if pre != y:
        color = 'red'
    plt.xlabel("{0}({1})".format(pre, y, color=color))
if __name__ == "__main__":
    train_x, test_x, train_y, test_y = get_data()
    #model_fit(x, y)

    train(train_x, test_x, train_y, test_y)

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

东哥aigc

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值