Difference between revisions of "User:Chase-san/Kd-Tree"

From Robowiki
Jump to navigation Jump to search
m (license)
m (Just a minor adjustment, uses generics now. Removed pointless references.)
Line 6: Line 6:
 
Oh yeah, am I the only one that has a Range function?
 
Oh yeah, am I the only one that has a Range function?
  
=== KDTreeC ===
+
=== KDTree ===
 
<syntaxhighlight>
 
<syntaxhighlight>
package org.csdgn.util.kd2;
+
package org.csdgn.util;
  
 
import java.util.Arrays;
 
import java.util.Arrays;
Line 17: Line 17:
 
  *
 
  *
 
  */
 
  */
public class KDTreeC {
+
public class KDTree<T> {
 
/**
 
/**
 
* Item, for moving data around.
 
* Item, for moving data around.
 
* @author Chase
 
* @author Chase
 
*/
 
*/
public class Item {
+
public static class Item {
 
public double[] pnt;
 
public double[] pnt;
 
public Object obj;
 
public Object obj;
Line 33: Line 33:
 
private final int bucket_size;
 
private final int bucket_size;
 
private NodeKD root;
 
private NodeKD root;
+
 
/**
 
/**
 
* Constructor with value for dimensions.
 
* Constructor with value for dimensions.
 
* @param dimensions - Number of dimensions
 
* @param dimensions - Number of dimensions
 
*/
 
*/
public KDTreeC(int dimensions) {
+
public KDTree(int dimensions) {
 
this.dimensions = dimensions;
 
this.dimensions = dimensions;
 
this.bucket_size = 64;
 
this.bucket_size = 64;
this.root = new NodeKD(this);
+
this.root = new NodeKD();
 
}
 
}
+
 
/**
 
/**
 
* Constructor with value for dimensions and bucket size.
 
* Constructor with value for dimensions and bucket size.
Line 49: Line 49:
 
* @param bucket - Size of the buckets.
 
* @param bucket - Size of the buckets.
 
*/
 
*/
public KDTreeC(int dimensions, int bucket) {
+
public KDTree(int dimensions, int bucket) {
 
this.dimensions = dimensions;
 
this.dimensions = dimensions;
 
this.bucket_size = bucket;
 
this.bucket_size = bucket;
this.root = new NodeKD(this);
+
this.root = new NodeKD();
 
}
 
}
+
 
/**
 
/**
 
* Add a key and its associated value to the tree.
 
* Add a key and its associated value to the tree.
Line 60: Line 60:
 
* @param val - object to add
 
* @param val - object to add
 
*/
 
*/
public void add(double[] key, Object val) {
+
public void add(double[] key, T val) {
 
root.add(new Item(key,val));
 
root.add(new Item(key,val));
 
}
 
}
+
 
/**
 
/**
 
* Returns all PointKD within a certain range defined by an upper and lower PointKD.
 
* Returns all PointKD within a certain range defined by an upper and lower PointKD.
Line 73: Line 73:
 
return root.range(high, low);
 
return root.range(high, low);
 
}
 
}
+
 
/**
 
/**
 
* Gets the N nearest neighbors to the given key.
 
* Gets the N nearest neighbors to the given key.
Line 114: Line 114:
 
return total;
 
return total;
 
}
 
}
+
 
//Internal tree node
 
//Internal tree node
 
private class NodeKD {
 
private class NodeKD {
private KDTreeC owner;
 
 
private NodeKD left, right;
 
private NodeKD left, right;
 
private double[] upper, lower;
 
private double[] upper, lower;
Line 123: Line 122:
 
private int current, dim;
 
private int current, dim;
 
private double slice;
 
private double slice;
+
 
//note: we always start as a bucket
 
//note: we always start as a bucket
private NodeKD(KDTreeC own) {
+
private NodeKD() {
owner = own;
 
 
upper = lower = null;
 
upper = lower = null;
 
left = right = null;
 
left = right = null;
bucket = new Item[own.bucket_size];
+
bucket = new Item[bucket_size];
 
current = 0;
 
current = 0;
 
dim = 0;
 
dim = 0;
Line 136: Line 134:
 
//we use this one here
 
//we use this one here
 
private NodeKD(NodeKD node) {
 
private NodeKD(NodeKD node) {
owner = node.owner;
 
 
dim = node.dim + 1;
 
dim = node.dim + 1;
bucket = new Item[owner.bucket_size];
+
bucket = new Item[bucket_size];
if(dim + 1 > owner.dimensions) dim = 0;
+
if(dim + 1 > dimensions) dim = 0;
 
left = right = null;
 
left = right = null;
 
upper = lower = null;
 
upper = lower = null;
Line 153: Line 150:
 
} else {
 
} else {
 
//Bucket
 
//Bucket
if(current+1>owner.bucket_size) {
+
if(current+1>bucket_size) {
 
split(m);
 
split(m);
 
return;
 
return;
Line 168: Line 165:
 
right.nearestn(arr, data);
 
right.nearestn(arr, data);
 
if(left.current != 0) {
 
if(left.current != 0) {
if(KDTreeC.distanceSq(left.nearestRect(data),data)
+
if(KDTree.distanceSq(left.nearestRect(data),data)
 
< arr.getLongest()) {
 
< arr.getLongest()) {
 
left.nearestn(arr, data);
 
left.nearestn(arr, data);
 
}
 
}
 
}
 
}
+
 
} else {
 
} else {
 
left.nearestn(arr, data);
 
left.nearestn(arr, data);
 
if(right.current != 0) {
 
if(right.current != 0) {
if(KDTreeC.distanceSq(right.nearestRect(data),data)  
+
if(KDTree.distanceSq(right.nearestRect(data),data)  
 
< arr.getLongest()) {
 
< arr.getLongest()) {
 
right.nearestn(arr, data);
 
right.nearestn(arr, data);
Line 186: Line 183:
 
//Bucket
 
//Bucket
 
for(int i = 0; i < current; i++) {
 
for(int i = 0; i < current; i++) {
bucket[i].distance = KDTreeC.distanceSq(bucket[i].pnt, data);
+
bucket[i].distance = KDTree.distanceSq(bucket[i].pnt, data);
 
arr.add(bucket[i]);
 
arr.add(bucket[i]);
 
}
 
}
Line 227: Line 224:
 
return tmp2;
 
return tmp2;
 
}
 
}
+
 
//These are helper functions from here down
 
//These are helper functions from here down
 
//check if this hyper rectangle contains a give hyper-point
 
//check if this hyper rectangle contains a give hyper-point
Line 273: Line 270:
 
private void expand(double[] data) {
 
private void expand(double[] data) {
 
if(upper == null) {
 
if(upper == null) {
upper = Arrays.copyOf(data, owner.dimensions);
+
upper = Arrays.copyOf(data, dimensions);
lower = Arrays.copyOf(data, owner.dimensions);
+
lower = Arrays.copyOf(data, dimensions);
 
return;
 
return;
 
}
 
}

Revision as of 00:12, 8 June 2012

Everyone and their brother has one of these now, me and Simonton started it, but I was to inexperienced to get anything written, I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine if you care to see it.

This and all my other code in which I display on the robowiki falls under the ZLIB License.

Oh yeah, am I the only one that has a Range function?

KDTree

package org.csdgn.util;

import java.util.Arrays;

/**
 * This is a KD Bucket Tree, for fast sorting and searching of K dimensional data.
 * @author Chase
 *
 */
public class KDTree<T> {
	/**
	 * Item, for moving data around.
	 * @author Chase
	 */
	public static class Item {
		public double[] pnt;
		public Object obj;
		public double distance;
		private Item(double[] p, Object o) {
			pnt = p; obj = o;
		}
	}
	private final int dimensions;
	private final int bucket_size;
	private NodeKD root;
 
