Difference between revisions of "User:Chase-san/Kd-Tree"

From Robowiki
Jump to navigation Jump to search
(Updated, now with nearest N neighbors)
m (→‎KDTreeF: (no major change, just some minor tweaks, i++ to ++i, removing the enhanced for loop))
 
(11 intermediate revisions by the same user not shown)
Line 2: Line 2:
 
Everyone and their brother has one of these now, me and Simonton started it, but I was to inexperienced to get anything written, I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine if you care to see it.
 
Everyone and their brother has one of these now, me and Simonton started it, but I was to inexperienced to get anything written, I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine if you care to see it.
  
=== KDTreeB ===
+
This and all my other code in which I display on the robowiki falls under the [http://en.wikipedia.org/wiki/Zlib_License ZLIB License].
<pre>
+
 
 +
Oh yeah, am I the only one that has a Range function?
 +
 
 +
=== KDTreeF ===
 +
<syntaxhighlight>
 
package org.csdgn.util;
 
package org.csdgn.util;
  
import java.util.Comparator;
+
import java.util.ArrayList;
 +
import java.util.Arrays;
 +
import java.util.List;
 +
 
 +
/**
 +
* This is a KD Bucket Tree, for fast sorting and searching of K dimensional
 +
* data.
 +
*
 +
* @author Chase
 +
*
 +
*/
 +
public class KDTree<T> {
 +
protected static final int defaultBucketSize = 48;
 +
 
 +
private final int dimensions;
 +
private final int bucketSize;
 +
private NodeKD root;
  
public class KDTreeB {
 
public final static int DEFAULT_BUCKET_SIZE = 200;
 
protected NodeKD root;
 
protected int dimensions;
 
protected int bucket_size;
 
protected long size = 0;
 
 
 
/**
 
/**
* This creates a KDTreeB with the given number of dimensions and
+
* Constructor with value for dimensions.
* the default bucket size.
+
*  
 
* @param dimensions
 
* @param dimensions
 +
*            - Number of dimensions
 
*/
 
*/
public KDTreeB(int dimensions) {
+
public KDTree(int dimensions) {
this(dimensions,DEFAULT_BUCKET_SIZE);
+
this.dimensions = dimensions;
 +
this.bucketSize = defaultBucketSize;
 +
this.root = new NodeKD();
 
}
 
}
 +
 
/**
 
/**
* This creates a KDTreeB with the given number of dimensions and
+
* Constructor with value for dimensions and bucket size.
* the given bucket size.
+
*  
 
* @param dimensions
 
* @param dimensions
* @param bucket_size
+
*            - Number of dimensions
 +
* @param bucket
 +
*            - Size of the buckets.
 
*/
 
*/
public KDTreeB(int dimensions, int bucket_size) {
+
public KDTree(int dimensions, int bucket) {
 
this.dimensions = dimensions;
 
this.dimensions = dimensions;
this.bucket_size = bucket_size;
+
this.bucketSize = bucket;
root = new BucketKD(this);
+
this.root = new NodeKD();
 
}
 
}
+
 
 
/**
 
/**
* Adds the given point to the tree. Uses a recursive algorithm.
+
* Add a key and its associated value to the tree.
* @param point
+
*
 +
* @param key
 +
*            - Key to add
 +
* @param val
 +
*            - object to add
 
*/
 
*/
public void add(PointKD point) {
+
public void add(double[] key, T val) {
root.add(point);
+
root.addPoint(key, val);
 
}
 
}
+
 
 
/**
 
/**
* Returns the nearest neighbor to the given point.
+
* Returns all PointKD within a certain range defined by an upper and lower
* @param point - PointKD to find the nearest to.
+
* PointKD.
* @return - The nearest PointKD to
+
*
 +
* @param low
 +
*            - lower bounds of area
 +
* @param high
 +
*            - upper bounds of area
 +
* @return - All PointKD between low and high.
 
*/
 
*/
public PointKD getNearestNeighbor(PointKD point) {
+
@SuppressWarnings("unchecked")
return root.nearest(point);
+
public List<T> getRange(double[] low, double[] high) {
}
+
Object[] objs = root.range(high, low);
+
ArrayList<T> range = new ArrayList<T>(objs.length);
/**
+
for(int i=0; i<objs.length; ++i) {
* Returns the nearest num neighbors to the given point.
+
range.add((T)objs[i]);
* @param point - PointKD to find the nearest to.
 
* @param num - The number of points to find.
 
* @return - The nearest PointKD's to point.
 
*/
 
public PointKD[] getNearestNeighbors(final PointKD point, int num) {
 
if(num == 1) {
 
return new PointKD[] { getNearestNeighbor(point) };
 
 
}
 
}
PriorityDeque<PointKD> queue = new PriorityDeque<PointKD>(
+
return range;
new Comparator<PointKD>(){
 
@Override
 
public int compare(PointKD o1, PointKD o2) {
 
double dist0 = point.distance(o1);
 
double dist1 = point.distance(o2);
 
return (dist0-dist1 < 0) ? -1 : 1;
 
}
 
},num);
 
 
root.nearestn(queue,point);
 
 
PointKD[] array = new PointKD[num];
 
Object[] obj = queue.toArray();
 
for(int i=0; i<num; ++i) {
 
array[i] = (PointKD)obj[i];
 
}
 
 
return array;
 
 
}
 
}
+
 
 
/**
 
/**
* Returns all PointKD within a certain RectangleKD.
+
* Gets the N nearest neighbors to the given key.
* @param rect - area to get PointKD from
+
*
* @return - All PointKD within rect.
+
* @param key
 +
*            - Key
 +
* @param num
 +
*            - Number of results
 +
* @return Array of Item Objects, distances within the items are the square
 +
*        of the actual distance between them and the key
 
*/
 
*/
public PointKD[] getRange(RectangleKD rect) {
+
public ResultHeap<T> getNearestNeighbors(double[] key, int num) {
return root.range(rect);
+
ResultHeap<T> heap = new ResultHeap<T>(num);
 +
root.nearest(heap, key);
 +
return heap;
 
}
 
}
  
/**
+
 
* Returns all PointKD within a certain range defined by an upper and lower PointKD.
+
// Internal tree node
* @param low - lower bounds of area
+
private class NodeKD {
* @param high - upper bounds of area
+
private NodeKD left, right;
* @return - All PointKD between low and high.
+
private double[] maxBounds, minBounds;
*/
+
private Object[] bucketValues;
public PointKD[] getRange(PointKD low, PointKD high) {
+
private double[][] bucketKeys;
return this.getRange(new RectangleKD(low, high));
+
private boolean isLeaf;
}
+
private int current, sliceDimension;
+
private double slice;
protected abstract class NodeKD {
+
 
protected KDTreeB owner;
+
private NodeKD() {
protected BranchKD parent;
+
bucketValues = new Object[bucketSize];
protected RectangleKD rect;
+
bucketKeys = new double[bucketSize][];
protected int depth = 0;
+
 
 +
left = right = null;
 +
maxBounds = minBounds = null;
 +
 +
isLeaf = true;
 +
 +
current = 0;
 +
}
 +
 
 +
// what it says on the tin
 +
private void addPoint(double[] key, Object val) {
 +
if(isLeaf) {
 +
addLeafPoint(key,val);
 +
} else {
 +
extendBounds(key);
 +
if (key[sliceDimension] > slice) {
 +
right.addPoint(key, val);
 +
} else {
 +
left.addPoint(key, val);
 +
}
 +
}
 +
}
 
 
protected abstract void add(PointKD k);
+
private void addLeafPoint(double[] key, Object val) {
protected abstract PointKD nearest(PointKD k);
+
extendBounds(key);
protected abstract void nearestn(PriorityDeque<PointKD> queue, PointKD k);
+
if (current + 1 > bucketSize) {
protected abstract PointKD[] range(RectangleKD r);
+
splitLeaf();
}
+
addPoint(key, val);
protected class BucketKD extends NodeKD {
+
return;
protected PointKD bucket[];
+
}
protected int current;
+
bucketKeys[current] = key;
+
bucketValues[current] = val;
public BucketKD(KDTreeB owner) {
+
++current;
this.owner = owner;
 
bucket = new PointKD[owner.bucket_size];
 
rect = new RectangleKD();
 
parent = null;
 
}
 
public BucketKD(BranchKD branch) {
 
this(branch.owner);
 
parent = branch;
 
depth = branch.depth + 1;
 
 
}
 
}
 
 
@Override
+
/**
protected void add(PointKD point) {
+
* Find the nearest neighbor recursively.
if(current >= bucket.length) {
+
*/
//Split the bucket into a branch
+
@SuppressWarnings("unchecked")
BranchKD branch = new BranchKD(this);
+
private void nearest(ResultHeap<T> heap, double[] data) {
if(parent == null) {
+
if(current == 0)
owner.root = branch;
+
return;
 +
if(isLeaf) {
 +
//IS LEAF
 +
for (int i = 0; i < current; ++i) {
 +
double dist = pointDistSq(bucketKeys[i], data);
 +
heap.offer(dist, (T) bucketValues[i]);
 +
}
 +
} else {
 +
//IS BRANCH
 +
if (data[sliceDimension] > slice) {
 +
right.nearest(heap, data);
 +
if(left.current == 0)
 +
return;
 +
if (!heap.isFull() || regionDistSq(data,left.minBounds,left.maxBounds) < heap.getMaxKey()) {
 +
left.nearest(heap, data);
 +
}
 
} else {
 
} else {
if(parent.isLeft(this)) {
+
left.nearest(heap, data);
parent.left = branch;
+
if (right.current == 0)
} else {
+
return;
parent.right = branch;
+
if (!heap.isFull() || regionDistSq(data,right.minBounds,right.maxBounds) < heap.getMaxKey()) {
 +
right.nearest(heap, data);
 
}
 
}
 
}
 
}
branch.add(point);
 
bucket = null;
 
current = 0;
 
return;
 
 
}
 
}
bucket[current++] = point;
 
rect.expand(point);
 
 
}
 
}
  
@Override
+
// gets all items from within a range
protected PointKD nearest(PointKD k) {
+
private Object[] range(double[] upper, double[] lower) {
double nearestDist = Double.POSITIVE_INFINITY;
+
if (bucketValues == null) {
int nearest = 0;
+
// Branch
for (int i = 0; i < current; i++) {
+
Object[] tmp = new Object[0];
double distance = k.distanceSq(bucket[i]);
+
if (intersects(upper, lower, left.maxBounds, left.minBounds)) {
if (distance < nearestDist) {
+
Object[] tmpl = left.range(upper, lower);
nearestDist = distance;
+
if (0 == tmp.length) tmp = tmpl;
nearest = i;
+
}
 +
if (intersects(upper, lower, right.maxBounds, right.minBounds)) {
 +
Object[] tmpr = right.range(upper, lower);
 +
if (0 == tmp.length)
 +
tmp = tmpr;
 +
else if (0 < tmpr.length) {
 +
Object[] tmp2 = new Object[tmp.length + tmpr.length];
 +
System.arraycopy(tmp, 0, tmp2, 0, tmp.length);
 +
System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length);
 +
tmp = tmp2;
 +
}
 
}
 
}
 +
return tmp;
 
}
 
}
return bucket[nearest];
+
// Leaf
}
+
Object[] tmp = new Object[current];
 
@Override
 
protected void nearestn(PriorityDeque<PointKD> queue, PointKD k) {
 
for(int i = 0; i < current; i++) {
 
queue.offer(bucket[i]);
 
}
 
}
 
 
 
@Override
 
protected PointKD[] range(RectangleKD r) {
 
PointKD[] tmp = new PointKD[current];
 
 
int n = 0;
 
int n = 0;
for (int i = 0; i < current; i++) {
+
for (int i = 0; i < current; ++i) {
if (r.contains(bucket[i])) {
+
if (contains(upper, lower, bucketKeys[i])) {
tmp[n++] = bucket[i];
+
tmp[n++] = bucketValues[i];
 
}
 
}
 
}
 
}
PointKD[] tmp2 = new PointKD[n];
+
Object[] tmp2 = new Object[n];
 
System.arraycopy(tmp, 0, tmp2, 0, n);
 
System.arraycopy(tmp, 0, tmp2, 0, n);
 
return tmp2;
 
return tmp2;
 
}
 
}
+
 
}
+
// These are helper functions from here down
protected class BranchKD extends NodeKD {
+
// check if this hyper rectangle contains a give hyper-point
protected NodeKD left, right;
+
public boolean contains(double[] upper, double[] lower, double[] point) {
protected double slice;
+
if (current == 0) return false;
protected int dim;
+
for (int i = 0; i < point.length; ++i) {
+
if (point[i] > upper[i] || point[i] < lower[i]) return false;
public BranchKD(BucketKD k) {
 
owner = k.owner;
 
parent = k.parent;
 
slice = 0;
 
rect = k.rect;
 
depth = k.depth;
 
left = new BucketKD(this);
 
right = new BucketKD(this);
 
 
dim = depth % owner.dimensions;
 
double total = 0;
 
for (int i = 0; i < k.current; i++) {
 
total += k.bucket[i].internal[dim];
 
 
}
 
}
slice = total / k.current;
+
return true;
for (int i = 0; i < k.current; i++)
 
add(k.bucket[i]);
 
 
}
 
}
+
 
@Override
+
// checks if two hyper-rectangles intersect
protected void add(PointKD k) {
+
public boolean intersects(double[] up0, double[] low0, double[] up1, double[] low1) {
if(k.internal[dim] > slice) {
+
for (int i = 0; i < up0.length; ++i) {
right.add(k);
+
if (up1[i] < low0[i] || low1[i] > up0[i]) return false;
} else {
 
left.add(k);
 
 
}
 
}
 +
return true;
 
}
 
}
  
@Override
+
private void splitLeaf() {
protected PointKD nearest(PointKD k) {
+
double bestRange = 0;
PointKD near = null;
+
for(int i=0;i<dimensions;++i) {
if (k.internal[dim] > slice) {
+
double range = maxBounds[i] - minBounds[i];
near = right.nearest(k);
+
if(range > bestRange) {
double t = near.distanceSq(k);
+
sliceDimension = i;
if (k.distanceSq(left.rect.getNearest(k)) < t) {
+
bestRange = range;
PointKD tmp = left.nearest(k);
 
if (tmp.distanceSq(k) < t) {
 
near = tmp;
 
}
 
}
 
} else {
 
near = left.nearest(k);
 
double t = near.distanceSq(k);
 
if (k.distanceSq(right.rect.getNearest(k)) < t) {
 
PointKD tmp = right.nearest(k);
 
if (tmp.distanceSq(k) < t) {
 
near = tmp;
 
}
 
 
}
 
}
 
}
 
}
return near;
+
}
+
left = new NodeKD();
+
right = new NodeKD();
@Override
+
protected void nearestn(PriorityDeque<PointKD> queue, PointKD k) {
+
slice = (maxBounds[sliceDimension] + minBounds[sliceDimension]) * 0.5;
//TODO
+
if(k.internal[dim] > slice) {
+
for (int i = 0; i < current; ++i) {
right.nearestn(queue, k);
+
if (bucketKeys[i][sliceDimension] > slice) {
double t = queue.peekBottom().distanceSq(k);
+
right.addLeafPoint(bucketKeys[i], bucketValues[i]);
if(k.distanceSq(left.rect.getNearest(k)) < t) {
+
} else {
left.nearestn(queue, k);
+
left.addLeafPoint(bucketKeys[i], bucketValues[i]);
}
 
} else {
 
left.nearestn(queue, k);
 
double t = queue.peekBottom().distanceSq(k);
 
if(k.distanceSq(right.rect.getNearest(k)) < t) {
 
right.nearestn(queue, k);
 
 
}
 
}
 
}
 
}
 +
bucketKeys = null;
 +
bucketValues = null;
 +
isLeaf = false;
 
}
 
}
  
@Override
+
// expands this hyper rectangle
protected PointKD[] range(RectangleKD r) {
+
private void extendBounds(double[] key) {
PointKD[] tmp = new PointKD[0];
+
if (maxBounds == null) {
if (r.intersects(left.rect)) {
+
maxBounds = Arrays.copyOf(key, dimensions);
PointKD[] tmpl = left.range(r);
+
minBounds = Arrays.copyOf(key, dimensions);
if(0 == tmp.length)
+
return;
tmp = tmpl;
 
 
}
 
}
if (r.intersects(right.rect)) {
+
for (int i = 0; i < key.length; ++i) {
PointKD[] tmpr = right.range(r);
+
if (maxBounds[i] < key[i]) maxBounds[i] = key[i];
if (0 == tmp.length)
+
if (minBounds[i] > key[i]) minBounds[i] = key[i];
tmp = tmpr;
 
else if (0 < tmpr.length) {
 
PointKD[] tmp2 = new PointKD[tmp.length + tmpr.length];
 
System.arraycopy(tmp, 0, tmp2, 0, tmp.length);
 
System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length);
 
tmp = tmp2;
 
}
 
 
}
 
}
return tmp;
 
}
 
 
protected boolean isLeft(BucketKD kd) {
 
if(left == kd) return true;
 
return false;
 
 
}
 
}
 
}
 
}
 +
 +
