User:Rednaxela/kD-Tree

From Robowiki
< User:Rednaxela
Revision as of 23:06, 23 August 2009 by Rednaxela (talk | contribs) (Post kD-Tree code)
(diff) ← Older revision | Latest revision (diff) | Newer revision → (diff)
Jump to navigation Jump to search

A nice, efficient 359-line kD-Tree. It's quite fast...

package ags.utils.newtree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Stack;

/**
 * An efficient, well-optimized kd-tree
 * 
 * @author Rednaxela
 */
public class KdTree<T> {
	// Every node is either a leaf or a stem, distinguished by 'locations':
	//   leaf: locations != null; the points live in locations[0..locationCount)
	//   stem: locations == null; points live in the 'left'/'right' subtrees,
	//         divided by splitValue along splitDimension
	// Only the root (the node built by the public constructor) owns the
	// location -> value map and the distance weights.

	// Initial capacity of a leaf bucket; buckets may grow beyond this
	private static final int BUCKET_SIZE = 32;

	// All types
	private final int dimensions;

	// Root only
	// The map is keyed by the caller's double[] instance. Arrays do not override
	// equals/hashCode, so lookups are by identity, not by coordinate values.
	private final HashMap<Object, T> map;
	private double[] weights;

	// Leaf only
	private double[][] locations;
	private int locationCount;

	// Stem only
	private KdTree<T> left, right;
	private int splitDimension;
	private double splitValue;

	// Bounds (both null until this node has seen its first point)
	private double[] minLimit, maxLimit;

	/**
	 * Extends the bounds of this node to include a new location
	 */
	private final void extendBounds(double[] location) {
		if (minLimit == null) {
			// First point seen: the bounds collapse to that single point
			minLimit = Arrays.copyOf(location, dimensions);
			maxLimit = Arrays.copyOf(location, dimensions);
			return;
		}

		for (int i = 0; i < dimensions; i++) {
			if (minLimit[i] > location[i]) {
				minLimit[i] = location[i];
			}
			if (maxLimit[i] < location[i]) {
				maxLimit[i] = location[i];
			}
		}
	}

	/**
	 * Find the widest axis of the bounds of this node.
	 * Ties are resolved in favour of the lowest axis index.
	 */
	private final int findWidestAxis() {
		int widest = 0;
		double width = (maxLimit[0] - minLimit[0]);
		for (int i = 1; i < dimensions; i++) {
			double nwidth = maxLimit[i] - minLimit[i];
			if (nwidth > width) {
				widest = i;
				width = nwidth;
			}
		}
		return widest;
	}

	/**
	 * Creates an empty tree (a root node that starts out as an empty leaf)
	 * with every dimension weighted 1.0.
	 *
	 * @param dimensions number of coordinates in each location
	 */
	public KdTree(int dimensions) {
		this.dimensions = dimensions;

		// Init as leaf
		this.locations = new double[BUCKET_SIZE][];
		this.locationCount = 0;

		// Init as root
		this.map = new HashMap<Object, T>();
		this.weights = new double[dimensions];
		Arrays.fill(this.weights, 1.0);
	}

	/**
	 * Child constructor: creates an empty non-root leaf.
	 *
	 * @param parent   node whose dimensionality is inherited
	 * @param capacity number of points the bucket must be able to hold;
	 *                 the bucket is never smaller than BUCKET_SIZE
	 */
	private KdTree(KdTree<T> parent, int capacity) {
		this.dimensions = parent.dimensions;

		// Init as leaf
		this.locations = new double[Math.max(BUCKET_SIZE, capacity)][];
		this.locationCount = 0;

		// Init as non-root
		this.map = null;
	}

	/**
	 * Tries to turn this full leaf into a stem with two child leaves, splitting
	 * at the midpoint of the widest axis of the bounds.
	 *
	 * The split is refused when it would leave one side empty. That happens when
	 * all points coincide, and also when the midpoint is unusable (it overflowed
	 * to infinity, is NaN, or rounded up to the maximum). Splitting in those
	 * cases would just recreate the same full leaf one level down.
	 *
	 * @return true if this node is now a stem, false if it was left untouched
	 */
	private boolean split() {
		final int axis = findWidestAxis();
		final double pivot = (minLimit[axis] + maxLimit[axis]) * 0.5;

		// Count first so each child's bucket can be sized to fit its share.
		// A grown bucket may hold more than BUCKET_SIZE points on one side.
		int rightCount = 0;
		for (int i = 0; i < locationCount; i++) {
			if (locations[i][axis] > pivot) {
				rightCount++;
			}
		}
		final int leftCount = locationCount - rightCount;
		if (rightCount == 0 || leftCount == 0) {
			return false;
		}

		// Create child leaves
		KdTree<T> newLeft = new KdTree<T>(this, leftCount);
		KdTree<T> newRight = new KdTree<T>(this, rightCount);

		// Move locations into children
		for (int i = 0; i < locationCount; i++) {
			double[] oldLocation = locations[i];
			KdTree<T> target = (oldLocation[axis] > pivot) ? newRight : newLeft;
			target.locations[target.locationCount] = oldLocation;
			target.locationCount++;
			target.extendBounds(oldLocation);
		}

		// Make into stem
		splitDimension = axis;
		splitValue = pivot;
		left = newLeft;
		right = newRight;
		locations = null;
		return true;
	}