	/**
	 * Constructor with value for dimensions.
	 * @param dimensions - Number of dimensions
	 */
	public KDTree(int dimensions) {
		this.dimensions = dimensions;
		this.bucket_size = 64;
		this.root = new NodeKD();
	}
 
	/**
	 * Constructor with value for dimensions and bucket size.
	 * @param dimensions - Number of dimensions
	 * @param bucket - Size of the buckets.
	 */
	public KDTree(int dimensions, int bucket) {
		this.dimensions = dimensions;
		this.bucket_size = bucket;
		this.root = new NodeKD();
	}
 
	/**
	 * Add a key and its associated value to the tree.
	 * @param key - Key to add
	 * @param val - object to add
	 */
	public void add(double[] key, T val) {
		root.add(new Item(key,val));
	}
 
	/**
	 * Returns all PointKD within a certain range defined by an upper and lower PointKD.
	 * @param low - lower bounds of area
	 * @param high - upper bounds of area
	 * @return - All PointKD between low and high.
	 */
	public Item[] getRange(double[] low, double[] high) {
		return root.range(high, low);
	}
 
	/**
	 * Gets the N nearest neighbors to the given key.
	 * @param key - Key
	 * @param num - Number of results
	 * @return Array of Item Objects, distances within the items
	 * are the square of the actual distance between them and the key
	 */
	public Item[] getNearestNeighbor(double[] key, int num) {
		ShiftArray arr = new ShiftArray(num);
		root.nearestn(arr, key);
		return arr.getArray();
	}
	/**
	 * Compares arrays of double and returns the euclidean distance
	 * between them.
	 * 
	 * @param a - The first set of numbers
	 * @param b - The second set of numbers
	 * @return The distance squared between <b>a</b> and <b>b</b>.
	 */
	public static final double distance(double[] a, double[] b) {
		double total = 0;
		for (int i = 0; i < a.length; ++i)
			total += (b[i] - a[i]) * (b[i] - a[i]);
		return Math.sqrt(total);
	}
	/**
	 * Compares arrays of double and returns the squared euclidean distance
	 * between them.
	 * 
	 * @param a - The first set of numbers
	 * @param b - The second set of numbers
	 * @return The distance squared between <b>a</b> and <b>b</b>.
	 */
	public static final double distanceSq(double[] a, double[] b) {
		double total = 0;
		for (int i = 0; i < a.length; ++i)
			total += (b[i] - a[i]) * (b[i] - a[i]);
		return total;
	}
 
