Difference between revisions of "User:Chase-san/Kd-Tree"

From Robowiki
Jump to navigation Jump to search
m (minor paragraph fixing)
m (→‎KDTreeF: (no major change, just some minor tweaks, i++ to ++i, removing the enhanced for loop))
 
(15 intermediate revisions by 2 users not shown)
Line 1: Line 1:
 
[[Category:Code Snippets|Kd-Tree]]
 
[[Category:Code Snippets|Kd-Tree]]
This is my version of the KDTree. I admit it has gotten a bit large, and it could be cut down a great deal. I will add a K-nearest-neighbors implementation soon enough; however, it does have a singular version, which will find the single closest neighbor (not as useful).
+
Everyone and their brother has one of these now. Simonton and I started it, but I was too inexperienced to get anything written. I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine, if you care to see it.
  
Simonton and I worked on this at about the same time, however I was more experienced and it turns out the whole time I had nothing but a very elusive off by 1 error. I recently rebuilt my kd-tree into this. Which unhooks the points from the branches to speed things up, it also uses buckets, which are great fun.
+
This and all the other code that I display on the Robowiki falls under the [http://en.wikipedia.org/wiki/Zlib_License ZLIB License].
  
Also unlike most others mine is fairly modular and includes a range search (good for those old fashioned pattern matchers!!).
+
Oh yeah, am I the only one that has a Range function?
  
Please forgive the incomplete documentation
+
=== KDTreeF ===
 +
<syntaxhighlight>
 +
package org.csdgn.util;
  
===KDTreeB===
+
import java.util.ArrayList;
<pre>
+
import java.util.Arrays;
package org.csdgn.util;
+
import java.util.List;
 +
 
 +
/**
 +
* This is a KD Bucket Tree, for fast sorting and searching of K dimensional
 +
* data.
 +
*
 +
* @author Chase
 +
*
 +
*/
 +
public class KDTree<T> {
 +
protected static final int defaultBucketSize = 48;
  
public class KDTreeB {
+
private final int dimensions;
public final static int DEFAULT_BUCKET_SIZE = 100;
+
private final int bucketSize;
protected int bucketSize;
+
private NodeKD root;
protected int dimensions;
 
protected NodeKD root;
 
protected PointKD keys[];
 
protected int size;
 
  
 
/**
 
/**
* Initializes a new KDTreeB with a number of dimensions.
+
* Constructor with value for dimensions.
 
*  
 
*  
* @param dim
+
* @param dimensions
*            - Dimensions
+
*            - Number of dimensions
 
*/
 
*/
public KDTreeB(int dim) {
+
public KDTree(int dimensions) {
this(dim, DEFAULT_BUCKET_SIZE);
+
this.dimensions = dimensions;
 +
this.bucketSize = defaultBucketSize;
 +
this.root = new NodeKD();
 
}
 
}
  
 
/**
 
/**
* Initializes a new KDTreeB with a number of dimensions and bucket size
+
* Constructor with value for dimensions and bucket size.
 
*  
 
*  
* @param dim
+
* @param dimensions
*            - Dimensions
+
*            - Number of dimensions
* @param buckets
+
* @param bucket
*            - BucketKD Size
+
*            - Size of the buckets.
 
*/
 
*/
public KDTreeB(int dim, int buckets) {
+
public KDTree(int dimensions, int bucket) {
if(dim < 1)
+
this.dimensions = dimensions;
System.err.println("Dimensions < 1: Undefined Behavior may occur.");
+
this.bucketSize = bucket;
if(buckets < 2)
+
this.root = new NodeKD();
System.err.println("Bucket Size < 2: Undefined Behavior may occur.");
 
 
bucketSize = buckets;
 
dimensions = dim;
 
keys = new PointKD[buckets];
 
size = 0;
 
 
}
 
}
  
 
/**
 
/**
* Adds a new PointKD to the Tree, uses a recursive algorithm.
+
* Add a key and its associated value to the tree.
 
*  
 
*  
* @param k
+
* @param key
*            - Point to add
+
*            - Key to add
 +
* @param val
 +
*            - object to add
 
*/
 
*/
public void add(PointKD k) {
+
public void add(double[] key, T val) {
if (null == root) {
+
root.addPoint(key, val);
root = new BucketKD(this);
 
}
 
if (size >= keys.length) {
 
// this is basically what a vector does, I am
 
// just removing the overhead of using a vector
 
PointKD tmp[] = new PointKD[keys.length * 2];
 
System.arraycopy(keys, 0, tmp, 0, size);
 
keys = tmp;
 
}
 
keys[size++] = k;
 
root.add(k);
 
 
}
 
}
  
 
/**
 
/**
* Finds the approximate nearest neighbor to PointKD k. This uses a
+
* Returns all PointKD within a certain range defined by an upper and lower
* recursive algorithm.
+
* PointKD.
 
*  
 
*  
* @param k
+
* @param low
* @return
+
*            - lower bounds of area
 +
* @param high
 +
*            - upper bounds of area
 +
* @return - All PointKD between low and high.
 
*/
 
*/
public PointKD getApproxNN(PointKD k) {
+
@SuppressWarnings("unchecked")
if (null == root)
+
public List<T> getRange(double[] low, double[] high) {
return null;
+
Object[] objs = root.range(high, low);
return root.approx(k);
+
ArrayList<T> range = new ArrayList<T>(objs.length);
 +
for(int i=0; i<objs.length; ++i) {
 +
range.add((T)objs[i]);
 +
}
 +
return range;
 
}
 
}
  
 
/**
 
/**
* Finds the nearest neighbor to PointKD k. Uses an exhastive search
+
* Gets the N nearest neighbors to the given key.
* method.
+
*  
* @param k -  
+
* @param key
* @return
+
*            - Key
 +
* @param num
 +
*            - Number of results
 +
* @return Array of Item Objects, distances within the items are the square
 +
*        of the actual distance between them and the key
 
*/
 
*/
public PointKD getNN(PointKD k) {
+
public ResultHeap<T> getNearestNeighbors(double[] key, int num) {
if (null == root)
+
ResultHeap<T> heap = new ResultHeap<T>(num);
return null;
+
root.nearest(heap, key);
return root.nearest(k);
+
return heap;
}
 
 
 
public PointKD[] getRange(RectKD r) {
 
if (null == root)
 
return new PointKD[0];
 
return root.range(r);
 
 
}
 
}
  
public PointKD[] getRange(PointKD low, PointKD high) {
 
return this.getRange(new RectKD(low, high));
 
}
 
  
}
+
// Internal tree node
</pre>
+
private class NodeKD {
 +
private NodeKD left, right;
 +
private double[] maxBounds, minBounds;
 +
private Object[] bucketValues;
 +
private double[][] bucketKeys;
 +
private boolean isLeaf;
 +
private int current, sliceDimension;
 +
private double slice;
  
===NodeKD===
+
private NodeKD() {
<pre>
+
bucketValues = new Object[bucketSize];
package org.csdgn.util;
+
bucketKeys = new double[bucketSize][];
  
public abstract class NodeKD {
+
left = right = null;
protected KDTreeB ref;
+
maxBounds = minBounds = null;
protected BranchKD parent;
+
protected int depth;
+
isLeaf = true;
protected RectKD rect;
+
 
+
current = 0;
protected abstract void add(PointKD k);
 
protected abstract PointKD approx(PointKD k);
 
protected abstract PointKD nearest(PointKD k);
 
protected abstract PointKD[] range(RectKD r);
 
}
 
</pre>
 
 
 
===BranchKD===
 
<pre>
 
package org.csdgn.util;
 
 
 
public class BranchKD extends NodeKD {
 
protected NodeKD left, right;
 
protected double slice;
 
 
 
protected BranchKD(KDTreeB reference) {
 
slice = 0;
 
ref = reference;
 
rect = new RectKD();
 
depth = 0;
 
ref.root = this;
 
left = new BucketKD(this);
 
right = new BucketKD(this);
 
}
 
protected BranchKD(BranchKD b) {
 
ref = b.ref;
 
parent = b;
 
slice = 0;
 
rect = new RectKD();
 
depth = b.depth + 1;
 
left = new BucketKD(this);
 
right = new BucketKD(this);
 
}
 
 
 
protected void add(PointKD k) {
 
rect.add(k);
 
int dim = depth % ref.dimensions;
 
if (k.array[dim] > slice) {
 
right.add(k);
 
} else {
 
left.add(k);
 
 
}
 
}
}
 
 
protected PointKD approx(PointKD k) {
 
int dim = depth % ref.dimensions;
 
if (k.array[dim] > slice)
 
return right.approx(k);
 
return left.approx(k);
 
}
 
  
protected PointKD nearest(PointKD k) {
+
// what it says on the tin
int dim = depth % ref.dimensions;
+
private void addPoint(double[] key, Object val) {
PointKD near = null;
+
if(isLeaf) {
if (k.array[dim] > slice) {
+
addLeafPoint(key,val);
near = right.nearest(k);
+
} else {
double t = near.distanceSq(k);
+
extendBounds(key);
if (k.distanceSq(left.rect.getNearest(k)) < t) {
+
if (key[sliceDimension] > slice) {
PointKD tmp = left.nearest(k);
+
right.addPoint(key, val);
if (tmp.distanceSq(k) < t) {
+
} else {
near = tmp;
+
left.addPoint(key, val);
 
}
 
}
 
}
 
}
} else {
+
}
near = left.nearest(k);
+
double t = near.distanceSq(k);
+
private void addLeafPoint(double[] key, Object val) {
if (k.distanceSq(right.rect.getNearest(k)) < t) {
+
extendBounds(key);
PointKD tmp = right.nearest(k);
+
if (current + 1 > bucketSize) {
if (tmp.distanceSq(k) < t) {
+
splitLeaf();
near = tmp;
+
addPoint(key, val);
 +
return;
 +
}
 +
bucketKeys[current] = key;
 +
bucketValues[current] = val;
 +
++current;
 +
}
 +
 +
/**
 +
* Find the nearest neighbor recursively.
 +
*/
 +
@SuppressWarnings("unchecked")
 +
private void nearest(ResultHeap<T> heap, double[] data) {
 +
if(current == 0)
 +
return;
 +
if(isLeaf) {
 +
//IS LEAF
 +
for (int i = 0; i < current; ++i) {
 +
double dist = pointDistSq(bucketKeys[i], data);
 +
heap.offer(dist, (T) bucketValues[i]);
 +
}
 +
} else {
 +
//IS BRANCH
 +
if (data[sliceDimension] > slice) {
 +
right.nearest(heap, data);
 +
if(left.current == 0)
 +
return;
 +
if (!heap.isFull() || regionDistSq(data,left.minBounds,left.maxBounds) < heap.getMaxKey()) {
 +
left.nearest(heap, data);
 +
}
 +
} else {
 +
left.nearest(heap, data);
 +
if (right.current == 0)
 +
return;
 +
if (!heap.isFull() || regionDistSq(data,right.minBounds,right.maxBounds) < heap.getMaxKey()) {
 +
right.nearest(heap, data);
 +
}
 
}
 
}
 
}
 
}
 
}
 
}
  
return near;
+
// gets all items from within a range
}
+
private Object[] range(double[] upper, double[] lower) {
 
+
if (bucketValues == null) {
protected PointKD[] range(RectKD r) {
+
// Branch
PointKD[] tmp = new PointKD[0];
+
Object[] tmp = new Object[0];
if (r.intersects(left.rect)) {
+
if (intersects(upper, lower, left.maxBounds, left.minBounds)) {
PointKD[] tmpl = left.range(r);
+
Object[] tmpl = left.range(upper, lower);
if(0 == tmp.length)
+
if (0 == tmp.length) tmp = tmpl;
tmp = tmpl;
+
}
}
+
if (intersects(upper, lower, right.maxBounds, right.minBounds)) {
if (r.intersects(right.rect)) {
+
Object[] tmpr = right.range(upper, lower);
PointKD[] tmpr = right.range(r);
+
if (0 == tmp.length)
if (0 == tmp.length)
+
tmp = tmpr;
tmp = tmpr;
+
else if (0 < tmpr.length) {
else if (0 < tmpr.length) {
+
Object[] tmp2 = new Object[tmp.length + tmpr.length];
PointKD[] tmp2 = new PointKD[tmp.length + tmpr.length];
+
System.arraycopy(tmp, 0, tmp2, 0, tmp.length);
System.arraycopy(tmp, 0, tmp2, 0, tmp.length);
+
System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length);
System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length);
+
tmp = tmp2;
tmp = tmp2;
+
}
 +
}
 +
return tmp;
 +
}
 +
// Leaf
 +
Object[] tmp = new Object[current];
 +
int n = 0;
 +
for (int i = 0; i < current; ++i) {
 +
if (contains(upper, lower, bucketKeys[i])) {
 +
tmp[n++] = bucketValues[i];
 +
}
 
}
 
}
 +
Object[] tmp2 = new Object[n];
 +
System.arraycopy(tmp, 0, tmp2, 0, n);
 +
return tmp2;
 
}
 
}
return tmp;
 
}
 
  
protected BranchKD extend(BucketKD b) {
+
// These are helper functions from here down
if (b == left) {
+
// check if this hyper rectangle contains a give hyper-point
left = null;
+
public boolean contains(double[] upper, double[] lower, double[] point) {
left = new BranchKD(this);
+
if (current == 0) return false;
return (BranchKD) left;
+
for (int i = 0; i < point.length; ++i) {
} else if (b == right) {
+
if (point[i] > upper[i] || point[i] < lower[i]) return false;
right = null;
+
}
right = new BranchKD(this);
+
return true;
return (BranchKD) right;
 
 
}
 
}
return null;
 
}
 
}
 
</pre>
 
  
===BucketKD===
+
// checks if two hyper-rectangles intersect
<pre>
+
public boolean intersects(double[] up0, double[] low0, double[] up1, double[] low1) {
package org.csdgn.util;
+
for (int i = 0; i < up0.length; ++i) {
 
+
if (up1[i] < low0[i] || low1[i] > up0[i]) return false;
public class BucketKD extends NodeKD {
 
protected PointKD bucket[];
 
protected int current;
 
 
 
protected BucketKD(KDTreeB reference) {
 
ref = reference;
 
ref.root = this;
 
bucket = new PointKD[ref.bucketSize];
 
parent = null;
 
rect = new RectKD();
 
depth = 0;
 
}
 
 
protected BucketKD(BranchKD p) {
 
ref = p.ref;
 
bucket = new PointKD[ref.bucketSize];
 
parent = p;
 
rect = new RectKD();
 
depth = p.depth + 1;
 
}
 
 
 
protected PointKD approx(PointKD k) {
 
double nearestDist = Double.POSITIVE_INFINITY;
 
int nearest = 0;
 
for (int i = 0; i < current; i++) {
 
double distance = k.distanceSq(bucket[i]);
 
if (distance < nearestDist) {
 
nearestDist = distance;
 
nearest = i;
 
 
}
 
}
 +
return true;
 
}
 
}
return bucket[nearest];
 
}
 
  
protected PointKD nearest(PointKD k) {
+
private void splitLeaf() {
return approx(k);
+
double bestRange = 0;
}
+
for(int i=0;i<dimensions;++i) {
 
+
double range = maxBounds[i] - minBounds[i];
protected void add(PointKD k) {
+
if(range > bestRange) {
if (current >= ref.bucketSize) {
+
sliceDimension = i;
BranchKD b = null;
+
bestRange = range;
if (null == parent)
+
}
b = new BranchKD(ref);
 
else
 
b = parent.extend(this);
 
 
 
int dim = b.depth % ref.dimensions;
 
double total = 0;
 
for (int i = 0; i < current; i++) {
 
total += bucket[i].array[dim];
 
 
}
 
}
b.slice = total / current;
+
for (int i = 0; i < current; i++) {
+
left = new NodeKD();
b.add(bucket[i]);
+
right = new NodeKD();
 +
 +
slice = (maxBounds[sliceDimension] + minBounds[sliceDimension]) * 0.5;
 +
 +
for (int i = 0; i < current; ++i) {
 +
if (bucketKeys[i][sliceDimension] > slice) {
 +
right.addLeafPoint(bucketKeys[i], bucketValues[i]);
 +
} else {
 +
left.addLeafPoint(bucketKeys[i], bucketValues[i]);
 +
}
 
}
 
}
b.add(k);
+
bucketKeys = null;
bucket = null;
+
bucketValues = null;
parent = null;
+
isLeaf = false;
current = 0;
 
return;
 
 
}
 
}
rect.add(k);
 
bucket[current++] = k;
 
}
 
  
protected PointKD[] range(RectKD r) {
+
// expands this hyper rectangle
PointKD[] tmp = new PointKD[current];
+
private void extendBounds(double[] key) {
int n = 0;
+
if (maxBounds == null) {
for (int i = 0; i < current; i++) {
+
maxBounds = Arrays.copyOf(key, dimensions);
if (r.contains(bucket[i])) {
+
minBounds = Arrays.copyOf(key, dimensions);
tmp[n++] = bucket[i];
+
return;
 +
}
 +
for (int i = 0; i < key.length; ++i) {
 +
if (maxBounds[i] < key[i]) maxBounds[i] = key[i];
 +
if (minBounds[i] > key[i]) minBounds[i] = key[i];
 
}
 
}
 
}
 
}
PointKD[] tmp2 = new PointKD[n];
 
System.arraycopy(tmp, 0, tmp2, 0, n);
 
return tmp2;
 
 
}
 
}
 +
 +
/* I may have borrowed these from an early version of Red's tree. I however forget. */
 +
private static final double pointDistSq(double[] p1, double[] p2) {
 +
        double d = 0;
 +
        double q = 0;
 +
        for (int i = 0; i < p1.length; ++i) {
 +
            d += (q=(p1[i] - p2[i]))*q;
 +
        }
 +
        return d;
 +
    }
 +
 +
    private static final double regionDistSq(double[] point, double[] min, double[] max) {
 +
        double d = 0;
 +
        double q = 0;
 +
        for (int i = 0; i < point.length; ++i) {
 +
            if (point[i] > max[i]) {
 +
            d += (q = (point[i] - max[i]))*q;
 +
            } else if (point[i] < min[i]) {
 +
                d += (q = (point[i] - min[i]))*q;
 +
            }
 +
        }
 +
        return d;
 +
    }
 
}
 
}
</pre>
+
</syntaxhighlight>
 
 
  
