K近邻算法之距离求解

k近邻算法需要注意的3大要素是,K值,距离度量和决策方式,以下代码表示的是距离度量方式,L1 L2和L3的三种计算方式

# -*-coding:utf-8-*-
import math
point_1 = [1, 1]
point_2 = [5, 1]
point_3 = [4, 4]
def get_distance_by_1():
    #L1 是曼哈顿离,计算公式是 每个向量差绝对值的和
    distance_1 = abs(point_2[0] - point_1[0]) + abs(point_2[1]-point_1[1])
    distance_2 = abs(point_3[0] - point_1[0]) + abs(point_3[1]-point_1[1])
    if distance_1>distance_2:
        print("L1点3比较近")
    else:
        print("L1点2比较近")
def get_distance_by_2():
    distance_1 = abs(point_2[0] - point_1[0])*abs(point_2[0] - point_1[0]) + abs(point_2[1] - point_1[1])*abs(point_2[1] - point_1[1])
    distance_1 = math.sqrt(distance_1)
    distance_2 = abs(point_3[0] - point_1[0])*abs(point_3[0] - point_1[0]) + abs(point_3[1] - point_1[1])*abs(point_3[1] - point_1[1])
    distance_2 = math.sqrt(distance_2)
    if(distance_1>distance_2):
        print("L2点3比较近")
    else:
        print("L2点2比较近")
def get_distance_by_3():
    if abs(point_2[0] - point_1[0]) >abs(point_2[1] - point_1[1]):
        distance_1 = abs(point_2[0] - point_1[0])
    else:
        distance_1 = abs(point_2[1] - point_1[1])
    if abs(point_3[0] - point_1[0]) > abs(point_3[1] - point_1[1]):
        distance_2 = abs(point_3[0] - point_1[0])
    else:
        distance_2 = abs(point_3[1] - point_1[1])
    if distance_1>distance_2:
        print("L3 点3比较近")
    else:
        print("L3 点2比较近")
get_distance_by_1()
get_distance_by_2()
get_distance_by_3()

自己写了半天还是看了牛逼人的代码觉得比较好

粘贴如下

# --*-- coding:utf-8 --*--
import numpy as np


class Node:  # 结点
    def __init__(self, data, lchild=None, rchild=None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild


class KdTree:  # kd树
    def __init__(self):
        self.kdTree = None

    def create(self, dataSet, depth):  # 创建kd树,返回根结点
        if (len(dataSet) > 0):
            m, n = np.shape(dataSet)  # 求出样本行,列
            midIndex = int(m / 2)  # 中间数的索引位置
            axis = depth % n  # 判断以哪个轴划分数据
            sortedDataSet = self.sort(dataSet, axis)  # 进行排序
            node = Node(sortedDataSet[midIndex])  # 将节点数据域设置为中位数,具体参考下书本
            # print sortedDataSet[midIndex]
            leftDataSet = sortedDataSet[: midIndex]  # 将中位数的左边创建2改副本
            rightDataSet = sortedDataSet[midIndex + 1:]
          #  print(leftDataSet)
           # print(rightDataSet)
            node.lchild = self.create(leftDataSet, depth + 1)  # 将中位数左边样本传入来递归创建树
            node.rchild = self.create(rightDataSet, depth + 1)
            return node
        else:
            return None

    def sort(self, dataSet, axis):  # 采用冒泡排序,利用aixs作为轴进行划分
        sortDataSet = dataSet[:]  # 由于不能破坏原样本,此处建立一个副本
        m, n = np.shape(sortDataSet)
        for i in range(m):
            for j in range(0, m - i - 1):
                if (sortDataSet[j][axis] > sortDataSet[j + 1][axis]):
                    temp = sortDataSet[j]
                    sortDataSet[j] = sortDataSet[j + 1]
                    sortDataSet[j + 1] = temp
        #print(sortDataSet)
        return sortDataSet

    def preOrder(self, node):  # 前序遍历
        if node != None:
            print("tttt->%s" % node.data)
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)

    def search(self, tree, x):  # 搜索
        self.nearestPoint = None  # 保存最近的点
        self.nearestValue = 0  # 保存最近的值

        def travel(node, depth=0):  # 递归搜索
            if node != None:  # 递归终止条件
                n = len(x)  # 特征数
                axis = depth % n  # 计算轴
                if x[axis] < node.data[axis]:  # 如果数据小于结点,则往左结点找
                    travel(node.lchild, depth + 1)
                else:
                    travel(node.rchild, depth + 1)

                # 以下是递归完毕后,往父结点方向回朔,对应算法3.3(3)
                distNodeAndX = self.dist(x, node.data)  # 目标和节点的距离判断
                if (self.nearestPoint == None):  # 确定当前点,更新最近的点和最近的值,对应算法3.3(3)(a)
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                elif (self.nearestValue > distNodeAndX):
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX

                print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
                if (abs(x[axis] - node.data[axis]) <= self.nearestValue):  # 确定是否需要去子节点的区域去找(圆的判断),对应算法3.3(3)(b)
                    if x[axis] < node.data[axis]:
                        travel(node.rchild, depth + 1)
                    else:
                        travel(node.lchild, depth + 1)

        travel(tree)
        return self.nearestPoint

    def dist(self, x1, x2):  # 欧式距离的计算
        return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5


if __name__ == '__main__':
    dataSet = [[2, 3],
               [5, 4],
               [9, 6],
               [4, 7],
               [8, 1],
               [7, 2]]
    x = [5, 3]
    kdtree = KdTree()
    tree = kdtree.create(dataSet, 0)
    #kdtree.preOrder(tree)
    print(kdtree.search(tree, x))


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

世纪殇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值