import sys
from operator import itemgetter
class Node(object):
    """A binary-tree node: one payload slot plus left and right child links."""

    def __init__(self):
        """Create a detached node with an empty list as data and no children."""
        # A fresh list per instance, so nodes never share the same payload object.
        self.data = list()
        # Both child links start out absent.
        self.l_child = self.r_child = None
class kdtree(Node):
    """Construction and traversal routines for a k-d tree made of ``Node`` objects.

    The tree being built or walked is always the ``Node`` passed to the
    method; the ``kdtree`` instance itself carries no tree state of its own
    beyond what it inherits from ``Node``.
    """

    def create_tree(self, tree, point_list, depth=0):
        """Populate ``tree`` and the subtree below it from ``point_list``.

        The splitting axis cycles with depth: the points are ordered on
        coordinate ``depth % k`` (``k`` being the length of the first point),
        the median point is stored in ``tree.data``, and the points before and
        after the median recursively form the left and right subtrees.

        Args:
            tree: Node to populate. Left untouched when ``point_list`` is empty.
            point_list: Sequence of points, each indexable and of length ``k``.
                It is not modified.
            depth: Depth of ``tree`` within the whole tree; selects the axis.

        Returns:
            None. The result is written into ``tree``; a side with no points
            is represented by a child link of ``None``.
        """
        if not point_list:
            return None

        k = len(point_list[0])
        axis = depth % k

        # Sort a copy so the caller's list keeps its original order.
        ordered = sorted(point_list, key=itemgetter(axis))
        median = len(ordered) // 2

        tree.data = ordered[median]
        # Only sides that actually hold points get a node; this keeps empty
        # placeholder nodes (data == []) out of the tree and the traversals.
        tree.l_child = self._build_child(ordered[:median], depth + 1)
        tree.r_child = self._build_child(ordered[median + 1:], depth + 1)
        return None

    def _build_child(self, points, depth):
        """Return a populated child node for ``points``, or None if there are none."""
        if not points:
            return None
        child = Node()
        self.create_tree(child, points, depth)
        return child

    def visit(self, tree):
        """Write ``tree.data`` followed by a tab to stdout, with no newline.

        Does nothing when ``tree`` is None.
        """
        if tree is not None:
            # sys.stdout.write produces the same bytes on Python 2 and 3,
            # unlike the Python 2-only trailing-comma print statement.
            sys.stdout.write(str(tree.data) + '\t')

    def pre_order(self, tree):
        """Visit ``tree`` first, then its left subtree, then its right subtree."""
        if tree is not None:
            self.visit(tree)
            self.pre_order(tree.l_child)
            self.pre_order(tree.r_child)

    def in_order(self, tree):
        """Visit the left subtree, then ``tree``, then the right subtree."""
        if tree is not None:
            self.in_order(tree.l_child)
            self.visit(tree)
            self.in_order(tree.r_child)

    def post_order(self, tree):
        """Visit the left subtree, then the right subtree, then ``tree``."""
        if tree is not None:
            self.post_order(tree.l_child)
            self.post_order(tree.r_child)
            self.visit(tree)
def main():
    """Example usage: build a tree from six 2-d points and print three traversals.

    The pre-order, in-order and post-order traversals are written in that
    order, with a line break after the first and after the second.
    """
    point_list = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
    t = Node()
    tree = kdtree()
    tree.create_tree(t, point_list)

    tree.pre_order(t)
    # print('') emits exactly one newline on both Python 2 and Python 3;
    # a bare `print` is a silent no-op expression on Python 3.
    print('')
    tree.in_order(t)
    print('')
    tree.post_order(t)
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
# Advanced version
from operator import itemgetter
class KdNode(object):
    """One node of a k-d tree: the stored point and the two subtrees."""

    def __init__(self, median, l_child, r_child):
        """Store the node's point and its child links as given.

        Args:
            median: Value kept at this node.
            l_child: Left subtree, stored unchanged.
            r_child: Right subtree, stored unchanged.
        """
        self.median, self.l_child, self.r_child = median, l_child, r_child
class KdTree(object):
    """A k-d tree built from a list of points when the object is created.

    Attributes set by the constructor:
        k: Number of coordinates per point (0 when there are no points).
        data: The point collection exactly as passed in.
        root: Root ``KdNode`` of the tree, or None when there are no points.
    """

    def __init__(self, data):
        """Build the tree from ``data`` and print its pre-order traversal.

        Args:
            data: Sequence of points, each indexable and of the same length.
                It is not modified. An empty sequence gives an empty tree.
        """
        # With no points there is no first point to measure, so use k = 0;
        # CreateNode returns before the axis is computed in that case.
        self.k = len(data[0]) if data else 0
        self.data = data
        self.root = self.CreateNode(data, 0)
        self.preorder(self.root)

    def CreateNode(self, data_set, depth):
        """Return the subtree for ``data_set``, or None when it is empty.

        The points are ordered on coordinate ``depth % self.k``; the median
        point becomes the node and the points before and after it recursively
        form the left and right subtrees.

        Args:
            data_set: Points for this subtree. Not modified.
            depth: Depth of the node being created; selects the axis.
        """
        if not data_set:
            return None

        axis = depth % self.k
        # Sort a copy so the caller's list (also kept as self.data) is not
        # reordered as a side effect of building the tree.
        ordered = sorted(data_set, key=itemgetter(axis))
        median = len(ordered) // 2

        return KdNode(
            ordered[median],
            self.CreateNode(ordered[:median], depth + 1),
            self.CreateNode(ordered[median + 1:], depth + 1),
        )

    def visit(self, node):
        """Print ``node.median`` on its own line; do nothing for a falsy node."""
        if node:
            # Parenthesised single-argument form prints the same text on
            # Python 2 and Python 3.
            print(node.median)

    def preorder(self, root):
        """Visit ``root`` first, then its left subtree, then its right subtree."""
        if root:
            self.visit(root)
            self.preorder(root.l_child)
            self.preorder(root.r_child)
# Demo: when run as a script, build a tree from six sample 2-d points.
if __name__ == "__main__":
    data = [
        [2, 3],
        [5, 4],
        [9, 6],
        [4, 7],
        [8, 1],
        [7, 2],
    ]
    kd = KdTree(data)