
import time
from collections import Counter

import numpy as np
def data_load(filename):
    '''
    Load a label-first CSV file and binarize its labels.

    Every non-blank line is parsed as comma-separated integers: the first
    field is the label and the remaining fields are feature values, each of
    which is divided by 255.

    :param filename: path of the CSV file to read
    :return: dataArr,labelArr -- a list of feature lists (floats) and a
        parallel list of labels, 1 when the parsed label is >= 5 and -1
        otherwise
    :raises ValueError: if a field of a non-blank line is not an integer
    '''
    print('start read file')
    dataArr, labelArr = [], []
    with open(filename, 'r') as f:
        # Stream the file line by line instead of materializing every line
        # up front with readlines().
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no sample.
                continue
            fields = line.split(',')
            # Collapse the parsed label into two classes: >= 5 -> 1, else -1.
            labelArr.append(1 if int(fields[0]) >= 5 else -1)
            dataArr.append([int(num) / 255 for num in fields[1:]])
    print('End')
    return dataArr, labelArr
def cal_distance(x1, x2):
    '''
    Compute the Euclidean distance between two points.

    :param x1: first point, array-like of numbers
    :param x2: second point, array-like broadcastable against x1
    :return: the Euclidean (L2) distance between x1 and x2
    '''
    # Work in floating point: with unsigned integer inputs (e.g. uint8
    # arrays) the subtraction and squaring below would silently wrap around.
    x1 = np.asarray(x1, dtype=float)
    x2 = np.asarray(x2, dtype=float)
    return np.sqrt(np.sum(np.square(x1 - x2)))
def knn(traindata, trainlabel, x, k):
    '''
    Predict the label of one point by majority vote of its k nearest
    training points (Euclidean distance).

    :param traindata: training samples, a sequence of equally sized points
    :param trainlabel: labels of the training samples, indexable in
        parallel with traindata
    :param x: the target point to classify
    :param k: the number of nearest neighbours that vote; when k exceeds
        the number of training samples, all of them vote
    :return: the most frequent label among the k nearest neighbours; on a
        tie, the tied label that appears first when the neighbours are
        ordered by increasing distance
    :raises ValueError: if traindata is empty or k is smaller than 1
    '''
    if k < 1:
        raise ValueError('k must be at least 1')
    if len(traindata) == 0:
        raise ValueError('traindata must not be empty')

    # Flatten every sample to one row so a single vectorized expression can
    # replace the per-sample Python loop; float avoids integer wrap-around.
    train = np.asarray(traindata, dtype=float)
    train = train.reshape(train.shape[0], -1)
    target = np.asarray(x, dtype=float).reshape(-1)
    distances = np.linalg.norm(train - target, axis=1)

    # A stable sort keeps equidistant samples in training order, which makes
    # the neighbour selection deterministic.
    nearest = np.argsort(distances, kind='stable')[:k]
    nearest_labels = [trainlabel[idx] for idx in nearest]

    # most_common lists equal counts in first-seen order, i.e. the label of
    # the closer neighbour wins a tie.
    return Counter(nearest_labels).most_common(1)[0][0]
def model_test(traindata, trainlabel, testdata, testlabel, k, num_samples=200):
    '''
    Measure k-nearest-neighbour accuracy on the leading test samples.

    :param traindata: training samples
    :param trainlabel: labels of the training samples
    :param testdata: test samples
    :param testlabel: labels of the test samples
    :param k: the number of neighbours passed to knn
    :param num_samples: upper bound on how many leading test samples are
        evaluated (default 200); clamped to the number of samples available
    :return: the percentage (0-100) of evaluated samples predicted correctly
    :raises ValueError: if there is no test sample to evaluate
    '''
    # Never index past the end of either test sequence.
    num_samples = min(num_samples, len(testdata), len(testlabel))
    if num_samples < 1:
        raise ValueError('at least one test sample is required')

    # Convert once up front so the nested lists are not re-converted to
    # arrays on every knn call.
    traindata = np.asarray(traindata, dtype=float)

    rigSum = 0
    for i in range(num_samples):
        print('iter:{}'.format(i))
        y_pred = knn(traindata, trainlabel, testdata[i], k)
        if testlabel[i] == y_pred:
            rigSum += 1
    return rigSum / num_samples * 100
if __name__ == '__main__':
    # Script entry point: load both CSV files, evaluate the classifier and
    # report the accuracy together with the elapsed evaluation time.

    # Number of nearest neighbours that vote on each prediction.
    k = 25
    traindata, trainlabel = data_load('dataset/mnist_train.csv')
    testdata, testlabel = data_load('dataset/mnist_test.csv')
    # perf_counter is the clock intended for measuring elapsed intervals; it
    # is unaffected by system clock adjustments, unlike time.time().
    start = time.perf_counter()
    print(model_test(traindata, trainlabel, testdata, testlabel, k))
    end = time.perf_counter()
    # The label printed below is Chinese for "training time"; the value is
    # the number of seconds spent inside model_test.
    print('训练时间:{}'.format(end - start))
