Kd树实现K近邻算法——鸢尾花分类

本文介绍了一种高效的空间数据结构——Kd树,并通过Python代码实现Kd树的构建及最近邻点搜索过程。使用鸢尾花数据集进行演示,展示了如何利用Kd树进行分类任务。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Kd树讲解推荐文章:https://www.cnblogs.com/bambipai/p/8435797.html

代码实现:

import numpy as np
from sklearn.model_selection import train_test_split


class Node:
    def __init__(self, value, lson=None, rson=None):
        self.lson = lson
        self.rson = rson
        self.val = value


class KdTree:
    def __init__(self, aixes):
        self.tree = None
        self.aixes = aixes
        self.nearestpoint = None
        self.nearestval = 0
        self.set = [] #用来保存最近邻点

    def creat(self, data, depth): #kd树创建
        if len(data) == 0:
            return None
        sort_val = depth % self.aixes
        mid = int(len(data) / 2)
        temp_data = data[data[:, sort_val].argsort()]
        node = Node(temp_data[mid])
        node.lson = self.creat(temp_data[:mid], depth + 1)
        node.rson = self.creat(temp_data[mid + 1:], depth + 1)
        return node

    def computer_dis(self, node_x, node_y):
        return ((node_x - node_y) ** 2).sum() ** 0.5

    def order(self, node):
        if node is None:
            return
        print(node.val)
        self.order(node.lson)
        self.order(node.rson)

    def check(self, tes): 
        for obj in self.set:
            if (obj == tes).all():  
                return True
        return False

    def search(self, pro_data, node, depth):
        if node is None:
            return
        aiex = depth % self.aixes
        if pro_data[aiex] < node.val[aiex]:
            self.search(pro_data, node.lson, depth + 1)
        else:
            self.search(pro_data, node.rson, depth + 1)

        dis = self.computer_dis(pro_data, node.val)
        if self.nearestpoint is None or self.nearestval > dis:
            if self.check(node.val) is False: #已经是近邻点,不用在考虑
                self.nearestpoint = node.val
                self.nearestval = dis
        if node.lson != None or node.rson != None:
            if abs(pro_data[aiex] - node.val[aiex]) <= self.nearestval:
                if pro_data[aiex] > node.val[aiex]:
                    self.search(pro_data, node.lson, depth + 1)
                else:
                    self.search(pro_data, node.rson, depth + 1)


def main():
    train_data = []
    train_target = []
    with open("iris.csv", "r", encoding="utf-8") as f:
        for line in f.readlines():
            temp_line = line.replace("\n", "").split(",")
            temp_x = []
            for element in temp_line[:-1]:
                temp_x.append(float(element))
            train_data.append(temp_x)
            train_target.append(temp_line[-1])
    tup_train_data = [tuple(obj) for obj in train_data]
    table = dict(zip(tup_train_data, train_target))
    train_X, test_X, train_y, test_y = train_test_split(train_data, train_target, random_state=0)
    KT = KdTree(len(train_X[0]))
    tree = KT.creat(np.array(train_X), 0)
    k = 3
    out = []
    for obj1, obj2 in zip(test_X, test_y):
        count = {}
        KT.set = []
        for i in range(k):
            KT.nearestpoint = None
            KT.nearestval = 0
            KT.search(obj1, tree, 0)
            res = KT.nearestpoint
            KT.set.append(res)
            val = table.get(tuple(res))
            if count.get(val) is None:
                count[val] = 1
            else:
                count[val] += 1
        out.append(max(count, key=count.get) == obj2)
    print("k:",k,sum(out)/len(out))
main()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值