Difference between revisions of "User:Rednaxela/kD-Tree"

From Robowiki
Jump to navigation Jump to search
(Use a heap, it's sliiiiightly faster)
(fix bug)
Line 298: Line 298:
 
Object[] data = resultHeap.getData();
 
Object[] data = resultHeap.getData();
 
double[] dist = resultHeap.getDistances();
 
double[] dist = resultHeap.getDistances();
for (int i=0; i<data.length; i++) {
+
for (int i=0; i<resultHeap.values; i++) {
 
T value = tree.map.get(data[i]);
 
T value = tree.map.get(data[i]);
 
results.add(new Entry<T>(dist[i], value));
 
results.add(new Entry<T>(dist[i], value));

Revision as of 20:21, 26 August 2009

A nice efficent small kD-Tree. It's quite fast... Feel free to use

/**
 * Copyright 2009 Rednaxela
 * 
 * This software is provided 'as-is', without any express or implied
 * warranty. In no event will the authors be held liable for any damages
 * arising from the use of this software.
 * 
 * Permission is granted to anyone to use this software for any purpose,
 * including commercial applications, and to alter it and redistribute it
 * freely, subject to the following restrictions:
 * 
 *    1. The origin of this software must not be misrepresented; you must not
 *    claim that you wrote the original software. If you use this software
 *    in a product, an acknowledgment in the product documentation would be
 *    appreciated but is not required.
 * 
 *    2. This notice may not be removed or altered from any source
 *    distribution.
 */

package ags.utils.newtree2;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Stack;

/**
 * An efficent well-optimized kd-tree
 * 
 * @author Rednaxela
 */
public class KdTree<T> {
	// Static variables
	private static final int bucketSize = 32;

	// All types
	private final int dimensions;

	// Root only
	private final HashMap<Object, T> map;
	private double[] weights;

	// Leaf only
	private double[][] locations;
	private int locationCount;

	// Stem only
	private KdTree<T> left, right;
	private int splitDimension;
	private double splitValue;

	// Bounds
	private double[] minLimit, maxLimit;

	/**
	 * Extends the bounds of this node do include a new location
	 */
	private final void extendBounds(double[] location) {
		if (minLimit == null) {
			minLimit = Arrays.copyOf(location, dimensions);
			maxLimit = Arrays.copyOf(location, dimensions);
			return;
		}

		for (int i=0; i<dimensions; i++) {
			if (minLimit[i] > location[i]) {
				minLimit[i] = location[i];
			}
			if (maxLimit[i] < location[i]) {
				maxLimit[i] = location[i];
			}
		}
	}

	/**
	 * Find the widest axis of the bounds of this node
	 */
	private final int findWidestAxis() {
		int widest = 0;
		double width = (maxLimit[0] - minLimit[0]);
		for (int i = 1; i < dimensions; i++) {
			double nwidth = maxLimit[i] - minLimit[i];
			if (nwidth > width) {
				widest = i;
				width = nwidth;
			}
		}
		return widest;
	}

	// Main constructor
	public KdTree(int dimensions) {
		this.dimensions = dimensions;

		// Init as leaf
		this.locations = new double[bucketSize][];
		this.locationCount = 0;

		// Init as root
		this.map = new HashMap<Object, T>();
		this.weights = new double[dimensions];
		Arrays.fill(this.weights, 1.0);
	}

	// Child constructor
	private KdTree(KdTree<T> parent, boolean right) {
		this.dimensions = parent.dimensions;

		// Init as leaf
		this.locations = new double[bucketSize][];
		this.locationCount = 0;

		// Init as non-root
		this.map = null;
	}

	/**
	 * Add a point and associated value to the tree
	 */
	public static <T> void addPoint(KdTree<T> tree, double[] location, T
			value) {
		KdTree<T> cursor = tree;

		while (cursor.locations == null || cursor.locationCount >=
			cursor.locations.length) {
			if (cursor.locations != null) {
				cursor.splitDimension = cursor.findWidestAxis();
				cursor.splitValue = (cursor.minLimit[cursor.splitDimension] +
						cursor.maxLimit[cursor.splitDimension]) * 0.5;

				// Don't split node if it has no width in any axis. Double the bucket size instead
				if ((cursor.minLimit[cursor.splitDimension] -
						cursor.maxLimit[cursor.splitDimension]) == 0) {
					cursor.locations = Arrays.copyOf(cursor.locations,
							cursor.locations.length * 2);
					break;
				}

				// Create child leaves
				KdTree<T> left = new KdTree<T>(cursor, false);
				KdTree<T> right = new KdTree<T>(cursor, true);

				// Move locations into children
				for (double[] oldLocation : cursor.locations) {
					if (oldLocation[cursor.splitDimension] > cursor.splitValue) {
						// Right
						right.locations[right.locationCount] = oldLocation;
						right.locationCount++;
						right.extendBounds(oldLocation);
					}
					else {
						// Left
						left.locations[left.locationCount] = oldLocation;
						left.locationCount++;
						left.extendBounds(oldLocation);
					}
				}

				// Make into stem
				cursor.left = left;
				cursor.right = right;
				cursor.locations = null;
			}

			cursor.extendBounds(location);

			if (location[cursor.splitDimension] > cursor.splitValue) {
				cursor = cursor.right;
			}
			else {
				cursor = cursor.left;
			}
		}

		cursor.locations[cursor.locationCount] = location;
		cursor.locationCount++;
		cursor.extendBounds(location);

		tree.map.put(location, value);
	}

	/**
	 * Enumeration representing the status of a node during the running 
	 */
	private static enum Status {
		NONE,
		LEFTVISITED,
		RIGHTVISITED,
		ALLVISITED
	}

	/**
	 * Stores a distance and value to output
	 */
	public static class Entry<T> {
		public final double distance;
		public final T value;
		private Entry(double distance, T value) {
			this.distance = distance;
			this.value = value;
		}
	}

	/**
	 * Calculates the nearest 'count' points to 'location', with an
arbitrary weighting on dimensions
	 */
	public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
			double[] location, int count, double[] weights) {
		tree.weights = weights;
		return nearestNeighbor(tree, location, count);
	}

	/**
	 * Calculates the nearest 'count' points to 'location'
	 */
	public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
			double[] location, int count) {
		KdTree<T> cursor = tree;
		Status status = Status.NONE;
		Stack<KdTree<T>> stack = new Stack<KdTree<T>>();
		Stack<Status> statusStack = new Stack<Status>();
		double range = Double.POSITIVE_INFINITY;
		ResultHeap resultHeap = new ResultHeap(count); 

		do {
			if (status == Status.ALLVISITED) {
				// At a fully visited part. Move up the tree
				cursor = stack.pop();
				status = statusStack.pop();
				continue;
			}

			if (status == Status.NONE && cursor.locations != null) {
				// At a leaf. Use the data.
				for (int i=0; i<cursor.locationCount; i++) {
					double dist = sqrPointDist(cursor.locations[i], location,
							tree.weights);
					resultHeap.addValue(dist, cursor.locations[i]);
				}
				range = resultHeap.getMaxDist();

				if (stack.empty()) {
					break;
				}
				cursor = stack.pop();
				status = statusStack.pop();
				continue;
			}

			// Going to descend
			KdTree<T> nextCursor = null;
			if (status == Status.NONE) {
				// At a fresh node, descend the most probably useful direction
				if (location[cursor.splitDimension] > cursor.splitValue) {
					// Descend right
					nextCursor = cursor.right;
					status = Status.RIGHTVISITED;
				}
				else {
					// Descend left;
					nextCursor = cursor.left;
					status = Status.LEFTVISITED;
				}
			}
			else if (status == Status.LEFTVISITED) {
				// Left node visited, descend right.
				nextCursor = cursor.right;
				status = Status.ALLVISITED;
			}
			else if (status == Status.RIGHTVISITED) {
				// Right node visited, descend left.
				nextCursor = cursor.left;
				status = Status.ALLVISITED;
			}

			// Check if it's worth descending. Assume it is if it's sibling has not been visited yet. 
			if (status == Status.ALLVISITED) {
				if (nextCursor.locationCount == 0 || sqrPointRegionDist(location,
						nextCursor.minLimit, nextCursor.maxLimit, tree.weights) > range) {
					continue;
				}
			}

			// Descend down the tree
			stack.push(cursor);
			statusStack.push(status);
			cursor = nextCursor;
			status = Status.NONE;
		} while (stack.size() > 0 || status != Status.ALLVISITED);

		ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(count);
		Object[] data = resultHeap.getData();
		double[] dist = resultHeap.getDistances();
		for (int i=0; i<resultHeap.values; i++) {
			T value = tree.map.get(data[i]);
			results.add(new Entry<T>(dist[i], value));
		}

		return results;
	}

	/**
	 * Calculates the (squared euclidean) distance between two points
	 */
	private static final double sqrPointDist(double[] p1, double[] p2,
			double[] weights) {
		double d = 0;

		for (int i=0; i<p1.length; i++) {
			double diff = (p1[i] - p2[i]) * weights[i];
			d += diff * diff;
		}

		return d;
	}

	/**
	 * Calculates the closest (squared euclidean) distance between in a
point and a bounding region
	 */
	private static final double sqrPointRegionDist(double[] point, double[]
	                                                                      min, double[] max, double[] weights) {
		double d = 0;

		for (int i=0; i<point.length; i++) {
			if (point[i] > max[i]) {
				double diff = (point[i] - max[i]) * weights[i];
				d += diff * diff;
			} else if (point[i] < min[i]) {
				double diff = (point[i] - min[i]) * weights[i];
				d += diff * diff;
			}
		}

		return d;
	}

	/**
	 * Class for tracking up to 'size' closest values
	 */
	private static class ResultHeap {
		private final Object[] data;
		private final double[] distance;
		private final int size;
		private int values;

		public ResultHeap(int size) {
			this.data = new Object[size+1];
			this.distance = new double[size+1];
			this.size = size;
			this.values = 0;
		}

		public void addValue(double dist, Object value) {
			if (values == size && dist >= distance[0]) {
				return;
			}

			// Insert value
			data[values] = value;
			distance[values] = dist;
			values++;

			// Up-Heapify
			for (int c = values-1, p = (c-1)/2; c != 0 && distance[c] > distance[p]; c = p, p = (c-1)/2) {
				Object pData = data[p];
				double pDist = distance[p];
				data[p] = data[c];
				distance[p] = distance[c];
				data[c] = pData;
				distance[c] = pDist;
			}

			// If too big, remove the highest value
			if (values > size) {
				// Move the last entry to the top
				values--;
				data[0] = data[values];
				distance[0] = distance[values];

				// Down-Heapify
				for (int p = 0, c = 1; c < values; p = c,c = p*2+1) {
					if (c+1 < values && distance[c] < distance[c+1]) {
						c++;
					}
					if (distance[p] < distance[c]) {
						// Swap the points
						Object pData = data[p];
						double pDist = distance[p];
						data[p] = data[c];
						distance[p] = distance[c];
						data[c] = pData;
						distance[c] = pDist;
					}
					else {
						break;
					}
				}
			}
		}

		public double getMaxDist() {
			if (values < size) {
				return Double.POSITIVE_INFINITY;
			}
			return distance[0];
		}

		public Object[] getData() {
			return data;
		}

		public double[] getDistances() {
			return distance;
		}
	}
}