An attempt at traversing a KD tree: no recursion? And no queue either?

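This article presents a way to traverse a KD tree with a Python iterator. By keeping a state flag on every node plus a single shared cursor node, it walks the tree without recursion and without an extra queue. The article explains the traversal logic and provides the complete code.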


A Little Skit

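Onlooker: A KD tree? Traversed without recursion? And without a queue? Quit kidding, kid.
Me: Don't you doubt it. I really am that good, and that confident.
Onlooker: Then do you dare to show it?
Me: Watch closely...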

Main Text

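In this article I introduce a way to traverse a KD tree (a binary tree) using a Python iterator. For an introduction to iterators, see this article I wrote.
URL: http://blog.youkuaiyun.com/weixin_37722024/article/details/62424311
I won't go into how to write an iterator here. I'll only explain the few member variables the traversal needs and the traversal logic inside the member function __next__().
Python documentation on iterators: https://docs.python.org/3/tutorial/classes.html#iterators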

1. Constructing the KD tree class

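Building a KD tree, which is a binary tree, is fairly simple: build the root node first, then recursively build the root's left and right subtrees. That is outside the scope of this article.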

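The KD_Node class defines what every binary tree needs, namely the node's value, its left subtree and its right subtree, plus the split that every KD tree node needs. On top of these basic members it adds three new ones: cur_trav, a class attribute shared by all instances, which holds a KD_Node used as the cursor during traversal; flag_trav, an int instance attribute that stores each node's state during the traversal; and father, a KD_Node instance attribute that stores each node's parent.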
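A minimal sketch of that layout, assuming a hypothetical `Node` class and `attach` helper rather than the author's KD_Node: the per-node fields live on the instance, while `cur_trav` lives on the class, so setting it once makes the same cursor visible from every node. A consequence worth keeping in mind is that there is exactly one cursor for all trees, so two traversals cannot run at the same time.

```python
class Node:
    """Hypothetical minimal node mirroring the fields described above."""
    cur_trav = None  # class attribute: one cursor shared by every node

    def __init__(self, point, split=None, father=None):
        self.point = point
        self.split = split
        self.left = None
        self.right = None
        self.father = father  # back-reference used to climb up the tree
        self.flag_trav = 0    # per-node traversal state (3 low bits)


def attach(parent, child_point, side):
    """Hypothetical helper: create a child and wire its father reference."""
    child = Node(child_point, father=parent)
    setattr(parent, side, child)  # side is "left" or "right"
    return child


root = Node((3, 5), split=0)
left = attach(root, (2, 4), "left")
right = attach(root, (5, 2), "right")

Node.cur_trav = root            # set once on the class...
print(left.cur_trav is root)    # ...every instance sees the same cursor: True
print(right.father is root)     # True
```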

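With this structure each node gains one int flag (plus its father reference), and the class gains a single cursor node shared by all nodes, which are themselves instances of the class. The extra memory is modest, yet it brings a lot of flexibility and keeps the code short. The same traversal can be used not only to read each node's value but also to draw the KD tree's space partition with matplotlib, or even to simulate a KNN search. That is what I intended for the draw_KDT function, but it isn't finished yet.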
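One possible building block for the partition drawing, sketched under assumptions: 2-D points, the 0 to 6 plot window that draw_KDT sets with xlim/ylim, and hypothetical helpers `region` and `split_segment` that are not part of the original code. They reuse the father pointers: a node's cell is obtained by climbing to the root and clipping the window by every ancestor's split, again without recursion.

```python
class Node:
    """Hypothetical 2-D KD node: point, split dimension, children, father."""
    def __init__(self, point, split, father=None):
        self.point, self.split, self.father = point, split, father
        self.left = self.right = None


def region(node, lo=(0.0, 0.0), hi=(6.0, 6.0)):
    """Bounding box [lo, hi] of `node`'s cell, found by climbing fathers."""
    lo, hi = list(lo), list(hi)
    child, parent = node, node.father
    while parent is not None:
        d, v = parent.split, parent.point[parent.split]
        if child is parent.left:
            hi[d] = min(hi[d], v)      # left cell: coordinate <= split value
        else:
            lo[d] = max(lo[d], v)      # right cell: coordinate >= split value
        child, parent = parent, parent.father
    return lo, hi


def split_segment(node, lo=(0.0, 0.0), hi=(6.0, 6.0)):
    """Endpoints of the splitting line through `node`, clipped to its cell."""
    (x0, y0), (x1, y1) = region(node, lo, hi)
    x, y = node.point
    if node.split == 0:
        return (x, y0), (x, y1)        # vertical line x = point[0]
    return (x0, y), (x1, y)            # horizontal line y = point[1]


root = Node((3, 5), 0)
a = Node((2, 4), 1, root)
root.left = a
c = Node((1, 1), 0, a)
a.left = c
print(split_segment(a))    # ((0.0, 4), (3, 4))
```

Inside a `for node in kd:` loop, each pair of endpoints could then be passed to `plt.plot([p[0], q[0]], [p[1], q[1]])`.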

2. The iteration process

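As mentioned, all of the traversal logic lives in __next__(); please read the code for the exact flow. When I started writing the iterator I wanted to implement __next__() recursively. I tried it, it didn't work, I gave up, started from scratch, and came up with this contraption.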
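A likely reason the recursive __next__() fails: each call must return exactly one node and the next call must resume where the previous one stopped, but a plain recursive function loses its whole call stack when it returns. For contrast, and not the approach this article takes, a Python generator can suspend a recursive walk between items. The sketch below uses a hypothetical `Node` class; note that it still recurses, and the suspended generator frames act as an implicit stack.

```python
class Node:
    """Hypothetical binary tree node, only for this comparison."""
    def __init__(self, point, left=None, right=None):
        self.point = point
        self.left = left
        self.right = right


def preorder_gen(node):
    """Recursive preorder walk; each `yield` suspends the whole call chain."""
    if node is None:
        return
    yield node.point
    yield from preorder_gen(node.left)
    yield from preorder_gen(node.right)


tree = Node((3, 5), Node((2, 4), Node((1, 1))), Node((5, 2), Node((4, 1))))
print(list(preorder_gen(tree)))
# [(3, 5), (2, 4), (1, 1), (5, 2), (4, 1)]
```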

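The variables used for the traversal all keep their initial values right after the KD tree has been built. When the traversal starts, the cursor cur_trav is pointed at the root node (a reference, not a copy); this is the node currently being examined. During the traversal, the state flag_trav of the node under the cursor decides whether to return that node, move to its left child, or move to its right child. The state therefore needs only the lowest 3 bits: bit0 records whether the current node itself has been returned, bit1 whether its left child has been entered, and bit2 whether its right child has been entered. father lets the cursor climb back to the parent node so that the parent's other subtree can be searched next.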
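The decision rule can be isolated as a pure function of the flag. A sketch with hypothetical constant names (`SELF_DONE`, `LEFT_DONE`, `RIGHT_DONE`) and a hypothetical `step` function; the original code uses the literals 0x01, 0x02 and 0x04 directly. It checks the bits in the same order as __next__(): all done first, then bit0, then bit1, then bit2.

```python
# Hypothetical names for the three low bits of flag_trav described above.
SELF_DONE = 0b001   # bit0: the node itself has been returned
LEFT_DONE = 0b010   # bit1: the left child has been entered (or skipped)
RIGHT_DONE = 0b100  # bit2: the right child has been entered (or skipped)
ALL_DONE = SELF_DONE | LEFT_DONE | RIGHT_DONE


def step(flag):
    """Return (action, new_flag) for a node whose state is `flag`."""
    if (flag & ALL_DONE) == ALL_DONE:
        return "up", flag                    # finished: climb to the father
    if not (flag & SELF_DONE):
        return "yield", flag | SELF_DONE     # return the node itself
    if not (flag & LEFT_DONE):
        return "left", flag | LEFT_DONE      # try the left child
    return "right", flag | RIGHT_DONE        # try the right child


flag = 0
for _ in range(4):
    action, flag = step(flag)
    print(action, format(flag, "03b"))
# yield 001
# left 011
# right 111
# up 111
```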

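The node state flag_trav changes according to these rules:
1. Its initial value is 0, meaning that neither the current node nor its left or right child has been visited.
2. The traversal is a preorder traversal: first the current node, then its left subtree, finally its right subtree. So [bit2, bit1, bit0] only ever takes the four values [000], [001], [011] and [111].
3. A missing child is skipped. Once a node's state reaches [111], that is, both children are done or absent, the cursor jumps to the node's parent.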
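To check rule 2, the sketch below runs the same state machine as __next__() over a hand-built tree (shape chosen for illustration, not produced by CreateKDT) and compares the order with a plain recursive preorder. `Node`, `add` and `flag_walk` are hypothetical names; `flag_walk` keeps its cursor in a local variable instead of the class attribute.

```python
class Node:
    """Hypothetical node with the fields the flag-driven walk needs."""
    def __init__(self, point, father=None):
        self.point, self.father = point, father
        self.left = self.right = None
        self.flag_trav = 0


def add(parent, point, side):
    child = Node(point, parent)
    setattr(parent, side, child)
    return child


def flag_walk(root):
    """Same state machine as __next__(), collected into a list."""
    out, cursor = [], root
    while True:
        if (cursor.flag_trav & 0b111) == 0b111:
            if cursor is root:
                return out                   # back at the start: done
            cursor = cursor.father           # climb
        elif not (cursor.flag_trav & 0b001):
            cursor.flag_trav |= 0b001
            out.append(cursor.point)         # "return" the node
        elif not (cursor.flag_trav & 0b010):
            cursor.flag_trav |= 0b010
            if cursor.left is not None:
                cursor = cursor.left
        else:
            cursor.flag_trav |= 0b100
            if cursor.right is not None:
                cursor = cursor.right


def preorder(node):
    """Reference recursive preorder: node, left subtree, right subtree."""
    if node is None:
        return []
    return [node.point] + preorder(node.left) + preorder(node.right)


root = Node((3, 5))
a = add(root, (2, 4), "left")
add(a, (1, 1), "left")
add(a, (1, 5), "right")
b = add(root, (5, 2), "right")
add(b, (4, 1), "left")

expected = preorder(root)
print(flag_walk(root) == expected)   # True
print(expected)   # [(3, 5), (2, 4), (1, 1), (1, 5), (5, 2), (4, 1)]
```

After the walk every flag is 0b111, so a second walk returns nothing until the flags are reset; this is why the full listing calls clear_trav() between traversals.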

3. Code

Here are the key snippets. If you don't feel like reading the long listing, these few pieces give you the gist.

Snippet 1, the class declaration:
class KD_Node:
    cur_trav = None             # cursor for traversal.

    def __init__(self, point=None, split=None, L=None, R=None, father=None):
        """
        initiate a kd tree.
        point: datum of this node
        split: split plane for this node
        L:     left son
        R:     right son
        father: father of this node, if root it's None
        """
        self.point  = point
        self.split  = split
        self.left   = L
        self.right  = R
        self.father = father
        self.flag_trav = 0      # traversal flag. 
                                #   bit 0 is notation for itself
                                #   bit 1 is for its left son
                                #   bit 2 is for its right son
        ......
Snippet 2, the __next__() function:
    def __next__(self):
        # Non-recursive, queue-free traversal driven by the per-node flags.
        if KD_Node.cur_trav is None:        # first call: start at this node
            KD_Node.cur_trav = self

        cursor = KD_Node.cur_trav
        while True:
            if (cursor.flag_trav & 0x07) == 0x07:   # all three bits set (value 7):
                                                    # the node and both subtrees
                                                    # are done.
                if cursor is self or cursor.father is None:
                    raise StopIteration             # back at the start node
                cursor = cursor.father              # climb to the parent

            elif (cursor.flag_trav & 0x01) == 0:    # bit0 == 0: not returned yet
                cursor.flag_trav |= 0x01            # set bit0 = 1
                break                               # BREAK! return the current node

            elif (cursor.flag_trav & 0x02) == 0:    # bit0 == 1, bit1 == 0
                cursor.flag_trav |= 0x02            # set bit1 = 1
                if cursor.left is not None:
                    cursor = cursor.left            # move to the left son
                # else: no left son, loop again and try the right son

            elif (cursor.flag_trav & 0x04) == 0:    # bit1 == 1, bit2 == 0
                cursor.flag_trav |= 0x04            # set bit2 = 1
                if cursor.right is not None:
                    cursor = cursor.right           # move to the right son
        KD_Node.cur_trav = cursor

        return KD_Node.cur_trav
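The clear_trav() method in the complete code further down resets the flags recursively. A sketch of a reset in the same spirit as __next__(), using only father pointers, under the assumption that flags were set by a preorder walk started at `root` (so a node with flag 0 has no visited descendants). `Node` and `reset_flags` are hypothetical names, not part of the original code.

```python
class Node:
    """Hypothetical node with a father pointer and a traversal flag."""
    def __init__(self, point, father=None):
        self.point, self.father = point, father
        self.left = self.right = None
        self.flag_trav = 0


def reset_flags(root):
    """Clear flag_trav below `root` without recursion and without a queue.

    Assumes the preorder invariant: a node whose flag_trav is 0 has no
    visited descendants, so only children with a non-zero flag are entered.
    """
    cursor = root
    while cursor is not None:
        if cursor.left is not None and cursor.left.flag_trav:
            cursor = cursor.left                  # clean the left subtree first
        elif cursor.right is not None and cursor.right.flag_trav:
            cursor = cursor.right                 # then the right subtree
        else:
            cursor.flag_trav = 0                  # children clean: clear this node
            cursor = None if cursor is root else cursor.father


root = Node((3, 5))
root.left = Node((2, 4), root)
root.right = Node((5, 2), root)
root.left.left = Node((1, 1), root.left)
nodes = [root, root.left, root.right, root.left.left]
for n in nodes:
    n.flag_trav = 0b111                           # state after a full traversal
reset_flags(root)
print([n.flag_trav for n in nodes])               # [0, 0, 0, 0]
```

Used in place of clear_trav(), it would also need `KD_Node.cur_trav = None`, since clear_trav() resets the cursor as well.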
Snippet 3, a concise traversal:
def main():
    kd = CreateKDT(data=X)          # build the tree from the sample points

    for node in kd:                 # each KD_Node is its own iterator
        print( '*' + ' '*17, node.point, node.split, ' '*22 + '*' )
The complete code:
import numpy as np
import matplotlib.pyplot as plt
"""
X,  feature vectors
Y,  class of X
X_with_class, I just use this to draw a graphic in the piece of code at \
    the bottom.
D,  dimension of each of vectors.
"""
# Sample points to be classified, with their class labels
X = np.array([ (3,5), (2,4), (1,1), (5,2), (1,5), (4,1) ])
Y = [ 'g', 'g', 'r', 'r', 'g', 'r' ]
X_with_class = [ [X[a,0],X[a,1],Y[a]] for a in range(len(X)) ]

D = 0
if len(X[0]):
    D = len(X[0])


class KD_Node:
    cur_trav = None             # cursor for traversal.

    def __init__(self, point=None, split=None, L=None, R=None, father=None):
        """
        initiate a kd tree.
        point: datum of this node
        split: split plane for this node
        L:     left son
        R:     right son
        father: father of this node, if root it's None
        """
        self.point  = point
        self.split  = split
        self.left   = L
        self.right  = R
        self.father = father
        self.flag_trav = 0      # traversal flag. 
                                #   bit 0 is notation for itself
                                #   bit 1 is for its left son
                                #   bit 2 is for its right son

    def clear_trav(self):
        KD_Node.cur_trav = None
        self.flag_trav = 0
        if self.left:
            self.left.clear_trav()
        if self.right:
            self.right.clear_trav()

    def __iter__(self):
        return self

    def __next__(self):
        # Non-recursive, queue-free traversal driven by the per-node flags.
        if KD_Node.cur_trav is None:        # first call: start at this node
            KD_Node.cur_trav = self

        cursor = KD_Node.cur_trav
        while True:
            if (cursor.flag_trav & 0x07) == 0x07:   # all three bits set (value 7):
                                                    # the node and both subtrees
                                                    # are done.
                if cursor is self or cursor.father is None:
                    raise StopIteration             # back at the start node
                cursor = cursor.father              # climb to the parent

            elif (cursor.flag_trav & 0x01) == 0:    # bit0 == 0: not returned yet
                cursor.flag_trav |= 0x01            # set bit0 = 1
                break                               # BREAK! return the current node

            elif (cursor.flag_trav & 0x02) == 0:    # bit0 == 1, bit1 == 0
                cursor.flag_trav |= 0x02            # set bit1 = 1
                if cursor.left is not None:
                    cursor = cursor.left            # move to the left son
                # else: no left son, loop again and try the right son

            elif (cursor.flag_trav & 0x04) == 0:    # bit1 == 1, bit2 == 0
                cursor.flag_trav |= 0x04            # set bit2 = 1
                if cursor.right is not None:
                    cursor = cursor.right           # move to the right son
        KD_Node.cur_trav = cursor

#        print("cursor:", KD_Node.cur_trav, "self flag:", KD_Node.cur_trav.flag_trav)

        return KD_Node.cur_trav


def CreateKDT(node=None, data=None, father=None):
    """
    Recursively build a KD tree from `data` and return its root node.
    INPUT: node,   unused (overwritten below); kept for the original signature
           data,   2-D NumPy array of points, e.g. np.array([(3,5), (2,4), (1,1)])
           father, parent of the node being built (None for the root)
    OUTPUT: the KD_Node built from `data`
    """
    # Split on the dimension with the largest variance.
    var = np.var(data, axis=0)
    split = np.argmax(var)
    # Position of the current node: the median along the split dimension.
    pos = len(data) // 2
    # Partially sort the indices so that pos_list[pos] is the median point.
    pos_list = np.argpartition(data[:,split], pos)
    point = data[pos_list[pos]]
    # Debug output:
    # print("#"*20)
    # print("data:", list(data))
    # print("split:", split, ",", data[:,split])
    # print("pos list: ", pos_list)
    # print("pos:", pos)
    # print("the node data:", data[pos_list[pos]])
    # print("DEBUG: LEFT=", data[pos_list[:pos]])
    # print("DEBUG: RIGHT=", data[pos_list[(pos+1):]])
    node = KD_Node(point, split, father=father)
    if len(data) > 1:
        if len(data[pos_list[:pos]]) != 0:
            node.left = CreateKDT(node.left, data[pos_list[:pos]], node)
        if len(data[pos_list[(pos+1):]]) != 0:
            node.right = CreateKDT(node.right, data[pos_list[(pos+1):]], node)
    return node

def get_split_pos(data, split):
    """Return the position to split at in data (the median index; not used yet)."""
    return len(data) // 2

def preorder(node, depth=-1):
    """
    Preorder a KD node
    """
    if node:
        s = '#' + '-'*50 + '#\n'
        s += 'Node:' + str(node) + '\n'
        s += 'Point:' + str(node.point) \
                      + ', Flag: ' + str(bin(node.flag_trav)) \
                      + ', Cursor:' + str(KD_Node.cur_trav) \
                      + '\n'
        s += "Father:" + str(node.father) + '\n'
        s += "L:" + str(node.left) + '\n'
        s += "R:" + str(node.right)
        print(s)
        if node.left:
            preorder(node.left)
        if node.right:
            preorder(node.right)

def draw_KDT(kd):
    """
    Draw a plot in which each of data determined by a point and draw the classifying plane.
    """
    plt.figure(figsize=(6,6))
    plt.xlabel("$x^{(1)}$")
    plt.ylabel("$x^{(2)}$")
    plt.title("Machine Learning: KD Tree")
    plt.xlim(0,6)
    plt.ylim(0,6)
    ax = plt.gca()
    ax.set_aspect(1)
    for node in kd:
        plt.scatter( node.point[0], node.point[1], color='g' )
    plt.show()
    pass

def find_knn(root, x):
    pass


def main():
    kd = CreateKDT(data=X)

    #preorder(kd)
    for node in kd:
        print( '*' * 50 )
        print( '*' + ' '*17, node.point, node.split, ' '*22 + '*' )
        print( '*' * 50 )
    #preorder(kd)
    kd.clear_trav()         # reset flags and cursor before the next traversal
    #preorder(kd)
    draw_KDT(kd)

if __name__ == "__main__":
    main()
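A related technique, named plainly since it is not the method of this article: a preorder walk over father pointers that needs no per-node flags at all. It remembers only the node it came from (`prev`); arriving from the father means a first visit, arriving from the left son means the right subtree is next, arriving from the right son means the node is finished. Because `prev` and `cur` live in the iterator object rather than on the class, independent iterators do not interfere and no reset is needed. `Node`, `add` and `PreorderIter` are hypothetical names.

```python
class Node:
    """Hypothetical node: only point, children and father are needed."""
    def __init__(self, point, father=None):
        self.point, self.father = point, father
        self.left = self.right = None


def add(parent, point, side):
    child = Node(point, parent)
    setattr(parent, side, child)
    return child


class PreorderIter:
    """Preorder iterator using father pointers and a `prev` reference."""
    def __init__(self, root):
        self.stop = root.father            # walking back to here ends the loop
        self.prev, self.cur = root.father, root

    def __iter__(self):
        return self

    def __next__(self):
        while self.cur is not self.stop:
            cur, prev = self.cur, self.prev
            first_visit = prev is cur.father
            if first_visit:                # came from above: go down if possible
                if cur.left is not None:
                    nxt = cur.left
                elif cur.right is not None:
                    nxt = cur.right
                else:
                    nxt = cur.father
            elif prev is cur.left and cur.right is not None:
                nxt = cur.right            # left subtree done, go right
            else:
                nxt = cur.father           # both subtrees done, climb
            self.prev, self.cur = cur, nxt
            if first_visit:
                return cur
        raise StopIteration


root = Node((3, 5))
a = add(root, (2, 4), "left")
add(a, (1, 1), "left")
add(a, (1, 5), "right")
b = add(root, (5, 2), "right")
add(b, (4, 1), "left")
print([n.point for n in PreorderIter(root)])
# [(3, 5), (2, 4), (1, 1), (1, 5), (5, 2), (4, 1)]
```

The trade-off compared with the flag approach: no extra int per node and no shared cursor, but the traversal logic depends on comparing node identities with `is`.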