手把手教你使用KNN算法(Python实现)

本文手把手教你使用KNN算法,详细解释了算法的基本原理,并通过Python代码展示了如何进行KNN分类。同时,用Breast Cancer数据集进行了实战应用,演示了如何使用scikit-learn库实现KNN算法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

上节我们简单进行了KNN算法的说明,想想假期结束再回味一下!

Knn算法基本原理:

假设我有如下两个数据集:

dataset = {'black':[ [1,2], [2,3], [3,1] ], 'red':[ [6,5], [7,7], [8,6] ] }

另外有一点绿颜色标记(3.5,5.3), KNN的任务就是判断这个点(下图中的绿点)该划分到哪个组。


KNN分类算法超级简单:只需使用初中所学的两点距离公式(欧拉距离公式),计算绿点到各组的距离,看绿点和哪组更接近。K代表取离绿点最近的k个点,这k个点如果其中属于红点个数占多数,我们就认为绿点应该划分为红组,反之,则划分为黑组。如果有两组数据(如上图),k值最小应为3(X轴坐标3.5)。

除了K-Nearest Neighbor之外还有其它分组的方法,如Radius-Based Neighbor。此方法后面在做介绍。

实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import math
import numpy as np
from matplotlib import pyplot
from collections import Counter
import warnings
 
# k-Nearest Neighbor算法
def k_nearest_neighbors ( data , predict , k = 5 ) :
 
     if len ( data ) >= k :
         warnings . warn ( "k is too small" )
 
     # 计算predict点到各点的距离
     distances = [ ]
     for group in data :
         for features in data [ group ] :
             #euclidean_distance = np.sqrt(np.sum((np.array(features)-np.array(predict))**2))   # 计算欧拉距离,这个方法没有下面一行代码快
             euclidean_distance = np . linalg . norm ( np . array ( features ) - np . array ( predict ) )
             distances . append ( [ euclidean_distance , group ] )
 
     sorted_distances = [ i [ 1 ]    for i in sorted ( distances ) ]
     top_nearest = sorted_distances [ : k ]
 
     #print(top_nearest)  ['red','black','red']
     group_res = Counter ( top_nearest ) . most_common ( 1 ) [ 0 ] [ 0 ]
     confidence = Counter ( top_nearest ) . most_common ( 1 ) [ 0 ] [ 1 ] * 1.0 / k
     # confidences是对本次分类的确定程度,例如(red,red,red),(red,red,black)都分为red组,但是前者显的更自信
     return group_res , confidence
 
if __name__ == '__main__' :
 
     dataset = { 'black' : [ [ 1 , 2 ] , [ 2 , 3 ] , [ 3 , 1 ] ] , 'red' : [ [ 6 , 5 ] , [ 7 , 7 ] , [ 8 , 6 ] ] }
     new_features = [ 3.5 , 5.2 ]    # 判断这个样本属于哪个组
 
     for i in dataset :
         for ii in dataset [ i ] :
             pyplot . scatter ( ii [ 0 ] , ii [ 1 ] , s = 50 , color = i )
 
     which_group , confidence = k_nearest_neighbors ( dataset , new_features , k = 3 )
     print ( which_group , confidence )
 
     pyplot . scatter ( new_features [ 0 ] , new_features [ 1 ] , s = 100 , color = which_group )
 
     pyplot . show ( )

结果如下所示


归为红色一类的概率为:0.66666666

我们使用实际数据进行应用

数据集(Breast Cancer):https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+%28Original%29

点击download: Data Folder/breast-cancer-wisconsin.data(复制粘贴到txt文件再重命名)

