Difference between revisions of "User:Chase-san/Kd-Tree"

From Robowiki
Jump to navigation Jump to search
(Updated, now with nearest N neighbors)
(Changed distance to use distanceSq in comparator)
Line 66: Line 66:
 
@Override
 
@Override
 
public int compare(PointKD o1, PointKD o2) {
 
public int compare(PointKD o1, PointKD o2) {
double dist0 = point.distance(o1);
+
double dist0 = point.distanceSq(o1);
double dist1 = point.distance(o2);
+
double dist1 = point.distanceSq(o2);
 
return (dist0-dist1 < 0) ? -1 : 1;
 
return (dist0-dist1 < 0) ? -1 : 1;
 
}
 
}

Revision as of 23:52, 1 March 2010

Everyone and their brother has one of these now, me and Simonton started it, but I was to inexperienced to get anything written, I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine if you care to see it.

KDTreeB

package org.csdgn.util;

import java.util.Comparator;

public class KDTreeB {
	public final static int DEFAULT_BUCKET_SIZE = 200;
	protected NodeKD root;
	protected int dimensions;
	protected int bucket_size;
	protected long size = 0;
	
	/**
	 * This creates a KDTreeB with the given number of dimensions and
	 * the default bucket size.
	 * @param dimensions
	 */
	public KDTreeB(int dimensions) {
		this(dimensions,DEFAULT_BUCKET_SIZE);
	}
	/**
	 * This creates a KDTreeB with the given number of dimensions and
	 * the given bucket size.
	 * @param dimensions
	 * @param bucket_size
	 */
	public KDTreeB(int dimensions, int bucket_size) {
		this.dimensions = dimensions;
		this.bucket_size = bucket_size;
		root = new BucketKD(this);
	}
	
	/**
	 * Adds the given point to the tree. Uses a recursive algorithm.
	 * @param point
	 */
	public void add(PointKD point) {
		root.add(point);
	}
	
	/**
	 * Returns the nearest neighbor to the given point.
	 * @param point - PointKD to find the nearest to.
	 * @return - The nearest PointKD to 
	 */
	public PointKD getNearestNeighbor(PointKD point) {
		return root.nearest(point);
	}
	
	/**
	 * Returns the nearest num neighbors to the given point.
	 * @param point - PointKD to find the nearest to.
	 * @param num - The number of points to find.
	 * @return - The nearest PointKD's to point.
	 */
	public PointKD[] getNearestNeighbors(final PointKD point, int num) {
		if(num == 1) {
			return new PointKD[] { getNearestNeighbor(point) };
		}
		PriorityDeque<PointKD> queue = new PriorityDeque<PointKD>(
				new Comparator<PointKD>(){
					@Override
					public int compare(PointKD o1, PointKD o2) {
						double dist0 = point.distanceSq(o1);
						double dist1 = point.distanceSq(o2);
						return (dist0-dist1 < 0) ? -1 : 1;
					}
				},num);
		
		root.nearestn(queue,point);
		
		PointKD[] array = new PointKD[num];
		Object[] obj = queue.toArray();
		for(int i=0; i<num; ++i) {
			array[i] = (PointKD)obj[i];
		}
		
		return array;
	}
	
	/**
	 * Returns all PointKD within a certain RectangleKD.
	 * @param rect - area to get PointKD from
	 * @return - All PointKD within rect.
	 */
	public PointKD[] getRange(RectangleKD rect) {
		return root.range(rect);
	}

	/**
	 * Returns all PointKD within a certain range defined by an upper and lower PointKD.
	 * @param low - lower bounds of area
	 * @param high - upper bounds of area
	 * @return - All PointKD between low and high.
	 */
	public PointKD[] getRange(PointKD low, PointKD high) {
		return this.getRange(new RectangleKD(low, high));
	}
	
	protected abstract class NodeKD {
		protected KDTreeB owner;
		protected BranchKD parent;
		protected RectangleKD rect;
		protected int depth = 0;
		
		protected abstract void add(PointKD k);
		protected abstract PointKD nearest(PointKD k);
		protected abstract void nearestn(PriorityDeque<PointKD> queue, PointKD k);
		protected abstract PointKD[] range(RectangleKD r);
	}
	protected class BucketKD extends NodeKD {
		protected PointKD bucket[];
		protected int current;
		
