User:Chase-san/Kd-Tree

From Robowiki
< User:Chase-san
Revision as of 09:12, 16 May 2009 by Chase-san (talk | contribs) (adding category)
Jump to navigation Jump to search


This is my version of the KD-tree. I admit it has gotten a bit large and could be cut down a great deal; I will add a k-nearest-neighbors implementation soon enough. Simonton and I worked on this at about the same time; however, I was more experienced, and it turns out that the whole time my only problem was a very elusive off-by-one error. I recently rebuilt my KD-tree into this version, which unhooks the points from the branches to speed things up. It also uses buckets, which are great fun.

Also, unlike most others, mine is fairly modular and includes a range search (good for those old-fashioned pattern matchers!!).

Please forgive the incomplete documentation.

KDTreeB

package org.csdgn.util;

/**
 * A bucketed k-d tree over {@link PointKD} values supporting approximate
 * nearest-neighbor, nearest-neighbor and axis-aligned range queries.
 * <p>
 * The root node is created lazily by the first call to {@link #add(PointKD)}.
 */
public class KDTreeB {
	public final static int DEFAULT_BUCKET_SIZE = 100;
	protected int bucketSize;
	protected int dimensions;
	protected NodeKD root;
	protected PointKD keys[];
	protected int size;

	/**
	 * Creates an empty tree whose buckets hold {@link #DEFAULT_BUCKET_SIZE}
	 * points.
	 * 
	 * @param dim
	 *            - number of dimensions of the points that will be added
	 */
	public KDTreeB(int dim) {
		this(dim, DEFAULT_BUCKET_SIZE);
	}

	/**
	 * Creates an empty tree with an explicit bucket capacity. Out-of-range
	 * arguments only produce a warning on {@code System.err}; no exception is
	 * thrown for them.
	 * 
	 * @param dim
	 *            - number of dimensions (a warning is printed if below 1)
	 * @param buckets
	 *            - points per bucket before it splits (a warning is printed if
	 *            below 2)
	 */
	public KDTreeB(int dim, int buckets) {
		boolean tooFewDimensions = dim < 1;
		boolean bucketTooSmall = buckets < 2;
		if (tooFewDimensions) {
			System.err.println("Dimensions < 1: Undefined Behavior may occur.");
		}
		if (bucketTooSmall) {
			System.err.println("Bucket Size < 2: Undefined Behavior may occur.");
		}

		dimensions = dim;
		bucketSize = buckets;
		size = 0;
		keys = new PointKD[buckets];
	}

	/**
	 * Inserts a point. The point is recorded in {@link #keys} and then handed
	 * to the root node, which places it recursively.
	 * 
	 * @param k
	 *            - the point to insert (stored by reference, not copied)
	 */
	public void add(PointKD k) {
		if (root == null) {
			// the BucketKD constructor also registers itself as ref.root
			root = new BucketKD(this);
		}
		if (keys.length <= size) {
			// grow geometrically, like a Vector, without the Vector overhead
			PointKD[] grown = new PointKD[2 * keys.length];
			System.arraycopy(keys, 0, grown, 0, size);
			keys = grown;
		}
		keys[size++] = k;
		root.add(k);
	}

	/**
	 * Returns an approximate nearest neighbor of {@code k}, obtained by
	 * descending the tree recursively.
	 * 
	 * @param k
	 *            - the query point
	 * @return the point found, or {@code null} if nothing has been added yet
	 */
	public PointKD getApproxNN(PointKD k) {
		return (root == null) ? null : root.approx(k);
	}

	/**
	 * Returns the nearest neighbor of {@code k} by squared Euclidean distance.
	 * 
	 * @param k
	 *            - the query point
	 * @return the nearest stored point, or {@code null} if nothing has been
	 *         added yet
	 */
	public PointKD getNN(PointKD k) {
		return (root == null) ? null : root.nearest(k);
	}

	/**
	 * Returns every stored point lying inside the rectangle {@code r}.
	 * 
	 * @param r
	 *            - the query rectangle
	 * @return the matching points; an empty array if nothing has been added
	 */
	public PointKD[] getRange(RectKD r) {
		return (root == null) ? new PointKD[0] : root.range(r);
	}

