Difference between revisions of "User:Chase-san/Kd-Tree"

From Robowiki
Jump to navigation Jump to search
(Changed distance to use distanceSq in comparator)
(There, updated to my smaller version.)
Line 2: Line 2:
 
Everyone and their brother has one of these now, me and Simonton started it, but I was to inexperienced to get anything written, I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine if you care to see it.
 
Everyone and their brother has one of these now, me and Simonton started it, but I was to inexperienced to get anything written, I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine if you care to see it.
  
=== KDTreeB ===
+
=== KDTreeC ===
 
<pre>
 
<pre>
package org.csdgn.util;
+
package org.csdgn.util.kd2;
  
import java.util.Comparator;
+
import java.util.Arrays;
  
public class KDTreeB {
+
/**
public final static int DEFAULT_BUCKET_SIZE = 200;
+
* This is a KD Bucket Tree, for fast sorting and searching of K dimensional data.
protected NodeKD root;
+
* @author Chase
protected int dimensions;
+
*
protected int bucket_size;
+
*/
protected long size = 0;
+
public class KDTreeC {
 
 
/**
 
/**
* This creates a KDTreeB with the given number of dimensions and
+
* Item, for moving data around.
* the default bucket size.
+
* @author Chase
* @param dimensions
 
 
*/
 
*/
public KDTreeB(int dimensions) {
+
public class Item {
this(dimensions,DEFAULT_BUCKET_SIZE);
+
public double[] pnt;
 +
public Object obj;
 +
public double distance;
 +
private Item(double[] p, Object o) {
 +
pnt = p; obj = o;
 +
}
 
}
 
}
 +
private final int dimensions;
 +
private final int bucket_size;
 +
private NodeKD root;
 +
 
/**
 
/**
* This creates a KDTreeB with the given number of dimensions and
+
* Constructor with value for dimensions.
* the given bucket size.
+
* @param dimensions - Number of dimensions
* @param dimensions
 
* @param bucket_size
 
 
*/
 
*/
public KDTreeB(int dimensions, int bucket_size) {
+
public KDTreeC(int dimensions) {
 
this.dimensions = dimensions;
 
this.dimensions = dimensions;
this.bucket_size = bucket_size;
+
this.bucket_size = 64;
root = new BucketKD(this);
+
this.root = new NodeKD(this);
 
}
 
}
 
 
 
/**
 
/**
* Adds the given point to the tree. Uses a recursive algorithm.
+
* Constructor with value for dimensions and bucket size.
* @param point
+
* @param dimensions - Number of dimensions
 +
* @param bucket - Size of the buckets.
 
*/
 
*/
public void add(PointKD point) {
+
public KDTreeC(int dimensions, int bucket) {
root.add(point);
+
this.dimensions = dimensions;
 +
this.bucket_size = bucket;
 +
this.root = new NodeKD(this);
 
}
 
}
 
 
 
/**
 
/**
* Returns the nearest neighbor to the given point.
+
* Add a key and its associated value to the tree.
* @param point - PointKD to find the nearest to.
+
* @param key - Key to add
* @return - The nearest PointKD to  
+
* @param val - object to add
 
*/
 
*/
public PointKD getNearestNeighbor(PointKD point) {
+
public void add(double[] key, Object val) {
return root.nearest(point);
+
root.add(new Item(key,val));
 
}
 
}
 
 
/**
 
* Returns the nearest num neighbors to the given point.
 
* @param point - PointKD to find the nearest to.
 
* @param num - The number of points to find.
 
* @return - The nearest PointKD's to point.
 
*/
 
public PointKD[] getNearestNeighbors(final PointKD point, int num) {
 
if(num == 1) {
 
return new PointKD[] { getNearestNeighbor(point) };
 
}
 
PriorityDeque<PointKD> queue = new PriorityDeque<PointKD>(
 
new Comparator<PointKD>(){
 
@Override
 
public int compare(PointKD o1, PointKD o2) {
 
double dist0 = point.distanceSq(o1);
 
double dist1 = point.distanceSq(o2);
 
return (dist0-dist1 < 0) ? -1 : 1;
 
}
 
},num);
 
 
root.nearestn(queue,point);
 
 
PointKD[] array = new PointKD[num];
 
Object[] obj = queue.toArray();
 
for(int i=0; i<num; ++i) {
 
array[i] = (PointKD)obj[i];
 
}
 
 
return array;
 
}
 
 
/**
 
* Returns all PointKD within a certain RectangleKD.
 
* @param rect - area to get PointKD from
 
* @return - All PointKD within rect.
 
*/
 
public PointKD[] getRange(RectangleKD rect) {
 
return root.range(rect);
 
}
 
 
 
/**
 
/**
 
* Returns all PointKD within a certain range defined by an upper and lower PointKD.
 
* Returns all PointKD within a certain range defined by an upper and lower PointKD.
Line 98: Line 66:
 
* @return - All PointKD between low and high.
 
* @return - All PointKD between low and high.
 
*/
 
*/
public PointKD[] getRange(PointKD low, PointKD high) {
+
public Item[] getRange(double[] low, double[] high) {
return this.getRange(new RectangleKD(low, high));
+
return root.range(high, low);
}
 
 
protected abstract class NodeKD {
 
protected KDTreeB owner;
 
protected BranchKD parent;
 
protected RectangleKD rect;
 
protected int depth = 0;
 
 
protected abstract void add(PointKD k);
 
protected abstract PointKD nearest(PointKD k);
 
protected abstract void nearestn(PriorityDeque<PointKD> queue, PointKD k);
 
protected abstract PointKD[] range(RectangleKD r);
 
}
 
protected class BucketKD extends NodeKD {
 
protected PointKD bucket[];
 
protected int current;
 
 
public BucketKD(KDTreeB owner) {
 
this.owner = owner;
 
bucket = new PointKD[owner.bucket_size];
 
rect = new RectangleKD();
 
parent = null;
 
}
 
public BucketKD(BranchKD branch) {
 
this(branch.owner);
 
parent = branch;
 
depth = branch.depth + 1;
 
}
 
 
@Override
 
protected void add(PointKD point) {
 
if(current >= bucket.length) {
 
//Split the bucket into a branch
 
BranchKD branch = new BranchKD(this);
 
if(parent == null) {
 
owner.root = branch;
 
} else {
 
if(parent.isLeft(this)) {
 
parent.left = branch;
 
} else {
 
parent.right = branch;
 
}
 
}
 
branch.add(point);
 
bucket = null;
 
current = 0;
 
return;
 
}
 
bucket[current++] = point;
 
rect.expand(point);
 
}
 
 
 
@Override
 
protected PointKD nearest(PointKD k) {
 
double nearestDist = Double.POSITIVE_INFINITY;
 
int nearest = 0;
 
for (int i = 0; i < current; i++) {
 
double distance = k.distanceSq(bucket[i]);
 
if (distance < nearestDist) {
 
nearestDist = distance;
 
nearest = i;
 
}
 
}
 
return bucket[nearest];
 
}
 
 
@Override
 
protected void nearestn(PriorityDeque<PointKD> queue, PointKD k) {
 
for(int i = 0; i < current; i++) {
 
queue.offer(bucket[i]);
 
}
 
}
 
 
 
@Override
 
protected PointKD[] range(RectangleKD r) {
 
PointKD[] tmp = new PointKD[current];
 
int n = 0;
 
for (int i = 0; i < current; i++) {
 
if (r.contains(bucket[i])) {
 
tmp[n++] = bucket[i];
 
}
 
}
 
PointKD[] tmp2 = new PointKD[n];
 
System.arraycopy(tmp, 0, tmp2, 0, n);
 
return tmp2;
 
}
 
 
}
 
protected class BranchKD extends NodeKD {
 
protected NodeKD left, right;
 
protected double slice;
 
protected int dim;
 
 
public BranchKD(BucketKD k) {
 
owner = k.owner;
 
parent = k.parent;
 
slice = 0;
 
rect = k.rect;
 
depth = k.depth;
 
left = new BucketKD(this);
 
right = new BucketKD(this);
 
 
dim = depth % owner.dimensions;
 
double total = 0;
 
for (int i = 0; i < k.current; i++) {
 
total += k.bucket[i].internal[dim];
 
}
 
slice = total / k.current;
 
for (int i = 0; i < k.current; i++)
 
add(k.bucket[i]);
 
}
 
 
@Override
 
protected void add(PointKD k) {
 
if(k.internal[dim] > slice) {
 
right.add(k);
 
} else {
 
left.add(k);
 
}
 
}
 
 
 
@Override
 
protected PointKD nearest(PointKD k) {
 
PointKD near = null;
 
if (k.internal[dim] > slice) {
 
near = right.nearest(k);
 
double t = near.distanceSq(k);
 
if (k.distanceSq(left.rect.getNearest(k)) < t) {
 
PointKD tmp = left.nearest(k);
 
if (tmp.distanceSq(k) < t) {
 
near = tmp;
 
}
 
}
 
} else {
 
near = left.nearest(k);
 
double t = near.distanceSq(k);
 
if (k.distanceSq(right.rect.getNearest(k)) < t) {
 
PointKD tmp = right.nearest(k);
 
if (tmp.distanceSq(k) < t) {
 
near = tmp;
 
}
 
}
 
}
 
return near;
 
}
 
 
@Override
 
protected void nearestn(PriorityDeque<PointKD> queue, PointKD k) {
 
//TODO
 
if(k.internal[dim] > slice) {
 
right.nearestn(queue, k);
 
double t = queue.peekBottom().distanceSq(k);
 
if(k.distanceSq(left.rect.getNearest(k)) < t) {
 
left.nearestn(queue, k);
 
}
 
} else {
 
left.nearestn(queue, k);
 
double t = queue.peekBottom().distanceSq(k);
 
if(k.distanceSq(right.rect.getNearest(k)) < t) {
 
right.nearestn(queue, k);
 
}
 
}
 
}
 
 
 
@Override
 
protected PointKD[] range(RectangleKD r) {
 
PointKD[] tmp = new PointKD[0];
 
if (r.intersects(left.rect)) {
 
PointKD[] tmpl = left.range(r);
 
if(0 == tmp.length)
 
tmp = tmpl;
 
}
 
if (r.intersects(right.rect)) {
 
PointKD[] tmpr = right.range(r);
 
if (0 == tmp.length)
 
tmp = tmpr;
 
else if (0 < tmpr.length) {
 
PointKD[] tmp2 = new PointKD[tmp.length + tmpr.length];
 
System.arraycopy(tmp, 0, tmp2, 0, tmp.length);
 
System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length);
 
tmp = tmp2;
 
}
 
}
 
return tmp;
 
}
 
 
protected boolean isLeft(BucketKD kd) {
 
if(left == kd) return true;
 
return false;
 
}
 
}
 
}
 
</pre>
 
 
 
=== PointKD ===
 
<pre>
 
package org.csdgn.util;
 
 
 
import java.io.Serializable;
 
 
 
/**
 
* PointKD class is a class that wraps a double array, for use in K-Dimensional structures.
 
* This wrapping is done to improve readability and modularity. Often used to define a
 
* point in k-dimensional space.
 
*
 
* @author Chase
 
*/
 
public class PointKD implements Serializable {
 
private static final long serialVersionUID = -841162798668123755L;
 
protected double[] internal;
 
 
/**
 
* Constructor
 
* @param dimensions - Number of dimensions for this PointKD.
 
*/
 
public PointKD(int dimensions) {
 
internal = new double[dimensions];
 
}
 
 
/**
 
* Constructor
 
* @param array - An array to use for the internal array of this PointKD.
 
*/
 
public PointKD(double[] array) {
 
internal = array.clone();
 
}
 
 
/**
 
* Constructor
 
* @param point - PointKD to clone.
 
*/
 
public PointKD(PointKD point) {
 
this.internal = point.internal.clone();
 
}
 
 
 
/**
 
* Returns array.
 
* @return The internal array of this pointKD, changes made to this will effect this PointKD.
 
*/
 
public double[] get() {
 
return internal;
 
}
 
 
/**
 
* Sets the location of this point.
 
* @param point - PointKD to copy.
 
*/
 
public void set(PointKD point) {
 
this.internal = point.internal.clone();
 
}
 
 
/**
 
* Sets the location of this point.
 
* @param array - Array to copy.
 
*/
 
public void set(double[] array) {
 
this.internal = array.clone();
 
}
 
 
/**
 
* Sets the value at dimension index
 
* @param index - Index
 
* @param value - Value to set
 
*/
 
public void set(int index, double value) {
 
internal[index] = value;
 
}
 
 
/**
 
* @return number of dimensions
 
*/
 
public int size() {
 
return internal.length;
 
}
 
 
/**
 
* Compares this to a selected point and returns the euclidean distance
 
* between them.
 
*
 
* @param p
 
*            - The Point to get the distance to.
 
* @return The distance between this and <b>p</b>.
 
*/
 
public double distance(PointKD p) {
 
return distance(this.internal, p.internal);
 
}
 
 
 
/**
 
* Compares this to a selected point and returns the squared euclidean
 
* distance between them.
 
*
 
* @param p
 
*            - The Point to get the distance to.
 
* @return The distance between this and <b>p</b>.
 
*/
 
public double distanceSq(PointKD p) {
 
return distanceSq(this.internal, p.internal);
 
 
}
 
}
 
 
 
/**
 
/**
* Clones this point.
+
* Gets the N nearest neighbors to the given key.
 +
* @param key - Key
 +
* @param num - Number of results
 +
* @return Array of Item Objects
 
*/
 
*/
public PointKD clone() {
+
public Item[] getNearestNeighbor(double[] key, int num) {
return new PointKD(internal);
+
ShiftArray arr = new ShiftArray(num);
 +
root.nearestn(arr, key);
 +
return arr.getArray();
 
}
 
}
 
/**
 
* Prints the class name and the point coordinates.
 
*/
 
public String toString() {
 
String output = getClass().getSimpleName() + "[";
 
for (int i = 0; i < internal.length; ++i) {
 
if (0 != i)
 
output += ",";
 
output += internal[i];
 
}
 
return output + "]";
 
}
 
 
 
/**
 
/**
 
* Compares arrays of double and returns the euclidean distance
 
* Compares arrays of double and returns the euclidean distance
Line 434: Line 95:
 
return Math.sqrt(total);
 
return Math.sqrt(total);
 
}
 
}
 
 
/**
 
/**
 
* Compares arrays of double and returns the squared euclidean distance
 
* Compares arrays of double and returns the squared euclidean distance
Line 448: Line 108:
 
total += (b[i] - a[i]) * (b[i] - a[i]);
 
total += (b[i] - a[i]) * (b[i] - a[i]);
 
return total;
 
return total;
}
 
}
 
</pre>
 
 
=== RectangleKD ===
 
<pre>
 
package org.csdgn.util;
 
 
import java.io.Serializable;
 
 
public class RectangleKD implements Serializable {
 
private static final long serialVersionUID = 524388821816648020L;
 
protected PointKD upper, lower;
 
/**
 
* Creates an empty RectangleKD
 
*/
 
public RectangleKD() {
 
upper = null;
 
lower = null;
 
}
 
/**
 
* Creates a RectangleKD from two points.
 
* @param lower
 
* @param upper
 
*/
 
public RectangleKD(PointKD lower, PointKD upper) {
 
this.lower = lower.clone();
 
this.upper = upper.clone();
 
 
}
 
}
 
 
/**
+
//Internal tree node
* Expands this RectangleKD to include point.
+
private class NodeKD {
* @param point - Point used to expand this Rectangle
+
private KDTreeC owner;
*/
+
private NodeKD left, right;
public void expand(PointKD point) {
+
private double[] upper, lower;
if(upper == null) {
+
private Item[] bucket;
upper = point.clone();
+
private int current, dim;
lower = point.clone();
+
private double slice;
return;
+
 +
//note: we always start as a bucket
 +
private NodeKD(KDTreeC own) {
 +
owner = own;
 +
upper = lower = null;
 +
left = right = null;
 +
bucket = new Item[own.bucket_size];
 +
current = 0;
 +
dim = 0;
 
}
 
}
for(int i=0; i<point.size(); ++i) {
+
//when we create non-root nodes within this class
if(upper.internal[i] < point.internal[i])
+
//we use this one here
upper.internal[i] = point.internal[i];
+
private NodeKD(NodeKD node) {
if(lower.internal[i] > point.internal[i])
+
owner = node.owner;
lower.internal[i] = point.internal[i];
+
dim = node.dim + 1;
 +
bucket = new Item[owner.bucket_size];
 +
if(dim + 1 > owner.dimensions) dim = 0;
 +
left = right = null;
 +
upper = lower = null;
 +
current = 0;
 
}
 
}
}
+
//what it says on the tin
+
private void add(Item m) {
/**
+
if(bucket == null) {
* Checks if this RectangleKD contains the given point.
+
//Branch
* @param point - PointKD to check.
+
if(m.pnt[dim] > slice)
* @return true - if this rectangle contains the given point.
+
right.add(m);
*/
+
else left.add(m);
public boolean contains(PointKD point) {
+
} else {
if(upper == null) return false;
+
//Bucket
for(int i=0; i<point.size(); ++i) {
+
if(current+1>owner.bucket_size) {
if(point.internal[i] > upper.internal[i] ||
+
split(m);
point.internal[i] < lower.internal[i])  
+
return;
return false;
+
}
 +
bucket[current++] = m;
 +
}
 +
expand(m.pnt);
 
}
 
}
return true;
+
//nearest neighbor thing
}
+
private void nearestn(ShiftArray arr, double[] data) {
+
if(bucket == null) {
/**
+
//Branch
* Checks if this table intersects the given table.
+
if(data[dim] > slice) {
* @param rectangle - The rectangle to check intersection with.
+
right.nearestn(arr, data);
* @return true - if this rectangle intersects the given rectangle.
+
if(left.current != 0) {
*/
+
if(KDTreeC.distanceSq(left.nearestRect(data),data)
public boolean intersects(RectangleKD rectangle) {
+
< arr.getLongest()) {
if(rectangle.upper == null || upper == null) return false;
+
left.nearestn(arr, data);
for(int i=0; i<upper.size(); ++i) {
+
}
if(rectangle.upper.internal[i] < lower.internal[i]
+
}
|| rectangle.lower.internal[i] > upper.internal[i]) return false;
+
 +
} else {
 +
left.nearestn(arr, data);
 +
if(right.current != 0) {
 +
if(KDTreeC.distanceSq(right.nearestRect(data),data)
 +
< arr.getLongest()) {
 +
right.nearestn(arr, data);
 +
}
 +
}
 +
}
 +
} else {
 +
//Bucket
 +
for(int i = 0; i < current; i++) {
 +
bucket[i].distance = KDTreeC.distanceSq(bucket[i].pnt, data);
 +
arr.add(bucket[i]);
 +
}
 +
}
 
}
 
}
return true;
+
//gets all items from within a range
}
+
private Item[] range(double[] upper, double[] lower) {
+
//TODO: clean this up a bit
/**
+
if(bucket == null) {
* Gets the nearest point in this RectangleKD to the given point
+
//Branch
* @param point - PointKD to get the nearest point to.
+
Item[] tmp = new Item[0];
* @return the nearest point in this RectangleKD to the given point.
+
if (intersects(upper,lower,left.upper,left.lower)) {
*/
+
Item[] tmpl = left.range(upper,lower);
public PointKD getNearest(PointKD point) {
+
if(0 == tmp.length)
if(upper == null) return null;
+
tmp = tmpl;
if(contains(point)) return point; //This check may not be needed.
+
}
PointKD nearest = new PointKD(point);
+
if (intersects(upper,lower,right.upper,right.lower)) {
for(int i = 0; i < upper.size(); ++i) {
+
Item[] tmpr = right.range(upper,lower);
if(nearest.internal[i] > upper.internal[i])
+
if (0 == tmp.length)
nearest.internal[i] = upper.internal[i];
+
tmp = tmpr;
if(nearest.internal[i] < lower.internal[i])
+
else if (0 < tmpr.length) {
nearest.internal[i] = lower.internal[i];
+
Item[] tmp2 = new Item[tmp.length + tmpr.length];
 +
System.arraycopy(tmp, 0, tmp2, 0, tmp.length);
 +
System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length);
 +
tmp = tmp2;
 +
}
 +
}
 +
return tmp;
 +
}
 +
