/**
 * A set of 2D points backed by a 2d-tree.
 *
 * <p>Nodes at even depth split their region with a vertical line through their point
 * (comparing x-coordinates). Nodes at odd depth split with a horizontal line (comparing
 * y-coordinates). A point whose splitting coordinate is strictly smaller goes to the
 * left/bottom subtree; ties go to the right/top subtree. Each node records the axis-aligned
 * region of the plane its subtree covers, which lets {@link #range} and {@link #nearest}
 * skip whole subtrees.
 *
 * <p>The root's region is the unit square [0,1]x[0,1]. Points are presumably expected to
 * lie inside it (no check is made).
 */
public class KdTree {
    /** Number of distinct points stored. */
    private int n;
    /** Root of the tree, or null when the set is empty. */
    private Node root;

    /** Constructs an empty set of points. */
    public KdTree() {
        root = null;
        n = 0;
    }

    /** A tree node: one point plus the region of the plane covered by its subtree. */
    private static class Node {
        private final Point2D p;    // the point
        private final RectHV rect;  // axis-aligned region corresponding to this node
        private Node lb;            // left (even depth) / bottom (odd depth) subtree
        private Node rt;            // right (even depth) / top (odd depth) subtree

        Node(Point2D p, RectHV rect) {
            this.p = p;
            this.rect = rect;
        }
    }

    /** Returns true if the set contains no points. */
    public boolean isEmpty() {
        return n == 0;
    }

    /** Returns the number of points in the set. */
    public int size() {
        return n;
    }

    /**
     * Adds p to the set if it is not already present.
     *
     * @param p the point to add
     * @throws NullPointerException if p is null
     */
    public void insert(Point2D p) {
        checkNotNull(p, "p");
        root = insert(root, p, 0, null, false);
    }

    // Inserts p into the subtree rooted at x (which sits at the given depth) and returns the
    // subtree root. parent/lower say where x hangs, so a new leaf's region can be derived.
    private Node insert(Node x, Point2D p, int depth, Node parent, boolean lower) {
        if (x == null) {
            n++;
            return new Node(p, childRect(parent, lower, depth - 1));
        }
        if (x.p.equals(p)) {
            return x; // already present: the set is unchanged
        }
        if (goesLower(x, p, depth)) {
            x.lb = insert(x.lb, p, depth + 1, x, true);
        } else {
            x.rt = insert(x.rt, p, depth + 1, x, false);
        }
        return x;
    }

    // Region for a new child of parent on the given side; the unit square when parent is null.
    private static RectHV childRect(Node parent, boolean lower, int parentDepth) {
        if (parent == null) {
            return new RectHV(0, 0, 1, 1);
        }
        RectHV r = parent.rect;
        if (parentDepth % 2 == 0) {
            // Parent splits on x: the left child lies west of the line, the right child east.
            return lower
                    ? new RectHV(r.xmin(), r.ymin(), parent.p.x(), r.ymax())
                    : new RectHV(parent.p.x(), r.ymin(), r.xmax(), r.ymax());
        }
        // Parent splits on y: the bottom child lies below the line, the top child above.
        return lower
                ? new RectHV(r.xmin(), r.ymin(), r.xmax(), parent.p.y())
                : new RectHV(r.xmin(), parent.p.y(), r.xmax(), r.ymax());
    }

    // True if p belongs on x's left/bottom side: its splitting coordinate is strictly smaller.
    private static boolean goesLower(Node x, Point2D p, int depth) {
        return depth % 2 == 0 ? p.x() < x.p.x() : p.y() < x.p.y();
    }

    /**
     * Returns true if the set contains the point p.
     *
     * @throws NullPointerException if p is null
     */
    public boolean contains(Point2D p) {
        checkNotNull(p, "p");
        Node x = root;
        int depth = 0;
        while (x != null) {
            if (x.p.equals(p)) {
                return true;
            }
            x = goesLower(x, p, depth) ? x.lb : x.rt;
            depth++;
        }
        return false;
    }

    /**
     * Draws every point in black, plus each node's splitting line clipped to its region.
     * Lines are red and vertical for x-splits (even depth), blue and horizontal for y-splits
     * (odd depth).
     */
    public void draw() {
        draw(root, 0);
    }

    // Pre-order drawing of the subtree rooted at x, which sits at the given depth.
    private void draw(Node x, int depth) {
        if (x == null) {
            return;
        }
        StdDraw.setPenRadius();
        if (depth % 2 == 0) {
            StdDraw.setPenColor(StdDraw.RED);
            StdDraw.line(x.p.x(), x.rect.ymin(), x.p.x(), x.rect.ymax());
        } else {
            StdDraw.setPenColor(StdDraw.BLUE);
            StdDraw.line(x.rect.xmin(), x.p.y(), x.rect.xmax(), x.p.y());
        }
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(.01);
        x.p.draw();
        draw(x.lb, depth + 1);
        draw(x.rt, depth + 1);
    }

    /**
     * Returns all points in the set that lie inside the query rectangle (boundary included,
     * per {@code RectHV.contains}).
     *
     * @throws NullPointerException if rect is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        checkNotNull(rect, "rect");
        Queue<Point2D> q = new Queue<Point2D>();
        range(root, rect, q);
        return q;
    }

    // Enqueues, in pre-order, the points of x's subtree that lie inside query.
    private void range(Node x, RectHV query, Queue<Point2D> found) {
        // A subtree whose region misses the query cannot contain a match. Because this check
        // runs on every child, both subtrees can be visited unconditionally below.
        if (x == null || !x.rect.intersects(query)) {
            return;
        }
        if (query.contains(x.p)) {
            found.enqueue(x.p);
        }
        range(x.lb, query, found);
        range(x.rt, query, found);
    }

    /**
     * Returns a point in the set nearest to p, or null if the set is empty.
     *
     * @throws NullPointerException if p is null
     */
    public Point2D nearest(Point2D p) {
        checkNotNull(p, "p");
        Node best = nearest(root, p, null, 0);
        return best == null ? null : best.p;
    }

    // Returns the node closest to p among best and the points of x's subtree.
    // On ties the earlier-found node is kept.
    private Node nearest(Node x, Point2D p, Node best, int depth) {
        if (x == null) {
            return best;
        }
        // Prune: no point in x's region can be closer than the current best.
        if (best != null && best.p.distanceSquaredTo(p) < x.rect.distanceSquaredTo(p)) {
            return best;
        }
        best = smallerToP(best, x, p);
        // Search the side containing p first; a good early best prunes more of the far side.
        boolean lowerFirst = goesLower(x, p, depth);
        Node nearSide = lowerFirst ? x.lb : x.rt;
        Node farSide = lowerFirst ? x.rt : x.lb;
        best = nearest(nearSide, p, best, depth + 1);
        return nearest(farSide, p, best, depth + 1);
    }

    // Returns whichever of x and y is closer to p (x on ties); a null argument loses.
    private Node smallerToP(Node x, Node y, Point2D p) {
        if (x == null) {
            return y;
        }
        if (y == null) {
            return x;
        }
        return x.p.distanceSquaredTo(p) <= y.p.distanceSquaredTo(p) ? x : y;
    }

    // Rejects null arguments up front so a null can never be stored in, or silently
    // queried against, the tree.
    private static void checkNotNull(Object arg, String name) {
        if (arg == null) {
            throw new NullPointerException(name + " is null");
        }
    }
}