	/**
	 * Convenience overload of {@link #getRange(RectKD)} that builds the query
	 * rectangle from its two corners.
	 * 
	 * @param low
	 *            - the lower corner
	 * @param high
	 *            - the upper corner
	 * @return the matching points
	 */
	public PointKD[] getRange(PointKD low, PointKD high) {
		RectKD query = new RectKD(low, high);
		return getRange(query);
	}

}

NodeKD

package org.csdgn.util;

/**
 * Common base of the k-d tree node types: the internal {@link BranchKD} and
 * the leaf {@link BucketKD}.
 */
public abstract class NodeKD {
	/** The tree this node belongs to. */
	protected KDTreeB ref;
	/** Enclosing branch; {@code null} for a root node. */
	protected BranchKD parent;
	/** Number of levels below the root; branches split on {@code depth % dimensions}. */
	protected int depth;
	/** Bounding box of the points added to this node. */
	protected RectKD rect;

	/**
	 * Inserts {@code point} into this subtree.
	 */
	protected abstract void add(PointKD point);

	/**
	 * Returns an approximate nearest neighbor of {@code query} within this
	 * subtree; may be {@code null}.
	 */
	protected abstract PointKD approx(PointKD query);

	/**
	 * Returns the nearest neighbor of {@code query} within this subtree; may be
	 * {@code null}.
	 */
	protected abstract PointKD nearest(PointKD query);

	/**
	 * Returns every point of this subtree that lies inside {@code region}.
	 */
	protected abstract PointKD[] range(RectKD region);
}

BranchKD

package org.csdgn.util;

public class BranchKD extends NodeKD {
	protected NodeKD left, right;
	protected double slice;

	protected BranchKD(KDTreeB reference) {
		slice = 0;
		ref = reference;
		rect = new RectKD();
		depth = 0;
		ref.root = this;
		left = new BucketKD(this);
		right = new BucketKD(this);
	}
	protected BranchKD(BranchKD b) {
		ref = b.ref;
		parent = b;
		slice = 0;
		rect = new RectKD();
		depth = b.depth + 1;
		left = new BucketKD(this);
		right = new BucketKD(this);
	}

	protected void add(PointKD k) {
		rect.add(k);
		int dim = depth % ref.dimensions;
		if (k.array[dim] > slice) {
			right.add(k);
		} else {
			left.add(k);
		}
	}

	protected PointKD approx(PointKD k) {
		int dim = depth % ref.dimensions;
		if (k.array[dim] > slice)
			return right.approx(k);
		return left.approx(k);
	}

	protected PointKD nearest(PointKD k) {
		int dim = depth % ref.dimensions;
		PointKD near = null;
		if (k.array[dim] > slice) {
			near = right.nearest(k);
			double t = near.distanceSq(k);
			if (k.distanceSq(left.rect.getNearest(k)) < t) {
				PointKD tmp = left.nearest(k);
				if (tmp.distanceSq(k) < t) {
					near = tmp;
				}
			}
		} else {
			near = left.nearest(k);
			double t = near.distanceSq(k);
			if (k.distanceSq(right.rect.getNearest(k)) < t) {
				PointKD tmp = right.nearest(k);
				if (tmp.distanceSq(k) < t) {
					near = tmp;
				}
			}
		}

		return near;
	}

	protected PointKD[] range(RectKD r) {
		PointKD[] tmp = new PointKD[0];
		if (r.intersects(left.rect)) {
			PointKD[] tmpl = left.range(r);
			if(0 == tmp.length)
				tmp = tmpl;
		}
		if (r.intersects(right.rect)) {
			PointKD[] tmpr = right.range(r);
			if (0 == tmp.length)
				tmp = tmpr;
			else if (0 < tmpr.length) {
				PointKD[] tmp2 = new PointKD[tmp.length + tmpr.length];
				System.arraycopy(tmp, 0, tmp2, 0, tmp.length);
				System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length);
				tmp = tmp2;
			}
		}
		return tmp;
	}

	protected BranchKD extend(BucketKD b) {
		if (b == left) {
			left = null;
			left = new BranchKD(this);
			return (BranchKD) left;
		} else if (b == right) {
			right = null;
			right = new BranchKD(this);
			return (BranchKD) right;
		}
		return null;
	}
}

BucketKD

package org.csdgn.util;

public class BucketKD extends NodeKD {
	protected PointKD bucket[];
	protected int current;

	protected BucketKD(KDTreeB reference) {
		ref = reference;
		ref.root = this;
		bucket = new PointKD[ref.bucketSize];
		parent = null;
		rect = new RectKD();
		depth = 0;
	}
	
