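A Short Skit
Onlooker: A KD tree? Traversed without recursion? And without a queue? Quit kidding, kid.
Me: Believe it. I really am that good, and that sure of myself.
Onlooker: Then let's see it, if you dare.
Me: Watch closely, then...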
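Main Text
In this article I present a way to traverse a KD tree (a binary tree) with a Python iterator. For an introduction to iterators, see this article I wrote earlier:
URL: http://blog.youkuaiyun.com/weixin_37722024/article/details/62424311
I won't go into how to write an iterator here. I will only explain the few member variables the traversal relies on and the traversal logic in the member function __next__().
Python documentation on iterators: https://docs.python.org/3/tutorial/classes.html#iterators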
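For readers new to the protocol, here is a minimal sketch of a class-based iterator. The class name CountDown is only an illustration. `__iter__()` returns the object itself, and `__next__()` returns one value per call and raises StopIteration when nothing is left. KD_Node relies on the same contract. Such an iterator is exhausted after one pass, which is why the full program later needs clear_trav().

```python
class CountDown:
    """A minimal iterator: __iter__ returns self, __next__ produces one
    value per call and raises StopIteration when it runs out."""

    def __init__(self, start):
        self.current = start

    def __iter__(self):
        return self

    def __next__(self):
        if self.current <= 0:
            raise StopIteration  # tells the for-loop to stop
        value = self.current
        self.current -= 1
        return value


print(list(CountDown(3)))  # [3, 2, 1]

c = CountDown(2)
print(list(c))  # [2, 1]
print(list(c))  # []  (the same iterator object is already exhausted)
```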
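1. Structure of the KD tree class
A KD tree is a binary tree, and building one is fairly simple. You construct the root node first, then recursively construct its left and right subtrees. Construction is outside the scope of this article.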
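The recursive construction can be sketched in pure Python. This is an illustration under assumptions and not the author's CreateKDT. The names Node and build are hypothetical. The split axis is the one with the largest variance and the node's point is the median along that axis, following the idea of the full program. statistics.pvariance and sorted take the place of NumPy. What matters for this article is that every child gets its father link while it is being built.

```python
from statistics import pvariance


class Node:
    """Hypothetical minimal KD node: point, split axis, children, parent."""

    def __init__(self, point, split, father=None):
        self.point = point
        self.split = split
        self.left = None
        self.right = None
        self.father = father


def build(points, father=None):
    """Split on the axis with the largest variance, store the median point
    in the node, then recurse on the lower and upper halves."""
    if not points:
        return None
    dims = len(points[0])
    split = max(range(dims), key=lambda d: pvariance([p[d] for p in points]))
    ordered = sorted(points, key=lambda p: p[split])
    mid = len(ordered) // 2
    node = Node(ordered[mid], split, father)
    node.left = build(ordered[:mid], node)       # children point back to node
    node.right = build(ordered[mid + 1:], node)
    return node


root = build([(3, 5), (2, 4), (1, 1), (5, 2), (1, 5), (4, 1)])
print(root.point, root.split)              # (2, 4) 1
print(root.left.point, root.right.point)   # (4, 1) (3, 5)
```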
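The KD_Node class defines what every binary tree node needs: the node value and the left and right subtrees. It also defines the split that each KD tree node requires. On top of these basic members I added three more. cur_trav is a class attribute shared by all instances. It holds a KD_Node and serves as the cursor during traversal. flag_trav is an int instance attribute that stores each node's state during traversal. father is an instance attribute of type KD_Node that points to the node's parent (None for the root).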
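cur_trav lives on the class, while flag_trav and father live on each instance. So every node reads the same cursor but keeps its own flag. The small sketch below shows the difference; the class Demo is hypothetical. It also shows why __next__() must assign through the class name (`KD_Node.cur_trav = ...`) rather than through `self.cur_trav = ...`. The second form would silently create a per-instance attribute. A side effect of the shared cursor is that only one traversal can be in progress at a time across all KD_Node trees.

```python
class Demo:
    shared = None  # class attribute: one value seen by every instance

    def __init__(self, name):
        self.name = name
        self.flag = 0  # instance attribute: one value per object


a, b = Demo("a"), Demo("b")
Demo.shared = a          # assign through the class, as KD_Node.__next__() does
a.flag = 1
print(b.shared.name)     # a  (b sees the same cursor)
print(b.flag)            # 0  (flags stay per object)

b.shared = b             # assigning through an instance creates b.shared only
print(Demo.shared.name)  # a  (the class-level value is unchanged)
print(a.shared.name)     # a
```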
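With this structure, each node gains only an int flag and a reference to its parent. The class gains one cursor node, shared by all nodes (which are themselves instances of the class). The extra memory cost is small, but it adds a lot of flexibility and keeps the code short. The same traversal can be used to read out each node's value. It can also be used to draw the KD tree's spatial partition with matplotlib, or even to simulate a KNN search. That is what I intended draw_KDT to do, but it is not finished yet; for now it only plots the points.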
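Here is a rough sketch of the missing step for the partition diagram. It computes each node's split line clipped to the cell that node owns. It walks the tree recursively instead of using the flag-based iterator, because each line's extent depends on the parent's cell, and a flat stream of nodes does not carry that information. The tree is a small hand-built one over the article's six sample points; its shape is assumed for illustration. Node and split_segments are hypothetical names. Each returned pair (xs, ys) can be passed straight to matplotlib's plt.plot(xs, ys).

```python
class Node:
    """Hypothetical 2-D node: point, split axis (0 = x, 1 = y), children."""

    def __init__(self, point, split, left=None, right=None):
        self.point, self.split = point, split
        self.left, self.right = left, right


def split_segments(node, xmin, xmax, ymin, ymax, out=None):
    """Collect each node's split line, clipped to the node's own cell,
    in pre-order. Each entry is ((x0, x1), (y0, y1))."""
    if out is None:
        out = []
    if node is None:
        return out
    x, y = node.point
    if node.split == 0:  # vertical line x = const splits the cell left/right
        out.append(((x, x), (ymin, ymax)))
        split_segments(node.left, xmin, x, ymin, ymax, out)
        split_segments(node.right, x, xmax, ymin, ymax, out)
    else:                # horizontal line y = const splits the cell low/high
        out.append(((xmin, xmax), (y, y)))
        split_segments(node.left, xmin, xmax, ymin, y, out)
        split_segments(node.right, xmin, xmax, y, ymax, out)
    return out


# Hand-built tree over the six sample points (shape assumed for illustration).
tree = Node((2, 4), 1,
            Node((4, 1), 0, Node((1, 1), 0), Node((5, 2), 0)),
            Node((3, 5), 0, Node((1, 5), 0)))
segments = split_segments(tree, 0, 6, 0, 6)
for xs, ys in segments:
    print(xs, ys)  # e.g. plt.plot(xs, ys, 'k-') would draw this line
```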
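2. The iteration process
As mentioned above, all the traversal logic lives in __next__(); please read the code for the details. When I started writing the iterator, I wanted to implement __next__() recursively. I tried, it didn't work, and I gave up. The reason is that each call to __next__() must return exactly one node, and a recursive call stack cannot be kept alive between two calls. So I started over from scratch and came up with this contraption.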
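For comparison only, here is a different technique from the one in this article: a recursive generator. A generator function keeps its whole call chain suspended between `yield`s. So recursion, which is impossible inside a plain `__next__()`, becomes natural inside `__iter__()`. Every `for` loop also gets a fresh generator, so no reset is needed. However, this approach still uses the interpreter's call stack, one frame per tree level, and that is exactly what the flag-and-father approach avoids. The Node class is a hypothetical stand-in for KD_Node.

```python
class Node:
    """Hypothetical minimal binary-tree node with a generator __iter__."""

    def __init__(self, point, left=None, right=None):
        self.point, self.left, self.right = point, left, right

    def __iter__(self):
        # Pre-order: the node itself, then the left subtree, then the right.
        # Each yield suspends the whole recursive chain until the next item.
        yield self
        if self.left is not None:
            yield from self.left
        if self.right is not None:
            yield from self.right


tree = Node(2, Node(1), Node(3, Node(4)))
print([n.point for n in tree])  # [2, 1, 3, 4]
print([n.point for n in tree])  # [2, 1, 3, 4]  (a new generator each loop)
```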
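Right after the KD tree is built, the traversal variables all hold their initial values. When traversal starts, the node it starts from (normally the root) is assigned to the cursor cur_trav; this is a reference, not a copy. That node is the one currently being visited. During the traversal, the state flag_trav of the node under the cursor decides the next step: return the current node, move to its left child, or move to its right child. The state uses only the lowest 3 bits. bit0 records whether the current node itself has been visited, bit1 whether its left child has been visited, and bit2 whether its right child has been visited. The father link lets the cursor climb back to the parent so that the other subtree can be searched next.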
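The three bits can be given names and checked in the same order that __next__() uses. The constant names and the helper next_action below are hypothetical and only meant to make the bit logic readable. In Python, `&` binds more tightly than `==`, so `flag & BIT_SELF == 0` means `(flag & BIT_SELF) == 0`. This differs from C, where that expression would need parentheses.

```python
# Hypothetical names for the three bits of flag_trav.
BIT_SELF = 0b001   # bit0: the node itself has been returned
BIT_LEFT = 0b010   # bit1: the left child has been handled
BIT_RIGHT = 0b100  # bit2: the right child has been handled
ALL_DONE = BIT_SELF | BIT_LEFT | BIT_RIGHT  # 0b111 == 7


def next_action(flag):
    """Return what the cursor does next for a node with this flag,
    testing the bits in the same order as __next__()."""
    if flag & ALL_DONE == ALL_DONE:
        return "climb to father"
    if flag & BIT_SELF == 0:
        return "return this node"
    if flag & BIT_LEFT == 0:
        return "go to left child"
    return "go to right child"


for flag in (0b000, 0b001, 0b011, 0b111):
    print(format(flag, "03b"), "->", next_action(flag))
```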
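How flag_trav evolves:
1. Its initial value is 0, meaning that neither the node itself nor its left or right child has been visited.
2. The order is pre-order: first the node itself, then its left subtree, then its right subtree. So [bit2, bit1, bit0] can only take four values: [000], [001], [011], [111].
3. A missing child is skipped: its bit is set without moving the cursor. Once all three bits are set, the cursor moves up to the node's parent. At the root, iteration ends with StopIteration.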
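The rules above can be checked on a small hand-built tree. The sketch below uses the same state machine as __next__(), written as a generator with a local cursor instead of the shared class attribute. This is a simplification for illustration, and Node, attach and flag_preorder are hypothetical names. Node B has only a right child, which exercises the skip rule. The visit order comes out as pre-order. The only flag values ever written are 001, 011 and 111; 000 is just the initial value.

```python
class Node:
    """Hypothetical node with the same extra fields as KD_Node."""

    def __init__(self, point, father=None):
        self.point, self.father = point, father
        self.left = self.right = None
        self.flag_trav = 0


def attach(father, point, side):
    """Create a child on the given side ('left' or 'right') of father."""
    child = Node(point, father)
    setattr(father, side, child)
    return child


def flag_preorder(root, seen_states):
    """Same state machine as KD_Node.__next__(), with a local cursor.
    Records every flag value written, to check which states occur."""
    cursor = root
    while cursor is not None:
        f = cursor.flag_trav
        if f & 0b111 == 0b111:      # finished here: climb to the father
            cursor = cursor.father  # None above the root ends the loop
            continue
        if f & 0b001 == 0:          # first the node itself
            cursor.flag_trav |= 0b001
            seen_states.add(cursor.flag_trav)
            yield cursor
        elif f & 0b010 == 0:        # then the left subtree (skip if missing)
            cursor.flag_trav |= 0b010
            seen_states.add(cursor.flag_trav)
            if cursor.left is not None:
                cursor = cursor.left
        else:                       # finally the right subtree
            cursor.flag_trav |= 0b100
            seen_states.add(cursor.flag_trav)
            if cursor.right is not None:
                cursor = cursor.right


root = Node("A")
b = attach(root, "B", "left")
attach(b, "D", "right")  # B has only a right child
attach(root, "C", "right")

states = set()
order = [n.point for n in flag_preorder(root, states)]
print(order)                                      # ['A', 'B', 'D', 'C']
print(sorted(format(s, "03b") for s in states))  # ['001', '011', '111']
```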
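3. Code
Below are the key excerpts. If you don't feel like reading the long listing, these excerpts give you the gist.
Excerpt 1, the class declaration: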
```python
class KD_Node:
    cur_trav = None  # cursor for traversal, shared by all nodes

    def __init__(self, point=None, split=None, L=None, R=None, father=None):
        """
        Initiate a KD tree node.
        point: datum of this node
        split: split plane for this node
        L: left son
        R: right son
        father: father of this node; None for the root
        """
        self.point = point
        self.split = split
        self.left = L
        self.right = R
        self.father = father
        self.flag_trav = 0  # traversal flag:
        # bit 0 is for the node itself
        # bit 1 is for its left son
        # bit 2 is for its right son

    # ......
```
Excerpt 2, the __next__() function:
```python
def __next__(self):
    # Traverse the tree without recursion (pre-order).
    if KD_Node.cur_trav is None:  # first call: start from this node
        KD_Node.cur_trav = self
    cursor = KD_Node.cur_trav
    while True:
        if cursor.flag_trav & 0x07 == 0x07:  # flag == 7 (0b111): this
                                             # subtree is finished
            if cursor.father is None:
                raise StopIteration
            cursor = cursor.father
        elif cursor.flag_trav & 0x01 == 0:  # bit0 == 0: visit the node itself
            cursor.flag_trav |= 0x01        # set bit0
            break                           # return the current node
        elif cursor.flag_trav & 0x02 == 0:  # bit1 == 0: go left
            cursor.flag_trav |= 0x02        # set bit1
            if cursor.left is not None:
                cursor = cursor.left        # otherwise skip the missing child
        elif cursor.flag_trav & 0x04 == 0:  # bit2 == 0: go right
            cursor.flag_trav |= 0x04        # set bit2
            if cursor.right is not None:
                cursor = cursor.right       # otherwise skip the missing child
    KD_Node.cur_trav = cursor
    return KD_Node.cur_trav
```
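Because `__iter__()` returns self without resetting anything, every flag is 7 after a full pass. A second `for node in kd:` loop therefore stops immediately unless clear_trav() is called first, which is why main() calls it before draw_KDT(). Below is a minimal sketch of one possible workaround. It assumes it is acceptable for `__iter__()` to reset the flags of the subtree it is called on. The class name KDNodeSketch is hypothetical, and the state machine is the same as in excerpt 2. The cursor is still a class attribute, so two traversals still cannot run interleaved.

```python
class KDNodeSketch:
    cur_trav = None  # shared cursor, as in KD_Node

    def __init__(self, point, father=None):
        self.point, self.father = point, father
        self.left = self.right = None
        self.flag_trav = 0

    def clear_trav(self):
        """Reset the cursor and every flag in this subtree."""
        KDNodeSketch.cur_trav = None
        self.flag_trav = 0
        for child in (self.left, self.right):
            if child is not None:
                child.clear_trav()

    def __iter__(self):
        self.clear_trav()  # assumption: each new for-loop restarts the walk
        return self

    def __next__(self):
        if KDNodeSketch.cur_trav is None:
            KDNodeSketch.cur_trav = self
        cursor = KDNodeSketch.cur_trav
        while True:
            if cursor.flag_trav & 0b111 == 0b111:  # subtree finished
                if cursor.father is None:
                    raise StopIteration
                cursor = cursor.father
            elif cursor.flag_trav & 0b001 == 0:    # visit the node itself
                cursor.flag_trav |= 0b001
                break
            elif cursor.flag_trav & 0b010 == 0:    # then the left child
                cursor.flag_trav |= 0b010
                if cursor.left is not None:
                    cursor = cursor.left
            else:                                  # finally the right child
                cursor.flag_trav |= 0b100
                if cursor.right is not None:
                    cursor = cursor.right
        KDNodeSketch.cur_trav = cursor
        return cursor


root = KDNodeSketch((2, 4))
root.left = KDNodeSketch((4, 1), root)
root.right = KDNodeSketch((3, 5), root)
first = [n.point for n in root]
second = [n.point for n in root]  # works again without a manual clear_trav()
print(first)   # [(2, 4), (4, 1), (3, 5)]
print(second)  # [(2, 4), (4, 1), (3, 5)]
```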
Excerpt 3, a concise traversal:
```python
def main():
    kd = CreateKDT(None, X)
    for node in kd:
        print('*' + ' ' * 17, node.point, node.split, ' ' * 22 + '*')
```
The complete code:
```python
import numpy as np
import matplotlib.pyplot as plt

# X: feature vectors
# Y: class of each vector in X
# X_with_class: only used to draw a graphic in the code at the bottom
# D: dimension of each vector

# Construct the initial data to be classified
X = np.array([(3, 5), (2, 4), (1, 1), (5, 2), (1, 5), (4, 1)])
Y = ['g', 'g', 'r', 'r', 'g', 'r']
X_with_class = [[X[a, 0], X[a, 1], Y[a]] for a in range(len(X))]
D = len(X[0]) if len(X) else 0


class KD_Node:
    cur_trav = None  # cursor for traversal, shared by all nodes

    def __init__(self, point=None, split=None, L=None, R=None, father=None):
        """
        Initiate a KD tree node.
        point: datum of this node
        split: split plane for this node
        L: left son
        R: right son
        father: father of this node; None for the root
        """
        self.point = point
        self.split = split
        self.left = L
        self.right = R
        self.father = father
        self.flag_trav = 0  # traversal flag:
        # bit 0 is for the node itself
        # bit 1 is for its left son
        # bit 2 is for its right son

    def clear_trav(self):
        """Reset the shared cursor and every flag in this subtree."""
        KD_Node.cur_trav = None
        self.flag_trav = 0
        if self.left:
            self.left.clear_trav()
        if self.right:
            self.right.clear_trav()

    def __iter__(self):
        return self

    def __next__(self):
        # Traverse the tree without recursion (pre-order).
        if KD_Node.cur_trav is None:  # first call: start from this node
            KD_Node.cur_trav = self
        cursor = KD_Node.cur_trav
        while True:
            if cursor.flag_trav & 0x07 == 0x07:  # flag == 7 (0b111): this
                                                 # subtree is finished
                if cursor.father is None:
                    raise StopIteration
                cursor = cursor.father
            elif cursor.flag_trav & 0x01 == 0:  # bit0 == 0: visit itself
                cursor.flag_trav |= 0x01
                break                           # return the current node
            elif cursor.flag_trav & 0x02 == 0:  # bit1 == 0: go left
                cursor.flag_trav |= 0x02
                if cursor.left is not None:
                    cursor = cursor.left
            elif cursor.flag_trav & 0x04 == 0:  # bit2 == 0: go right
                cursor.flag_trav |= 0x04
                if cursor.right is not None:
                    cursor = cursor.right
        KD_Node.cur_trav = cursor
        # print("cursor:", KD_Node.cur_trav, "self flag:", KD_Node.cur_trav.flag_trav)
        return KD_Node.cur_trav


def CreateKDT(node=None, data=None, father=None):
    """
    Build a KD tree recursively and return its root.
    INPUT: node, placeholder kept for the original call style (not used)
           data, array of points, e.g. np.array([(3, 5), (2, 4), (1, 1)])
           father, the father of the node being created
    OUTPUT: the new KD_Node
    """
    # Split on the dimension with the largest variance.
    var = np.var(data, axis=0)
    split = np.argmax(var)
    # The current node is the median along that dimension.
    pos = len(data) // 2
    # Partially sort the indices so that pos_list[pos] is the median.
    pos_list = np.argpartition(data[:, split], pos)
    point = data[pos_list[pos]]
    # Debug output:
    # print("#" * 20)
    # print("data:", list(data))
    # print("split:", split, ",", data[:, split])
    # print("pos list: ", pos_list)
    # print("pos:", pos)
    # print("the node data:", data[pos_list[pos]])
    # print("DEBUG: LEFT=", data[pos_list[:pos]])
    # print("DEBUG: RIGHT=", data[pos_list[(pos + 1):]])
    node = KD_Node(point, split, father=father)
    if len(data) > 1:
        if len(data[pos_list[:pos]]) != 0:
            node.left = CreateKDT(node.left, data[pos_list[:pos]], node)
        if len(data[pos_list[(pos + 1):]]) != 0:
            node.right = CreateKDT(node.right, data[pos_list[(pos + 1):]], node)
    return node


def get_split_pos(data, split):
    """Return the position to split in data (the median index)."""
    return len(data) // 2


def preorder(node, depth=-1):
    """
    Print a KD node and its subtrees in pre-order (for debugging).
    """
    if node:
        s = '#' + '-' * 50 + '#\n'
        s += 'Node:' + str(node) + '\n'
        s += 'Point:' + str(node.point) \
            + ', Flag: ' + bin(node.flag_trav) \
            + ', Cursor:' + str(KD_Node.cur_trav) \
            + '\n'
        s += "Father:" + str(node.father) + '\n'
        s += "L:" + str(node.left) + '\n'
        s += "R:" + str(node.right)
        print(s)
        if node.left:
            preorder(node.left)
        if node.right:
            preorder(node.right)


def draw_KDT(kd):
    """
    Plot each data point. Drawing the split planes is not implemented yet.
    """
    plt.figure(figsize=(6, 6))
    plt.xlabel("$x^{(1)}$")
    plt.ylabel("$x^{(2)}$")
    plt.title("Machine Learning: KD Tree")
    plt.xlim(0, 6)
    plt.ylim(0, 6)
    ax = plt.gca()
    ax.set_aspect(1)
    for node in kd:
        plt.scatter(node.point[0], node.point[1], color='g')
    plt.show()


def find_knn(root, x):
    pass  # not implemented yet


def main():
    kd = CreateKDT(None, X)
    # preorder(kd)
    for node in kd:
        print('*' * 50)
        print('*' + ' ' * 17, node.point, node.split, ' ' * 22 + '*')
    print('*' * 50)
    # preorder(kd)
    kd.clear_trav()  # reset cursor and flags so the tree can be traversed again
    # preorder(kd)
    draw_KDT(kd)


if __name__ == "__main__":
    main()
```