//Bucket
 +
Item[] tmp = new Item[current];
 +
int n = 0;
 +
for (int i = 0; i < current; i++) {
 +
if (contains(upper, lower, bucket[i].pnt)) {
 +
tmp[n++] = bucket[i];
 +
}
 +
}
 +
Item[] tmp2 = new Item[n];
 +
System.arraycopy(tmp, 0, tmp2, 0, n);
 +
return tmp2;
 
}
 
}
return nearest;
+
}
+
//These are helper functions from here down
}
+
//check if this hyper rectangle contains a give hyper-point
</pre>
+
public boolean contains(double[] upper, double[] lower, double[] point) {
 
+
if(current == 0) return false;
=== PriorityDeque ===
+
for(int i=0; i<point.length; ++i) {
<pre>
+
if(point[i] > upper[i] ||
package org.csdgn.util;
+
point[i] < lower[i])
 
+
return false;
import java.util.Comparator;
+
}
 
+
return true;
/**
 
* This is a home made PriorityDeque with maximum size limiter, and
 
* comparator or natural ordering selection.
 
*
 
* @author Chase
 
* @param <E>
 
*/
 
public class PriorityDeque<E> {
 
private class Item {
 
public Item down, up;
 
public E obj;
 
public Item(E item) {
 
obj = item;
 
down = up = null;
 
 
}
 
}
}
+
//checks if two hyper-rectangles intersect
private final Comparator<? super E> comparator;
+
public boolean intersects(double[] up0, double[] low0,
private Item bottom, top;
+
double[] up1, double[] low1) {
private int maximum_size;
+
for(int i=0; i<up0.length; ++i) {
private int size;
+
if(up1[i] < low0[i] || low1[i] > up0[i]) return false;
+
}
/**
+
return true;
* Generic Constructor
 
*/
 
public PriorityDeque() {
 
this(null,-1);
 
}
 
/**
 
* Constructor with a defined maximum number of entries
 
* @param maximum
 
*/
 
public PriorityDeque(int maximum) {
 
this(null,maximum);
 
}
 
/**
 
* Constructor with a defined comparator and an unlimited number of items.
 
* @param comp
 
*/
 
public PriorityDeque(Comparator<? super E> comp) {
 
this(comp,-1);
 
}
 
/**
 
* Constructor with a defined conparator and limited number of items.
 
* @param comp
 
* @param maximum
 
*/
 
public PriorityDeque(Comparator<? super E> comp, int maximum) {
 
comparator = comp;
 
maximum_size = maximum;
 
bottom = top = null;
 
size = 0;
 
}
 
/**
 
* Adds an item to this deque.
 
* @param value - item to add
 
*/
 
public void offer(E value) {
 
//System.err.println("Offering: " + value.toString());
 
if(bottom == null) {
 
//System.err.println("-Is first item.");
 
bottom = top = new Item(value);
 
return;
 
}
 
//do ordering etc
 
if(comparator != null) {
 
//System.err.println("-Comparator.");
 
offerComparator(value);
 
} else {
 
//System.err.println("-Natural.");
 
offerNatural(value);
 
}
 
}
 
 
/**
 
* Removes and returns an item from the bottom of the list (lowest value)
 
* @return the lowest value (bottom)
 
*/
 
public E pollBottom() {
 
if(bottom == null) return null;
 
Item tmp = bottom;
 
if(bottom == top) bottom = top = null;
 
else {
 
bottom = bottom.up;
 
bottom.down = null;
 
tmp.up = null;
 
}
 
--size;
 
return tmp.obj;
 
}
 
/**
 
* Returns but does not remove the item from the bottom of the list.
 
* @return the lowest value (bottom)
 
*/
 
public E peekBottom() {
 
if(bottom == null) return null;
 
return bottom.obj;
 
}
 
/**
 
* Removes and returns an item from the top of the list (highest value).
 
* @return the highest value (top)
 
*/
 
public E pollTop() {
 
if(top == null) return null;
 
Item tmp = top;
 
if(bottom == top) bottom = top = null;
 
else {
 
top = top.down;
 
top.up = null;
 
tmp.down = null;
 
}
 
--size;
 
return tmp.obj;
 
}
 
/**
 
* Returns but does not remove the item from the top of the list (highest value).
 
* @return the highest value (top)
 
*/
 
public E peekTop() {
 
if(top == null) return null;
 
return top.obj;
 
}
 
/**
 
* This is a generic toArray, returns a non-castable Object
 
* array with all the items in this deque.
 
* @return an Object array with everything in this deque.
 
*/
 
public Object[] toArray() {
 
Object array[] = new Object[size+1];
 
Item current = bottom;
 
int i = 0;
 
while(current != null) {
 
array[i++] = current.obj;
 
current = current.up;
 
 
}
 
}
return array;
+
//splits a bucket into a branch
}
+
private void split(Item m) {
/**
+
//split based on our bound data
* Add/sorts a value using the comparator
+
slice = (upper[dim]+lower[dim])/2.0;
* @param value
+
left = new NodeKD(this);
*/
+
right = new NodeKD(this);
private void offerComparator(E value) {
+
for(int i=0; i<current; ++i) {
if(maximum_size > 0 && size >= maximum_size) {
+
if(bucket[i].pnt[dim] > slice)
if(comparator.compare(value, top.obj) >= 0) {
+
right.add(bucket[i]);
return;
+
else left.add(bucket[i]);
 
}
 
}
 +
bucket = null;
 +
add(m);
 
}
 
}
Item nItem = new Item(value);
+
//gets nearest point to data within this hyper rectangle
if(comparator.compare(value,bottom.obj) < 0) {
+
private double[] nearestRect(double[] data) {
//less than the bottom, put on the bottom
+
double[] nearest = data.clone();
nItem.up = bottom;
+
for(int i = 0; i < data.length; ++i) {
bottom.down = nItem;
+
if(nearest[i] > upper[i]) nearest[i] = upper[i];
bottom = nItem;
+
if(nearest[i] < lower[i]) nearest[i] = lower[i];
++size;
 
if(maximum_size > 0 && size > maximum_size) {
 
pollTop();
 
 
}
 
}
return;
+
return nearest;
 
}
 
}
+
//expands this hyper rectangle
//start at the bottom
+
private void expand(double[] data) {
Item current = bottom;
+
if(upper == null) {
while(current != null) {
+
upper = Arrays.copyOf(data, owner.dimensions);
if(comparator.compare(value,current.obj) < 0) {
+
lower = Arrays.copyOf(data, owner.dimensions);
nItem.up = current;
 
nItem.down = current.down;
 
current.down.up = nItem;
 
current.down = nItem;
 
++size;
 
if(maximum_size > 0 && size > maximum_size) {
 
pollTop();
 
}
 
 
return;
 
return;
 
}
 
}
current = current.up;
+
for(int i=0; i<data.length; ++i) {
 +
if(upper[i] < data[i]) upper[i] = data[i];
 +
if(lower[i] > data[i]) lower[i] = data[i];
 +
}
 
}
 
}
+
}
//else put it on top
+
//A simple shift array that sifts data up
nItem.down = top;
+
//as we add new ones to lower in the array.
top.up = nItem;
+
private class ShiftArray {
top = nItem;
+
private Item[] items;
++size;
+
private final int max;
if(maximum_size > 0 && size > maximum_size) {
+
private int current;
pollTop();
+
private ShiftArray(int maximum) {
 +
max = maximum;
 +
current = 0;
 +
items = new Item[max];
 
}
 
}
}
+
private void add(Item m) {
+
int i;
/**
+
for(i=current;i>0&&items[i-1].distance > m.distance; --i);
* Add/sorts a value using the natural order
+
if(i >= max) return;
* @param value
+
if(current < max) ++current;
*/
+
System.arraycopy(items, i, items, i+1, current-(i+1));
@SuppressWarnings("unchecked")
+
items[i] = m;
private void offerNatural(E value) {
 
Comparable<? super E> key = (Comparable<? super E>)value;
 
if(maximum_size > 0 && size >= maximum_size) {
 
if(key.compareTo(top.obj) >= 0) {
 
return;
 
}
 
 
}
 
}
Item nItem = new Item(value);
+
private double getLongest() {
if(key.compareTo(bottom.obj) < 0) {
+
if (current < max) return Double.POSITIVE_INFINITY;
//less than the bottom, put on the bottom
+
return items[max-1].distance;
nItem.up = bottom;
 
bottom.down = nItem;
 
bottom = nItem;
 
++size;
 
if(maximum_size > 0 && size > maximum_size) {
 
pollTop();
 
}
 
return;
 
 
}
 
}
+
private Item[] getArray() {
//start at the bottom
+
return items.clone();
Item current = bottom;
 
while(current != null) {
 
if(key.compareTo(current.obj) < 0) {
 
nItem.up = current;
 
nItem.down = current.down;
 
current.down.up = nItem;
 
current.down = nItem;
 
++size;
 
if(maximum_size > 0 && size > maximum_size) {
 
pollTop();
 
}
 
return;
 
}
 
current = current.up;
 
 
}
 
}
 
//else put it on top
 
nItem.down = top;
 
top.up = nItem;
 
top = nItem;
 
++size;
 
if(maximum_size > 0 && size > maximum_size) {
 
pollTop();
 
}
 
 
}
 
}
 
}
 
}
 
 
</pre>
 