/* I may have borrowed these from an early version of Red's tree. I however forget. */
 +
private static final double pointDistSq(double[] p1, double[] p2) {
 +
        double d = 0;
 +
        double q = 0;
 +
        for (int i = 0; i < p1.length; ++i) {
 +
            d += (q=(p1[i] - p2[i]))*q;
 +
        }
 +
        return d;
 +
    }
 +
 +
    private static final double regionDistSq(double[] point, double[] min, double[] max) {
 +
        double d = 0;
 +
        double q = 0;
 +
        for (int i = 0; i < point.length; ++i) {
 +
            if (point[i] > max[i]) {
 +
            d += (q = (point[i] - max[i]))*q;
 +
            } else if (point[i] < min[i]) {
 +
                d += (q = (point[i] - min[i]))*q;
 +
            }
 +
        }
 +
        return d;
 +
    }
 
}
 
}
</pre>
+
</syntaxhighlight>
  
=== PointKD ===
+
=== ResultHeap ===
<pre>
+
<syntaxhighlight>
 
package org.csdgn.util;
 
package org.csdgn.util;
 
import java.io.Serializable;
 
  
 
/**
 
/**
  * PointKD class is a class that wraps a double array, for use in K-Dimensional structures.
+
  * @author Chase
* This wrapping is done to improve readability and modularity. Often used to define a
 
* point in k-dimensional space.
 
 
  *  
 
  *  
  * @author Chase
+
  * @param <T>
 
  */
 
  */
