基于TensorFlow的mnist数据集的最近邻算法实现代码

本文介绍了一种基于K近邻算法的手写数字识别方法,使用TensorFlow框架和MNIST数据集进行实现。通过计算训练样本与测试样本之间的距离来找到最近邻样本,并据此进行分类预测。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
#导入mnist数据集
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
 
#5000样本作为训练集 每一个训练和测试样本的数据都是1*784的矩阵,标签是1*10的矩阵并且采用one-hot编码
X_train , Y_train = mnist.train.next_batch(5000)
#600样本作为测试集
X_test , Y_test = mnist.test.next_batch(200)
 
#创建占位符 None代表将来可以选多个样本的,如:[60,784]代表选取60个样本,每一个样本的是784列
x_train = tf.placeholder("float",[None,784])
x_test = tf.placeholder("float",[784])#x_test代表只用一个样本
#计算距离
#tf.negative(-2)的输出的结果是2
#tf.negative(2)的输出的结果是-2
#reduce_sum的参数reduction_indices解释见下图
#计算一个测试样本和训练样本的的距离
#distance 返回的是N个训练样本的和单个测试样本的距离
distance = tf.reduce_sum(tf.abs(tf.add(x_train,tf.negative(x_test))),reduction_indices=1)
#的到距离最短的训练样本的索引
prediction = tf.arg_min(distance,0)
accuracy = 0
#初始化变量
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
 
    for i in range(len(X_test)):#遍历整个测试集,每次用一个的测试样本和整个训练样本的做距离运算
        #获得最近邻
        # 获得训练集中与本次参与运算的测试样本最近的样本编号
        nn_index = sess.run(prediction,feed_dict={x_train:X_train,x_test:X_test[i,:]})
        #打印样本编号的预测类别和准确类别
        print("Test",i,"Prediction:",np.argmax(Y_train[nn_index]),"True Class:",np.argmax(Y_test[i]))
        if np.argmax(Y_train[nn_index]) == np.argmax(Y_test[i]):
            #如果预测正确。更新准确率
            accuracy += 1./len(X_test)
    print("完成!")

    print("准确率:",accuracy)


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值