</pre>

Revision as of 05:09, 2 March 2010

Everyone and their brother has one of these now, me and Simonton started it, but I was to inexperienced to get anything written, I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine if you care to see it.

KDTreeC

package org.csdgn.util.kd2;

import java.util.Arrays;

/**
 * This is a KD Bucket Tree, for fast sorting and searching of K dimensional data.
 * @author Chase
 *
 */
public class KDTreeC {
	/**
	 * Item, for moving data around.
	 * @author Chase
	 */
	public class Item {
		public double[] pnt;
		public Object obj;
		public double distance;
		private Item(double[] p, Object o) {
			pnt = p; obj = o;
		}
	}
	private final int dimensions;
	private final int bucket_size;
	private NodeKD root;
	
	/**
	 * Constructor with value for dimensions.
	 * @param dimensions - Number of dimensions
	 */
	public KDTreeC(int dimensions) {
		this.dimensions = dimensions;
		this.bucket_size = 64;
		this.root = new NodeKD(this);
	}
	
	/**
	 * Constructor with value for dimensions and bucket size.
	 * @param dimensions - Number of dimensions
	 * @param bucket - Size of the buckets.
	 */
	public KDTreeC(int dimensions, int bucket) {
		this.dimensions = dimensions;
		this.bucket_size = bucket;
		this.root = new NodeKD(this);
	}
	