public class PointKD implements Serializable {
+
public class ResultHeap<T> {
private static final long serialVersionUID = -841162798668123755L;
+
private Object[] data;
protected double[] internal;
+
private double[] keys;
+
private int capacity;
/**
+
private int size;
* Constructor
+
 
* @param dimensions - Number of dimensions for this PointKD.
+
protected ResultHeap(int capacity) {
*/
+
this.data = new Object[capacity];
public PointKD(int dimensions) {
+
this.keys = new double[capacity];
internal = new double[dimensions];
+
this.capacity = capacity;
 +
this.size = 0;
 
}
 
}
+
 
/**
+
protected void offer(double key, T value) {
* Constructor
+
int i = size;
* @param array - An array to use for the internal array of this PointKD.
+
for (; i > 0 && keys[i - 1] > key; --i);
*/
+
if (i >= capacity) return;
public PointKD(double[] array) {
+
if (size < capacity) ++size;
internal = array.clone();
+
int j = i + 1;
 +
System.arraycopy(keys, i, keys, j, size - j);
 +
keys[i] = key;
 +
System.arraycopy(data, i, data, j, size - j);
 +
data[i] = value;
 
}
 
}
+
 
/**
+
public double getMaxKey() {
* Constructor
+
return keys[size - 1];
* @param point - PointKD to clone.
 
*/
 
