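Description
This post works through Program Listing 2-1 on page 19 of Machine Learning in Action. I rewrote part of the code based on my understanding of it. I then converted the labels to numbers and added the query point to the data, so that the samples and the new point could be plotted together.
Code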
from numpy import *              # star import as in the book: provides array, tile, vstack
import matplotlib.pyplot as plt

def createInfo():
    # six 2-D sample points and their class labels
    dataSet = [[2,2],[2.5,2.4],[1.6,1.7],[0.7,0.7],[1,0],[1,1]]
    labels = ['big','big','big','small','small','small']
    return array(dataSet),array(labels)
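def classify(inX,dataSet,labels,k):
    rows = dataSet.shape[0]
    # repeat the query point once per sample so it can be subtracted row by row
    diff = tile(inX,(rows,1)) - dataSet
    # Euclidean distance from the query point to every sample
    distance = ((diff**2).sum(axis=1))**0.5
    # sample indices ordered from nearest to farthest
    index = distance.argsort()
    # count the labels of the k nearest samples (named votes to avoid shadowing the built-in dict)
    votes = {}
    for i in range(k):
        label = labels[index[i]]
        votes[label] = votes.get(label,0)+1
    # sort (label, count) pairs by count, largest first, and return the winning label
    sortedVotes = sorted(votes.items(),key=lambda x:x[1],reverse=True)
    return sortedVotes[0][0]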
dataSet,labels = createInfo()
# classify the new point [1,1.5] by a vote among its 3 nearest neighbours; prints: small
print(classify([1,1.5],dataSet,labels,3))
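# map the labels to numbers for plotting: big -> 2, small -> 1
num_label = []
for item in labels:
    if item=='big':
        num_label.append(2)
    else:
        num_label.append(1)
# give the query point its own value, 3, so it stands out in the plot
num_label.append(3)
num_label = array(num_label)
# append the query point to the samples so it is plotted as well
# (vstack does the same as row_stack, which newer NumPy versions deprecate)
dataSet = vstack((dataSet,[1,1.5]))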
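fig = plt.figure()
ax = fig.add_subplot(111)
# x, y, marker size s and colour value c: a larger label value gives a larger point and a different colour
ax.scatter(dataSet[:,0],dataSet[:,1],s=15*num_label,c=num_label)
plt.show()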
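The tile call in classify only serves to give the query point the same shape as dataSet. NumPy broadcasting does this implicitly, so the distance step can be written without it. Below is a minimal sketch of an equivalent classifier under that approach; the name classify_broadcast is hypothetical, and it swaps the dictionary-plus-sorted vote for collections.Counter.

```python
from collections import Counter

import numpy as np


def classify_broadcast(inX, dataSet, labels, k):
    """kNN vote using broadcasting instead of tile (hypothetical helper)."""
    inX = np.asarray(inX, dtype=float)
    # (rows, 2) - (2,) broadcasts the query point against every row
    distance = np.sqrt(((dataSet - inX) ** 2).sum(axis=1))
    # indices of the k nearest samples, nearest first
    nearest = distance.argsort()[:k]
    # most_common(1) returns [(label, count)] for the most frequent label
    return Counter(labels[nearest]).most_common(1)[0][0]


dataSet = np.array([[2, 2], [2.5, 2.4], [1.6, 1.7], [0.7, 0.7], [1, 0], [1, 1]])
labels = np.array(['big', 'big', 'big', 'small', 'small', 'small'])
print(classify_broadcast([1, 1.5], dataSet, labels, 3))  # small
```

Ties in the vote resolve the same way as in the listing. Counter.most_common keeps equal counts in first-seen order, and sorted(..., reverse=True) is stable, so both return the tied label that appears first among the nearest neighbours.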
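Note
Convert num_label to a NumPy array (array(num_label)) before calling scatter. If it is a plain Python list, 15*num_label repeats the list 15 times instead of multiplying each element by 15, so the size argument no longer matches the number of points and scatter raises an error. Likewise, the column slicing dataSet[:,0] only works on an array, not on a nested list.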
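A quick way to see why the conversion matters: multiplying a Python list by an integer repeats the list, whereas multiplying a NumPy array scales each element. The sketch below also shows one alternative way to build num_label directly as an array with np.where, using the same mapping as the listing (big -> 2, otherwise 1, plus 3 for the query point).

```python
import numpy as np

labels = np.array(['big', 'big', 'big', 'small', 'small', 'small'])

as_list = [2, 2, 2, 1, 1, 1, 3]
print(len(15 * as_list))        # 105: the 7-element list repeated 15 times
print(15 * np.array(as_list))   # [30 30 30 15 15 15 45]: element-wise scaling

# build the numeric labels as an array from the start: big -> 2, otherwise 1,
# then append 3 for the query point
num_label = np.append(np.where(labels == 'big', 2, 1), 3)
print(num_label)                # [2 2 2 1 1 1 3]
```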
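Summary
The logic of the kNN algorithm is fairly simple. Stepping through the code in debug mode and inspecting each variable is a good way to see how each step feeds into the next.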
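If a debugger is not at hand, printing the intermediate values gives a similar view. This sketch reruns the steps of classify for the query point [1, 1.5] with k = 3 and prints each one. Rounding to three decimals is only for readability, and str() turns NumPy string scalars into plain strings so the printed dictionary is easier to read.

```python
import numpy as np

dataSet = np.array([[2, 2], [2.5, 2.4], [1.6, 1.7], [0.7, 0.7], [1, 0], [1, 1]])
labels = np.array(['big', 'big', 'big', 'small', 'small', 'small'])
inX, k = [1, 1.5], 3

diff = np.tile(inX, (dataSet.shape[0], 1)) - dataSet   # query minus each sample
distance = (diff ** 2).sum(axis=1) ** 0.5               # Euclidean distances
index = distance.argsort()                              # nearest first
print(np.round(distance, 3))
print(index)                                            # [5 2 3 0 4 1]

votes = {}
for i in range(k):
    label = str(labels[index[i]])
    votes[label] = votes.get(label, 0) + 1
print(votes)                                            # {'small': 2, 'big': 1}
```

The three nearest samples are [1,1], [1.6,1.7] and [0.7,0.7]. Two of them are small, which is why the listing prints small.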