	/**
	 * Add a key and its associated value to the tree.
	 * @param key - Key to add
	 * @param val - object to add
	 */
	public void add(double[] key, Object val) {
		root.add(new Item(key,val));
	}
	
	/**
	 * Returns all PointKD within a certain range defined by an upper and lower PointKD.
	 * @param low - lower bounds of area
	 * @param high - upper bounds of area
	 * @return - All PointKD between low and high.
	 */
	public Item[] getRange(double[] low, double[] high) {
		return root.range(high, low);
	}
	
	/**
	 * Gets the N nearest neighbors to the given key.
	 * @param key - Key
	 * @param num - Number of results
	 * @return Array of Item Objects
	 */
	public Item[] getNearestNeighbor(double[] key, int num) {
		ShiftArray arr = new ShiftArray(num);
		root.nearestn(arr, key);
		return arr.getArray();
	}
	/**
	 * Compares arrays of double and returns the euclidean distance
	 * between them.
	 * 
	 * @param a - The first set of numbers
	 * @param b - The second set of numbers
	 * @return The distance squared between <b>a</b> and <b>b</b>.
	 */
	public static final double distance(double[] a, double[] b) {
		double total = 0;
		for (int i = 0; i < a.length; ++i)
			total += (b[i] - a[i]) * (b[i] - a[i]);
		return Math.sqrt(total);
	}
	/**
	 * Compares arrays of double and returns the squared euclidean distance
	 * between them.
	 * 
	 * @param a - The first set of numbers
	 * @param b - The second set of numbers
	 * @return The distance squared between <b>a</b> and <b>b</b>.
	 */
	public static final double distanceSq(double[] a, double[] b) {
		double total = 0;
		for (int i = 0; i < a.length; ++i)
			total += (b[i] - a[i]) * (b[i] - a[i]);
		return total;
	}
	