public PointKD(PointKD point) {
 
this.internal = point.internal.clone();
 
 
}
 
}
 
 
+
@SuppressWarnings("unchecked")
/**
+
public T removeMax() {
* Returns array.
+
if(isEmpty()) return null;
* @return The internal array of this pointKD, changes made to this will effect this PointKD.
+
return (T)data[--size];
*/
 
public double[] get() {
 
return internal;
 
 
}
 
}
+
 
/**
+
public boolean isEmpty() {
* Sets the location of this point.
+
return size == 0;
* @param point - PointKD to copy.
 
*/
 
public void set(PointKD point) {
 
this.internal = point.internal.clone();
 
}
 
 
/**
 
* Sets the location of this point.
 
* @param array - Array to copy.
 
*/
 
public void set(double[] array) {
 
this.internal = array.clone();
 
}
 
 
/**
 
* Sets the value at dimension index
 
* @param index - Index
 
* @param value - Value to set
 
*/
 
public void set(int index, double value) {
 
internal[index] = value;
 
}
 
 
/**
 
* @return number of dimensions
 
*/
 
public int size() {
 
return internal.length;
 
}
 
 
/**
 
* Compares this to a selected point and returns the euclidean distance
 
* between them.
 
*
 
* @param p
 
*            - The Point to get the distance to.
 
* @return The distance between this and <b>p</b>.
 
*/
 
