第二章 k-近邻算法
【参考书籍】机器学习实战(Machine Learning in Action)
本章内容:
- k-近邻分类算法
对未知类别属性数据集中的每个点依次执行以下操作:
(1)计算已知类别数据集中的点与当前点之间距离;
(2)按照距离递增次序排序;
(3)选取与当前点距离最小的k个点;
(4)确定前k个点所在类别的出现频率;
(5)返回前k个点出现频率最高的类别作为当前点的预测分类。 - 从文本文件中解析和导入数据
- 使用Matplotlib创建散点图
归一化数值
归一化数值:处理不同取值范围的特征值时,通常采用的方法是将数值归一化,入将取值范围处理为0到1或者-1到1之间。将取值范围的特征值转化为0到1区间内的值:
newValue = (oldValue - min) / (max - min)
其中min和max分别是数据集中最小、最大特征值。评估算法的正确率:通常使用提供的已有数据的90%作为训练样本来训练分类器,而使用余下的10%的数据去测试分类器,检测分类器的正确率。
使用的函数
函数名/属性 | 功能 |
---|---|
array() | 创建一个数组 |
shape | 数组或矩阵的各个维的大小 |
tile(A, reps) | 将数组A,根据数组reps沿各个维度重复多次,构成一个新的数组。reps的数字从后往前分别对应A的第N个维度的重复次数。 |
sum(arr,axis=1) | 根据行列(轴),求和 |
max(arr,axis=1) | 根据行列(轴),求最大值 |
min(arr,axis=1) | 根据行列(轴),求最小值 |
mean(arr,axis=1) | 根据行列(轴),求平均值 |
argsort() | 得到矩阵中每个元素的排序序号 |
dict.get(key,default) | 获取字典中,一个给定的key对应的值。若key不存在,则返回默认值default。 |
sorted(iterable[, key][, reverse]) | 第一个参数是一个iterable,返回值是一个对iterable中元素进行排序后的列表(list)。 |
open(filename) | 返回一个文件对象 |
fr.readlines() | 读取文件对象fr中的所有行,返回数组 |
fr.readline() | 读取文件对象fr的当前行,返回字符串 |
len(arr) | 返回数组的长度 |
zeros((n,m)) | 创建一个n*m的矩阵,用0填充 |
line.strip() | 删除文本行line后的回车符 |
str.spit(‘\t’) | 使用’\t’分割字符串str,返回一个列表 |
list[-1] | 获取列表的最后一个元素 |
vec.append(item) | 在向量、列表vec后追加元素item |
mat[index, :] | 获取矩阵/数组的第index行的所有元素 |
list[m:n] | 获取列表索引m到n的元素的值 |
plt.figure() | 创建画布? |
fig.add_subplot((m,n,x)) | 把画布分割成m*n的区块,在第x块上绘图 |
scatter() | 绘制散点 |
格式化输出 | |
raw_input(“prompt string”) | 显示提示字符串,将用户的输入转换成string |
input(“prompt string”) | 会根据用户输入变换相应的类型,而且如果要输入字符和字符串的时候必须要用引号包起来 |
range() | range(1,5) #代表从1到5(不包含5); range(1,5,2) #代表从1到5,间隔2(不包含5); range(5) #代表从0到5(不包含5) |
listdir(‘folder’) | from os import listdir,获取给定文件夹下的文件名列表,不含文件路径 |
程序代码:
# coding=utf-8
# "coding=utf-8" 中间不要有空格,代码中有中文注释,需要添加第一行代码
from numpy import *
import operator
# from os import listdir 用于手写数字识别系统的测试代码
from os import listdir
# 导入测试数据
def createDataSet() :
dataSet = array([[1.0,1.1]
,[1.0,1.0]
,[0,0]
,[0,0.1]])
labels = ['A','A','B','B']
return dataSet, labels
# kNN算法
# inX: 用于分类的输入向量
# dataSet: 输入的训练样本集
# labels: 标签向量
def classify0(inX, dataSet, labels, k) :
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
sortedDistIndicies = distances.argsort()
classCount = {}
for i in range(k) :
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
# 功能:将约会数据文本文件导入矩阵
# 第一列表示每年获取的飞行常客里程数
# 第二列表示玩视频游戏所耗时间百分比
# 第三列表示每周所消费的冰激凌公升数
# 最后一列表示标签
def file2matrix(filename) :
fr = open(filename)
arrayOLines = fr.readlines()
# 得到文件行数
numberOfLines = len(arrayOLines)
# 创建一个行列为(numberOfLines, 3)的NumPython矩阵,元素用0填充
returnMat = zeros((numberOfLines, 3))
classLabelVector = []
index = 0
for line in arrayOLines :
# 截取文本行中回车符
line = line.strip()
# 以'\t'分割字符串,返回一个元素列表
listFromLine = line.split('\t')
returnMat[index, :] = listFromLine[0:3]
# 索引值-1表示列表中的最后一列元素
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat, classLabelVector
# 归一化数据
# 处理不同取值范围的特征值时,通常采用的方法是将数值归一化,如将取值范围处理为0到1或者-1到1之间。
# 下面公式将任意取值范围的特征值转化为0到1区间内的值
# newValue = (oldValue - min) / (max - min)
# 其中min和max分别是数据集中的最小和最大特征值。
def autoNorm(dataSet) :
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m,1))
normDataSet = normDataSet/tile(ranges, (m,1))
return normDataSet, ranges, minVals
# 功能:分类器针对约会网站的测试代码
def datingClassTest() :
hoRatio = 0.10
datingDataMat, datingLabels = file2matrix('C:\Python27\datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m*hoRatio)
errorCount = 0
# 取整个没有按特定目的排序的前10%为测试数据normMat[0:numTestVecs, :]
# 余下90%的数据normMat[numTestVecs:m, :]作为训练样本来训练分类器
for i in range(numTestVecs) :
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], \
datingLabels[numTestVecs:m], 3)
print 'the classifier came back with: %d, the real answer is: %d' \
% (classifierResult, datingLabels[i])
if( classifierResult != datingLabels[i]) :
errorCount += 1.0
print 'The total error rate is %f' % (errorCount/float(numTestVecs))
# 约会网站预测函数
def classifyPerson() :
resultList = ['not at all','in small doses','in large doses'];
# 输入数据
percentTats = float(raw_input('percentage of time spent playing video games?'))
ffMiles = float(raw_input('frequent flier miles earned per year?'))
iceCream = float(raw_input('liters of ice cream consumed per year?'))
# 获取训练样本
datingDataMat, datingLabels = file2matrix('C:\Python27\datingTestSet2.txt')
# 归一化训练样本
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = array([ffMiles, percentTats, iceCream])
# 归一化输入数据
classifierResult = classify0((inArr-minVals)/ranges, normMat, datingLabels, 3)
print 'You will probably like this person: ', resultList[classifierResult - 1]
# 将图像格式化处理为一个向量,将32*32的二进制图像矩阵转换为1*1024的向量
def img2vector(filename) :
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32) :
lineStr = fr.readline()
for j in range(32) :
returnVect[0, 32*i+j] = int(lineStr[j])
return returnVect
# 手写数字识别系统的测试代码
def handwritingClassTest() :
hwLabels = []
# 获取目录下的文件列表
trainingFileList = listdir('C:/Python27/ml/trainingDigits')
m = len(trainingFileList)
trainingMat = zeros((m, 1024))
# 将目录trainingDigits下的所有文件中的训练样本存储在shape为(m,1024)的trainingMat矩阵中
for i in range(m) :
# 带扩展名的文件名0_1.txt
fileNameStr = trainingFileList[i]
# 不带扩展名的文件名0_1
fileStr = fileNameStr.split('.')[0]
# 通过文件名获得样本标签0
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i, :] = img2vector('C:/Python27/ml/trainingDigits/%s' % fileNameStr)
testFileList = listdir('C:/Python27/ml/testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest) :
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('C:/Python27/ml/trainingDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print 'The classifier came back with: %d, the real answer is: %d' \
% (classifierResult, classNumStr)
if( classifierResult != classNumStr) : errorCount += 1.0
print '\nThe total number of errors is: %d' % errorCount
print '\nThe total error rate is %f' % (errorCount/float(mTest))
在命令行中执行:
>>> import kNN
>>> dataSet, labels = kNN.createDataSet()
>>> kNN.classify0([0,0], dataSet, labels, 3)
'B'
>>> reload(kNN)
<module 'kNN' from 'C:\Python27\kNN.py'>
>>> datingDataMat, datingLabels = kNN.file2matrix('datingTestSet2.txt')
>>> datingDataMat
array([[ 4.09200000e+04, 8.32697600e+00, 9.53952000e-01],
[ 1.44880000e+04, 7.15346900e+00, 1.67390400e+00],
[ 2.60520000e+04, 1.44187100e+00, 8.05124000e-01],
...,
[ 2.65750000e+04, 1.06501020e+01, 8.66627000e-01],
[ 4.81110000e+04, 9.13452800e+00, 7.28045000e-01],
[ 4.37570000e+04, 7.88260100e+00, 1.33244600e+00]])
>>> datingLabels[0:20]
[3, 2, 1, 1, 1, 1, 3, 3, 1, 3, 1, 1, 2, 1, 1, 1, 1, 1, 2, 3]
>>> reload(kNN)
<module 'kNN' from 'C:\Python27\kNN.py'>
>>> normMat, ranges, minVals = kNN.autoNorm(datingDataMat)
>>> normMat
array([[ 0.44832535, 0.39805139, 0.56233353],
[ 0.15873259, 0.34195467, 0.98724416],
[ 0.28542943, 0.06892523, 0.47449629],
...,
[ 0.29115949, 0.50910294, 0.51079493],
[ 0.52711097, 0.43665451, 0.4290048 ],
[ 0.47940793, 0.3768091 , 0.78571804]])
>>> ranges
array([ 9.12730000e+04, 2.09193490e+01, 1.69436100e+00])
>>> minVals
array([ 0. , 0. , 0.001156])
>>> reload(kNN)
<module 'kNN' from 'C:\Python27\kNN.py'>
>>> kNN.datingClassTest()
the classifier came back with: 3, the real answer is: 3
the classifier came back with: 2, the real answer is: 2
...
the classifier came back with: 1, the real answer is: 1
the classifier came back with: 3, the real answer is: 1
The total error rate is 0.050000
>>> reload(kNN)
<module 'kNN' from 'C:\Python27\kNN.py'>
>>> kNN.classifyPerson()
percentage of time spent playing video games?10
frequent flier miles earned per year?10000
liters of ice cream consumed per year?0.5
You will probably like this person: in small doses
>>> reload(kNN)
<module 'kNN' from 'C:\Python27\kNN.py'>
>>> testVector = kNN.img2vector('C:\Python27/ml/testDigits/0_13.txt')
>>> testVector[0, 0:31]
array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0.])
>>> testVector[0, 32:63]
array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0.])
>>> reload(kNN)
>>> kNN.handwritingClassTest()
The classifier came back with: 0, the real answer is: 0
The classifier came back with: 0, the real answer is: 0
The classifier came back with: 0, the real answer is: 0
...
The classifier came back with: 8, the real answer is: 8
The classifier came back with: 8, the real answer is: 8
The classifier came back with: 1, the real answer is: 8
The classifier came back with: 8, the real answer is: 8
The classifier came back with: 8, the real answer is: 8
The classifier came back with: 1, the real answer is: 8
...
The classifier came back with: 9, the real answer is: 9
The classifier came back with: 9, the real answer is: 9
The total number of errors is: 15
The total error rate is 0.015856
使用Matplotlib创建散点图
- 使用了datingDataMat矩阵的第二、三列的数据生成散点图
>>> import matplotlib
>>> import matplotlib.pyplot as plt
>>> fig = plt.figure()
>>> ax = fig.add_subplot(111)
>>> ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
<matplotlib.collections.PathCollection object at 0x02C248F0>
>>> plt.show()
- 采用色彩或其他标记来标记不同样本分类,生成散点图。
>>> from numpy import *
>>> import matplotlib
>>> import matplotlib.pyplot as plt
>>> fig = plt.figure()
>>> ax = fig.add_subplot(111)
>>> ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*array(datingLabels),
15.0*array(datingLabels))
<matplotlib.collections.PathCollection object at 0x02D641B0>
>>> plt.show()
上图虽然容易区分数据点从属类别,但很难根据这张图得出结论性信息
- 接下来,生成每年赢得的飞行常客里程数与玩视频游戏所占百分比的约会数据散点图。通过此图产生两个特征更容易区分数据点从属的类别。
>>> fig = plt.figure()
>>> ax = fig.add_subplot(111)
>>> ax.scatter(datingDataMat[:,0], datingDataMat[:,1], 15.0*array(datingLabels),
15.0*array(datingLabels))
<matplotlib.collections.PathCollection object at 0x02EDF570>
>>>
>>> plt.show()