===PointKD===
+
=== ResultHeap ===
<pre>
+
<syntaxhighlight>
 
package org.csdgn.util;
 
package org.csdgn.util;
  
 
/**
 
/**
* A K-Dimensional Point, used in conjunction with the KD-Tree multidimensional
+
  * @author Chase
* data structure.
 
*
 
  * @author Robert Maupin
 
 
  *  
 
  *  
 +
* @param <T>
 
  */
 
  */
public class PointKD {
+
public class ResultHeap<T> {
protected double[] array;
+
private Object[] data;
 +
private double[] keys;
 +
private int capacity;
 +
private int size;
  
/**
+
protected ResultHeap(int capacity) {
* Constructor defining the number of dimensions to use in this Point;
+
this.data = new Object[capacity];
*
+
this.keys = new double[capacity];
* @param dimensions
+
this.capacity = capacity;
*            - The number of dimensions in this point
+
this.size = 0;
*/
 
public PointKD(int dimensions) {
 
array = new double[dimensions];
 
 
}
 
}
  
/**
+
protected void offer(double key, T value) {
* Creates a PointKD from an array.
+
int i = size;
*
+
for (; i > 0 && keys[i - 1] > key; --i);
* @param pos
+
if (i >= capacity) return;
*            - An array of Numbers
+
if (size < capacity) ++size;
*/
+
int j = i + 1;
public PointKD(double[] pos) {
+
System.arraycopy(keys, i, keys, j, size - j);
array = pos.clone();
+
keys[i] = key;
// array = new double[pos.length];
+
System.arraycopy(data, i, data, j, size - j);
// System.arraycopy(pos, 0, array, 0, array.length);
+
data[i] = value;
 
}
 
}
  
