用 Python 实现 KNN（k 近邻）分类器
运行结果：程序会依次输出 k = 1、4、7 时在 CIFAR-10 测试集上的分类准确率。
实现代码如下：
import numpy as np
import pickle as pic
import heapq as hp
from collections import Counter
class Solution:
    """k-nearest-neighbour (KNN) classifier using the L1 (Manhattan) distance.

    Load pickled batch files with ``Init``, then evaluate on the test set
    with ``finalTest``. The default of five training batches matches the
    CIFAR-10 python layout this script reads.
    """

    def __init__(self):
        # Per-instance state. Mutable class-level attributes would be shared
        # by every Solution object.
        self.path_train = ''
        self.train_data = []
        self.train_label = []
        self.test_data = []
        self.test_label = []
        self.test_record = {}

    def readData(self, file):
        """Load and return the object pickled in ``file``.

        ``encoding='bytes'`` makes 8-bit strings from Python 2 pickles load as
        ``bytes``. That is why callers index the result with ``b'data'`` and
        ``b'labels'``.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only call this
        on trusted files.
        """
        with open(file, 'rb') as fo:
            return pic.load(fo, encoding='bytes')

    def Init(self, path1, path2, num_batches=5):
        """Load the training and test sets from pickled batch files.

        Each file must hold a dict with a ``b'data'`` entry (one row per
        sample) and a ``b'labels'`` entry (one label per row).

        :param path1: prefix of the training batch files. Batch ``i`` is read
            from ``path1 + str(i)`` for ``i`` in ``1..num_batches``.
        :param path2: path of the test batch file.
        :param num_batches: number of training batches to read (default 5).
        :raises ValueError: if ``num_batches`` is less than 1.
        """
        if num_batches < 1:
            raise ValueError("num_batches must be at least 1")
        self.path_train = path1
        self.test_record = self.readData(path2)  # read the test data
        self.test_data = np.array(self.test_record[b'data'])
        self.test_label = np.array(self.test_record[b'labels'])
        # Read every training batch, then join them once. Calling np.append
        # per batch would copy the growing array every time.
        batches = [self.readData(self.path_train + str(i))
                   for i in range(1, num_batches + 1)]
        self.train_data = np.concatenate([b[b'data'] for b in batches], axis=0)
        # Flatten the per-batch labels so that both lists and arrays work.
        self.train_label = np.array(
            [lab for b in batches for lab in b[b'labels']])

    @staticmethod
    def _widen(arr):
        """Return ``arr`` with a signed dtype wide enough for element differences."""
        arr = np.asarray(arr)
        if arr.dtype.kind in "biu" and arr.dtype.itemsize < 8:
            # Double the width and make it signed ('i2'/'i4'/'i8'). Unsigned
            # subtraction such as uint8 pixel data would otherwise wrap around
            # instead of going negative, which corrupts the L1 distance.
            return arr.astype("i" + str(2 * arr.dtype.itemsize))
        return arr

    def finalTest(self, k):
        """Classify every test sample with KNN (L1 distance) and report accuracy.

        Each prediction is the most common label among the ``k`` nearest
        training samples. Samples at equal distance are ordered by training
        index. Tied vote counts go to the label seen first in that
        nearest-first order, which is the label of the closest tied
        neighbour.

        :param k: number of neighbours; must be at least 1.
        :returns: the accuracy as a float in ``[0, 1]``. It is also printed.
        :raises ValueError: if ``k < 1`` or there are no test samples.
        """
        if k < 1:
            raise ValueError("k must be at least 1")
        test_labels = np.asarray(self.test_label)
        total = np.size(test_labels)
        if total == 0:
            raise ValueError("no test samples; call Init first")

        # Convert once, outside the per-sample loop.
        train = self._widen(self.train_data)
        test = self._widen(self.test_data)
        train_labels = np.asarray(self.train_label)

        predictions = []
        for i in range(total):
            dists = np.abs(train - test[i]).sum(axis=1)
            # A stable sort keeps training order among equal distances, so
            # tie-breaking is deterministic.
            nearest = np.argsort(dists, kind='stable')[:k]
            votes = Counter(train_labels[nearest].tolist())
            predictions.append(votes.most_common(1)[0][0])

        accuracy = float(np.mean(np.asarray(predictions) == test_labels))
        print("accuracy: " + str(accuracy))
        return accuracy
# Script entry: load the CIFAR-10 python batches from the hard-coded local
# directory, then print the test accuracy for k = 1, 4 and 7.
a = Solution()
a.Init(
    "D:\\data\\cifar-10-batches-py\\data_batch_",
    "D:\\data\\cifar-10-batches-py\\test_batch",
)
for i in (1, 4, 7):
    a.finalTest(i)