		public BucketKD(KDTreeB owner) {
			this.owner = owner;
			bucket = new PointKD[owner.bucket_size];
			rect = new RectangleKD();
			parent = null;
		}
		public BucketKD(BranchKD branch) {
			this(branch.owner);
			parent = branch;
			depth = branch.depth + 1;
		}
		
		@Override
		protected void add(PointKD point) {
			if(current >= bucket.length) {
				//Split the bucket into a branch
				BranchKD branch = new BranchKD(this);
				if(parent == null) {
					owner.root = branch;
				} else {
					if(parent.isLeft(this)) {
						parent.left = branch;
					} else {
						parent.right = branch;
					}
				}
				branch.add(point);
				bucket = null;
				current = 0;
				return;
			}
			bucket[current++] = point;
			rect.expand(point);
		}

		@Override
		protected PointKD nearest(PointKD k) {
			double nearestDist = Double.POSITIVE_INFINITY;
			int nearest = 0;
			for (int i = 0; i < current; i++) {
				double distance = k.distanceSq(bucket[i]);
				if (distance < nearestDist) {
					nearestDist = distance;
					nearest = i;
				}
			}
			return bucket[nearest];
		}
		
		@Override
		protected void nearestn(PriorityDeque<PointKD> queue, PointKD k) {
			for(int i = 0; i < current; i++) {
				queue.offer(bucket[i]);
			}
		}

		@Override
		protected PointKD[] range(RectangleKD r) {
			PointKD[] tmp = new PointKD[current];
			int n = 0;
			for (int i = 0; i < current; i++) {
				if (r.contains(bucket[i])) {
					tmp[n++] = bucket[i];
				}
			}
			PointKD[] tmp2 = new PointKD[n];
			System.arraycopy(tmp, 0, tmp2, 0, n);
			return tmp2;
		}
		
	}
	protected class BranchKD extends NodeKD {
		protected NodeKD left, right;
		protected double slice;
		protected int dim;
		
		public BranchKD(BucketKD k) {
			owner = k.owner;
			parent = k.parent;
			slice = 0;
			rect = k.rect;
			depth = k.depth;
			left = new BucketKD(this);
			right = new BucketKD(this);
			
			dim = depth % owner.dimensions;
			double total = 0;
			for (int i = 0; i < k.current; i++) {
				total += k.bucket[i].internal[dim];
			}
			slice = total / k.current;
			for (int i = 0; i < k.current; i++)
				add(k.bucket[i]);
		}
		
		@Override
		protected void add(PointKD k) {
			if(k.internal[dim] > slice) {
				right.add(k);
			} else {
				left.add(k);
			}
		}

		@Override
		protected PointKD nearest(PointKD k) {
			PointKD near = null;
			if (k.internal[dim] > slice) {
				near = right.nearest(k);
				double t = near.distanceSq(k);
				if (k.distanceSq(left.rect.getNearest(k)) < t) {
					PointKD tmp = left.nearest(k);
					if (tmp.distanceSq(k) < t) {
						near = tmp;
					}
				}
			} else {
				near = left.nearest(k);
				double t = near.distanceSq(k);
				if (k.distanceSq(right.rect.getNearest(k)) < t) {
					PointKD tmp = right.nearest(k);
					if (tmp.distanceSq(k) < t) {
						near = tmp;
					}
				}
			}
			return near;
		}
		
		@Override
		protected void nearestn(PriorityDeque<PointKD> queue, PointKD k) {
			//TODO
			if(k.internal[dim] > slice) {
				right.nearestn(queue, k);
				double t = queue.peekBottom().distanceSq(k);
				if(k.distanceSq(left.rect.getNearest(k)) < t) {
					left.nearestn(queue, k);
				}
			} else {
				left.nearestn(queue, k);
				double t = queue.peekBottom().distanceSq(k);
				if(k.distanceSq(right.rect.getNearest(k)) < t) {
					right.nearestn(queue, k);
				}
			}
		}