/**
+
public double getMaxKey() {
* Creates a copy of the selected point.
+
return keys[size - 1];
*
 
* @param p
 
*            - A point to copy
 
*/
 
public PointKD(PointKD p) {
 
// array = new double[p.array.length];
 
setPoint(p);
 
 
}
 
}
 
+
/**
+
@SuppressWarnings("unchecked")
* Sets the coordinates of this point to be the same as the selected point.
+
public T removeMax() {
*
+
if(isEmpty()) return null;
* @param p
+
return (T)data[--size];
*            - A point to copy
 
*/
 
public void setPoint(PointKD p) {
 
this.array = p.array.clone();
 
// System.arraycopy(p.array, 0, array, 0, array.length);
 
}
 
 
 
/**
 
* Returns the number of dimensions this point has.
 
*
 
* @return the dimensions
 
*/
 
public int getDimensions() {
 
return array.length;
 
}
 
 
 
/**
 
* Returns the coordinate at dimension <b>i</b>.
 
*
 
* @param i
 
*            - The dimension to retrieve.
 
* @return The coordinate at i.
 
*/
 
public double getCoordinate(int i) {
 
return array[i];
 
}
 
 
 
/**
 
* Sets the coordinate to <b>k</b> at dimension <b>i</b>.
 
*
 
* @param i
 
*            - the dimension to set
 
* @param k
 
*            - the value to set the dimension to
 
*/
 
