本文利用KNN算法,实现手写体数字的识别
knn算法概述
knn算法又称为k近邻分类(k-nearest neighbor classification)算法,核心思想:给定测试样本,基于某种距离度量找出训练集中与其最靠近的k个训练样本,然后基于这k个相邻点的信息进行预测。
通常,在分类任务中可使用"投票法",即将这k个样本中出现最多的类别标记作为预测结果;在回归任务中可使用“平均法”,即将这k个样本的实际值输入标记的平均值作为预测结果;还可以基于距离远近进行加权平均或者加权投票,距离越近的样本权重越大。
下面通过一个简单的例子说明一下:
如下图,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。
由此也说明了KNN算法的结果很大程度取决于K的选择。
在KNN中,通过计算对象间距离来作为各个对象之间的非相似性指标,避免了对象之间的匹配问题,在这里距离一般使用欧氏距离或曼哈顿距离:
接下来对KNN算法的思想总结一下:就是在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类,其算法的描述为:
1)计算测试数据与各个训练数据之间的距离;
2)按照距离的递增关系进行排序;
3)选取距离最小的K个点;
4)确定前K个点所在类别的出现频率;
5)返回前K个点中出现频率最高的类别作为测试数据的预测分类。
算法实战:利用KNN算法实现手写体数字的识别
1.手写体数字准备
将0~9,10个数据的手写体图片转化成文本存储(图片像素颜色,黑色用数字1替代,白色用0替代,图片像素大小为32*32),尽可能多的提供训练数据(将图片转化为文本,需要用到PIL模块,可参考博文:https://blog.youkuaiyun.com/d1240673769/article/details/77150964)。
如下图,手写体数字0转化的结果:
训练集:
2.用同样的方法提供一批测试样本集
注:文中用到的数据集文件可在这里进行下载 https://download.youkuaiyun.com/download/d1240673769/20813109
3.代码示例
import numpy as np
import operator
import os
#KNN算法
def knn(k,testdata,traindata,labels):#(k,测试集,训练集,分类)
traindatasize=traindata.shape[0]#行数
#测试样本(一维)和训练集样本数不一样,因此需要将测试集样本数扩展成和训练集一样多(因为需要求这个样本和所有训练样本之间的距离)
#从行方向扩展 tile(a,(size,1))
dif=np.tile(testdata,(traindatasize,1))-traindata
#计算距离
sqdif=dif**2
sumsqdif=sqdif.sum(axis=1)
distance=sumsqdif**0.5
sortdistance=distance.argsort()#从小到大排列,结果返回元素位置
count={}
for i in range(k):
vote=labels[sortdistance[i]]
#统计每一类列样本的数量
count[vote]=count.get(vote,0)+1
sortcount=sorted(count.items(),key=operator.itemgetter(1),reverse=True)
#取包含样本数量最多的那一类别
return sortcount[0][0]
#加载数据,将文件转化为数组形式
def datatoarray(filename):
arr=[]
fh=open(filename)
for i in range(32):
thisline=fh.readline()
for j in range(32):
arr.append(int(thisline[j]))
return arr
#获取文件的lable
def get_labels(filename):
label=int(filename.split('_')[0])
return label
#建立训练数据
def train_data():
labels=[]
trainlist=os.listdir('traindata/')
num=len(trainlist)
#长度1024(列),每一行存储一个文件
#用一个数组存储所有训练数据,行:文件总数,列:1024
trainarr=np.zeros((num,1024))
for i in range(num):
thisfile=trainlist[i]
labels.append(get_labels(thisfile))
trainarr[i,:]=datatoarray("traindata/"+thisfile)
return trainarr,labels
#用测试数据调用KNN算法进行测试
def datatest():
a=[]#准确结果
b=[]#预测结果
traindata,labels=train_data()
testlist=os.listdir('testdata/')
fh=open('result_knn.csv','a')
for test in testlist:
testfile='testdata/'+test
testdata=datatoarray(testfile)
result=knn(3,testdata,traindata,labels)
#将预测结果存在文本中
fh.write(test+'-----------'+str(result)+'\n')
a.append(int(test.split('_')[0]))
b.append(int(result))
fh.close()
return a,b
if __name__=='__main__':
a,b=datatest()
num=0
for i in range(len(a)):
if(a[i]==b[i]):
num+=1
else:
print("预测失误:",a[i],"预测为",b[i])
print("测试样本数为:",len(a))
print("预测成功数为:",num)
print("模型准确率为:",num/len(a))
运行结果