		@Override
		protected PointKD[] range(RectangleKD r) {
			PointKD[] tmp = new PointKD[0];
			if (r.intersects(left.rect)) {
				PointKD[] tmpl = left.range(r);
				if(0 == tmp.length)
					tmp = tmpl;
			}
			if (r.intersects(right.rect)) {
				PointKD[] tmpr = right.range(r);
				if (0 == tmp.length)
					tmp = tmpr;
				else if (0 < tmpr.length) {
					PointKD[] tmp2 = new PointKD[tmp.length + tmpr.length];
					System.arraycopy(tmp, 0, tmp2, 0, tmp.length);
					System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length);
					tmp = tmp2;
				}
			}
			return tmp;
		}
		
		protected boolean isLeft(BucketKD kd) {
			if(left == kd) return true;
			return false;
		}
	}
}

PointKD

package org.csdgn.util;

import java.io.Serializable;

/**
 * PointKD class is a class that wraps a double array, for use in K-Dimensional structures.
 * This wrapping is done to improve readability and modularity. Often used to define a
 * point in k-dimensional space.
 * 
 * @author Chase
 */
public class PointKD implements Serializable {
	private static final long serialVersionUID = -841162798668123755L;
	protected double[] internal;
	
	/**
	 * Constructor
	 * @param dimensions - Number of dimensions for this PointKD.
	 */
	public PointKD(int dimensions) {
		internal = new double[dimensions];
	}
	
	/**
	 * Constructor
	 * @param array - An array to use for the internal array of this PointKD.
	 */
	public PointKD(double[] array) {
		internal = array.clone();
	}
	
	/**
	 * Constructor
	 * @param point - PointKD to clone.
	 */
	public PointKD(PointKD point) {
		this.internal = point.internal.clone();
	}
	
	
	/**
	 * Returns array.
	 * @return The internal array of this pointKD, changes made to this will effect this PointKD.
	 */
	public double[] get() {
		return internal;
	}
	
	/**
	 * Sets the location of this point.
	 * @param point - PointKD to copy.
	 */
	public void set(PointKD point) {
		this.internal = point.internal.clone();
	}
	
	/**
	 * Sets the location of this point.
	 * @param array - Array to copy.
	 */
	public void set(double[] array) {
		this.internal = array.clone();
	}
	
	/**
	 * Sets the value at dimension index
	 * @param index - Index
	 * @param value - Value to set
	 */
	public void set(int index, double value) {
		internal[index] = value;
	}
	
	/**
	 * @return number of dimensions
	 */
	public int size() {
		return internal.length;
	}
	
	/**
	 * Compares this to a selected point and returns the euclidean distance
	 * between them.
	 * 
	 * @param p
	 *            - The Point to get the distance to.
	 * @return The distance between this and <b>p</b>.
	 */
	public double distance(PointKD p) {
		return distance(this.internal, p.internal);
	}

	/**
	 * Compares this to a selected point and returns the squared euclidean
	 * distance between them.
	 * 
	 * @param p
	 *            - The Point to get the distance to.
	 * @return The distance between this and <b>p</b>.
	 */
	public double distanceSq(PointKD p) {
		return distanceSq(this.internal, p.internal);
	}
	
	/**
	 * Clones this point.
	 */
	public PointKD clone() {
		return new PointKD(internal);
	}
	
	/**
	 * Prints the class name and the point coordinates.
	 */
	public String toString() {
		String output = getClass().getSimpleName() + "[";
		for (int i = 0; i < internal.length; ++i) {
			if (0 != i)
				output += ",";
			output += internal[i];
		}
		return output + "]";
	}
	
	/**
	 * Compares arrays of double and returns the euclidean distance
	 * between them.
	 * 
	 * @param a - The first set of numbers
	 * @param b - The second set of numbers
	 * @return The distance squared between <b>a</b> and <b>b</b>.
	 */
	public static final double distance(double[] a, double[] b) {
		double total = 0;
		for (int i = 0; i < a.length; ++i)
			total += (b[i] - a[i]) * (b[i] - a[i]);
		return Math.sqrt(total);
	}
	
	/**
	 * Compares arrays of double and returns the squared euclidean distance
	 * between them.
	 * 
	 * @param a - The first set of numbers
	 * @param b - The second set of numbers
	 * @return The distance squared between <b>a</b> and <b>b</b>.
	 */
	public static final double distanceSq(double[] a, double[] b) {
		double total = 0;
		for (int i = 0; i < a.length; ++i)
			total += (b[i] - a[i]) * (b[i] - a[i]);
		return total;
	}
}

