import time
import numpy as np
def loaddata(fileName):
    '''
    Load a comma-separated dataset in which every row is one sample.

    The first field of each row is parsed as the integer label and the
    remaining fields are parsed as the integer feature values.

    :param fileName: path of the dataset file to load
    :return: tuple (dataArr, labelArr) -- a list of per-sample feature lists
             and the list of labels, both in file order
    :raises ValueError: if a non-blank row contains a field int() rejects
    '''
    print('start to read file')
    dataArr = []
    labelArr = []
    # The context manager closes the file even when parsing raises.
    with open(fileName, 'r') as fr:
        # Iterate the file lazily instead of readlines() so the raw text of
        # the whole file is never held in memory alongside the parsed rows.
        for line in fr:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline at EOF) hold no sample.
                continue
            curline = line.split(',')
            dataArr.append([int(num) for num in curline[1:]])
            labelArr.append(int(curline[0]))
    return dataArr, labelArr
def calDist(x1, x2):
    '''
    Compute the Euclidean (L2) distance between two samples.

    :param x1: first vector
    :param x2: second vector
    :return: the distance between the two vectors
    '''
    offset = x1 - x2
    squaredLength = np.sum(np.square(offset))
    return np.sqrt(squaredLength)
def getClosest(trainDataMat, trainLabelMat, x, topK):
    '''
    Predict the label of sample x with a k-nearest-neighbour majority vote.

    The topK training samples closest to x (Euclidean distance) are found
    and the label that occurs most often among them is returned.

    :param trainDataMat: training samples, one row per sample
    :param trainLabelMat: training labels (non-negative integers), one per
                          row of trainDataMat; a flat sequence or a column
                          vector are both accepted
    :param x: the sample to classify (flat vector or single-row matrix)
    :param topK: number of nearest training samples taking part in the vote
                 (the choice affects accuracy; see section 3.2.3,
                 "choice of the k value", referenced by the original notes)
    :return: the predicted label as a Python int; when several labels share
             the highest vote count the smallest label wins
    '''
    # Plain ndarrays give ordinary broadcasting and 1-D indexing whether the
    # caller passed lists, ndarrays or np.matrix objects.
    samples = np.asarray(trainDataMat)
    labels = np.asarray(trainLabelMat).ravel()
    query = np.asarray(x).ravel()

    # Same Euclidean distance as calDist, but evaluated for every training
    # row in one vectorised pass instead of one Python-level call per row.
    # NOTE: this builds temporaries the size of the training matrix.
    distList = np.sqrt(np.square(samples - query).sum(axis=1))

    # Indices of the topK smallest distances.
    topKlist = np.argsort(distList)[:topK]

    # Tally the votes. minlength=10 keeps the ten-bin tally of the original
    # implementation (so an empty vote still yields label 0), while bincount
    # grows automatically when labels larger than 9 are present.
    votes = np.bincount(labels[topKlist].astype(np.intp), minlength=10)

    # argmax returns the first maximum, i.e. the smallest label on a tie.
    return int(votes.argmax())
def model_test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, topK, testNum=200):
    '''
    Measure the classification accuracy on the leading test samples.

    :param trainDataArr: training samples, one row per sample
    :param trainLabelArr: training labels
    :param testDataArr: test samples, one row per sample
    :param testLabelArr: test labels
    :param topK: number of nearest neighbours consulted for each prediction
    :param testNum: maximum number of test samples to evaluate, taken from
                    the start of the test set (default 200); fewer are used
                    when the test set is smaller
    :return: accuracy as a fraction between 0 and 1
    :raises ValueError: if there is no test sample to evaluate
    '''
    print('start test')
    # np.asarray replaces np.mat, which was removed in NumPy 2.0. Labels are
    # flattened so each one indexes as a scalar.
    trainData = np.asarray(trainDataArr)
    trainLabels = np.asarray(trainLabelArr).ravel()
    testData = np.asarray(testDataArr)
    testLabels = np.asarray(testLabelArr).ravel()

    # Never index past the end of a test set shorter than testNum.
    testCnt = min(testNum, len(testData), len(testLabels))
    if testCnt <= 0:
        raise ValueError('model_test needs at least one test sample')

    errorCnt = 0
    for i in range(testCnt):
        print('test %d:%d' % (i, testCnt))
        y = getClosest(trainData, trainLabels, testData[i], topK)
        if y != testLabels[i]:
            errorCnt += 1
    return 1 - (errorCnt / testCnt)
if __name__ == '__main__':
    # Wall-clock timer covering both data loading and evaluation.
    startedAt = time.time()

    # Read the training and the test split from their CSV files.
    trainData, trainLabels = loaddata('Mnist/mnist_train.csv')
    testData, testLabels = loaddata('Mnist/mnist_test.csv')

    # Evaluate with the 25 nearest neighbours voting on each prediction.
    accuracy = model_test(trainData, trainLabels, testData, testLabels, 25)
    print('accuracy is:%d' % (accuracy * 100), '%')

    finishedAt = time.time()
    print('time span:', finishedAt - startedAt)
# Recorded output (presumably from an earlier run of this script):
#   accuracy is:97 %
#   time span: 465.8361213207245