public void setCoordinate(int i, double k) {
 
array[i] = k;
 
}
 
 
 
/**
 
* Compares this to a selected point and returns the euclidean distance
 
* between them.
 
*
 
* @param p
 
*            - The Point to get the distance to.
 
* @return The distance between this and <b>p</b>.
 
*/
 
public double distance(PointKD p) {
 
return distance(this, p);
 
}
 
 
 
/**
 
* Compares this to a selected point and returns the squared euclidean
 
* distance between them.
 
*
 
* @param p
 
*            - The Point to get the distance to.
 
* @return The distance between this and <b>p</b>.
 
*/
 
public double distanceSq(PointKD p) {
 
return distanceSq(this, p);
 
}
 
 
 
/**
 
* Prints the class name and the point coordinates.
 
*/
 
public String toString() {
 
String output = getClass().getName() + "[";
 
for (int i = 0; i < array.length; i++) {
 
if (0 != i)
 
output += ",";
 
output += array[i];
 
}
 
return output + "]";
 
}
 
 
 
/**
 
* @return a distinct copy of this object
 
*/
 
public Object clone() {
 
return new PointKD(this);
 
}
 
 
 
/**
 
* Compares two Points and returns the euclidean distance between them.
 
*
 
* @param a
 
*            - The first set of numbers
 
* @param b
 
*            - The second set of numbers
 
* @return The distance between <b>a</b> and <b>b</b>.
 
*/
 
