import numpy as np
class KDTree():
    """Node of a k-d tree; the root node stands for the whole tree.

    Attributes are plain instance fields:
        key: the point stored at this node, or the placeholder string ``' '``
            for an empty / not-yet-built subtree (``buildKD`` creates such
            placeholders for empty children).
        lchild: left child node, or ``None``.
        rchild: right child node, or ``None``.
    """

    def __init__(self, obj):
        self.key = obj
        self.lchild = None
        self.rchild = None

    def addlChild(self, obj):
        """Insert a new node holding ``obj`` directly below this node, on the left.

        Any existing left subtree becomes the left child of the new node.
        """
        # Previously an empty slot received the raw ``obj`` instead of a node,
        # unlike the non-empty branch; both cases now create a KDTree node.
        node = KDTree(obj)
        node.lchild = self.lchild
        self.lchild = node

    def addrChild(self, obj):
        """Insert a new node holding ``obj`` directly below this node, on the right.

        Any existing right subtree becomes the right child of the new node.
        """
        node = KDTree(obj)
        node.rchild = self.rchild
        self.rchild = node

    def getRootVal(self):
        """Return the key stored at this node."""
        return self.key

    def splitLR(self, root, left, right):
        """Set this node's key to ``root``.

        NOTE(review): ``left`` and ``right`` are accepted but ignored; the
        children are not modified. Confirm whether callers expect otherwise.
        """
        self.key = root

    def buildKD(self, data, depth, numAxis=None):
        """Build the subtree rooted at this node from a sequence of points.

        The split axis is ``depth % numAxis``. The first point whose
        coordinate on that axis equals the median becomes this node's key.
        Points with a smaller coordinate go to the left subtree. All remaining
        points, including other points tied with the median, go to the right
        subtree, so no input point is dropped.

        Args:
            data: sequence of points; each point is an indexable sequence of
                numeric coordinates.
            depth: depth of this node (0 for the root); selects the split axis.
            numAxis: number of axes to cycle through. Defaults to the number
                of coordinates of the first point (2 for the usual 2-d case).

        If ``data`` is empty the node is left untouched. A node created as
        ``KDTree(' ')`` therefore stays a placeholder with no children.
        """
        if len(data) == 0:
            return
        if numAxis is None:
            numAxis = len(data[0])
        splitAxis = depth % numAxis
        # Compute the median once; it used to be re-sorted twice per point.
        median = self.calMedian([point[splitAxis] for point in data])
        rootFound = False
        lchilds = []
        rchilds = []
        for point in data:
            coord = point[splitAxis]
            if not rootFound and coord == median:
                # calMedian returns one of the coordinates, so this always hits.
                self.key = point
                rootFound = True
            elif coord < median:
                lchilds.append(point)
            else:
                # Ties with the median (beyond the chosen root) go right
                # instead of overwriting the key and being lost.
                rchilds.append(point)
        # Children always exist; empty ones keep the ' ' placeholder key.
        self.lchild = KDTree(' ')
        self.rchild = KDTree(' ')
        self.lchild.buildKD(lchilds, depth + 1, numAxis)
        self.rchild.buildKD(rchilds, depth + 1, numAxis)

    def calMedian(self, data):
        """Return the upper median of a non-empty sequence of numbers.

        For an even-length input this is the larger of the two middle values.
        Raises IndexError on empty input.
        """
        # Floor division: ``/`` produces a float index under Python 3.
        return np.sort(data)[len(data) // 2]
def printTree(KDTree):
    """Print the key of every point stored in a tree, one per line.

    Nodes are visited in pre-order: a node, then its left subtree, then its
    right subtree. Empty subtrees are skipped. These are ``None`` or
    placeholder nodes whose key is the string ``' '``.

    Args:
        KDTree: root node of the tree. The parameter name shadows the class;
            it is kept unchanged for callers passing it by keyword.
    """
    stack = [KDTree]
    while stack:
        node = stack.pop()
        # Placeholder nodes have key ' ' and no children. The previous version
        # dereferenced their missing children and crashed whenever a node had
        # exactly one empty child. The isinstance check avoids elementwise
        # comparison when keys are numpy arrays.
        if node is None or (isinstance(node.key, str) and node.key == ' '):
            continue
        # Internal nodes are printed too; previously only leaves were.
        print(node.key)
        # Push right first so the left subtree is printed before it.
        stack.append(node.rchild)
        stack.append(node.lchild)
if __name__ == "__main__":
    # Demo: build a tree from six sample 2-d points, starting at depth 0,
    # then hand it to printTree.
    data = [
        [2, 3], [5, 4], [9, 6],
        [4, 7], [8, 1], [7, 2],
    ]
    tree = KDTree(' ')  # ' ' is the placeholder key of an unbuilt node
    tree.buildKD(data, depth=0)
    printTree(KDTree=tree)
python实现KD树模型
最新推荐文章于 2024-07-10 19:25:55 发布