机器学习实战——python实现knn算法

本文详细介绍了KNN算法的工作原理及其Python实现过程,并通过一个具体案例演示了如何使用KNN进行分类。主要内容包括距离计算、分类步骤及数据可视化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

knn算法描述

对需要分类的点依次执行以下操作:
1.计算已知类别数据集中每个点与该点之间的距离
2.按照距离递增顺序排序
3.选取与该点距离最近的k个点
4.确定前k个点所在类别出现的频率
5.返回前k个点出现频率最高的类别作为该点的预测分类

knn算法实现

数据处理

#从文件中读取数据,返回的数据和分类均为二维数组
def loadDataSet(filename):
    dataSet = []
    labels = []
    fr = open(filename)
    for line in fr.readlines():
        lineArr = line.strip().split(",")
        dataSet.append([float(lineArr[0]),float(lineArr[1])])
        labels.append([float(lineArr[2])])
    return dataSet , labels

knn算法

#计算两个向量之间的欧氏距离
def calDist(X1 , X2):
    sum = 0
    for x1 , x2 in zip(X1 , X2):
        sum += (x1 - x2) ** 2
    return sum ** 0.5

def knn(data , dataSet , labels , k):
    n = shape(dataSet)[0]
    for i in range(n):
        dist = calDist(data , dataSet[i])
        #只记录两点之间的距离和已知点的类别
        labels[i].append(dist)
    #按照距离递增排序
    labels.sort(key=lambda x:x[1])
    count = {}
    #统计每个类别出现的频率
    for i in range(k):
        key = labels[i][0]
        if count.has_key(key):
            count[key] += 1
        else : count[key] = 1
    #按频率递减排序
    sortCount = sorted(count.items(),key=lambda item:item[1],reverse=True)
    return sortCount[0][0]#返回频率最高的key,即label

结果测试

已知类别数据(来源于西瓜书+虚构)

0.697,0.460,1
0.774,0.376,1
0.720,0.330,1
0.634,0.264,1
0.608,0.318,1
0.556,0.215,1
0.403,0.237,1
0.481,0.149,1
0.437,0.211,1
0.525,0.186,1
0.666,0.091,0
0.639,0.161,0
0.657,0.198,0
0.593,0.042,0
0.719,0.103,0
0.671,0.196,0
0.703,0.121,0
0.614,0.116,0

绘图方法

def drawPoints(data , dataSet, labels):
    xcord1 = [];
    ycord1 = [];
    xcord2 = [];
    ycord2 = [];
    for i in range(shape(dataSet)[0]):
        if labels[i][0] == 0:
            xcord1.append(dataSet[i][0])
            ycord1.append(dataSet[i][1])
        if labels[i][0] == 1:
            xcord2.append(dataSet[i][0])
            ycord2.append(dataSet[i][1])
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(xcord1, ycord1, s=30, c='blue', marker='s',label=0)
    ax.scatter(xcord2, ycord2, s=30, c='green',label=1)
    ax.scatter(data[0], data[1], s=30, c='red',label="testdata")
    plt.legend(loc='upper right')
    plt.show()

测试代码

dataSet , labels = loadDataSet('dataSet.txt')
data = [0.6767,0.2122]
drawPoints(data , dataSet, labels)
newlabels = knn(data, dataSet , labels , 5)
print newlabels

运行结果

这里写图片描述

这里写图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值