10行代码实现KNN(K近邻)

本文通过实例演示了K近邻算法的基本原理与实现过程,包括使用sklearn生成数据集、KNN分类应用及KD树的构建与搜索。通过代码展示了如何进行数据集划分、距离计算及最近邻居的选取。

1.K近邻一句话概括:近朱者赤,近墨者黑  +  投票选举

    话不多说,上代码


2.用sklearn生成数据集

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
n_samples = 5000
n_bins = 3  # use 3 bins for calibration_curve as we have 3 clusters here
centers = [(-1, -1), (5, 5)]
X, y = make_blobs(n_samples=n_samples, n_features=2, cluster_std=1.0,
                  centers=centers, shuffle=False, random_state=42)

y[:n_samples // 2] = 0
y[n_samples // 2:] = 1
sample_weight = np.random.RandomState(42).rand(y.shape[0])
X_train, X_test, y_train, y_test, sw_train, sw_test = \
    train_test_split(X, y, sample_weight, test_size=0.1, random_state=42)


3.分类--KNN

k = 500
sum1 = 0
def result(dist,k):
    index = dist.argsort()#获取排序后的数组下标
    index = index[:k]
    out = y_train[index].tolist()
    return out.count(0) < k - out.count(0)
for i in range(len(X_test)):
    dist = distance(X_test[i], X_train)
    sum1 = sum1 + np.equal(y_test[i],result(dist,k))#y_train[dist.index(min(dist))])
print(np.float(sum1/len(X_test)))

4.画图显示

    

plt.scatter(X_train[:,0],X_train[:,1],c=y_train)
plt.scatter(X_test[:,0],X_test[:,1],c='b')
plt.show()


5.拓展--KD数存储与搜索

class Node:
    def __init__(self, data, lchild = None, rchild = None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild

    def create(self, dataSet, depth):  # 创建kd树,返回根结点
        if (len(dataSet) > 0):
            m, n = np.shape(dataSet)  # 求出样本行,列
            midIndex = m / 2  # 中间数的索引位置
            axis = depth % n  # 判断以哪个轴划分数据,对应书中算法3.2(2)公式j()
            sortedDataSet = self.sort(dataSet, axis)  # 进行排序
            node = Node(sortedDataSet[midIndex])  # 将节点数据域设置为中位数,具体参考下书本
            # print sortedDataSet[midIndex]
            leftDataSet = sortedDataSet[: midIndex]  # 将中位数的左边创建2个副本
            rightDataSet = sortedDataSet[midIndex + 1:]
            print(leftDataSet)
            print(rightDataSet)
            node.lchild = self.create(leftDataSet, depth + 1)  # 将中位数左边样本传入来递归创建树
            node.rchild = self.create(rightDataSet, depth + 1)
            return node
        else:
            return None

    def sort(self, dataSet, axis):  # 采用冒泡排序,利用aixs作为轴进行划分
        sortDataSet = dataSet[:]  # 由于不能破坏原样本,此处建立一个副本
        m, n = np.shape(sortDataSet)
        for i in range(m):
            for j in range(0, m - i - 1):
                if (sortDataSet[j][axis] > sortDataSet[j + 1][axis]):
                    temp = sortDataSet[j]
                    sortDataSet[j] = sortDataSet[j + 1]
                    sortDataSet[j + 1] = temp
        print(sortDataSet)
        return sortDataSet

    def preOrder(self, node):
        if node != None:
            print("tttt->%s" % node.data)
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)

    #kd树搜索
    def search(self, tree, x):  # 搜索
        self.nearestPoint = None  # 保存最近的点
        self.nearestValue = 0  # 保存最近的值

        def travel(node, depth=0):  # 递归搜索
            if node != None:  # 递归终止条件
                n = len(x)  # 特征数
                axis = depth % n  # 计算轴
                if x[axis] < node.data[axis]:  # 如果数据小于结点,则往左结点找
                    travel(node.lchild, depth + 1)
                else:
                    travel(node.rchild, depth + 1)

                    # 以下是递归完毕,对应算法3.3(3)
                distNodeAndX = self.dist(x, node.data)  # 目标和节点的距离判断
                if (self.nearestPoint == None):  # 确定当前点,更新最近的点和最近的值,对应算法3.3(3)(a)
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                elif (self.nearestValue > distNodeAndX):
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX

                print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
                if (abs(x[axis] - node.data[axis]) <= self.nearestValue):  # 确定是否需要去子节点的区域去找(圆的判断),对应算法3.3(3)(b)
                    if x[axis] < node.data[axis]:
                        travel(node.rchild, depth + 1)
                    else:
                        travel(node.lchild, depth + 1)

        travel(tree)
        return self.nearestPoint


    def dist(self, x1, x2):  # 欧式距离的计算
        return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5


<think>好的,用户想用MATLAB在10代码实现k-近邻算法,并且之前已经给出过一些参考。我需要先回顾之前的回答,确保这次补充完整且符合用户需求。首先,用户提供的代码中第六需要补充hist的部分。之前的回答中使用了histcounts函数来统计近邻的类别,这可能就是关键所在。用户现在希望用10代码完成,所以需要代码简洁。接下来,我需要检查之前的示例是否正确。之前的代码包括数据生成、网格创建、距离计算、排序找k近邻、统计类别、多数投票以及可视化步骤。这些步骤是否能在10内完成?看起来是的,但需要确保每一都是必要的,并且没有冗余。 用户提到的引用中有几个不同的实现方法,比如使用pdist2计算距离,sort排序获取索引。这些方法在MATLAB中都是高效的,应该保留。histcounts的参数可能需要调整,因为Y是逻辑值,转换为数值后0和1,所以分箱区间设为0.5到2.5,分成两个区间,分别对应0和1。这样统计每个测试点的k近邻中属于正类和负类的数量。 然后,多数投票通过比较两列的统计数,确定预测标签。这一步是否正确?是的,如果counts(:,2)(即类别1的数量)大于counts(:,1),则预测为1,否则为0。这样就能得到预测结果。可视化部分需要绘制等高线,使用reshape将预测结果转换为网格形状,然后contourf填充颜色。这部分是否在10内?需要确认代码数,确保不超过。现在需要检查是否有遗漏的步骤,比如数据标准化或其他处理?用户的问题中没有提到,所以可能不需要。但用户引用中提到不同量纲的问题,这在实现中是否处理?当前代码没有处理,所以可能需要提醒用户注意,但代码本身不包含标准化步骤,以保持简洁。 最后,确保代码数正确。将每个分号分隔的部分算作一,例如数据准备部分可能合并为两,核心算法部分两,预测一,可视化三,总共大约八,剩下的两可能需要调整。或者可能用户将注释也算在内?需要仔细检查。例如:1. X= rand(100,2)*2-1; Y= sum(X,2)> randn(100,1)/10;k=3;2.[xg,yg] =meshgrid(-1:0.01:1); test= [xg(:) yg(:)]; 3. D= pdist2(test,X); [~,idx]= sort(D,2);knn= idx(:,1:k);4.counts =histcounts(Y(knn),0.5:1:2.5,2);5.pred =counts(:,2)> counts(:,1); 6. contourf(xg,yg,reshape(pred,size(xg))); holdon;7.plot(X(Y,1),X(Y,2),'ok','MarkerFaceColor','b');8.plot(X(~Y,1),X(~Y,2),'ok','MarkerFaceColor','r');9.%或者添加其他必要的设置,如boxon等这样看来,可能刚好在10内,包括注释和绘图命令。需要确保所有步骤正确,没有遗漏。另外,用户引用中的代码示例有使用循环计算距离,但这里用pdist2更高效。所以应该保留,确保代码简洁高效。 最后,生成相关问题需要考虑k-NN的扩展和常见问题,如加权、高维问题、参数调优等,参考之前的回答和相关引用。</think>以下是使用MATLAB在10实现k-近邻算法的完整代码及解析: ```matlab % 生成数据与初始化 X = rand(100,2)*2-1; Y = sum(X,2) > randn(100,1)/10; k=3; % 生成二维数据与标签[^3] [xg,yg] = meshgrid(-1:0.01:1); test = [xg(:) yg(:)]; % 创建测试网格 % KNN核心算法 D = pdist2(test,X); [~,idx] = sort(D,2); knn = idx(:,1:k); % 计算距离并排序[^4] counts = histcounts(Y(knn), 0.5:1:2.5, 2); % 统计类别频次[^1] pred = counts(:,2) > counts(:,1); % 多数投票决策 % 可视化结果 contourf(xg, yg, reshape(pred,size(xg))); hold on; box on; axis equal; % 绘制决策边界 plot(X(Y,1),X(Y,2),'ok','MarkerFaceColor','b'); % 绘制正类样本 plot(X(~Y,1),X(~Y,2),'ok','MarkerFaceColor','r'); % 绘制负类样本 ``` ### 关键实现说明 1. **距离计算**:使用`pdist2`函数计算测试点与训练样本的欧氏距离矩阵$D \in \mathbb{R}^{N_{test} \times N_{train}}$,其中$N_{test}$是测试样本数,$N_{train}=100$是训练样本数[^4] 2. **近邻索引**:通过`sort`函数对距离矩阵按排序,获取前$k$个最近邻的索引$idx \in \mathbb{R}^{N_{test} \times k}$[^3] 3. **类别统计**:`histcounts`函数将逻辑标签转换为数值后统计区间$[0.5,1.5)$和$[1.5,2.5)$的频次,对应原始标签的0和1[^1] 4. **决策边界**:`contourf`函数将预测结果映射到网格坐标,通过颜色填充展示分类区域
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值