教程:
http://cuijiahua.com/blog/2017/11/ml_1_knn.html
1.KNN简介
K-nearest neighbor
KNN categorizes objects based on the classes of their nearest neighbors in the dataset
KNN predictions assume that objects near each other are similar. Distance metrics, such as Euclidean, city block, cosine, and Chebychev, are used to find the nearest neighbor.
KNN算法的思想不难理解,首先,了解一下最近邻:找出与待预测值距离最近的值,该值就是待预测值可能取到的值。k近邻是指:先计算待预测值与所有已给定的值的距离,将这些距离按从小到大排序,取前k个数中出现次数最多的那个值,就是待预测值取到的值。
KNN算法优缺点:
举个例子:使用k-近邻算法分类一个电影是爱情片还是动作片
上表是我们已有的数据集,即训练样本集。该数据集有两个特征:打斗镜头数和接吻镜头数。除此之外,我们也知道每个电影的所属类型,即分类标签。现给出一部未知分类的电影,根据这两个特征使用k-近邻算法可粗略得出这部电影是什么类型,如,给出打斗镜头数=100,接吻镜头数=5,分别计算(100, 5)与(1, 101), (5, 89), (108, 5), (115, 8)的距离,根据算法可得出这部电影属于动作片
2.距离度量
二维(2个特征)
可使用两点距离计算公式:(欧氏距离在二维空间上的公式)
高维(n个特征n>2)
欧氏距离:
3.k-近邻算法步骤
Step1:计算已知类别数据集中的点与当前点之间的距离;
Step2:按照距离递增次序排序;
Step3:选取与当前点距离最小的k个点;
Step4:确定前k个点所在类别的出现频率;
Step5:返回前k个点所出现频率最高的类别作为当前点的预测分类。
4.k-近邻算法实战
K-近邻算法的一般流程:
Step1:收集数据:可以使用爬虫进行数据的收集,也可以使用第三方提供的免费或收费的数据。一般来讲,数据放在txt文本文件中,按照一定的格式进行存储,便于解析及处理。
Step2:准备数据:使用Python解析、预处理数据。
Step3:分析数据:可以使用很多方法对数据进行分析,例如使用Matplotlib将数据可视化。
Step4:测试算法:计算错误率。
Step5:使用算法:错误率在可接受范围内,就可以运行k-近邻算法进行分类。
4.0 实战背景
约会网站配对效果判定
海伦女士一直使用在线约会网站寻找适合自己的约会对象。尽管约会网站会推荐不同的任选,但她并不是喜欢每一个人。经过一番总结,她发现自己交往过的人可以进行如下分类:
不喜欢的人
魅力一般的人
极具魅力的人
海伦收集约会数据已经有了一段时间,她把这些数据存放在文本文件datingTestSet.txt中,每个样本数据占据一行,总共有1000行。datingTestSet.txt数据下载:数据集下载
海伦收集的样本数据主要包含以下3种特征:
每年获得的飞行常客里程数
玩视频游戏所消耗时间百分比
每周消费的冰淇淋公升数
4.1 准备数据:数据解析
在将上述特征数据输入到分类器前,必须将待处理的数据的格式改变为分类器可以接收的格式。分类器接收的数据格式:将数据分类两部分,即特征矩阵和对应的分类标签向量
4.2 数据可视化
详见4.7代码。
4.3 准备数据:数据归一化
为什么要使用数据归一化?
举个例子,
根据上表中给出的数据,若想计算样本3和样本4之间的距离,可以使用欧拉公式计算,如下:
很容易发现,上面方程中数字差值最大的属性(飞行里程数)对计算结果影响最大,而冰激凌公升数对结果影响非常非常小,产生这种现象的原因仅仅是因为飞行常客里程数远大于其他特征值。但海伦认为这三种特征是同等重要的,因此作为三个等权重的特征之一,飞行常客里程数并不应该如此严重影响到计算结果。
在处理这种不同取值范围的特征值时,通常采用的方法是将数值归一化。如将取值范围处理为0到1或者-1到1之间。可用下面的公式进行归一化:
其中,min 和max分别是数据集中的最小特征值和最大特征值
实现过程见4.7代码。
4.4 KNN算法实现分类
通常我们只提供已有数据的90%作为训练样本来训练分类器,而使用其余的10%数据去测试分类器,检测分类器的正确率。需要注意的是,10%的测试数据应该是随机选择的,由于海伦提供的数据并没有按照特定目的来排序,所以我们可以随意选择10%数据而不影响其随机性。
实现过程见代码。
4.5 测试算法性能
详见代码
4.6 应用:根据给出的数据输出结果
详见代码
运行给出的结果:
4.7 总代码
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
from matplotlib.font_manager import FontProperties
import numpy as np
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import operator
# 1.准备数据:数据解析