User:Chase-san/Kd-Tree

From Robowiki
< User:Chase-san
Revision as of 05:28, 9 June 2012 by Chase-san (talk | contribs) (some minor fixes)
Jump to navigation Jump to search

Everyone and their brother has one of these now. Simonton and I started one, but I was too inexperienced to get anything written. I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine, if you care to see it.

This, and all the other code I display on the Robowiki, falls under the zlib License.

Oh yeah, am I the only one that has a Range function?

KDTreeF

package org.csdgn.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * A k-d bucket tree for storing and searching K-dimensional points.
 * <p>
 * Points are kept in leaf buckets. When a bucket overflows it is cut in two at
 * the midpoint of its widest dimension. Every node also tracks the bounding
 * box of the keys inserted beneath it, which the range and nearest-neighbour
 * searches use to skip subtrees.
 * <p>
 * Key arrays are stored by reference (they are not copied) and are expected to
 * hold {@code dimensions} values.
 * 
 * @author Chase
 * 
 * @param <T>
 *            type of the value stored with each key
 */
public class KDTree<T> {
	/** Bucket capacity used when none is given to the constructor. */
	protected static final int defaultBucketSize = 48;

	private final int dimensions;
	private final int bucketSize;
	private final NodeKD root;

	/**
	 * Creates a tree with the default bucket size.
	 * 
	 * @param dimensions
	 *            - Number of dimensions
	 * @throws IllegalArgumentException
	 *             if {@code dimensions} is less than 1
	 */
	public KDTree(int dimensions) {
		this(dimensions, defaultBucketSize);
	}

	/**
	 * Creates a tree with the given number of dimensions and bucket size.
	 * 
	 * @param dimensions
	 *            - Number of dimensions
	 * @param bucket
	 *            - Number of points a leaf holds before it is split
	 * @throws IllegalArgumentException
	 *             if {@code dimensions} or {@code bucket} is less than 1
	 */
	public KDTree(int dimensions, int bucket) {
		if (dimensions < 1) {
			throw new IllegalArgumentException("dimensions must be at least 1, got " + dimensions);
		}
		if (bucket < 1) {
			// A bucket that can hold nothing would be split forever.
			throw new IllegalArgumentException("bucket size must be at least 1, got " + bucket);
		}
		this.dimensions = dimensions;
		this.bucketSize = bucket;
		this.root = new NodeKD();
	}

	/**
	 * Add a key and its associated value to the tree.
	 * 
	 * @param key
	 *            - Key to add; the array is stored as-is, not copied
	 * @param val
	 *            - object to add
	 */
	public void add(double[] key, T val) {
		root.addPoint(key, val);
	}

	/**
	 * Returns every value whose key lies inside the axis-aligned box spanned
	 * by {@code low} and {@code high} (both bounds inclusive).
	 * 
	 * @param low
	 *            - lower bounds of area
	 * @param high
	 *            - upper bounds of area
	 * @return - All values whose key is between low and high.
	 */
	public List<T> getRange(double[] low, double[] high) {
		List<T> found = new ArrayList<T>();
		root.range(high, low, found);
		return found;
	}

	/**
	 * Gets the N nearest neighbors to the given key.
	 * 
	 * @param key
	 *            - Key
	 * @param num
	 *            - Number of results
	 * @return a ResultHeap holding up to {@code num} values; the keys inside
	 *         the heap are the square of the actual distance between each
	 *         stored key and {@code key}
	 */
	public ResultHeap<T> getNearestNeighbors(double[] key, int num) {
		ResultHeap<T> heap = new ResultHeap<T>(num);
		// A zero-capacity heap is "full" while empty; searching with it would
		// ask it for a maximum key it does not have.
		if (num > 0) {
			root.nearest(heap, key);
		}
		return heap;
	}

	// Internal tree node
	private class NodeKD {
		private NodeKD left, right;
		private double[] maxBounds, minBounds;
		private Object[] bucketValues;
		private double[][] bucketKeys;
		private boolean isLeaf;
		// For a leaf: number of points in the bucket. A branch keeps the count
		// it had when it was split, so it is never zero for a branch.
		private int current, sliceDimension;
		private double slice;

		private NodeKD() {
			bucketValues = new Object[bucketSize];
			bucketKeys = new double[bucketSize][];
			isLeaf = true;
		}

		// Routes a point down to the leaf it belongs in, widening the bounds
		// of every node passed on the way.
		private void addPoint(double[] key, Object val) {
			if (isLeaf) {
				addLeafPoint(key, val);
				return;
			}
			extendBounds(key);
			NodeKD child = key[sliceDimension] > slice ? right : left;
			child.addPoint(key, val);
		}