	/**
	 * Add a point and associated value to the tree.
	 *
	 * The location array is stored by reference (not copied) and is also the key
	 * under which the value is remembered, so it must not be modified afterwards.
	 * Adding the same array instance twice replaces the value seen for both entries.
	 *
	 * @param tree     root of the tree to add to
	 * @param location coordinates of the point
	 * @param value    value to associate with the point
	 */
	public static <T> void addPoint(KdTree<T> tree, double[] location, T value) {
		KdTree<T> cursor = tree;

		// Walk down until we reach a leaf with a free slot
		while (cursor.locations == null || cursor.locationCount >= cursor.locations.length) {
			if (cursor.locations != null) {
				// Full leaf: split it, or double the bucket if it cannot be split
				if (!cursor.split()) {
					cursor.locations = Arrays.copyOf(cursor.locations, cursor.locations.length * 2);
					break;
				}
			}

			// cursor is a stem here; its bounds must cover everything below it
			cursor.extendBounds(location);

			if (location[cursor.splitDimension] > cursor.splitValue) {
				cursor = cursor.right;
			}
			else {
				cursor = cursor.left;
			}
		}

		cursor.locations[cursor.locationCount] = location;
		cursor.locationCount++;
		cursor.extendBounds(location);

		tree.map.put(location, value);
	}

	/**
	 * Enumeration representing which children of a stem have been handled
	 * during a nearest-neighbor search
	 */
	private static enum Status {
		NONE,
		LEFTVISITED,
		RIGHTVISITED,
		ALLVISITED
	}

	/**
	 * Stores a distance and value to output.
	 * The distance is the weighted squared euclidean distance to the query point.
	 */
	public static class Entry<T> {
		public final double distance;
		public final T value;
		private Entry(double distance, T value) {
			this.distance = distance;
			this.value = value;
		}
	}

