参考《机器学习实战》
理论定义:K-近邻算法采用测量不同特征值之间的距离的方法进行分类。
优点:精度高、对异常值不敏感、无数据输入假定。
缺点:计算复杂度高、空间复杂度高。
适用数据范围:数值型和标称型。
步骤流程:
1.计算已知类别数据集中的点到当前点之间的距离
2.按照距离递增次序排序
3.选取与当前点距离最小的k个点
4.确定前k个点所在类别的出现频率
5.返回前k个点出现频率最高的类别作为当前点的预测分类。
#-*-coding:utf-8-*-
'实现KNN算法'
# 引入模块
from numpy import*
import operator
def createDataSet():
    """Build a toy training set: four 2-D points and their class labels.

    Returns (group, labels): group is a 4x2 numpy array with one point per
    row, and labels[i] is the class ('A' or 'B') of group[i].
    """
    points = [
        [1.0, 1.1],
        [1.0, 1.0],
        [0, 0],
        [0, 0.1],
    ]
    return array(points), ['A', 'A', 'B', 'B']
def classfy0(inx, dataSet, labels, k):
    """Classify ``inx`` by majority vote among its ``k`` nearest neighbours.

    inx     -- the vector to classify, one value per column of dataSet
    dataSet -- 2-D numpy array of training samples, one sample per row
    labels  -- class labels; labels[i] is the class of dataSet[i]
    k       -- number of nearest neighbours that vote; a value larger than
               the number of training samples is clamped to that number

    Returns the label with the most votes. Ties are broken by dict
    iteration order (the label seen first among the nearest neighbours on
    Python 3.7+).
    Raises ValueError if k < 1 or dataSet has no rows.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    dataSetSize = dataSet.shape[0]  # number of training samples (rows)
    if dataSetSize == 0:
        raise ValueError("dataSet must contain at least one sample")
    k = min(k, dataSetSize)
    # Euclidean distance from inx to every row; tile repeats inx once per row.
    diffMat = tile(inx, (dataSetSize, 1)) - dataSet
    distances = ((diffMat ** 2).sum(axis=1)) ** 0.5
    # argsort gives the row indices ordered from nearest to farthest.
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # .items() works on Python 2 and 3 (.iteritems() is Python 2 only).
    # Sort by vote count (second tuple field), highest first.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
if __name__ == '__main__':
    # raw_input only exists on Python 2; input() is the Python 3 equivalent.
    try:
        readLine = raw_input
    except NameError:
        readLine = input
    group, labels = createDataSet()
    inputNum = readLine("please input your number:").strip()  # a string
    # Whitespace/comma separated input ("1.0 1.1", "0,0.1") is parsed as
    # floats; unseparated input keeps the original one-digit-per-coordinate
    # reading ("11" -> [1, 1]).
    tokens = inputNum.replace(',', ' ').split()
    if len(tokens) > 1:
        list_inputNum = [float(x) for x in tokens]
    else:
        list_inputNum = [int(x) for x in inputNum]
    k = int(readLine("please input your k:"))
    result = classfy0(list_inputNum, group, labels, k)
    print(result)
运行测试:
#-*-coding:utf-8-*-
'''
Created on Sep 16, 2010
kNN: k Nearest Neighbors
Input: inX: vector to compare to existing dataset (1xN)
dataSet: size m data set of known vectors (MxN)
labels: data set labels (1xM vector)
k: number of neighbors to use for comparison (should be an odd number)
Output: the most popular class label
@author: pbharrin
'''
# 引入模块
from numpy import*
import operator
from os import listdir
import matplotlib
import matplotlib.pyplot as plt
def createDataSet():
    """Build a toy training set: four 2-D points and their class labels.

    Returns (group, labels): group is a 4x2 numpy array with one point per
    row, and labels[i] is the class ('A' or 'B') of group[i].
    """
    points = [
        [1.0, 1.1],
        [1.0, 1.0],
        [0, 0],
        [0, 0.1],
    ]
    return array(points), ['A', 'A', 'B', 'B']
def classify0(inx, dataSet, labels, k):
    """Classify ``inx`` by majority vote among its ``k`` nearest neighbours.

    inx     -- the vector to classify, one value per column of dataSet
    dataSet -- 2-D numpy array of training samples, one sample per row
    labels  -- class labels; labels[i] is the class of dataSet[i]
    k       -- number of nearest neighbours that vote; a value larger than
               the number of training samples is clamped to that number

    Returns the label with the most votes. Ties are broken by dict
    iteration order (the label seen first among the nearest neighbours on
    Python 3.7+).
    Raises ValueError if k < 1 or dataSet has no rows.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    dataSetSize = dataSet.shape[0]  # number of training samples (rows)
    if dataSetSize == 0:
        raise ValueError("dataSet must contain at least one sample")
    k = min(k, dataSetSize)
    # Euclidean distance from inx to every row; tile repeats inx once per row.
    diffMat = tile(inx, (dataSetSize, 1)) - dataSet
    distances = ((diffMat ** 2).sum(axis=1)) ** 0.5
    # argsort gives the row indices ordered from nearest to farthest.
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # .items() works on Python 2 and 3 (.iteritems() is Python 2 only).
    # Sort by vote count (second tuple field), highest first.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def file2matrix(filename):
    """Parse a tab-separated data file into a feature matrix and label list.

    Each non-blank line holds at least three tab-separated numeric feature
    fields followed by a class label in the last field. The label may be an
    integer string, or one of the text labels 'didntLike', 'smallDoses',
    'largeDoses', which are mapped to 1, 2 and 3 respectively.
    Blank lines are skipped.

    Returns (returnMat, classLabelVector): returnMat is an (n, 3) float
    array of the first three fields of each line, and classLabelVector is
    the list of n integer labels.
    """
    # Integer codes for the text labels used in the non-numeric data file.
    labelMap = {'didntLike': 1, 'smallDoses': 2, 'largeDoses': 3}
    rows = []
    classLabelVector = []
    # Single pass over the file; `with` guarantees it is closed.
    with open(filename) as fr:
        for line in fr:
            line = line.strip()
            if not line:
                continue
            listFromLine = line.split('\t')
            rows.append([float(value) for value in listFromLine[0:3]])
            label = listFromLine[-1]
            if label in labelMap:
                classLabelVector.append(labelMap[label])
            else:
                classLabelVector.append(int(label))
    # Pre-sized (n, 3) matrix so an empty file still yields shape (0, 3).
    returnMat = zeros((len(rows), 3))
    for index, row in enumerate(rows):
        returnMat[index, :] = row
    return returnMat, classLabelVector