	//Internal tree node
	private class NodeKD {
		private NodeKD left, right;
		private double[] upper, lower;
		private Item[] bucket;
		private int current, dim;
		private double slice;
 
		//note: we always start as a bucket
		private NodeKD() {
			upper = lower = null;
			left = right = null;
			bucket = new Item[bucket_size];
			current = 0;
			dim = 0;
		}
		//when we create non-root nodes within this class
		//we use this one here
		private NodeKD(NodeKD node) {
			dim = node.dim + 1;
			bucket = new Item[bucket_size];
			if(dim + 1 > dimensions) dim = 0;
			left = right = null;
			upper = lower = null;
			current = 0;
		}
		//what it says on the tin
		private void add(Item m) {
			if(bucket == null) {
				//Branch
				if(m.pnt[dim] > slice)
					right.add(m);
				else left.add(m);
			} else {
				//Bucket
				if(current+1>bucket_size) {
					split(m);
					return;
				}
				bucket[current++] = m;
			}
			expand(m.pnt);
		}
		//nearest neighbor thing
		private void nearestn(ShiftArray arr, double[] data) {
			if(bucket == null) {
				//Branch
				if(data[dim] > slice) {
					right.nearestn(arr, data);
					if(left.current != 0) {
						if(KDTree.distanceSq(left.nearestRect(data),data)
								< arr.getLongest()) {
							left.nearestn(arr, data);
						}
					}
 
				} else {
					left.nearestn(arr, data);
					if(right.current != 0) {
						if(KDTree.distanceSq(right.nearestRect(data),data) 
								< arr.getLongest()) {
							right.nearestn(arr, data);
						}
					}
				}
			} else {
				//Bucket
				for(int i = 0; i < current; i++) {
					bucket[i].distance = KDTree.distanceSq(bucket[i].pnt, data);
					arr.add(bucket[i]);
				}
			}
		}
		//gets all items from within a range
		private Item[] range(double[] upper, double[] lower) {
			//TODO: clean this up a bit
			if(bucket == null) {
				//Branch
				Item[] tmp = new Item[0];
				if (intersects(upper,lower,left.upper,left.lower)) {
					Item[] tmpl = left.range(upper,lower);
					if(0 == tmp.length)
						tmp = tmpl;
				}
				if (intersects(upper,lower,right.upper,right.lower)) {
					Item[] tmpr = right.range(upper,lower);
					if (0 == tmp.length)
						tmp = tmpr;
					else if (0 < tmpr.length) {
						Item[] tmp2 = new Item[tmp.length + tmpr.length];
						System.arraycopy(tmp, 0, tmp2, 0, tmp.length);
						System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length);
						tmp = tmp2;
					}
				}
				return tmp;
			}
			//Bucket
			Item[] tmp = new Item[current];
			int n = 0;
			for (int i = 0; i < current; i++) {
				if (contains(upper, lower, bucket[i].pnt)) {
					tmp[n++] = bucket[i];
				}
			}
			Item[] tmp2 = new Item[n];
			System.arraycopy(tmp, 0, tmp2, 0, n);
			return tmp2;
		}
 
