因为嫌线性时间扫描的方法太慢了,写了两天写了kd树,结果发现,kd树一样慢!!!
因为mnist特征长度为784,训练数据为60000,在kd树建造时最大深度只有14(0开始计算),子空间的判断终止条件是测试点与父节点在单一特征上的距离要大于队列中存储的距离最大值,两个784维的特征点的距离怎么可能比单一维度的距离要小呢,所以这个条件是几乎不可能满足的,因而每次测试时实际又遍历了整棵树,,,心累,,,
代码如下,心累不想写说明,,,
#k近邻算法
#算法过程:1.根据距离度量,在训练集找到与x最邻近的k个点,涵盖这k个点的邻域为Nk(x)
# 2.在Nk(x)中根据分类决策规则(多数表决)决定x的类别y:y=argmax(sigama)I(yi=cj)
#k近邻算法没有学习过程
#因为太慢了,只测了100个点
#k=25,accuracy=98%,time=85.1s
import numpy as np
import time
import operator
#加载数据过程类似
def loadData(filename):
print('start to read data')
image_array=[]
label_array=[]
file=open(filename,'r')
for line in file.readlines():
#数据格式:每个样本一行,以','为间隔,以'\n'结尾的字符串,首字符为类别,后面跟28*28个像素值
curline=line.strip().split(',')
#strip() 方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列。
label_array.append(int(curline[0]))
image=[int(num) for num in curline[1:]]
#这里并没有用除255,只在意它的相对大小,整数计算比浮点数快。
image_array.append(image)
return image_array,label_array
#距离度量
def eucl_dist(x1,x2):
result=np.sqrt(np.sum(np.square(x1-x2)))
#print('result',result)
return result
#线性扫描,实在是太慢了
'''
采用线性扫描的方式,因为构造kd树太复杂
def get_closest(image_array,label_array,x,k):
distance_list=[0]*len(image_array)
for i in range