public double distance(PointKD p) {
 
return distance(this.internal, p.internal);
 
 
}
 
}
  
/**
+
public boolean isFull() {
* Compares this to a selected point and returns the squared euclidean
+
return size == capacity;
* distance between them.
 
*
 
* @param p
 
*            - The Point to get the distance to.
 
* @return The distance between this and <b>p</b>.
 
*/
 
public double distanceSq(PointKD p) {
 
return distanceSq(this.internal, p.internal);
 
}
 
 
/**
 
* Clones this point.
 
*/
 
public PointKD clone() {
 
return new PointKD(internal);
 
}
 
 
/**
 
* Prints the class name and the point coordinates.
 
*/
 
public String toString() {
 
String output = getClass().getSimpleName() + "[";
 
for (int i = 0; i < internal.length; ++i) {
 
if (0 != i)
 
output += ",";
 
output += internal[i];
 
}
 
return output + "]";
 
}
 
 
/**
 
* Compares arrays of double and returns the euclidean distance
 
* between them.
 
*
 
* @param a - The first set of numbers
 
* @param b - The second set of numbers
 
* @return The distance squared between <b>a</b> and <b>b</b>.
 
*/
 
public static final double distance(double[] a, double[] b) {
 
double total = 0;
 
for (int i = 0; i < a.length; ++i)
 
total += (b[i] - a[i]) * (b[i] - a[i]);
 
return Math.sqrt(total);
 
 
}
 
}
 
/**
 
* Compares arrays of double and returns the squared euclidean distance
 
* between them.
 
*
 
* @param a - The first set of numbers
 
* @param b - The second set of numbers
 
* @return The distance squared between <b>a</b> and <b>b</b>.
 
*/
 
public static final double distanceSq(double[] a, double[] b) {
 
double total = 0;
 
for (int i = 0; i < a.length; ++i)
 
total += (b[i] - a[i]) * (b[i] - a[i]);
 
return total;
 
}
 
}
 
</pre>
 
  
=== RectangleKD ===
+
public int size() {
<pre>
+
return size;
package org.csdgn.util;
 
 
 
import java.io.Serializable;
 
 
 
public class RectangleKD implements Serializable {
 
private static final long serialVersionUID = 524388821816648020L;
 
protected PointKD upper, lower;
 
/**
 
* Creates an empty RectangleKD
 
*/
 
public RectangleKD() {
 
upper = null;
 
lower = null;
 
}
 
/**
 
* Creates a RectangleKD from two points.
 
* @param lower
 
* @param upper
 
*/
 
public RectangleKD(PointKD lower, PointKD upper) {
 
this.lower = lower.clone();
 
this.upper = upper.clone();
 
 
}
 
}
 
/**
 
* Expands this RectangleKD to include point.
 
* @param point - Point used to expand this Rectangle
 
*/
 
public void expand(PointKD point) {
 
if(upper == null) {
 
upper = point.clone();
 
lower = point.clone();
 
return;
 
}
 
for(int i=0; i<point.size(); ++i) {
 
if(upper.internal[i] < point.internal[i])
 
upper.internal[i] = point.internal[i];
 
if(lower.internal[i] > point.internal[i])
 
lower.internal[i] = point.internal[i];
 
}
 
}
 
 
/**
 
* Checks if this RectangleKD contains the given point.
 
* @param point - PointKD to check.
 
* @return true - if this rectangle contains the given point.
 
*/
 
public boolean contains(PointKD point) {
 
if(upper == null) return false;
 
for(int i=0; i<point.size(); ++i) {
 
if(point.internal[i] > upper.internal[i] ||
 
point.internal[i] < lower.internal[i])
 
return false;
 
}
 
return true;
 
}
 
 
/**
 
* Checks if this table intersects the given table.
 
* @param rectangle - The rectangle to check intersection with.
 
* @return true - if this rectangle intersects the given rectangle.
 
*/
 
public boolean intersects(RectangleKD rectangle) {
 
if(rectangle.upper == null || upper == null) return false;
 
for(int i=0; i<upper.size(); ++i) {
 
if(rectangle.upper.internal[i] < lower.internal[i]
 
|| rectangle.lower.internal[i] > upper.internal[i]) return false;
 
}
 
return true;
 
}
 
 
/**
 
* Gets the nearest point in this RectangleKD to the given point
 
* @param point - PointKD to get the nearest point to.
 
* @return the nearest point in this RectangleKD to the given point.
 
*/
 
public PointKD getNearest(PointKD point) {
 
if(upper == null) return null;
 
if(contains(point)) return point; //This check may not be needed.
 
PointKD nearest = new PointKD(point);
 
for(int i = 0; i < upper.size(); ++i) {
 
if(nearest.internal[i] > upper.internal[i])
 
nearest.internal[i] = upper.internal[i];
 
if(nearest.internal[i] < lower.internal[i])
 
nearest.internal[i] = lower.internal[i];
 
}
 
return nearest;
 
}
 
}
 
</pre>
 
 
=== PriorityDeque ===
 
<pre>
 