代码如下:(if __name__=='__main__':前面代码一样

import math
import numpy as np
from collections import Counter
import warnings
import pandas as pd
import random
 
# k-Nearest Neighbor算法
def k_nearest_neighbors ( data , predict , k = 5 ) :
 
     if len ( data ) >= k :
         warnings . warn ( "k is too small" )
 
     # 计算predict点到各点的距离
     distances = [ ]
     for group in data :
         for features in data [ group ] :
             euclidean_distance = np . linalg . norm ( np . array ( features ) - np . array ( predict ) )
             distances . append ( [ euclidean_distance , group ] )
 
     sorted_distances = [ i [ 1 ]    for i in sorted ( distances ) ]
     top_nearest = sorted_distances [ : k ]
 
     group_res = Counter ( top_nearest ) . most_common ( 1 ) [ 0 ] [ 0 ]
     confidence = Counter ( top_nearest ) . most_common ( 1 ) [ 0 ] [ 1 ] * 1.0 / k
 
     return group_res , confidence
if __name__=='__main__':
    df=pd.read_csv('iris.csv')#加载数据
    #print (df.head())
    #print(df.shape)
     df . replace ( '?' , np . nan , inplace = True )    # -99999
     df . dropna ( inplace = True )    # 去掉无效数据
     #print(df.shape)
     df . drop ( [ 'id' ] , 1 , inplace = True )#去掉id 这一列(第一列名字为id)
 
     # 把数据分成两部分,训练数据和测试数据
     full_data = df . astype ( float ) . values . tolist ( )
 
     random . shuffle ( full_data )
 
     test_size = 0.2    # 测试数据占20%
     train_data = full_data [ : - int ( test_size * len ( full_data ) ) ]
     test_data = full_data [ - int ( test_size * len ( full_data ) ) : ]
 
     train_set = { 2 : [ ] , 4 : [ ] }
     test_set = { 2 : [ ] , 4 : [ ] }
     for i in train_data :
         train_set [ i [ - 1 ] ] . append ( i [ : - 1 ] )
     for i in test_data :
         test_set [ i [ - 1 ] ] . append ( i [ : - 1 ] )
 
     correct = 0
     total = 0
 
     for group in test_set :
         for data in test_set [ group ] :
             res , confidence = k_nearest_neighbors ( train_set , data , k = 5 ) # 你可以调整这个k看看准确率的变化,你也可以使用matplotlib画出k对应的准确率,找到最好的k值
             if group == res :
                 correct += 1
             else :
                 print ( confidence )
             total += 1
 
     print ( correct / total )    # 准确率
 
     print ( k_nearest_neighbors ( train_set , [ 4 , 2 , 1 , 1 , 1 , 2 , 3 , 2 , 1 ] , k = 5 ) ) # 预测一条记录

结果如下所示:


使用scikit-learn 中K临近算法

代码如下:

import numpy as np
from sklearn import preprocessing , cross_validation , neighbors    # cross_validation已deprecated,使用model_selection替代
import pandas as pd
 
df=pd.read_csv('iris.csv')#加载exel数据
#print(df.head())
#print(df.shape)
df . replace ( '?' , np . nan , inplace = True )    # -99999
df . dropna ( inplace = True )
#print(df.shape)
df . drop ( [ 'id' ] , 1 , inplace = True )
 
X = np . array ( df . drop ( [ 'class' ] , 1 ) )
Y = np . array ( df [ 'class' ] )
 
X_trian , X_test , Y_train , Y_test = cross_validation . train_test_split ( X , Y , test_size = 0.2 )
 
clf = neighbors . KNeighborsClassifier ( )
clf . fit ( X_trian , Y_train )
 
accuracy = clf . score ( X_test , Y_test )
print ( accuracy )
 
sample = np . array ( [ 4 , 2 , 1 , 1 , 1 , 2 , 3 , 2 , 1 ] )
print ( sample . reshape ( 1 , - 1 ) )
print ( clf . predict ( sample . reshape ( 1 , - 1 ) ) )

结果如下:(里面有个警告但不妨碍结果)


scikit-learn中的算法和我们上面实现的算法原理完全一样,只是它的效率更高,支持的参数更全。

(以上内容学习于大熊猫)



人工智能时代,编程已成为一项基本技能。Python,人工智能时代最佳的编程入门语言。本系列课程分为三部分:手把手你学Python(基础篇)、手把手你学Python(进阶篇)、手把手你学Python(实战篇)。面向零基础用户,从无到有,从易到难,层层递进,带你遨游Python世界;采用案例驱动,即学即练即用,将学习落到实处。人工智能时代,编程已成为一项基本技能。国内一些发达省市,已将编程纳入中小学材;编程门槛大幅降低,已不再是计算机行业的专利;学编程训练思维,受益终生;掌握编程可有效,提升工作效率。Python,人工智能时代最佳的编程入门语言。设计人性化,语法简单,容易掌握,近年来热度不断攀升;丰富的内置标准库,强大的第三方库,大大缩减编程工作量;网络爬虫、数据处理、科学计算方面的优势,适用于各行各业;强大的技术体系,能够胜任Web开发、系统运维、人工智能等主流领域本系列课程的主要内容安排。   基础篇:语法基础程序结构基本数据结构函数与异常处理常见库操作文件操作 …………   进阶篇:面向对象思想数据库操作Numpy库介绍Pandas库介绍数据可视化机器学习算法…………   实战篇:网络爬虫原理Requests库学习电影网站信息抓取Scrapy爬虫框架研招网数据抓取学位论文数据抓取…………课程学特点:零基础,从无到有,从易到难,层层递进,带你遨游Python的世界;理论联系实践,案例驱动,即学即练即用,将学习落到实处;提供学习交流平台,在线答疑,自学途中不迷茫;本课程适用人群:计算机相关专业的新生准备转型从事数据处理的职场人员各行各业在职数据处理人员希望从事科学研究的人员程序设计爱好者课程目录安排如下: 第9Python面向对象9_1_类和对象9_2_实例变量9_3_类变量9_4_类中的方法9_5_类的继承9_6_对象的拷贝9_7_本章小结9_8_练习讲解19_9_练习讲解2第10章 Python操作数据库10_1_数据库基础10_2_结构化查询语句SQL10_3_Python操作数据库核心API10_4_Python操作数据库案例10_5_本章小结10_6_练习讲解第11章 Numpy入门与实践11_1_数组对象-ndarray11_2_索引和切片(上)11_3_索引和切片(下)11_4_Numpy中的通用函数11_5_数组运算11_6_本章小结11_7_练习讲解第12章 Pandas入门与实践12_1_Series和Index介绍12_2_Series数据访问和常用方法12_3_DataFrame创建与数据访问12_4_DataFrame中的属性和方法12_5_DataFrame的合并12_6_Pandas加载数据和缺失值处理12_7_Pandas中分组操作12_8_Pandas中数据合并操作12_9_Pandas综合案例12_10_本章小结12_11_练习讲解第13章 数据可视化-matplotlib13_1_pyplot绘图基础13_2_绘制线形图13_3_绘制直方图13_4_绘制条形图13_5_绘制饼状图13_6_绘制散点图13_7_生成词云图13_8_本章小结13_9_练习讲解第14章 机器学习库 Scikit-learn14_1_机器学习基础14_2_鸢尾花数据读取和可视化14_3_自己写KNN算法实现14_4_调用sklearn中的KNN算法14_5_波士顿房价预测问题14_6_手写数字识别14_7_本章小结 
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值