		// Stores a point in this leaf, splitting the leaf first if it is full.
		private void addLeafPoint(double[] key, Object val) {
			extendBounds(key);
			if (current == bucketKeys.length) {
				if (splitLeaf()) {
					// This node is a branch now; let it route the point.
					addPoint(key, val);
					return;
				}
				// No plane separates the points (e.g. they are all identical),
				// so splitting would recurse forever. Enlarge the bucket.
				int grown = bucketKeys.length * 2;
				bucketKeys = Arrays.copyOf(bucketKeys, grown);
				bucketValues = Arrays.copyOf(bucketValues, grown);
			}
			bucketKeys[current] = key;
			bucketValues[current] = val;
			++current;
		}

		/**
		 * Find the nearest neighbors recursively, offering every candidate to
		 * the heap together with its squared distance to {@code data}.
		 */
		@SuppressWarnings("unchecked")
		private void nearest(ResultHeap<T> heap, double[] data) {
			if (current == 0) {
				return;
			}
			if (isLeaf) {
				for (int i = 0; i < current; i++) {
					heap.offer(pointDistSq(bucketKeys[i], data), (T) bucketValues[i]);
				}
				return;
			}
			// Search the side of the plane containing the query first.
			NodeKD near, far;
			if (data[sliceDimension] > slice) {
				near = right;
				far = left;
			} else {
				near = left;
				far = right;
			}
			near.nearest(heap, data);
			if (far.current == 0) {
				// An empty leaf has no bounds to measure against.
				return;
			}
			// The other side is only worth visiting while there is still room
			// in the heap, or when its bounding box is closer than the worst
			// result found so far.
			if (!heap.isFull() || regionDistSq(data, far.minBounds, far.maxBounds) < heap.getMaxKey()) {
				far.nearest(heap, data);
			}
		}

		// Appends to 'out' every value beneath this node whose key lies within
		// [lower, upper]; left subtree first, then right, leaves in bucket
		// order.
		@SuppressWarnings("unchecked")
		private void range(double[] upper, double[] lower, List<T> out) {
			if (isLeaf) {
				for (int i = 0; i < current; i++) {
					if (contains(upper, lower, bucketKeys[i])) {
						out.add((T) bucketValues[i]);
					}
				}
				return;
			}
			// An empty child has null bounds, so test for that before
			// intersecting (mirrors the guard in nearest()).
			if (left.current != 0 && intersects(upper, lower, left.maxBounds, left.minBounds)) {
				left.range(upper, lower, out);
			}
			if (right.current != 0 && intersects(upper, lower, right.maxBounds, right.minBounds)) {
				right.range(upper, lower, out);
			}
		}

		// These are helper functions from here down
		// check if the hyper rectangle [lower, upper] contains a given
		// hyper-point (bounds inclusive)
		private boolean contains(double[] upper, double[] lower, double[] point) {
			if (current == 0) {
				return false;
			}
			for (int i = 0; i < point.length; ++i) {
				if (point[i] > upper[i] || point[i] < lower[i]) {
					return false;
				}
			}
			return true;
		}

		// checks if two hyper-rectangles intersect (touching counts)
		private boolean intersects(double[] up0, double[] low0, double[] up1, double[] low1) {
			for (int i = 0; i < up0.length; ++i) {
				if (up1[i] < low0[i] || low1[i] > up0[i]) {
					return false;
				}
			}
			return true;
		}

		// Turns this leaf into a branch by cutting it at the midpoint of its
		// widest dimension and handing the bucket's points to two new leaves.
		// Returns false, leaving the leaf untouched, when no dimension has a
		// midpoint that actually separates the points.
		private boolean splitLeaf() {
			int splitDim = -1;
			double splitAt = 0;
			double widest = 0;
			for (int i = 0; i < dimensions; ++i) {
				double extent = maxBounds[i] - minBounds[i];
				if (extent > widest) {
					double mid = (maxBounds[i] + minBounds[i]) * 0.5;
					// Require min <= mid < max so the point at the minimum
					// goes left and the point at the maximum goes right. This
					// rejects midpoints ruined by rounding, overflow or NaN.
					if (mid >= minBounds[i] && mid < maxBounds[i]) {
						splitDim = i;
						splitAt = mid;
						widest = extent;
					}
				}
			}
			if (splitDim < 0) {
				return false;
			}

			sliceDimension = splitDim;
			slice = splitAt;
			left = new NodeKD();
			right = new NodeKD();

			for (int i = 0; i < current; ++i) {
				NodeKD child = bucketKeys[i][sliceDimension] > slice ? right : left;
				child.addLeafPoint(bucketKeys[i], bucketValues[i]);
			}
			// A branch stores no points of its own.
			bucketKeys = null;
			bucketValues = null;
			isLeaf = false;
			return true;
		}

