__author__ = '糖衣豆豆' from numpy import * from os import listdir import operator #从列方向扩展 #tile(a,(size,1)) #实现KNN算法,需要指定k,需要测试数据集,需要训练数据集,类别名(标签), def knn(k,testdata,traindata,labels): #通过shape获得行数 traindatasize=traindata.shape[0] #扩展testdata的维数,tile函数可以扩展testdata和traindata相同的行数,然后和traindata的向量相减计算测试机和训练集的差值 dif=tile(testdata,(traindatasize,1))-traindata #计算差值的平方 sqdif=dif**2 #计算平方和,每一行的各列求和,axis=1每一行的各列求和 sumsqdif=sqdif.sum(axis=1) #开方 distance=sumsqdif**0.5 #排序 sortdistance=distance.argsort() #空字典 count={} #选择距离最短的k for i in range(0,k): #获取类别,下标决定属于哪一类 vote=labels[sortdistance[i]] #整理为一定格式,得到类别vote,每出现一次统计一次 count[vote]=count.get(vote,0)+1 #取出最多的类别,reverse=True表示降序 sortcount=sorted(count.items(),key=operator.itemgetter(1),reverse=True) return sortcount[0][0] #图片处理 #先将图片转为固定宽高,比如32*32,然后再转为文本 ''' from PIL import Image im=Image.open("~/Downloads/123.png") fh=open("~/Downloads/123_txt","a") width=im.size[0] height=im.size[1] #k=im.getpixel((1,9)) #print(k) for i in range(0,width): for j in range(0,height): cl=im.getpixel((i,j)) clall=cl[0]+cl[1]+cl[2] if(clall==0): #黑色 fh.write("1") else: fh.write("0") fh.write("\n") fh.close() ''' #加载数据 #将数据转为数组 def datatoarray(fname): arr=[] fh=open(fname) #图片是32*32的横轴每次读取32 for i in range(0,32): thisline=fh.readline() #读每一行 for j in range(0,32): #读入到数组里 arr.append(int(thisline[j])) return arr arr1=datatoarray("~/coding/python/data/testandtraindata/testdata/0_74.txt") #print(arr1) #建立一个函数,取文件的前缀 def seplabel(fname): filestr=fname.split(".")[0] label=int(filestr.split("_")[0]) return label #建立训练数据 def traindata(): #存储类别 labels=[] #得到训练目录下所有的文件 trainfile=listdir("~/coding/python/data/testandtraindata/traindata") #取当前文件有多少个 num=len(trainfile) #生成一个多少行多少列的向量,行的长度应该是32*32=1024(列),每一行存储一个文件 #用一个数组存储所有训练数据,行:文件总数,列:1024 trainarr=zeros((num,1024)) #第一层循环文件 for i in range(0,num): thisfname=trainfile[i] #调用seplabel函数 thislabel=seplabel(thisfname) #存到数组里 labels.append(thislabel) #调用datatoarray函数,i,:处理重复读取 trainarr[i,:]=datatoarray("~/coding/python/data/testandtraindata/traindata/"+thisfname) return trainarr,labels #用测试数据条用KNN算法去测试,看是否能够准确识别 def datatest(): trainarr,labels=traindata() testlist=listdir("~/coding/python/data/testandtraindata/testdata") tnum=len(testlist) for i in range(0,tnum): thistestfile=testlist[i] testarr=datatoarray("~/coding/python/data/testandtraindata/testdata/"+thistestfile) rknn=knn(3,testarr,trainarr,labels) print(rknn) #a=datatest() #print(a) #抽某一个文件测试文件出来进行验证 trainarr,labels=traindata() thistestfile="8_15.txt" testarr=datatoarray("~/coding/python/data/testandtraindata/testdata/"+thistestfile) rknn=knn(3,testarr,trainarr,labels) print(rknn)