import numpy
import os
from numpy import array
from numpy import tile
import operator
import matplotlib.pyplot as plt
#数据例子
def createDataSet():
    """Build the toy data set used to demonstrate the kNN classifier.

    Returns:
        A ``(group, labels)`` pair. ``group`` is a 4x2 array with one
        sample point per row; ``labels`` is a list holding the class of
        each row, in the same order.
    """
    # labels[i] is the class of row i of the sample matrix.
    labels = ['A', 'A', 'B', 'B']
    samples = [
        [1.0, 1.1],
        [1.0, 1.0],
        [0, 0],
        [0, 0.1],
    ]
    return array(samples), labels
'''******************************主要分类函数************************************************************'''
# Take the k points nearest to the query and return the class that occurs most often among them.
def classify0(point, dataArray, labels, k):
    """Classify ``point`` by majority vote among its k nearest samples.

    Args:
        point: Feature vector to classify. A flat sequence/array of N
            values or a 1xN row vector are both accepted.
        dataArray: Reference samples, one per row (2-D array or nested
            sequence with N columns).
        labels: Sequence giving the class label of each row of
            ``dataArray``.
        k: Number of nearest samples that vote. Values larger than the
            number of samples are clamped to that number.

    Returns:
        The label that received the most votes. On a tie, the label
        that was met first when walking the neighbours from nearest to
        farthest wins.

    Raises:
        ValueError: If ``k`` is smaller than 1 or ``dataArray`` has no rows.
    """
    samples = numpy.asarray(dataArray)
    sampleCount = samples.shape[0]
    if sampleCount == 0:
        raise ValueError('dataArray must contain at least one sample')
    if k < 1:
        raise ValueError('k must be at least 1, got %r' % (k,))
    # Asking for more neighbours than exist would index past the end of
    # the sorted distances, so vote with every available sample instead.
    k = min(k, sampleCount)

    # Squared Euclidean distance to every sample. Broadcasting subtracts
    # the query from each row; the square root is skipped because it does
    # not change the ordering.
    squaredDist = ((samples - numpy.asarray(point)) ** 2).sum(axis=1)
    # A stable sort keeps equally distant samples in their original order,
    # which makes the result deterministic.
    nearest = squaredDist.argsort(kind='stable')[:k]

    classCount = {}
    for sampleIndex in nearest:
        lab = labels[sampleIndex]
        classCount[lab] = classCount.get(lab, 0) + 1
    # max() returns the first key with the highest count, i.e. the label
    # whose first voter is nearest (dicts keep insertion order).
    return max(classCount, key=classCount.get)
#测试
'''
group0,labels0=createDataSet()
print(group0,labels0)
print(classify0([0,0],group0,labels0,3))
'''
'''*********************************************************************'''
'''***************************约会配对分类*******************************'''
'''*********************************************************************'''
#获取文件数据 返回数据Array 标签list
def file2matrix(filename):
    """Load a tab-separated data file into a feature matrix and a label list.

    Every non-blank line is split on tabs. The first three fields are
    read as floating-point features and the last field as an integer
    class label. Blank lines are skipped.

    Args:
        filename: Path of the text file to read.

    Returns:
        A ``(returnMat, labels)`` pair: ``returnMat`` is an (n, 3) float
        array with one row per data line and ``labels`` is a list of the
        n integer labels in the same order.

    Raises:
        ValueError: If a field cannot be converted to a number or a line
            has fewer than three feature fields.
    """
    features = []
    labels = []
    # The with-block guarantees the handle is closed even if parsing fails.
    with open(filename) as dataFile:
        for line in dataFile:
            line = line.strip()  # drops the trailing newline
            if not line:
                continue  # tolerate blank / trailing empty lines
            fields = line.split('\t')
            features.append([float(value) for value in fields[0:3]])
            labels.append(int(fields[-1]))
    # reshape keeps the (n, 3) shape even when no data line was found.
    returnMat = numpy.array(features, dtype=float).reshape(-1, 3)
    return returnMat, labels
#数据归一化
def Normalize(dataMat):
    """Rescale every column of ``dataMat`` to the interval [0, 1].

    Each value becomes ``(value - column_min) / (column_max - column_min)``.
    A column whose values are all equal has a range of zero; such a
    column is mapped to 0 instead of producing NaN from a 0/0 division.

    Args:
        dataMat: 2-D array (or nested sequence) with one sample per row.

    Returns:
        A ``(normMat, range_value, min_value)`` triple: the rescaled
        matrix, the per-column ``max - min`` (zero for constant columns)
        and the per-column minimum.
    """
    dataMat = numpy.asarray(dataMat)
    min_value = dataMat.min(0)
    max_value = dataMat.max(0)
    range_value = max_value - min_value
    # Divide constant columns by 1 so they become 0 rather than NaN; the
    # returned range_value still reports the true (zero) range.
    divisor = numpy.where(range_value == 0, 1, range_value)
    # Broadcasting applies the per-column offset and scale to every row.
    normMat = (dataMat - min_value) / divisor
    return normMat, range_value, min_value
#测试KNN错误率
def datingClassTest():
    """Print the kNN error rate measured on the dating data set.

    Loads 'datingTestSet2.txt', normalises the features, then classifies
    the first tenth of the rows with k = 3 using the remaining rows as
    the reference set, and prints the fraction that was misclassified.
    """
    rawData, datingLabels = file2matrix('datingTestSet2.txt')
    normData, _ranges, _mins = Normalize(rawData)

    total = normData.shape[0]
    holdOut = total // 10  # number of leading rows used as test samples
    # Everything after the hold-out rows serves as the reference set.
    refData = normData[holdOut:total]
    refLabels = datingLabels[holdOut:total]

    mistakes = 0
    for row in range(holdOut):
        predicted = classify0(normData[row], refData, refLabels, 3)
        if predicted != datingLabels[row]:
            mistakes += 1
    print('错误率:%f' % (mistakes / float(holdOut)))
#测试
'''
datingDataMat,datingLabels=file2matrix('datingTestSet2.txt')
datingDataMat,datingDataRange,datingDataMin=Normalize(datingDataMat)
fg=plt.figure()
subfg1=fg.add_subplot(111)
subfg1.scatter(datingDataMat[:,0],datingDataMat[:,1],15.0*array(datingLabels),15.0*array(datingLabels))
#subfg1.scatter(datingDataMat[:,0],datingDataMat[:,1],15.0*array(datingLabels),1*array(tile([1],(array(datingLabels).shape[0],1))))
plt.xlabel('玩视频耗时百分比')
plt.ylabel('周消耗冰激凌公升数')
plt.show()
datingClassTest()
'''
'''*********************************************************************'''
def classifyPerson():
    """Interactively classify one person with the dating data set.

    Loads and normalises 'datingTestSet2.txt', asks for three numeric
    features on standard input, scales them with the same minimum and
    range as the data set, classifies the result with k = 3 and prints
    the matching category name.
    """
    # Data-set labels are 1, 2, 3; index 0 corresponds to label 1.
    verdicts = ('不喜欢', '一般', '有魅力')
    prompts = ('每年飞行里程数:', '玩游戏小号百分比:', '每周冰淇淋公升:')

    rawData, knownLabels = file2matrix('datingTestSet2.txt')
    normData, ranges, mins = Normalize(rawData)

    answers = [float(input(prompt)) for prompt in prompts]
    # Apply the data set's scaling so the query is comparable to it.
    sample = array((answers - mins) / ranges)
    predicted = classify0(sample, normData, knownLabels, 3)
    print('类型是:', verdicts[predicted - 1])
#测试
'''
classifyPerson()
'''
'''*********************************************************************'''
'''*****************************手写识别*********************************'''
'''*********************************************************************'''
def img2vector(filename):
    """Flatten a 32x32 text image into a 1x1024 row vector.

    The file is expected to hold at least 32 lines with at least 32
    numeric characters each. Character ``j`` of line ``i`` is stored at
    position ``i * 32 + j`` of the result.

    Args:
        filename: Path of the text image to read.

    Returns:
        A float array of shape (1, 1024).

    Raises:
        ValueError: If the file has fewer than 32 lines, a line has fewer
            than 32 characters, or a character is not numeric.
    """
    returnVec = numpy.zeros((1, 1024))
    # The with-block closes the file even when a malformed row raises.
    with open(filename) as imageFile:
        for row in range(32):
            line = imageFile.readline()
            if len(line) < 32:
                raise ValueError(
                    '%s: row %d has %d characters, expected at least 32'
                    % (filename, row, len(line))
                )
            start = row * 32
            returnVec[0, start:start + 32] = [float(ch) for ch in line[:32]]
    return returnVec
def handWriteClassTest(trainDir='trainingDigits', testDir='testDigits', k=3):
    """Print the kNN error rate for the handwritten-digit images.

    Every file in ``trainDir`` is loaded as a reference sample and every
    file in ``testDir`` is classified against them. The true digit of a
    file is taken from the first character of its file name.

    The test images are read from a directory separate from the training
    images: classifying the training files themselves would make each
    sample its own nearest neighbour and report a misleadingly low error
    rate.

    Args:
        trainDir: Directory holding the reference text images.
        testDir: Directory holding the text images to classify.
        k: Number of neighbours used for the vote.

    Raises:
        ValueError: If ``testDir`` contains no files, or a file name does
            not start with a digit.
    """
    trainList = os.listdir(trainDir)
    trainArray = numpy.zeros((len(trainList), 1024))
    labels = []
    for i, filename in enumerate(trainList):
        labels.append(int(filename[0]))  # leading character is the digit
        trainArray[i, :] = img2vector(os.path.join(trainDir, filename))

    testList = os.listdir(testDir)
    if not testList:
        # Report the real problem instead of a ZeroDivisionError below.
        raise ValueError('no test samples found in %r' % (testDir,))

    error_count = 0
    for filename in testList:
        label = int(filename[0])
        testArray = img2vector(os.path.join(testDir, filename))
        testLabel = classify0(testArray, trainArray, labels, k)
        if label != testLabel:
            error_count += 1
    error_rate = error_count / len(testList)
    print('错误率:%f' % error_rate)
#测试
'''
handWriteClassTest()
'''
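# 机器学习实战-KNN算法 (Machine Learning in Action: the kNN algorithm)
# 最新推荐文章于 2023-12-12 08:42:44 发布 (blog-page footer captured with the code; not part of the program)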