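I ran this experiment mainly to consolidate my understanding of kNN and some basic Python syntax by typing out the code myself.
1. The core is the kNN classification algorithm; the code is adapted mainly from the Chapter 2 source of "Machine Learning in Action" (《机器学习实战》).
2. Exporting the MNIST data into the format I needed is based on python-mnist 0.5; see
https://github.com/sorki/python-mnist/blob/master/mnist/loader.py
3. numpy is used throughout the code, mainly for handling 1-D and 2-D arrays, for example the tile() function, which expands a small data structure into a 2-D array of a specified shape.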
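As a quick illustration of the tile() step, here is a minimal sketch on made-up data (the arrays `query` and `train` are assumptions for the example, not MNIST values). It also shows that NumPy broadcasting should give the same distances without building the tiled copy.

```python
import numpy as np

# A single query point (1-D) and a tiny "training set" (2-D); values are made up.
query = np.array([1.0, 2.0])
train = np.array([[1.0, 2.0],
                  [4.0, 6.0],
                  [0.0, 0.0]])

# tile() repeats the query once per training row so both shapes are (3, 2).
tiled = np.tile(query, (train.shape[0], 1))

# Row-wise Euclidean distance from the query to every training point.
distances = np.sqrt(((tiled - train) ** 2).sum(axis=1))

# Broadcasting computes the same thing without the explicit tiled copy.
distances_broadcast = np.sqrt(((train - query) ** 2).sum(axis=1))
```

For a training set the size of MNIST, the broadcasting form avoids allocating a second full-size array for every query, which may matter for memory use.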
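I first tested this on a Mac with Python 2.7 and later ran it on Windows with Python 3.5, making some changes for Python 3 syntax along the way. If you want to test the code, please download MNIST yourself (http://yann.lecun.com/exdb/mnist/), decompress the four files, and save them under a path that contains no Chinese characters. Remember to update the paths in the handwritingClassTest() function to match your actual location.
In my test, with 3 nearest neighbors the error rate was 2.86%.
Here is the code: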
kNN.py
import numpy as np
import operator
import loadMINST

def classify0(inX, dataSet, labels, k):
    # Repeat the query vector once per training row, then take row-wise Euclidean distances.
    dataSetSize = dataSet.shape[0]
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5
    # Indices of the training rows, sorted from nearest to farthest.
    sortedDistIndicies = distances.argsort()
    # Count the labels of the k nearest neighbors.
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # Return the label with the most votes.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def handwritingClassTest():
    # Update these paths to wherever you extracted the MNIST files.
    trainingMat, hwLabels, size = loadMINST.load('D:/electrical/temp/MNIST_data/train-images-idx3-ubyte', 'D:/electrical/temp/MNIST_data/train-labels-idx1-ubyte')
    # After this call, size holds the number of test images.
    dataUnderTest, classNumStr, size = loadMINST.load('D:/electrical/temp/MNIST_data/t10k-images-idx3-ubyte', 'D:/electrical/temp/MNIST_data/t10k-labels-idx1-ubyte')
    errorCount = 0.0
    for i in range(size):
        classifierResult = classify0(dataUnderTest[i, :], trainingMat, hwLabels, 3)
        print("the NO.%d classifier came back with: %d, the real answer is: %d, error count is: %d" % (i, classifierResult, classNumStr[i], errorCount))
        if classifierResult != classNumStr[i]:
            errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount / float(size)))
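To sanity-check the voting logic without loading MNIST, a small self-contained sketch like the following may help. The helper name `knn_predict` and the four-point dataset are assumptions for illustration; it uses collections.Counter for the vote instead of the dictionary-and-sort approach in kNN.py, so it is a variation rather than the code above.

```python
from collections import Counter
import numpy as np

def knn_predict(query, data, labels, k):
    """Hypothetical helper: majority vote among the k nearest rows of data."""
    # Euclidean distance from the query to every row (broadcasting, no tile()).
    distances = np.sqrt(((data - query) ** 2).sum(axis=1))
    # Indices of the k smallest distances.
    nearest = np.argsort(distances)[:k]
    # Tally the labels of those neighbors and return the most common one.
    votes = Counter(labels[i] for i in nearest)
    return votes.most_common(1)[0][0]

# Toy data: two points near (1, 1) labeled 'A', two near (0, 0) labeled 'B'.
data = np.array([[1.0, 1.1], [1.0, 1.0], [0.0, 0.0], [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B']

prediction_near_b = knn_predict(np.array([0.1, 0.2]), data, labels, 3)
prediction_near_a = knn_predict(np.array([0.9, 1.0]), data, labels, 3)
```

With k=3, each query has two neighbors from its own cluster and one from the other, so the majority vote picks the nearby cluster's label.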
loadMINST.py
import struct
import array
import numpy

# Adapted from https://github.com/sorki/python-mnist/blob/master/mnist/loader.py
def load(path_img, path_lbl):
    labels = []
    # Label file: 8-byte big-endian header (magic number, item count), then one byte per label.
    with open(path_lbl, 'rb') as file:
        magic, size = struct.unpack(">II", file.read(8))
        if magic != 2049:
            raise ValueError('Magic number mismatch, expected 2049, '
                             'got {}'.format(magic))
        label_data = array.array("B", file.read())
    for i in range(size):
        labels.append(label_data[i])
    # Image file: 16-byte big-endian header (magic number, count, rows, cols), then one byte per pixel.
    with open(path_img, 'rb') as file:
        magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
        if magic != 2051:
            raise ValueError('Magic number mismatch, expected 2051, '
                             'got {}'.format(magic))
        image_data = array.array("B", file.read())
    # Flatten each image into one row of a (size, rows * cols) array.
    images = numpy.zeros((size, rows * cols))
    for i in range(size):
        if (i % 2000 == 0) or (i + 1 == size):
            print("%d numbers imported" % (i))
        images[i, :] = image_data[i * rows * cols: (i + 1) * rows * cols]
    return images, labels, size
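The per-image copy loop in load() can probably be replaced by a single numpy.frombuffer call. The sketch below is one possible alternative, not the python-mnist approach: the helper name `read_idx_images` is hypothetical, and the input is a tiny synthetic file built in memory (two 2x2 images) rather than a real MNIST file.

```python
import io
import struct
import numpy as np

def read_idx_images(stream):
    """Hypothetical helper: parse an IDX3 image stream into a (size, rows*cols) array."""
    # 16-byte big-endian header: magic number, image count, rows, cols.
    magic, size, rows, cols = struct.unpack(">IIII", stream.read(16))
    if magic != 2051:
        raise ValueError("Magic number mismatch, expected 2051, got {}".format(magic))
    # Interpret the remaining bytes as unsigned 8-bit pixels in one step.
    pixels = np.frombuffer(stream.read(), dtype=np.uint8)
    # One flattened image per row; convert to float like numpy.zeros() would give.
    return pixels.reshape(size, rows * cols).astype(np.float64)

# Synthetic stand-in for an image file: header plus 2 images of 2x2 pixels.
header = struct.pack(">IIII", 2051, 2, 2, 2)
body = bytes([0, 1, 2, 3, 4, 5, 6, 7])
images = read_idx_images(io.BytesIO(header + body))
```

To use this on a real file, the stream would be the object returned by open(path_img, 'rb').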
test.py
import kNN
kNN.handwritingClassTest()
I used PyCharm as the editor; to run the test, just type the following in its built-in terminal: python test.py