	//Internal tree node
	private class NodeKD {
		private KDTreeC owner;
		private NodeKD left, right;
		private double[] upper, lower;
		private Item[] bucket;
		private int current, dim;
		private double slice;
		
		//note: we always start as a bucket
		private NodeKD(KDTreeC own) {
			owner = own;
			upper = lower = null;
			left = right = null;
			bucket = new Item[own.bucket_size];
			current = 0;
			dim = 0;
		}
		//when we create non-root nodes within this class
		//we use this one here
		private NodeKD(NodeKD node) {
			owner = node.owner;
			dim = node.dim + 1;
			bucket = new Item[owner.bucket_size];
			if(dim + 1 > owner.dimensions) dim = 0;
			left = right = null;
			upper = lower = null;
			current = 0;
		}
		//what it says on the tin
		private void add(Item m) {
			if(bucket == null) {
				//Branch
				if(m.pnt[dim] > slice)
					right.add(m);
				else left.add(m);
			} else {
				//Bucket
				if(current+1>owner.bucket_size) {
					split(m);
					return;
				}
				bucket[current++] = m;
			}
			expand(m.pnt);
		}
		//nearest neighbor thing
		private void nearestn(ShiftArray arr, double[] data) {
			if(bucket == null) {
				//Branch
				if(data[dim] > slice) {
					right.nearestn(arr, data);
					if(left.current != 0) {
						if(KDTreeC.distanceSq(left.nearestRect(data),data)
								< arr.getLongest()) {
							left.nearestn(arr, data);
						}
					}
							
				} else {
					left.nearestn(arr, data);
					if(right.current != 0) {
						if(KDTreeC.distanceSq(right.nearestRect(data),data) 
								< arr.getLongest()) {
							right.nearestn(arr, data);
						}
					}
				}
			} else {
				//Bucket
				for(int i = 0; i < current; i++) {
					bucket[i].distance = KDTreeC.distanceSq(bucket[i].pnt, data);
					arr.add(bucket[i]);
				}
			}
		}
		//gets all items from within a range
		private Item[] range(double[] upper, double[] lower) {
			//TODO: clean this up a bit
			if(bucket == null) {
				//Branch
				Item[] tmp = new Item[0];
				if (intersects(upper,lower,left.upper,left.lower)) {
					Item[] tmpl = left.range(upper,lower);
					if(0 == tmp.length)
						tmp = tmpl;
				}
				if (intersects(upper,lower,right.upper,right.lower)) {
					Item[] tmpr = right.range(upper,lower);
					if (0 == tmp.length)
						tmp = tmpr;
					else if (0 < tmpr.length) {
						Item[] tmp2 = new Item[tmp.length + tmpr.length];
						System.arraycopy(tmp, 0, tmp2, 0, tmp.length);
						System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length);
						tmp = tmp2;
					}
				}
				return tmp;
			}
			//Bucket
			Item[] tmp = new Item[current];
			int n = 0;
			for (int i = 0; i < current; i++) {
				if (contains(upper, lower, bucket[i].pnt)) {
					tmp[n++] = bucket[i];
				}
			}
			Item[] tmp2 = new Item[n];
			System.arraycopy(tmp, 0, tmp2, 0, n);
			return tmp2;
		}
		