public static final double distance(PointKD a, PointKD b) {
 
return distance(a.array, b.array);
 
 
}
 
}
  
/**
+
public boolean isEmpty() {
* Compares two Points and returns the squared euclidean distance between
+
return size == 0;
* them.
 
*
 
* @param a
 
*            - The first set of numbers
 
* @param b
 
*            - The second set of numbers
 
* @return The distance between <b>a</b> and <b>b</b>.
 
*/
 
public static final double distanceSq(PointKD a, PointKD b) {
 
return distanceSq(a.array, b.array);
 
 
}
 
}
  
/**
+
public boolean isFull() {
* Compares arrays of double and returns the euclidean distance between
+
return size == capacity;
* them.
 
*
 
* @param a
 
*            - The first set of numbers
 
* @param b
 
*            - The second set of numbers
 
* @return The distance between <b>a</b> and <b>b</b>.
 
*/
 
public static final double distance(double[] a, double[] b) {
 
return Math.sqrt(distanceSq(a, b));
 
 
}
 
}
  
/**
+
public int size() {
* Compares arrays of double and returns the squared euclidean distance
+
return size;
* between them.
 
*
 
* @param a
 
*            - The first set of numbers
 
* @param b
 
*            - The second set of numbers
 
* @return The distance squared between <b>a</b> and <b>b</b>.
 
*/
 
public static final double distanceSq(double[] a, double[] b) {
 
if (a.length != b.length || a.length < 1)
 
return -1;
 
double total = 0;
 
for (int i = 0; i < a.length; i++)
 
total += (b[i] - a[i]) * (b[i] - a[i]);
 
return total;
 
 
}
 
}
}
 