		// expands this hyper rectangle so that it covers the key
		private void extendBounds(double[] key) {
			if (maxBounds == null) {
				// First point: the box is just that point.
				maxBounds = Arrays.copyOf(key, dimensions);
				minBounds = Arrays.copyOf(key, dimensions);
				return;
			}
			for (int i = 0; i < key.length; ++i) {
				if (maxBounds[i] < key[i]) {
					maxBounds[i] = key[i];
				}
				if (minBounds[i] > key[i]) {
					minBounds[i] = key[i];
				}
			}
		}
	}

	/* I may have borrowed these from an early version of Red's tree. I however forget. */

	// Squared Euclidean distance between two points.
	private static double pointDistSq(double[] p1, double[] p2) {
		double sum = 0;
		for (int i = 0; i < p1.length; i++) {
			double diff = p1[i] - p2[i];
			sum += diff * diff;
		}
		return sum;
	}

	// Squared Euclidean distance from a point to the box [min, max]; zero when
	// the point lies inside the box.
	private static double regionDistSq(double[] point, double[] min, double[] max) {
		double sum = 0;
		for (int i = 0; i < point.length; i++) {
			double diff;
			if (point[i] > max[i]) {
				diff = point[i] - max[i];
			} else if (point[i] < min[i]) {
				diff = point[i] - min[i];
			} else {
				continue;
			}
			sum += diff * diff;
		}
		return sum;
	}
}

ResultHeap

package org.csdgn.util;

/**
 * A fixed-capacity collection that keeps the entries with the smallest keys
 * out of everything offered to it.
 * <p>
 * Despite the name, the entries are held in two parallel arrays sorted by
 * ascending key, so the largest key is always the last occupied slot.
 * 
 * @author Chase
 * 
 * @param <T>
 *            type of the values stored alongside the keys
 */
public class ResultHeap<T> {
	private final Object[] data;
	private final double[] keys;
	private final int capacity;
	private int size;

	/**
	 * Creates an empty heap.
	 * 
	 * @param capacity
	 *            - maximum number of entries kept
	 */
	protected ResultHeap(int capacity) {
		this.data = new Object[capacity];
		this.keys = new double[capacity];
		this.capacity = capacity;
		this.size = 0;
	}

	/**
	 * Offers an entry. It is inserted in key order, after any entries with an
	 * equal key. When the heap is already full the entry with the largest key
	 * is dropped to make room, unless the offered key is itself not smaller
	 * than that largest key, in which case the offer is ignored.
	 * 
	 * @param key
	 *            - sort key of the entry
	 * @param value
	 *            - value of the entry
	 */
	protected void offer(double key, T value) {
		// Walk back from the end to the first slot whose predecessor is not
		// larger than the new key.
		int slot = size;
		while (slot > 0 && keys[slot - 1] > key) {
			--slot;
		}
		if (slot >= capacity) {
			return;
		}
		if (size < capacity) {
			++size;
		}
		// Shift the tail one place right; when full this overwrites (evicts)
		// the last entry.
		int tail = size - slot - 1;
		System.arraycopy(keys, slot, keys, slot + 1, tail);
		System.arraycopy(data, slot, data, slot + 1, tail);
		keys[slot] = key;
		data[slot] = value;
	}

	/**
	 * Returns the largest key currently held.
	 * 
	 * @return the key of the entry that {@link #removeMax()} would return
	 * @throws ArrayIndexOutOfBoundsException
	 *             if the heap is empty
	 */
	public double getMaxKey() {
		return keys[size - 1];
	}

	/**
	 * Removes and returns the value with the largest key.
	 * 
	 * @return the removed value, or {@code null} if the heap is empty
	 */
	@SuppressWarnings("unchecked")
	public T removeMax() {
		if (isEmpty()) {
			return null;
		}
		--size;
		T removed = (T) data[size];
		// Clear the vacated slot so the heap does not keep the value alive.
		data[size] = null;
		return removed;
	}

	/**
	 * @return {@code true} if the heap holds no entries
	 */
	public boolean isEmpty() {
		return size == 0;
	}

	/**
	 * @return {@code true} if the heap holds as many entries as its capacity
	 */
	public boolean isFull() {
		return size == capacity;
	}

	/**
	 * @return the number of entries currently held
	 */
	public int size() {
		return size;
	}

	/**
	 * @return the maximum number of entries this heap keeps
	 */
	public int capacity() {
		return capacity;
	}
}