package org.csdgn.util;
 
 
import java.util.Comparator;
 
  
/**
+
public int capacity() {
* This is a home made PriorityDeque with maximum size limiter, and
+
return capacity;
* comparator or natural ordering selection.
 
*
 
* @author Chase
 
* @param <E>
 
*/
 
public class PriorityDeque<E> {
 
private class Item {
 
public Item down, up;
 
public E obj;
 
public Item(E item) {
 
obj = item;
 
down = up = null;
 
}
 
}
 
private final Comparator<? super E> comparator;
 
private Item bottom, top;
 
private int maximum_size;
 
private int size;
 
 
/**
 
* Generic Constructor
 
*/
 
public PriorityDeque() {
 
this(null,-1);
 
}
 
/**
 
* Constructor with a defined maximum number of entries
 
* @param maximum
 
*/
 
public PriorityDeque(int maximum) {
 
this(null,maximum);
 
}
 
/**
 
* Constructor with a defined comparator and an unlimited number of items.
 
* @param comp
 
*/
 
public PriorityDeque(Comparator<? super E> comp) {
 
this(comp,-1);
 
}
 
/**
 
* Constructor with a defined conparator and limited number of items.
 
* @param comp
 
* @param maximum
 
*/
 
public PriorityDeque(Comparator<? super E> comp, int maximum) {
 
comparator = comp;
 
maximum_size = maximum;
 
bottom = top = null;
 
size = 0;
 
}
 
/**
 
* Adds an item to this deque.
 
* @param value - item to add
 
*/
 
public void offer(E value) {
 
//System.err.println("Offering: " + value.toString());
 
if(bottom == null) {
 
//System.err.println("-Is first item.");
 
bottom = top = new Item(value);
 
return;
 
}
 
//do ordering etc
 
if(comparator != null) {
 
//System.err.println("-Comparator.");
 
offerComparator(value);
 
} else {
 
//System.err.println("-Natural.");
 
offerNatural(value);
 
}
 
}
 
 
/**
 
* Removes and returns an item from the bottom of the list (lowest value)
 
* @return the lowest value (bottom)
 
*/
 
public E pollBottom() {
 
if(bottom == null) return null;
 
Item tmp = bottom;
 
if(bottom == top) bottom = top = null;
 
else {
 
bottom = bottom.up;
 
bottom.down = null;
 
tmp.up = null;
 
}
 
--size;
 
return tmp.obj;
 
}
 
/**
 
* Returns but does not remove the item from the bottom of the list.
 
* @return the lowest value (bottom)
 
*/
 
public E peekBottom() {
 
if(bottom == null) return null;
 
return bottom.obj;
 
}
 
/**
 
* Removes and returns an item from the top of the list (highest value).
 
* @return the highest value (top)
 
*/
 
public E pollTop() {
 
if(top == null) return null;
 
Item tmp = top;
 
if(bottom == top) bottom = top = null;
 
else {
 
top = top.down;
 
top.up = null;
 
tmp.down = null;
 
}
 
--size;
 
return tmp.obj;
 
}
 
/**
 
* Returns but does not remove the item from the top of the list (highest value).
 
* @return the highest value (top)
 
*/
 
public E peekTop() {
 
if(top == null) return null;
 
return top.obj;
 
}
 
/**
 
* This is a generic toArray, returns a non-castable Object
 
* array with all the items in this deque.
 
* @return an Object array with everything in this deque.
 
*/
 
public Object[] toArray() {
 
Object array[] = new Object[size+1];
 
Item current = bottom;
 
int i = 0;
 
while(current != null) {
 
array[i++] = current.obj;
 
current = current.up;
 
}
 
return array;
 
}
 
/**
 
* Add/sorts a value using the comparator
 
* @param value
 
*/
 
private void offerComparator(E value) {
 
if(maximum_size > 0 && size >= maximum_size) {
 
if(comparator.compare(value, top.obj) >= 0) {
 
return;
 
}
 
}
 
Item nItem = new Item(value);
 
if(comparator.compare(value,bottom.obj) < 0) {
 
//less than the bottom, put on the bottom
 
nItem.up = bottom;
 
bottom.down = nItem;
 
bottom = nItem;
 
++size;
 
if(maximum_size > 0 && size > maximum_size) {
 
pollTop();
 
}
 
return;
 
}
 
 
//start at the bottom
 
Item current = bottom;
 
while(current != null) {
 
if(comparator.compare(value,current.obj) < 0) {
 
nItem.up = current;
 
nItem.down = current.down;
 
current.down.up = nItem;
 
current.down = nItem;
 
++size;
 
if(maximum_size > 0 && size > maximum_size) {
 
pollTop();
 
}
 
return;
 
}
 
current = current.up;
 
}
 
 
//else put it on top
 
nItem.down = top;
 
top.up = nItem;
 
top = nItem;
 
++size;
 
if(maximum_size > 0 && size > maximum_size) {
 
pollTop();
 
}
 
}
 
 
/**
 
* Add/sorts a value using the natural order
 
* @param value
 
*/
 
@SuppressWarnings("unchecked")
 
private void offerNatural(E value) {
 
Comparable<? super E> key = (Comparable<? super E>)value;
 
if(maximum_size > 0 && size >= maximum_size) {
 
if(key.compareTo(top.obj) >= 0) {
 
return;
 
}
 
}
 
Item nItem = new Item(value);
 
if(key.compareTo(bottom.obj) < 0) {
 
//less than the bottom, put on the bottom
 
nItem.up = bottom;
 
bottom.down = nItem;
 
bottom = nItem;
 
++size;
 
if(maximum_size > 0 && size > maximum_size) {
 
pollTop();
 
}
 
return;
 
}
 
 
//start at the bottom
 
Item current = bottom;
 
while(current != null) {
 
if(key.compareTo(current.obj) < 0) {
 
nItem.up = current;
 
nItem.down = current.down;
 
current.down.up = nItem;
 
current.down = nItem;
 
++size;
 
if(maximum_size > 0 && size > maximum_size) {
 
pollTop();
 
}
 
return;
 
}
 
current = current.up;
 
}
 
 
//else put it on top
 
nItem.down = top;
 
top.up = nItem;
 
top = nItem;
 
++size;
 
if(maximum_size > 0 && size > maximum_size) {
 
pollTop();
 
}
 
 
}
 
}
 
}
 
}
 