</pre>
 
  
===RectKD===
+
public int capacity() {
<pre>
+
return capacity;
package org.csdgn.util;
 
 
 
/**
 
* A K-Dimensional Hyper Rectangle, used in conjunction with the KD-Tree
 
* multidimensional data structure.
 
*
 
* @author Robert Maupin
 
*
 
*/
 
public class RectKD {
 
PointKD upper;
 
PointKD lower;
 
 
 
/**
 
* Creates an empty RectKD
 
*/
 
public RectKD() {
 
upper = null;
 
lower = null;
 
}
 
 
 
public RectKD(PointKD s) {
 
this(s, s);
 
}
 
 
 
public RectKD(PointKD l, PointKD u) {
 
this();
 
if (l.array.length == u.array.length) {
 
upper = new PointKD(u);
 
lower = new PointKD(l);
 
}
 
}
 
 
 
public void add(PointKD p) {
 
if (upper == null) {
 
upper = new PointKD(p);
 
lower = new PointKD(p);
 
return;
 
}
 
 
 
for (int i = 0; i < upper.array.length; i++) {
 
if (p.array[i] > upper.array[i])
 
upper.array[i] = p.array[i];
 
if (p.array[i] < lower.array[i])
 
lower.array[i] = p.array[i];
 
}
 
}
 
 
 
public boolean contains(PointKD p) {
 
if (upper == null)
 
return false;
 
if (upper.array.length != p.array.length)
 
return false;
 
boolean inside = true;
 
for (int i = 0; i < upper.array.length; i++) {
 
if (p.array[i] > upper.array[i])
 
inside = false;
 
if (p.array[i] < lower.array[i])
 
inside = false;
 
}
 
return inside;
 
}
 
 
 
public boolean intersects(RectKD r) {
 
boolean check = false;
 
 
 
if (null == r)
 
return false;
 
if (null == r.lower)
 
return false;
 
if (null == r.upper)
 
return false;
 
if (r.upper.array.length != upper.array.length)
 
return false;
 
 
 
int len = upper.array.length;
 
for (int i = 0; i < len; i++) {
 
if (upper.array[i] < r.lower.array[i])
 
check = true;
 
if (lower.array[i] > r.upper.array[i])
 
check = true;
 
}
 
 
 
return !check;
 
}
 
 
 
public PointKD getNearest(PointKD p) {
 
if (upper == null)
 
return null;
 
if (upper.array.length != p.array.length)
 
return null;
 
PointKD near = new PointKD(p);
 
for (int i = 0; i < upper.array.length; i++) {
 
near.array[i] = p.array[i];
 
if (p.array[i] > upper.array[i])
 
near.array[i] = upper.array[i];
 
if (p.array[i] < lower.array[i])
 
near.array[i] = lower.array[i];
 
}
 
return near;
 
 
}
 
}
 
}
 
}
</pre>
+
</syntaxhighlight>

Latest revision as of 21:00, 7 November 2012

Everyone and their brother has one of these now. Simonton and I started it, but I was too inexperienced to get anything written. I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine, if you care to see it.

This and all the other code that I display on the Robowiki falls under the ZLIB License.

Oh yeah, am I the only one that has a Range function?

KDTreeF