	/**
	 * Calculates the nearest 'count' points to 'location', with an arbitrary weighting on dimensions.
	 *
	 * Note: the weights array is stored on the tree (by reference) and stays in
	 * effect for later calls to the overload without a weights argument.
	 *
	 * @param weights per-dimension factor applied to coordinate differences before squaring
	 * @return see {@link #nearestNeighbor(KdTree, double[], int)}
	 */
	public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree, double[] location, int count, double[] weights) {
		tree.weights = weights;
		return nearestNeighbor(tree, location, count);
	}

	/**
	 * Calculates the nearest 'count' points to 'location', using the weights
	 * currently stored on the tree (all 1.0 unless the weighted overload was called).
	 *
	 * @param tree     root of the tree to search
	 * @param location query point
	 * @param count    maximum number of neighbors to return
	 * @return at most 'count' entries, nearest first; fewer if the tree holds fewer points
	 * @throws IllegalArgumentException if count is negative
	 */
	public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree, double[] location, int count) {
		if (count < 0) {
			throw new IllegalArgumentException("count must not be negative: " + count);
		}
		if (count == 0) {
			// Nothing was asked for; also keeps TopArray from indexing distance[-1]
			return new ArrayList<Entry<T>>();
		}

		KdTree<T> cursor = tree;
		Status status = Status.NONE;
		// Explicit stacks replace recursion: the path of stems above 'cursor',
		// each paired with how far it has been processed
		Stack<KdTree<T>> stack = new Stack<KdTree<T>>();
		Stack<Status> statusStack = new Stack<Status>();
		// Squared distance of the current worst kept candidate; infinite until 'count' are found
		double range = Double.POSITIVE_INFINITY;
		TopArray top = new TopArray(count);

		do {
			if (status == Status.ALLVISITED) {
				// At a fully visited part. Move up the tree
				cursor = stack.pop();
				status = statusStack.pop();
				continue;
			}

			if (status == Status.NONE && cursor.locations != null) {
				// At a leaf. Use the data.
				for (int i = 0; i < cursor.locationCount; i++) {
					double dist = sqrPointDist(cursor.locations[i], location, tree.weights);
					top.addValue(dist, cursor.locations[i]);
				}
				range = top.getMaxDist();

				if (stack.empty()) {
					// The root itself is a leaf; nothing else to search
					break;
				}
				cursor = stack.pop();
				status = statusStack.pop();
				continue;
			}

			// Going to descend
			KdTree<T> nextCursor = null;
			if (status == Status.NONE) {
				// At a fresh node, descend the most probably useful direction
				if (location[cursor.splitDimension] > cursor.splitValue) {
					// Descend right
					nextCursor = cursor.right;
					status = Status.RIGHTVISITED;
				}
				else {
					// Descend left;
					nextCursor = cursor.left;
					status = Status.LEFTVISITED;
				}
			}
			else if (status == Status.LEFTVISITED) {
				// Left node visited, descend right.
				nextCursor = cursor.right;
				status = Status.ALLVISITED;
			}
			else if (status == Status.RIGHTVISITED) {
				// Right node visited, descend left.
				nextCursor = cursor.left;
				status = Status.ALLVISITED;
			}

			// Check if it's worth descending. Assume it is if its sibling has not been visited yet.
			if (status == Status.ALLVISITED) {
				// The locationCount test must come first: an empty leaf has null bounds
				if (nextCursor.locationCount == 0 || sqrPointRegionDist(location, nextCursor.minLimit, nextCursor.maxLimit, tree.weights) > range) {
					continue;
				}
			}

			// Descend down the tree
			stack.push(cursor);
			statusStack.push(status);
			cursor = nextCursor;
			status = Status.NONE;
		} while (stack.size() > 0 || status != Status.ALLVISITED);

		// Translate the kept locations back into the values they were added with
		ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(count);
		Object[] data = top.getData();
		double[] dist = top.getDistances();
		for (int i = 0; i < data.length; i++) {
			T value = tree.map.get(data[i]);
			results.add(new Entry<T>(dist[i], value));
		}

		return results;
	}

	/**
	 * Calculates the (weighted squared euclidean) distance between two points
	 */
	private static final double sqrPointDist(double[] p1, double[] p2, double[] weights) {
		double d = 0;

		for (int i = 0; i < p1.length; i++) {
			double diff = (p1[i] - p2[i]) * weights[i];
			d += diff * diff;
		}

		return d;
	}

	/**
	 * Calculates the closest (weighted squared euclidean) distance between a point and a bounding region.
	 * Axes on which the point lies inside [min, max] contribute nothing.
	 */
	private static final double sqrPointRegionDist(double[] point, double[] min, double[] max, double[] weights) {
		double d = 0;

		for (int i = 0; i < point.length; i++) {
			if (point[i] > max[i]) {
				double diff = (point[i] - max[i]) * weights[i];
				d += diff * diff;
			} else if (point[i] < min[i]) {
				double diff = (point[i] - min[i]) * weights[i];
				d += diff * diff;
			}
		}

		return d;
	}

	/**
	 * Class for tracking up to 'size' closest values, kept sorted by ascending distance.
	 * 'size' must be at least 1 for getMaxDist() to be usable.
	 */
	private static class TopArray {
		private final Object[] data;
		private final double[] distance;
		private final int size;
		private int values;

		public TopArray(int size) {
			this.data = new Object[size];
			this.distance = new double[size];
			this.size = size;
			this.values = 0;
		}

		/**
		 * Inserts the value at its sorted position, dropping the current worst
		 * entry if the array is already full. Ignored if it is no better than
		 * the worst entry of a full array.
		 */
		public void addValue(double dist, Object value) {
			// Find the insertion index, scanning back from the end
			int i = values;
			while (i > 0 && distance[i-1] > dist) {
				i--;
			}

			if (i >= size) {
				return;
			}

			if (values < size) {
				values++;
			}

			// Shift the tail one slot right; when full this overwrites the last entry
			System.arraycopy(distance, i, distance, i+1, values-(i+1));
			distance[i] = dist;
			System.arraycopy(data, i, data, i+1, values-(i+1));
			data[i] = value;
		}

		/**
		 * Returns the largest kept distance, or positive infinity while fewer
		 * than 'size' values are held (anything would still be accepted).
		 */
		public double getMaxDist() {
			if (values < size) {
				return Double.POSITIVE_INFINITY;
			}
			return distance[size-1];
		}

		public Object[] getData() {
			return Arrays.copyOf(data, values);
		}

		public double[] getDistances() {
			return Arrays.copyOf(distance, values);
		}
	}
}