		//These are helper functions from here down
		//check if this hyper rectangle contains a give hyper-point
		public boolean contains(double[] upper, double[] lower, double[] point) {
			if(current == 0) return false;
			for(int i=0; i<point.length; ++i) {
				if(point[i] > upper[i] ||
					point[i] < lower[i]) 
						return false;
			}
			return true;
		}
		//checks if two hyper-rectangles intersect
		public boolean intersects(double[] up0, double[] low0,
				double[] up1, double[] low1) {
			for(int i=0; i<up0.length; ++i) {
				if(up1[i] < low0[i] || low1[i] > up0[i]) return false;
			}
			return true;
		}
		//splits a bucket into a branch
		private void split(Item m) {
			//split based on our bound data
			slice = (upper[dim]+lower[dim])/2.0;
			left = new NodeKD(this);
			right = new NodeKD(this);
			for(int i=0; i<current; ++i) {
				if(bucket[i].pnt[dim] > slice)
					right.add(bucket[i]);
				else left.add(bucket[i]);
			}
			bucket = null;
			add(m);
		}
		//gets nearest point to data within this hyper rectangle
		private double[] nearestRect(double[] data) {
			double[] nearest = data.clone();
			for(int i = 0; i < data.length; ++i) {
				if(nearest[i] > upper[i]) nearest[i] = upper[i];
				if(nearest[i] < lower[i]) nearest[i] = lower[i];
			}
			return nearest;
		}
		//expands this hyper rectangle
		private void expand(double[] data) {
			if(upper == null) {
				upper = Arrays.copyOf(data, owner.dimensions);
				lower = Arrays.copyOf(data, owner.dimensions);
				return;
			}
			for(int i=0; i<data.length; ++i) {
				if(upper[i] < data[i]) upper[i] = data[i];
				if(lower[i] > data[i]) lower[i] = data[i];
			}
		}
	}
	//A simple shift array that sifts data up
	//as we add new ones to lower in the array.
	private class ShiftArray {
		private Item[] items;
		private final int max;
		private int current;
		private ShiftArray(int maximum) {
			max = maximum;
			current = 0;
			items = new Item[max];
		}
		private void add(Item m) {
			int i;
			for(i=current;i>0&&items[i-1].distance >  m.distance; --i);
			if(i >= max) return;
			if(current < max) ++current;
			System.arraycopy(items, i, items, i+1, current-(i+1));
			items[i] = m;
		}
		private double getLongest() {
			if (current < max) return Double.POSITIVE_INFINITY;
			return items[max-1].distance;
		}
		private Item[] getArray() {
			return items.clone();
		}
	}
}