package org.csdgn.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * A k-d bucket tree for fast range and nearest-neighbor searches over
 * K-dimensional data.
 * <p>
 * Points are stored in leaf buckets. When a bucket fills up it is split at
 * the midpoint of its widest dimension. A bucket that cannot be split (for
 * example because every key in it is identical) grows instead, so any number
 * of duplicate keys may be stored.
 * <p>
 * Keys are stored by reference, not copied; callers must not modify a key
 * array after passing it to {@link #add(double[], Object)}.
 * 
 * @author Chase
 * 
 * @param <T>
 *            type of the values associated with the keys
 */
public class KDTree<T> {
	protected static final int defaultBucketSize = 48;

	private final int dimensions;
	private final int bucketSize;
	private final NodeKD root;

	/**
	 * Constructor with value for dimensions, using the default bucket size.
	 * 
	 * @param dimensions
	 *            - Number of dimensions, at least 1
	 * @throws IllegalArgumentException
	 *             if dimensions is less than 1
	 */
	public KDTree(int dimensions) {
		this(dimensions, defaultBucketSize);
	}

	/**
	 * Constructor with value for dimensions and bucket size.
	 * 
	 * @param dimensions
	 *            - Number of dimensions, at least 1
	 * @param bucket
	 *            - Size of the buckets, at least 1
	 * @throws IllegalArgumentException
	 *             if dimensions or bucket is less than 1
	 */
	public KDTree(int dimensions, int bucket) {
		if (dimensions < 1) {
			throw new IllegalArgumentException("dimensions must be >= 1, got " + dimensions);
		}
		if (bucket < 1) {
			// A bucket size of 0 could never hold a point and would split forever.
			throw new IllegalArgumentException("bucket size must be >= 1, got " + bucket);
		}
		this.dimensions = dimensions;
		this.bucketSize = bucket;
		this.root = new NodeKD();
	}

	/**
	 * Add a key and its associated value to the tree.
	 * 
	 * @param key
	 *            - Key to add; expected to have one entry per dimension. The
	 *            array is stored as-is, so it must not be modified afterwards.
	 * @param val
	 *            - object to add
	 */
	public void add(double[] key, T val) {
		root.addPoint(key, val);
	}

	/**
	 * Returns all values whose keys lie within the axis-aligned box defined by
	 * a lower and an upper corner (both inclusive).
	 * 
	 * @param low
	 *            - lower bounds of area
	 * @param high
	 *            - upper bounds of area
	 * @return - All values with keys between low and high; empty if none.
	 */
	public List<T> getRange(double[] low, double[] high) {
		List<T> range = new ArrayList<T>();
		root.range(high, low, range);
		return range;
	}

	/**
	 * Gets the N nearest neighbors to the given key.
	 * 
	 * @param key
	 *            - Key
	 * @param num
	 *            - Number of results
	 * @return a ResultHeap holding at most num values; the keys inside the
	 *         heap are the square of the actual distance between the stored
	 *         key and the given key. Empty when num is 0 or the tree is empty.
	 */
	public ResultHeap<T> getNearestNeighbors(double[] key, int num) {
		ResultHeap<T> heap = new ResultHeap<T>(num);
		// A zero-capacity heap is always "full" yet has no max key to prune
		// against, so there is nothing to search for.
		if (num > 0) {
			root.nearest(heap, key);
		}
		return heap;
	}


	// Internal tree node; either a leaf holding a bucket or a branch with two children
	private class NodeKD {
		private NodeKD left, right;
		// bounding box of every key stored at or below this node; null while empty
		private double[] maxBounds, minBounds;
		// bucket contents, only non-null while this node is a leaf
		private Object[] bucketValues;
		private double[][] bucketKeys;
		private boolean isLeaf;
		// current: number of entries in the bucket (leaf only)
		private int current, sliceDimension;
		private double slice;

		private NodeKD() {
			bucketValues = new Object[bucketSize];
			bucketKeys = new double[bucketSize][];
			isLeaf = true;
		}

		// true until the first point has been added at or below this node
		private boolean isEmpty() {
			return maxBounds == null;
		}

		// what it says on the tin
		private void addPoint(double[] key, Object val) {
			if (isLeaf) {
				addLeafPoint(key, val);
				return;
			}
			extendBounds(key);
			if (key[sliceDimension] > slice) {
				right.addPoint(key, val);
			} else {
				left.addPoint(key, val);
			}
		}

		// adds a point to this leaf, splitting or growing the bucket when full
		private void addLeafPoint(double[] key, Object val) {
			extendBounds(key);
			if (current == bucketKeys.length) {
				if (splitLeaf()) {
					// this node is a branch now, route the point to a child
					addPoint(key, val);
					return;
				}
				// No plane separates the points (e.g. all keys identical), so
				// splitting would recurse forever; enlarge the bucket instead.
				int grown = bucketKeys.length * 2;
				bucketKeys = Arrays.copyOf(bucketKeys, grown);
				bucketValues = Arrays.copyOf(bucketValues, grown);
			}
			bucketKeys[current] = key;
			bucketValues[current] = val;
			++current;
		}

		/**
		 * Find the nearest neighbors recursively, visiting the child on the
		 * query's side of the slice first and the other child only if it
		 * could still contain a closer point.
		 */
		@SuppressWarnings("unchecked")
		private void nearest(ResultHeap<T> heap, double[] data) {
			if (isEmpty()) {
				return;
			}
			if (isLeaf) {
				for (int i = 0; i < current; ++i) {
					heap.offer(pointDistSq(bucketKeys[i], data), (T) bucketValues[i]);
				}
				return;
			}
			NodeKD near, far;
			if (data[sliceDimension] > slice) {
				near = right;
				far = left;
			} else {
				near = left;
				far = right;
			}
			near.nearest(heap, data);
			if (far.isEmpty()) {
				return;
			}
			if (!heap.isFull() || regionDistSq(data, far.minBounds, far.maxBounds) < heap.getMaxKey()) {
				far.nearest(heap, data);
			}
		}

		// appends all values whose keys lie within [lower, upper] to out
		@SuppressWarnings("unchecked")
		private void range(double[] upper, double[] lower, List<T> out) {
			if (isEmpty()) {
				return;
			}
			if (isLeaf) {
				for (int i = 0; i < current; ++i) {
					if (contains(upper, lower, bucketKeys[i])) {
						out.add((T) bucketValues[i]);
					}
				}
				return;
			}
			// an empty child has no bounds to test against
			if (!left.isEmpty() && intersects(upper, lower, left.maxBounds, left.minBounds)) {
				left.range(upper, lower, out);
			}
			if (!right.isEmpty() && intersects(upper, lower, right.maxBounds, right.minBounds)) {
				right.range(upper, lower, out);
			}
		}

		/**
		 * Turns this leaf into a branch by cutting its bounding box in half
		 * along the widest dimension and handing the bucket to two new leaves.
		 * 
		 * @return false (leaving the leaf untouched) when no usable cut
		 *         exists, i.e. the midpoint does not fall strictly below the
		 *         upper bound, so all points would land in one child
		 */
		private boolean splitLeaf() {
			int widest = 0;
			double bestRange = 0;
			for (int i = 0; i < dimensions; ++i) {
				double extent = maxBounds[i] - minBounds[i];
				if (extent > bestRange) {
					widest = i;
					bestRange = extent;
				}
			}

			double middle = (maxBounds[widest] + minBounds[widest]) * 0.5;
			// Written as a negation so that NaN also counts as "no usable cut".
			if (!(middle >= minBounds[widest] && middle < maxBounds[widest])) {
				return false;
			}

			sliceDimension = widest;
			slice = middle;
			left = new NodeKD();
			right = new NodeKD();

			for (int i = 0; i < current; ++i) {
				NodeKD child = bucketKeys[i][widest] > middle ? right : left;
				child.addLeafPoint(bucketKeys[i], bucketValues[i]);
			}
			bucketKeys = null;
			bucketValues = null;
			current = 0;
			isLeaf = false;
			return true;
		}

		// expands this node's bounding box so that it includes key
		private void extendBounds(double[] key) {
			if (maxBounds == null) {
				maxBounds = Arrays.copyOf(key, dimensions);
				minBounds = Arrays.copyOf(key, dimensions);
				return;
			}
			for (int i = 0; i < key.length; ++i) {
				if (maxBounds[i] < key[i]) maxBounds[i] = key[i];
				if (minBounds[i] > key[i]) minBounds[i] = key[i];
			}
		}
	}

	// These are helper functions from here down

	// checks if the hyper-rectangle [lower, upper] contains the given point
	private static boolean contains(double[] upper, double[] lower, double[] point) {
		for (int i = 0; i < point.length; ++i) {
			if (point[i] > upper[i] || point[i] < lower[i]) return false;
		}
		return true;
	}

	// checks if two hyper-rectangles intersect
	private static boolean intersects(double[] up0, double[] low0, double[] up1, double[] low1) {
		for (int i = 0; i < up0.length; ++i) {
			if (up1[i] < low0[i] || low1[i] > up0[i]) return false;
		}
		return true;
	}

	/* I may have borrowed these from an early version of Red's tree. I however forget. */

	// squared euclidean distance between two points
	private static double pointDistSq(double[] p1, double[] p2) {
		double d = 0;
		for (int i = 0; i < p1.length; ++i) {
			double q = p1[i] - p2[i];
			d += q * q;
		}
		return d;
	}

	// squared euclidean distance from a point to the box [min, max]; 0 if inside
	private static double regionDistSq(double[] point, double[] min, double[] max) {
		double d = 0;
		for (int i = 0; i < point.length; ++i) {
			double q = 0;
			if (point[i] > max[i]) {
				q = point[i] - max[i];
			} else if (point[i] < min[i]) {
				q = point[i] - min[i];
			}
			d += q * q;
		}
		return d;
	}
}

ResultHeap

package org.csdgn.util;

/**
 * A fixed-capacity collection that keeps the entries with the smallest keys
 * offered to it.
 * <p>
 * Despite the name this is not a binary heap: entries are kept in an array
 * sorted by ascending key, so the entry with the largest key is always the
 * last one. Among equal keys, entries offered earlier come first.
 * 
 * @author Chase
 * 
 * @param <T>
 *            type of the stored values
 */
public class ResultHeap<T> {
	private final Object[] data;
	private final double[] keys;
	private final int capacity;
	private int size;

	/**
	 * Creates an empty heap.
	 * 
	 * @param capacity
	 *            - maximum number of entries to keep, 0 or more
	 * @throws IllegalArgumentException
	 *             if capacity is negative
	 */
	protected ResultHeap(int capacity) {
		if (capacity < 0) {
			throw new IllegalArgumentException("capacity must be >= 0, got " + capacity);
		}
		this.data = new Object[capacity];
		this.keys = new double[capacity];
		this.capacity = capacity;
		this.size = 0;
	}

	/**
	 * Offers an entry. It is inserted at its sorted position; if the heap is
	 * already full, the entry with the largest key is dropped to make room, or
	 * the offered entry is ignored if its key is not smaller than every
	 * stored key that it would have to displace.
	 * 
	 * @param key
	 *            - sort key of the entry (smaller is better)
	 * @param value
	 *            - value of the entry
	 */
	protected void offer(double key, T value) {
		// find the insert position, scanning from the largest key downwards
		int pos = size;
		while (pos > 0 && keys[pos - 1] > key) {
			--pos;
		}
		if (pos >= capacity) {
			// heap is full and the key is no better than the current worst
			return;
		}
		if (size < capacity) {
			++size;
		}
		// shift the tail up by one; when full this overwrites the last entry
		int moved = size - pos - 1;
		System.arraycopy(keys, pos, keys, pos + 1, moved);
		System.arraycopy(data, pos, data, pos + 1, moved);
		keys[pos] = key;
		data[pos] = value;
	}

	/**
	 * @return the largest key currently stored, or
	 *         {@link Double#NEGATIVE_INFINITY} (the maximum of an empty set)
	 *         if the heap is empty
	 */
	public double getMaxKey() {
		if (size == 0) {
			return Double.NEGATIVE_INFINITY;
		}
		return keys[size - 1];
	}

	/**
	 * Removes the entry with the largest key.
	 * 
	 * @return the value of the removed entry, or null if the heap is empty
	 */
	@SuppressWarnings("unchecked")
	public T removeMax() {
		if (isEmpty()) return null;
		--size;
		T value = (T) data[size];
		// drop the reference so the removed value can be garbage collected
		data[size] = null;
		return value;
	}

	/**
	 * @return true if no entries are stored
	 */
	public boolean isEmpty() {
		return size == 0;
	}

	/**
	 * @return true if the number of stored entries equals the capacity
	 */
	public boolean isFull() {
		return size == capacity;
	}

	/**
	 * @return the number of entries currently stored
	 */
	public int size() {
		return size;
	}

	/**
	 * @return the maximum number of entries this heap keeps
	 */
	public int capacity() {
		return capacity;
	}
}