	protected BucketKD(BranchKD p) {
		ref = p.ref;
		bucket = new PointKD[ref.bucketSize];
		parent = p;
		rect = new RectKD();
		depth = p.depth + 1;
	}

	protected PointKD approx(PointKD k) {
		double nearestDist = Double.POSITIVE_INFINITY;
		int nearest = 0;
		for (int i = 0; i < current; i++) {
			double distance = k.distanceSq(bucket[i]);
			if (distance < nearestDist) {
				nearestDist = distance;
				nearest = i;
			}
		}
		return bucket[nearest];
	}

	protected PointKD nearest(PointKD k) {
		return approx(k);
	}

	protected void add(PointKD k) {
		if (current >= ref.bucketSize) {
			BranchKD b = null;
			if (null == parent)
				b = new BranchKD(ref);
			else
				b = parent.extend(this);

			int dim = b.depth % ref.dimensions;
			double total = 0;
			for (int i = 0; i < current; i++) {
				total += bucket[i].array[dim];
			}
			b.slice = total / current;
			for (int i = 0; i < current; i++) {
				b.add(bucket[i]);
			}
			b.add(k);
			bucket = null;
			parent = null;
			current = 0;
			return;
		}
		rect.add(k);
		bucket[current++] = k;
	}

	protected PointKD[] range(RectKD r) {
		PointKD[] tmp = new PointKD[current];
		int n = 0;
		for (int i = 0; i < current; i++) {
			if (r.contains(bucket[i])) {
				tmp[n++] = bucket[i];
			}
		}
		PointKD[] tmp2 = new PointKD[n];
		System.arraycopy(tmp, 0, tmp2, 0, n);
		return tmp2;
	}
}


PointKD

package org.csdgn.util;

/**
 * A K-Dimensional Point, used in conjunction with the KD-Tree multidimensional
 * data structure.
 * 
 * @author Robert Maupin
 * 
 */
public class PointKD {
	protected double[] array;

	/**
	 * Constructor defining the number of dimensions to use in this Point;
	 * 
	 * @param dimensions
	 *            - The number of dimensions in this point
	 */
	public PointKD(int dimensions) {
		array = new double[dimensions];
	}

	/**
	 * Creates a PointKD from an array.
	 * 
	 * @param pos
	 *            - An array of Numbers
	 */
	public PointKD(double[] pos) {
		array = pos.clone();
		// array = new double[pos.length];
		// System.arraycopy(pos, 0, array, 0, array.length);
	}

	/**
	 * Creates a copy of the selected point.
	 * 
	 * @param p
	 *            - A point to copy
	 */
	public PointKD(PointKD p) {
		// array = new double[p.array.length];
		setPoint(p);
	}

	/**
	 * Sets the coordinates of this point to be the same as the selected point.
	 * 
	 * @param p
	 *            - A point to copy
	 */
	public void setPoint(PointKD p) {
		this.array = p.array.clone();
		// System.arraycopy(p.array, 0, array, 0, array.length);
	}

	/**
	 * Returns the number of dimensions this point has.
	 * 
	 * @return the dimensions
	 */
	public int getDimensions() {
		return array.length;
	}

	/**
	 * Returns the coordinate at dimension <b>i</b>.
	 * 
	 * @param i
	 *            - The dimension to retrieve.
	 * @return The coordinate at i.
	 */
	public double getCoordinate(int i) {
		return array[i];
	}

	/**
	 * Sets the coordinate to <b>k</b> at dimension <b>i</b>.
	 * 
	 * @param i
	 *            - the dimension to set
	 * @param k
	 *            - the value to set the dimension to
	 */
	public void setCoordinate(int i, double k) {
		array[i] = k;
	}

	/**
	 * Compares this to a selected point and returns the euclidean distance
	 * between them.
	 * 
	 * @param p
	 *            - The Point to get the distance to.
	 * @return The distance between this and <b>p</b>.
	 */
	public double distance(PointKD p) {
		return distance(this, p);
	}

	/**
	 * Compares this to a selected point and returns the squared euclidean
	 * distance between them.
	 * 
	 * @param p
	 *            - The Point to get the distance to.
	 * @return The distance between this and <b>p</b>.
	 */
	public double distanceSq(PointKD p) {
		return distanceSq(this, p);
	}