+
</syntaxhighlight>
</pre>
 

Latest revision as of 21:00, 7 November 2012

Everyone and their brother has one of these now, me and Simonton started it, but I was to inexperienced to get anything written, I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine if you care to see it.

This and all my other code in which I display on the robowiki falls under the ZLIB License.

Oh yeah, am I the only one that has a Range function?

KDTreeF

package org.csdgn.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * This is a KD Bucket Tree, for fast sorting and searching of K dimensional
 * data.
 * 
 * @author Chase
 * 
 */
public class KDTree<T> {
	protected static final int defaultBucketSize = 48;

	private final int dimensions;
	private final int bucketSize;
	private NodeKD root;

	/**
	 * Constructor with value for dimensions.
	 * 
	 * @param dimensions
	 *            - Number of dimensions
	 */
	public KDTree(int dimensions) {
		this.dimensions = dimensions;
		this.bucketSize = defaultBucketSize;
		this.root = new NodeKD();
	}

	/**
	 * Constructor with value for dimensions and bucket size.
	 * 
	 * @param dimensions
	 *            - Number of dimensions
	 * @param bucket
	 *            - Size of the buckets.
	 */
	public KDTree(int dimensions, int bucket) {
		this.dimensions = dimensions;
		this.bucketSize = bucket;
		this.root = new NodeKD();
	}

	/**
	 * Add a key and its associated value to the tree.
	 * 
	 * @param key
	 *            - Key to add
	 * @param val
	 *            - object to add
	 */
	public void add(double[] key, T val) {
		root.addPoint(key, val);
	}

	/**
	 * Returns all PointKD within a certain range defined by an upper and lower
	 * PointKD.
	 * 
	 * @param low
	 *            - lower bounds of area
	 * @param high
	 *            - upper bounds of area
	 * @return - All PointKD between low and high.
	 */
	@SuppressWarnings("unchecked")
	public List<T> getRange(double[] low, double[] high) {
		Object[] objs = root.range(high, low);
		ArrayList<T> range = new ArrayList<T>(objs.length);
		for(int i=0; i<objs.length; ++i) {
			range.add((T)objs[i]);
		}
		return range;
	}

	/**
	 * Gets the N nearest neighbors to the given key.
	 * 
	 * @param key
	 *            - Key
	 * @param num
	 *            - Number of results
	 * @return Array of Item Objects, distances within the items are the square
	 *         of the actual distance between them and the key
	 */
	public ResultHeap<T> getNearestNeighbors(double[] key, int num) {
		ResultHeap<T> heap = new ResultHeap<T>(num);
		root.nearest(heap, key);
		return heap;
	}


	// Internal tree node
	private class NodeKD {
		private NodeKD left, right;
		private double[] maxBounds, minBounds;
		private Object[] bucketValues;
		private double[][] bucketKeys;
		private boolean isLeaf;
		private int current, sliceDimension;
		private double slice;

		private NodeKD() {
			bucketValues = new Object[bucketSize];
			bucketKeys = new double[bucketSize][];

			left = right = null;
			maxBounds = minBounds = null;
			
			isLeaf = true;
			
			current = 0;
		}

		// what it says on the tin
		private void addPoint(double[] key, Object val) {
			if(isLeaf) {
				addLeafPoint(key,val);
			} else {
				extendBounds(key);
				if (key[sliceDimension] > slice) {
					right.addPoint(key, val);
				} else {
					left.addPoint(key, val);
				}
			}
		}
		
		private void addLeafPoint(double[] key, Object val) {
			extendBounds(key);
			if (current + 1 > bucketSize) {
				splitLeaf();
				addPoint(key, val);
				return;
			}
			bucketKeys[current] = key;
			bucketValues[current] = val;
			++current;
		}
		
		/**
		 * Find the nearest neighbor recursively.
		 */
		@SuppressWarnings("unchecked")
		private void nearest(ResultHeap<T> heap, double[] data) {
			if(current == 0)
				return;
			if(isLeaf) {
				//IS LEAF
				for (int i = 0; i < current; ++i) {
					double dist = pointDistSq(bucketKeys[i], data);
					heap.offer(dist, (T) bucketValues[i]);
				}
			} else {
				//IS BRANCH
				if (data[sliceDimension] > slice) {
					right.nearest(heap, data);
					if(left.current == 0)
						return;
					if (!heap.isFull() || regionDistSq(data,left.minBounds,left.maxBounds) < heap.getMaxKey()) {
						left.nearest(heap, data);
					}
				} else {
					left.nearest(heap, data);
					if (right.current == 0)
						return;
					if (!heap.isFull() || regionDistSq(data,right.minBounds,right.maxBounds) < heap.getMaxKey()) {
						right.nearest(heap, data);
					}
				}
			}
		}