def autoNorm(dataSet):
    """Min-max normalise every column of ``dataSet`` into [0, 1].

    dataSet -- 2-D numpy array, one sample per row

    Returns (normDataSet, ranges, minVals), where minVals and ranges are the
    per-column minimum and (max - min), and
    normDataSet == (dataSet - minVals) / ranges row by row.
    A column whose values are all equal has range 0. Its range is reported
    as 1.0 instead, so that column normalises to zeros rather than nan/inf,
    and ``ranges`` can safely be reused to normalise new samples.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    # Float ranges also prevent integer (floor) division on Python 2.
    ranges = (maxVals - minVals).astype(float)
    ranges[ranges == 0] = 1.0
    # Broadcasting applies minVals/ranges to every row (no tile needed).
    normDataSet = (dataSet - minVals) / ranges
    return normDataSet, ranges, minVals
def datingClassTest(hoRatio=0.50, k=3, filename='datingTestSet2.txt'):
    """Hold-out evaluation of classify0 on the dating data set.

    Loads ``filename`` with file2matrix and min-max normalises it with
    autoNorm. The first int(m * hoRatio) rows are then classified against
    the remaining rows. Prints every prediction, the error rate and the raw
    error count.

    hoRatio  -- fraction of rows held out as the test set (default 0.50)
    k        -- number of neighbours passed to classify0 (default 3)
    filename -- tab-separated data file (default 'datingTestSet2.txt')

    Returns the error rate as a float.
    Raises ValueError if hoRatio leaves no test rows or no training rows.
    """
    datingDataMat, datingLabels = file2matrix(filename)  # load data set from file
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)  # the first numTestVecs rows form the test set
    if not 0 < numTestVecs < m:
        raise ValueError(
            "hoRatio=%r selects %d of %d rows for testing; need at least one "
            "test row and one training row" % (hoRatio, numTestVecs, m))
    # The training split is the same for every test row, so slice it once.
    trainMat = normMat[numTestVecs:m, :]
    trainLabels = datingLabels[numTestVecs:m]
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], trainMat, trainLabels, k)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    errorRate = errorCount / float(numTestVecs)
    print("the total error rate is: %f" % errorRate)
    print(errorCount)
    return errorRate
def img2vector(filename):
    """Read a 32x32 text image of digit characters into a 1x1024 vector.

    The first 32 characters of each of the first 32 lines are converted
    with int(); line i fills columns 32*i .. 32*i+31 of the result.
    Returns a numpy array of shape (1, 1024).
    """
    returnVect = zeros((1, 1024))
    # `with` closes the file even if a character fails to parse.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
def handwritingClassTest(trainingDir='trainingDigits', testDir='testDigits', k=3):
    """Evaluate classify0 on a directory-based handwritten digit data set.

    Every file in trainingDir and testDir is read with img2vector. A file's
    class is the integer before the first '_' in its name, after dropping
    everything from the first '.' (e.g. '7_12.txt' -> 7). Prints each test
    prediction, the number of errors and the error rate.

    trainingDir -- directory of training images (default 'trainingDigits')
    testDir     -- directory of test images (default 'testDigits')
    k           -- number of neighbours passed to classify0 (default 3)

    Returns the error rate as a float.
    Raises ValueError if testDir contains no files.
    """
    def digitLabel(fileNameStr):
        # '7_12.txt' -> '7_12' -> 7
        return int(fileNameStr.split('.')[0].split('_')[0])

    hwLabels = []
    trainingFileList = listdir(trainingDir)  # load the training set
    trainingMat = zeros((len(trainingFileList), 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(digitLabel(fileNameStr))
        trainingMat[i, :] = img2vector('%s/%s' % (trainingDir, fileNameStr))
    testFileList = listdir(testDir)  # iterate through the test set
    mTest = len(testFileList)
    if mTest == 0:
        raise ValueError("no test files found in %s" % testDir)
    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumStr = digitLabel(fileNameStr)
        vectorUnderTest = img2vector('%s/%s' % (testDir, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    errorRate = errorCount / float(mTest)
    print("\nthe total error rate is: %f" % errorRate)
    return errorRate
def parseInputVector(text):
    """Turn a line typed by the user into a list of coordinates.

    Values separated by whitespace and/or commas ("1.0 1.1", "0,0.1") are
    parsed as floats. Input without separators keeps the original
    one-digit-per-coordinate reading ("11" -> [1, 1]).
    """
    text = text.strip()
    tokens = text.replace(',', ' ').split()
    if len(tokens) > 1:
        return [float(x) for x in tokens]
    return [int(x) for x in text]


def test():
    """Interactively classify one point against createDataSet() with classify0.

    Prompts for the point and for k on stdin, then prints the predicted label.
    """
    # raw_input only exists on Python 2; input() is the Python 3 equivalent.
    try:
        readLine = raw_input
    except NameError:
        readLine = input
    group, labels = createDataSet()
    inputNum = readLine("please input your number:")  # a string
    list_inputNum = parseInputVector(inputNum)
    k = int(readLine("please input your k:"))
    result = classify0(list_inputNum, group, labels, k)
    print(result)
def test1():
    """Load 'datingTestSet.txt' with file2matrix and print the parsed data.

    Prints the whole feature matrix followed by the first 20 labels.
    """
    filename = 'datingTestSet.txt'
    datingDataMat, datingLabels = file2matrix(filename)
    # Single-argument print() behaves identically on Python 2 and 3.
    print(datingDataMat)
    print(datingLabels[0:20])
def pic_one():
    """Scatter-plot the first two feature columns of 'datingTestSet.txt'.

    Both marker size and marker colour are 15 times the class label, so
    different classes are visually distinguishable. Opens a matplotlib
    window via plt.show().
    """
    figure = plt.figure()
    axes = figure.add_subplot(111)
    dataMat, labelList = file2matrix('datingTestSet.txt')
    scaledLabels = 15.0 * array(labelList)
    axes.scatter(dataMat[:, 0], dataMat[:, 1], s=scaledLabels, c=scaledLabels)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()
if __name__ == '__main__':
    # Run the hold-out evaluation on the dating data set.
    # Other demos, disabled: test1() prints the parsed data set,
    # pic_one() shows a scatter plot of it.
    datingClassTest()