	/**
	 * Prints the class name and the point coordinates.
	 */
	public String toString() {
		String output = getClass().getName() + "[";
		for (int i = 0; i < array.length; i++) {
			if (0 != i)
				output += ",";
			output += array[i];
		}
		return output + "]";
	}

	/**
	 * @return a distinct copy of this object
	 */
	public Object clone() {
		return new PointKD(this);
	}

	/**
	 * Compares two Points and returns the euclidean distance between them.
	 * 
	 * @param a
	 *            - The first set of numbers
	 * @param b
	 *            - The second set of numbers
	 * @return The distance between <b>a</b> and <b>b</b>.
	 */
	public static final double distance(PointKD a, PointKD b) {
		return distance(a.array, b.array);
	}

	/**
	 * Compares two Points and returns the squared euclidean distance between
	 * them.
	 * 
	 * @param a
	 *            - The first set of numbers
	 * @param b
	 *            - The second set of numbers
	 * @return The distance between <b>a</b> and <b>b</b>.
	 */
	public static final double distanceSq(PointKD a, PointKD b) {
		return distanceSq(a.array, b.array);
	}

	/**
	 * Compares arrays of double and returns the euclidean distance between
	 * them.
	 * 
	 * @param a
	 *            - The first set of numbers
	 * @param b
	 *            - The second set of numbers
	 * @return The distance between <b>a</b> and <b>b</b>.
	 */
	public static final double distance(double[] a, double[] b) {
		return Math.sqrt(distanceSq(a, b));
	}

	/**
	 * Compares arrays of double and returns the squared euclidean distance
	 * between them.
	 * 
	 * @param a
	 *            - The first set of numbers
	 * @param b
	 *            - The second set of numbers
	 * @return The distance squared between <b>a</b> and <b>b</b>.
	 */
	public static final double distanceSq(double[] a, double[] b) {
		if (a.length != b.length || a.length < 1)
			return -1;
		double total = 0;
		for (int i = 0; i < a.length; i++)
			total += (b[i] - a[i]) * (b[i] - a[i]);
		return total;
	}
}

RectKD

package org.csdgn.util;

/**
 * A K-Dimensional Hyper Rectangle, used in conjunction with the KD-Tree
 * multidimensional data structure.
 * 
 * @author Robert Maupin
 * 
 */
public class RectKD {
	PointKD upper;
	PointKD lower;

	/**
	 * Creates an empty RectKD
	 */
	public RectKD() {
		upper = null;
		lower = null;
	}

	public RectKD(PointKD s) {
		this(s, s);
	}

	public RectKD(PointKD l, PointKD u) {
		this();
		if (l.array.length == u.array.length) {
			upper = new PointKD(u);
			lower = new PointKD(l);
		}
	}

	public void add(PointKD p) {
		if (upper == null) {
			upper = new PointKD(p);
			lower = new PointKD(p);
			return;
		}

		for (int i = 0; i < upper.array.length; i++) {
			if (p.array[i] > upper.array[i])
				upper.array[i] = p.array[i];
			if (p.array[i] < lower.array[i])
				lower.array[i] = p.array[i];
		}
	}

	public boolean contains(PointKD p) {
		if (upper == null)
			return false;
		if (upper.array.length != p.array.length)
			return false;
		boolean inside = true;
		for (int i = 0; i < upper.array.length; i++) {
			if (p.array[i] > upper.array[i])
				inside = false;
			if (p.array[i] < lower.array[i])
				inside = false;
		}
		return inside;
	}

	public boolean intersects(RectKD r) {
		boolean check = false;

		if (null == r)
			return false;
		if (null == r.lower)
			return false;
		if (null == r.upper)
			return false;
		if (r.upper.array.length != upper.array.length)
			return false;

		int len = upper.array.length;
		for (int i = 0; i < len; i++) {
			if (upper.array[i] < r.lower.array[i])
				check = true;
			if (lower.array[i] > r.upper.array[i])
				check = true;
		}

		return !check;
	}

	public PointKD getNearest(PointKD p) {
		if (upper == null)
			return null;
		if (upper.array.length != p.array.length)
			return null;
		PointKD near = new PointKD(p);
		for (int i = 0; i < upper.array.length; i++) {
			near.array[i] = p.array[i];
			if (p.array[i] > upper.array[i])
				near.array[i] = upper.array[i];
			if (p.array[i] < lower.array[i])
				near.array[i] = lower.array[i];
		}
		return near;
	}
}