		// gets all items from within a range
		private Object[] range(double[] upper, double[] lower) {
			if (bucketValues == null) {
				// Branch
				Object[] tmp = new Object[0];
				if (intersects(upper, lower, left.maxBounds, left.minBounds)) {
					Object[] tmpl = left.range(upper, lower);
					if (0 == tmp.length) tmp = tmpl;
				}
				if (intersects(upper, lower, right.maxBounds, right.minBounds)) {
					Object[] tmpr = right.range(upper, lower);
					if (0 == tmp.length)
						tmp = tmpr;
					else if (0 < tmpr.length) {
						Object[] tmp2 = new Object[tmp.length + tmpr.length];
						System.arraycopy(tmp, 0, tmp2, 0, tmp.length);
						System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length);
						tmp = tmp2;
					}
				}
				return tmp;
			}
			// Leaf
			Object[] tmp = new Object[current];
			int n = 0;
			for (int i = 0; i < current; ++i) {
				if (contains(upper, lower, bucketKeys[i])) {
					tmp[n++] = bucketValues[i];
				}
			}
			Object[] tmp2 = new Object[n];
			System.arraycopy(tmp, 0, tmp2, 0, n);
			return tmp2;
		}

		// These are helper functions from here down
		// check if this hyper rectangle contains a give hyper-point
		public boolean contains(double[] upper, double[] lower, double[] point) {
			if (current == 0) return false;
			for (int i = 0; i < point.length; ++i) {
				if (point[i] > upper[i] || point[i] < lower[i]) return false;
			}
			return true;
		}

		// checks if two hyper-rectangles intersect
		public boolean intersects(double[] up0, double[] low0, double[] up1, double[] low1) {
			for (int i = 0; i < up0.length; ++i) {
				if (up1[i] < low0[i] || low1[i] > up0[i]) return false;
			}
			return true;
		}

		private void splitLeaf() {
			double bestRange = 0;
			for(int i=0;i<dimensions;++i) {
				double range = maxBounds[i] - minBounds[i];
				if(range > bestRange) {
					sliceDimension = i;
					bestRange = range;
				}
			}
			
			left = new NodeKD();
			right = new NodeKD();
			
			slice = (maxBounds[sliceDimension] + minBounds[sliceDimension]) * 0.5;
			
			for (int i = 0; i < current; ++i) {
				if (bucketKeys[i][sliceDimension] > slice) {
					right.addLeafPoint(bucketKeys[i], bucketValues[i]);
				} else {
					left.addLeafPoint(bucketKeys[i], bucketValues[i]);
				}
			}
			bucketKeys = null;
			bucketValues = null;
			isLeaf = false;
		}

		// expands this hyper rectangle
		private void extendBounds(double[] key) {
			if (maxBounds == null) {
				maxBounds = Arrays.copyOf(key, dimensions);
				minBounds = Arrays.copyOf(key, dimensions);
				return;
			}
			for (int i = 0; i < key.length; ++i) {
				if (maxBounds[i] < key[i]) maxBounds[i] = key[i];
				if (minBounds[i] > key[i]) minBounds[i] = key[i];
			}
		}
	}
	
	/* I may have borrowed these from an early version of Red's tree. I however forget. */
	private static final double pointDistSq(double[] p1, double[] p2) {
        double d = 0;
        double q = 0;
        for (int i = 0; i < p1.length; ++i) {
            d += (q=(p1[i] - p2[i]))*q;
        }
        return d;
    }

    private static final double regionDistSq(double[] point, double[] min, double[] max) {
        double d = 0;
        double q = 0;
        for (int i = 0; i < point.length; ++i) {
            if (point[i] > max[i]) {
            	d += (q = (point[i] - max[i]))*q;
            } else if (point[i] < min[i]) {
                d += (q = (point[i] - min[i]))*q;
            }
        }
        return d;
    }
}

ResultHeap

package org.csdgn.util;

/**
 * @author Chase
 * 
 * @param <T>
 */
public class ResultHeap<T> {
	private Object[] data;
	private double[] keys;
	private int capacity;
	private int size;

	protected ResultHeap(int capacity) {
		this.data = new Object[capacity];
		this.keys = new double[capacity];
		this.capacity = capacity;
		this.size = 0;
	}

	protected void offer(double key, T value) {
		int i = size;
		for (; i > 0 && keys[i - 1] > key; --i);
		if (i >= capacity) return;
		if (size < capacity) ++size;
		int j = i + 1;
		System.arraycopy(keys, i, keys, j, size - j);
		keys[i] = key;
		System.arraycopy(data, i, data, j, size - j);
		data[i] = value;
	}

	public double getMaxKey() {
		return keys[size - 1];
	}
	
	@SuppressWarnings("unchecked")
	public T removeMax() {
		if(isEmpty()) return null;
		return (T)data[--size];
	}

	public boolean isEmpty() {
		return size == 0;
	}

	public boolean isFull() {
		return size == capacity;
	}

	public int size() {
		return size;
	}

	public int capacity() {
		return capacity;
	}
}