tensorflow Examples:<1>实现Nearest Neighbor

本文详细介绍了如何使用TensorFlow实现Nearest Neighbor算法,通过实例演示了在MNIST数据集上的应用,帮助理解近邻搜索在深度学习中的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

TensorFlow 实现Nearest Neighbor

#coding: utf-8

#Env:
#python         2.7
#tensorflow     1.1.0
#numpy          1.12.1


import numpy as np
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

#导入mnist数据集
mnist = input_data.read_data_sets('data2/', one_hot=True)

#重置图(这个是为了用于多次运行)
tf.reset_default_graph()

#使用训练集数目为5000条
#使用验证集(测试集)数目为300
Xtr, Ytr = mnist.train.next_batch(5000)
Xte, Yte = mnist.test.next_batch(300)

xtr = tf.placeholder('float', [None, 784])
xte = tf.placeholder('float', [784])

#计算各个对应位置的距离(减法使用广播形式)
#底下俩作用相同
distance = tf.reduce_sum(tf.abs(tf.subtract(xtr, xte)), reduction_indices=1)
#distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)

#寻找距离最近(即最相似的行所在位置)
pred = tf.arg_min(distance, 0)

accuracy = 0.

#初始化
init = tf.global_variables_initializer()


with tf.Session() as sess:
    sess.run(init)

    for i in range(len(Xte)):
        #计算最相近的所在行位置
        nn_index = sess.run(pred, feed_dict={xtr:Xtr, xte: Xte[i, :]})

        #取出测试集上最相近行对应的label与真是label对比
        print 'Test', i, 'Prediction: ', np.argmax(Ytr[nn_index]),\
               'True Class: ', np.argmax(Yte[i])
        if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
            accuracy += 1./len(Xte)
    print 'Done!'
    print 'Accuracy: ', accuracy
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值