RectangleKD

package org.csdgn.util;

import java.io.Serializable;

public class RectangleKD implements Serializable {
	private static final long serialVersionUID = 524388821816648020L;
	protected PointKD upper, lower;
	/**
	 * Creates an empty RectangleKD
	 */
	public RectangleKD() {
		upper = null;
		lower = null;
	}
	/**
	 * Creates a RectangleKD from two points.
	 * @param lower
	 * @param upper
	 */
	public RectangleKD(PointKD lower, PointKD upper) {
		this.lower = lower.clone();
		this.upper = upper.clone();
	}
	
	/**
	 * Expands this RectangleKD to include point.
	 * @param point - Point used to expand this Rectangle
	 */
	public void expand(PointKD point) {
		if(upper == null) {
			upper = point.clone();
			lower = point.clone();
			return;
		}
		for(int i=0; i<point.size(); ++i) {
			if(upper.internal[i] < point.internal[i])
				upper.internal[i] = point.internal[i];
			if(lower.internal[i] > point.internal[i])
				lower.internal[i] = point.internal[i];
		}
	}
	
	/**
	 * Checks if this RectangleKD contains the given point.
	 * @param point - PointKD to check.
	 * @return true - if this rectangle contains the given point.
	 */
	public boolean contains(PointKD point) {
		if(upper == null) return false;
		for(int i=0; i<point.size(); ++i) {
			if(point.internal[i] > upper.internal[i] ||
				point.internal[i] < lower.internal[i]) 
					return false;
		}
		return true;
	}
	
	/**
	 * Checks if this table intersects the given table.
	 * @param rectangle - The rectangle to check intersection with.
	 * @return true - if this rectangle intersects the given rectangle.
	 */
	public boolean intersects(RectangleKD rectangle) {
		if(rectangle.upper == null || upper == null) return false;
		for(int i=0; i<upper.size(); ++i) {
			if(rectangle.upper.internal[i] < lower.internal[i]
			|| rectangle.lower.internal[i] > upper.internal[i]) return false;
		}
		return true;
	}
	
	/**
	 * Gets the nearest point in this RectangleKD to the given point
	 * @param point - PointKD to get the nearest point to.
	 * @return the nearest point in this RectangleKD to the given point.
	 */
	public PointKD getNearest(PointKD point) {
		if(upper == null) return null;
		if(contains(point)) return point; //This check may not be needed.
		PointKD nearest = new PointKD(point);
		for(int i = 0; i < upper.size(); ++i) {
			if(nearest.internal[i] > upper.internal[i])
				nearest.internal[i] = upper.internal[i];
			if(nearest.internal[i] < lower.internal[i])
				nearest.internal[i] = lower.internal[i];
		}
		return nearest;
	}
}

PriorityDeque

package org.csdgn.util;

import java.util.Comparator;

/**
 * This is a home made PriorityDeque with maximum size limiter, and
 * comparator or natural ordering selection.
 * 
 * @author Chase
 * @param <E>
 */
public class PriorityDeque<E> {
	private class Item {
		public Item down, up;
		public E obj;
		public Item(E item) { 
			obj = item;
			down = up = null;
		}
	}
	private final Comparator<? super E> comparator;
	private Item bottom, top;
	private int maximum_size;
	private int size;
	
	/**
	 * Generic Constructor
	 */
	public PriorityDeque() {
		this(null,-1);
	}
	/**
	 * Constructor with a defined maximum number of entries
	 * @param maximum
	 */
	public PriorityDeque(int maximum) {
		this(null,maximum);
	}
	/**
	 * Constructor with a defined comparator and an unlimited number of items.
	 * @param comp
	 */
	public PriorityDeque(Comparator<? super E> comp) {
		this(comp,-1);
	}
	/**
	 * Constructor with a defined conparator and limited number of items.
	 * @param comp
	 * @param maximum
	 */
	public PriorityDeque(Comparator<? super E> comp, int maximum) {
		comparator = comp;
		maximum_size = maximum;
		bottom = top = null;
		size = 0;
	}
	/**
	 * Adds an item to this deque.
	 * @param value - item to add
	 */
	public void offer(E value) {
		//System.err.println("Offering: " + value.toString());
		if(bottom == null) {
			//System.err.println("-Is first item.");
			bottom = top = new Item(value);
			return;
		}
		//do ordering etc
		if(comparator != null) {
			//System.err.println("-Comparator.");
			offerComparator(value);
		} else {
			//System.err.println("-Natural.");
			offerNatural(value);
		}
	}
	
