KD树简单实现

本文介绍了一种高效的空间数据结构——Kd树的实现细节。包括如何构建空树、插入节点、查找节点等核心操作,并实现了绘制树结构、范围查询及最近邻搜索等功能。
public class KdTree {
   private int n;
   private Node root;
   private static final int LEFT = 1;
   private static final int RIGHT = 2;
   private static final int BOTTOM = 3;
   private static final int TOP = 4;
   public KdTree()                               // construct an empty set of points
   {
       root = null;
       n = 0;
   }
   private static class Node {
       private Point2D p;      // the point
       private RectHV rect;    // the axis-aligned rectangle corresponding to this node
       private Node lb;        // the left/bottom subtree
       private Node rt;        // the right/top subtree
       public Node(Point2D p,RectHV rect)
       {
           this.p = p;
           this.rect = rect;

       }

    }
   public boolean isEmpty()                        // is the set empty?
   {
       return n == 0;
   }
   public int size()                               // number of points in the set
   {
       return n;
   }
   private Node insert(Node x,Point2D p,int depth,double original,int orientation,RectHV rect)
   {
       if (x == null) 
       {
           if (depth == 0)
           {
               n++;
               return new Node(p,new RectHV(0,0,1,1));
           }
           else
           {
               double xmin = rect.xmin();
               double xmax = rect.xmax();
               double ymin = rect.ymin();
               double ymax = rect.ymax();
               switch (orientation)
               {

                   case LEFT : xmax = original;break;
                   case RIGHT: xmin = original;break;
                   case TOP: ymin = original;break;
                   case BOTTOM: ymax = original;break;
                   default :break;
               }
               //StdOut.print(original);
               n++;
               return new Node(p,new RectHV(xmin,ymin,xmax,ymax));
           }
       }
       if (x.p.equals(p)) return x;
       if (depth % 2 == 0)
       {
           if (p.x() < x.p.x()) x.lb = insert(x.lb,p,depth + 1,x.p.x(),LEFT,x.rect);
           else x.rt = insert(x.rt,p,depth + 1,x.p.x(),RIGHT,x.rect);
       }
       else
       {
           if (p.y() < x.p.y()) x.lb = insert(x.lb,p,depth + 1,x.p.y(),BOTTOM,x.rect);
           else x.rt = insert(x.rt,p,depth + 1,x.p.y(),TOP,x.rect);
       }

       return x;

   }
   public void insert(Point2D p)                   // add the point p to the set (if it is not already in the set)
   {
       root = insert(root,p,0,0,0,new RectHV(1,1,1,1));

   }
   private boolean contains(Node x,Point2D p,int depth)
   {
       if (x == null) 
           return false;
       if (x.p.equals(p)) return true;
       if (depth % 2 == 0)
       {
           if (p.x() < x.p.x()) return contains(x.lb,p,depth + 1);
           else return contains(x.rt,p,depth + 1);
       }
       else
       {
           if (p.y() < x.p.y()) return contains(x.lb,p,depth + 1);
           else return contains(x.rt,p,depth + 1);
       }

   }
   public boolean contains(Point2D p)              // does the set contain the point p?
   {
       return contains(root,p,0);
   }
   private void draw(Node x,int depth)
   {
       if (x == null) return;
       if (depth % 2 == 0)
       {
           StdDraw.setPenColor(StdDraw.BLACK);
           StdDraw.setPenRadius(.01);
           x.p.draw();
           StdDraw.setPenColor(StdDraw.RED);
           StdDraw.setPenRadius();
           x.rect.draw();
       }
       else
       {
           StdDraw.setPenColor(StdDraw.BLACK);
           StdDraw.setPenRadius(.01);
           x.p.draw();
           StdDraw.setPenColor(StdDraw.BLUE);
           StdDraw.setPenRadius();
           x.rect.draw();
       }
       draw(x.lb,depth + 1);
       draw(x.rt,depth + 1);
   }
   public void draw()                              // draw all of the points to standard draw
   {
       draw(root,0);
   }
   private void range(Node x,RectHV rect,Queue q,int depth)
   {

       if (x == null || !x.rect.intersects(rect)) return;
       if (rect.contains(x.p)) q.enqueue(x.p);

       if (depth % 2 == 0)
       {
           double xmin = rect.xmin();
           double xmax = rect.xmax();
           if (x.p.x() >= xmin && x.p.x() <= xmax)
           {
               range(x.lb,rect,q,depth + 1);
               range(x.rt,rect,q,depth + 1);
           }
           else if (x.p.x() < xmin)
               range(x.rt,rect,q,depth + 1);
           else
               range(x.lb,rect,q,depth + 1);
       }
       else
       {
           double ymin = rect.ymin();
           double ymax = rect.ymax();
           if (x.p.y() >= ymin && x.p.y() <= ymax)
           {
               range(x.lb,rect,q,depth + 1);
               range(x.rt,rect,q,depth + 1);
           }
           else if (x.p.y() < ymin)
               range(x.rt,rect,q,depth + 1);
           else
               range(x.lb,rect,q,depth + 1);
       }

   }
   public Iterable<Point2D> range(RectHV rect)     // all points in the set that are inside the rectangle
   {
       Queue<Point2D> q = new Queue<Point2D>();
       range(root,rect,q,0);
       return q;
   }
   private Node smallerToP(Node x,Node y,Point2D p)
   {
       if (x == null) return y;
       if (y == null) return x;
       return x.p.distanceSquaredTo(p) <= y.p.distanceSquaredTo(p) ? x : y;
   }
   private Node nearest(Node x,Point2D p,Node shortNode,int depth)
   {
       if (shortNode == null) shortNode = x;
       if (x == null || shortNode.p.distanceTo(p) < x.rect.distanceTo(p)) return null;
       shortNode = smallerToP(shortNode,x,p);
       Node q;
       if (depth % 2 == 0)
       {
           if (p.x() < x.p.x()) 
           {
               q = nearest(x.lb,p,shortNode,depth + 1);  shortNode = smallerToP(shortNode,q,p);
               q = nearest(x.rt,p,shortNode,depth + 1);  shortNode = smallerToP(shortNode,q,p);
           }
           else
           {
               q = nearest(x.rt,p,shortNode,depth + 1);  shortNode = smallerToP(shortNode,q,p);
               q = nearest(x.lb,p,shortNode,depth + 1);   shortNode = smallerToP(shortNode,q,p);
           }
       }
       else
       {

           if (p.y() < x.p.y())
           {
               q = nearest(x.lb,p,shortNode,depth + 1);  shortNode = smallerToP(shortNode,q,p);
               q = nearest(x.rt,p,shortNode,depth + 1);  shortNode = smallerToP(shortNode,q,p);
           }
           else
           {
               q = nearest(x.rt,p,shortNode,depth + 1);  shortNode = smallerToP(shortNode,q,p);
               q = nearest(x.lb,p,shortNode,depth + 1);   shortNode = smallerToP(shortNode,q,p);
           }
       }

       return shortNode;
   }
   public Point2D nearest(Point2D p) // a nearest neighbor in the set to p; null if set is empty              
   {
       Node q = nearest(root,p,null,0);
       if (q == null) return null;
       return q.p;

   }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值