		//These are helper functions from here down
		//check if this hyper rectangle contains a give hyper-point
		public boolean contains(double[] upper, double[] lower, double[] point) {
			if(current == 0) return false;
			for(int i=0; i<point.length; ++i) {
				if(point[i] > upper[i] ||
					point[i] < lower[i]) 
						return false;
			}
			return true;
		}
		//checks if two hyper-rectangles intersect
		public boolean intersects(double[] up0, double[] low0,
				double[] up1, double[] low1) {
			for(int i=0; i<up0.length; ++i) {
				if(up1[i] < low0[i] || low1[i] > up0[i]) return false;
			}
			return true;
		}
		//splits a bucket into a branch
		private void split(Item m) {
			//split based on our bound data
			slice = (upper[dim]+lower[dim])/2.0;
			left = new NodeKD(this);
			right = new NodeKD(this);
			for(int i=0; i<current; ++i) {
				if(bucket[i].pnt[dim] > slice)
					right.add(bucket[i]);
				else left.add(bucket[i]);
			}
			bucket = null;
			add(m);
		}
		//gets nearest point to data within this hyper rectangle
		private double[] nearestRect(double[] data) {
			double[] nearest = data.clone();
			for(int i = 0; i < data.length; ++i) {
				if(nearest[i] > upper[i]) nearest[i] = upper[i];
				if(nearest[i] < lower[i]) nearest[i] = lower[i];
			}
			return nearest;
		}
		//expands this hyper rectangle
		private void expand(double[] data) {
			if(upper == null) {
				upper = Arrays.copyOf(data, dimensions);
				lower = Arrays.copyOf(data, dimensions);
				return;
			}
			for(int i=0; i<data.length; ++i) {
				if(upper[i] < data[i]) upper[i] = data[i];
				if(lower[i] > data[i]) lower[i] = data[i];
			}
		}
	}
	//A simple shift array that sifts data up
	//as we add new ones to lower in the array.
	private class ShiftArray {
		private Item[] items;
		private final int max;
		private int current;
		private ShiftArray(int maximum) {
			max = maximum;
			current = 0;
			items = new Item[max];
		}
		private void add(Item m) {
			int i;
			for(i=current;i>0&&items[i-1].distance >  m.distance; --i);
			if(i >= max) return;
			if(current < max) ++current;
			System.arraycopy(items, i, items, i+1, current-(i+1));
			items[i] = m;
		}
		private double getLongest() {
			if (current < max) return Double.POSITIVE_INFINITY;
			return items[max-1].distance;
		}
		private Item[] getArray() {
			return items.clone();
		}
	}
}