	/**
	 * Removes and returns an item from the bottom of the list (lowest value)
	 * @return the lowest value (bottom)
	 */
	public E pollBottom() {
		if(bottom == null) return null;
		Item tmp = bottom;
		if(bottom == top) bottom = top = null;
		else {
			bottom = bottom.up;
			bottom.down = null;
			tmp.up = null;
		}
		--size;
		return tmp.obj;
	}
	/**
	 * Returns but does not remove the item from the bottom of the list.
	 * @return the lowest value (bottom)
	 */
	public E peekBottom() {
		if(bottom == null) return null;
		return bottom.obj;
	}
	/**
	 * Removes and returns an item from the top of the list (highest value).
	 * @return the highest value (top)
	 */
	public E pollTop() {
		if(top == null) return null;
		Item tmp = top;
		if(bottom == top) bottom = top = null;
		else {
			top = top.down;
			top.up = null;
			tmp.down = null;
		}
		--size;
		return tmp.obj;
	}
	/**
	 * Returns but does not remove the item from the top of the list (highest value).
	 * @return the highest value (top)
	 */
	public E peekTop() {
		if(top == null) return null;
		return top.obj;
	}
	/**
	 * This is a generic toArray, returns a non-castable Object
	 * array with all the items in this deque.
	 * @return an Object array with everything in this deque.
	 */
	public Object[] toArray() {
		Object array[] = new Object[size+1];
		Item current = bottom;
		int i = 0;
		while(current != null) {
			array[i++] = current.obj;
			current = current.up;
		}
		return array;
	}
	/**
	 * Add/sorts a value using the comparator
	 * @param value
	 */
	private void offerComparator(E value) {
		if(maximum_size > 0 && size >= maximum_size) {
			if(comparator.compare(value, top.obj) >= 0) {
				return;
			}
		}
		Item nItem = new Item(value);
		if(comparator.compare(value,bottom.obj) < 0) {
			//less than the bottom, put on the bottom
			nItem.up = bottom;
			bottom.down = nItem;
			bottom = nItem;
			++size;
			if(maximum_size > 0 && size > maximum_size) {
				pollTop();
			}
			return;
		}
		
		//start at the bottom
		Item current = bottom;
		while(current != null) {
			if(comparator.compare(value,current.obj) < 0) {
				nItem.up = current;
				nItem.down = current.down;
				current.down.up = nItem;
				current.down = nItem;
				++size;
				if(maximum_size > 0 && size > maximum_size) {
					pollTop();
				}
				return;
			}
			current = current.up;
		}
		
		//else put it on top
		nItem.down = top;
		top.up = nItem;
		top = nItem;
		++size;
		if(maximum_size > 0 && size > maximum_size) {
			pollTop();
		}
	}
	
	/**
	 * Add/sorts a value using the natural order
	 * @param value
	 */
	@SuppressWarnings("unchecked")
	private void offerNatural(E value) {
		Comparable<? super E> key = (Comparable<? super E>)value;
		if(maximum_size > 0 && size >= maximum_size) {
			if(key.compareTo(top.obj) >= 0) {
				return;
			}
		}
		Item nItem = new Item(value);
		if(key.compareTo(bottom.obj) < 0) {
			//less than the bottom, put on the bottom
			nItem.up = bottom;
			bottom.down = nItem;
			bottom = nItem;
			++size;
			if(maximum_size > 0 && size > maximum_size) {
				pollTop();
			}
			return;
		}
		
		//start at the bottom
		Item current = bottom;
		while(current != null) {
			if(key.compareTo(current.obj) < 0) {
				nItem.up = current;
				nItem.down = current.down;
				current.down.up = nItem;
				current.down = nItem;
				++size;
				if(maximum_size > 0 && size > maximum_size) {
					pollTop();
				}
				return;
			}
			current = current.up;
		}
		
		//else put it on top
		nItem.down = top;
		top.up = nItem;
		top = nItem;
		++size;
		if(maximum_size > 0 && size > maximum_size) {
			pollTop();
		}		
	}
}