User:Chase-san/Kd-Tree
Jump to navigation
Jump to search
Everyone and their brother has one of these now. Simonton and I started it, but I was too inexperienced to get anything written. Today I took an hour or two to rewrite it, because I am no longer completely terrible at these things, so here is mine if you care to see it.
Contents
KDTreeB
package org.csdgn.util;
import java.util.Arrays;
import java.util.Comparator;
import java.util.PriorityQueue;
/**
 * A bucketed k-d tree supporting nearest-neighbour and range queries over
 * {@link PointKD} values.
 * <p>
 * Points are stored in leaf buckets. When a bucket holds more than
 * {@code bucket_size} points, it is replaced by a branch that partitions those
 * points along dimension {@code depth % dimensions}. Every node keeps a
 * bounding {@link RectangleKD} of the points beneath it. The queries use these
 * boxes to skip subtrees that cannot contain a better answer.
 * <p>
 * This class performs no synchronization.
 */
public class KDTreeB {
	/** Bucket size used by {@link #KDTreeB(int)}. */
	public final static int DEFAULT_BUCKET_SIZE = 200;
	/** Current root; starts as an empty bucket and is replaced when that bucket splits. */
	protected NodeKD root;
	/** Number of leading coordinates used for splitting. */
	protected int dimensions;
	/** Number of points a leaf holds before it is split. */
	protected int bucket_size;
	/** Number of points added through {@link #add(PointKD)}. */
	protected long size = 0;

	/**
	 * Creates a KDTreeB with the given number of dimensions and the default
	 * bucket size.
	 * @param dimensions - number of coordinates used for splitting (at least 1)
	 * @throws IllegalArgumentException if dimensions is less than 1
	 */
	public KDTreeB(int dimensions) {
		this(dimensions, DEFAULT_BUCKET_SIZE);
	}

	/**
	 * Creates a KDTreeB with the given number of dimensions and bucket size.
	 * @param dimensions - number of coordinates used for splitting (at least 1)
	 * @param bucket_size - points a leaf holds before splitting (at least 1)
	 * @throws IllegalArgumentException if either argument is less than 1
	 */
	public KDTreeB(int dimensions, int bucket_size) {
		if (dimensions < 1) {
			throw new IllegalArgumentException("dimensions must be positive: " + dimensions);
		}
		if (bucket_size < 1) {
			throw new IllegalArgumentException("bucket_size must be positive: " + bucket_size);
		}
		this.dimensions = dimensions;
		this.bucket_size = bucket_size;
		root = new BucketKD(this);
	}

	/**
	 * Adds the given point to the tree. Uses a recursive algorithm.
	 * @param point - point to add; must have at least {@code dimensions} coordinates
	 * @throws NullPointerException if point is null
	 * @throws IllegalArgumentException if point has fewer than {@code dimensions} coordinates
	 */
	public void add(PointKD point) {
		// Validate before touching the tree, so a bad point cannot leave a half-applied split.
		if (point == null) {
			throw new NullPointerException("point");
		}
		if (point.size() < dimensions) {
			throw new IllegalArgumentException("point has " + point.size()
					+ " coordinates but the tree splits on " + dimensions);
		}
		root.add(point);
		++size;
	}

	/**
	 * Returns the nearest neighbor to the given point.
	 * @param point - PointKD to find the nearest to.
	 * @return - The stored PointKD closest to point, or null if the tree is empty.
	 */
	public PointKD getNearestNeighbor(PointKD point) {
		return root.nearest(point);
	}

	/**
	 * Returns the nearest num neighbors to the given point, nearest first.
	 * @param point - PointKD to find the nearest to.
	 * @param num - The number of points to find.
	 * @return - The nearest PointKD's to point. The array is shorter than num
	 *           when the tree holds fewer points, and empty when num is not positive.
	 */
	public PointKD[] getNearestNeighbors(final PointKD point, int num) {
		if (num <= 0) {
			return new PointKD[0];
		}
		if (num == 1) {
			PointKD nearest = getNearestNeighbor(point);
			return nearest == null ? new PointKD[0] : new PointKD[] { nearest };
		}
		// Max-heap on distance: the head is the worst current candidate. It is
		// the one to evict, and once the heap is full it sets the pruning radius.
		PriorityQueue<PointKD> heap = new PriorityQueue<PointKD>(
				(int) Math.max(1L, Math.min((long) num, size)),
				new Comparator<PointKD>() {
					@Override
					public int compare(PointKD o1, PointKD o2) {
						return Double.compare(point.distanceSq(o2), point.distanceSq(o1));
					}
				});
		root.nearestn(heap, num, point);
		PointKD[] result = new PointKD[heap.size()];
		// poll() yields the farthest first, so fill from the end.
		for (int i = result.length - 1; i >= 0; --i) {
			result[i] = heap.poll();
		}
		return result;
	}

	/**
	 * Returns all PointKD within a certain RectangleKD.
	 * @param rect - area to get PointKD from
	 * @return - All PointKD within rect (bounds inclusive).
	 */
	public PointKD[] getRange(RectangleKD rect) {
		return root.range(rect);
	}

	/**
	 * Returns all PointKD within a certain range defined by an upper and lower PointKD.
	 * @param low - lower bounds of area
	 * @param high - upper bounds of area
	 * @return - All PointKD between low and high (inclusive).
	 */
	public PointKD[] getRange(PointKD low, PointKD high) {
		return this.getRange(new RectangleKD(low, high));
	}

	/**
	 * Offers candidate to a max-heap that keeps at most num points closest to k.
	 */
	private static void offerBounded(PriorityQueue<PointKD> heap, int num,
			PointKD candidate, PointKD k) {
		if (heap.size() < num) {
			heap.add(candidate);
		} else if (k.distanceSq(candidate) < k.distanceSq(heap.peek())) {
			heap.poll();
			heap.add(candidate);
		}
	}

	/**
	 * Returns the squared pruning radius: infinite until num candidates are
	 * known, then the squared distance to the worst kept candidate.
	 */
	private static double worstDistanceSq(PriorityQueue<PointKD> heap, int num, PointKD k) {
		return heap.size() < num ? Double.POSITIVE_INFINITY : k.distanceSq(heap.peek());
	}

	/** Common state of tree nodes: owning tree, parent, bounding box and depth. */
	protected abstract class NodeKD {
		protected KDTreeB owner;
		protected BranchKD parent;
		protected RectangleKD rect;
		protected int depth = 0;
		protected abstract void add(PointKD k);
		/** Returns the closest stored point to k, or null if this subtree is empty. */
		protected abstract PointKD nearest(PointKD k);
		/** Offers this subtree's candidates to a max-heap bounded to num entries. */
		protected abstract void nearestn(PriorityQueue<PointKD> heap, int num, PointKD k);
		protected abstract PointKD[] range(RectangleKD r);

		/**
		 * Squared distance from k to this node's bounding box, or positive
		 * infinity if the node holds no points (its box is still empty).
		 */
		protected double minDistanceSq(PointKD k) {
			PointKD closest = rect.getNearest(k);
			return closest == null ? Double.POSITIVE_INFINITY : k.distanceSq(closest);
		}
	}

	/** Leaf node holding points in an array. */
	protected class BucketKD extends NodeKD {
		protected PointKD bucket[];
		protected int current;

		public BucketKD(KDTreeB owner) {
			this.owner = owner;
			bucket = new PointKD[owner.bucket_size];
			rect = new RectangleKD();
			parent = null;
		}

		public BucketKD(BranchKD branch) {
			this(branch.owner);
			parent = branch;
			depth = branch.depth + 1;
		}

		/**
		 * Stores the point. Once the bucket exceeds bucket_size it is split into
		 * a branch. If the stored points are identical in every split coordinate,
		 * it grows instead, because no slice could separate them and splitting
		 * would recurse forever.
		 */
		@Override
		protected void add(PointKD point) {
			if (current == bucket.length) {
				bucket = Arrays.copyOf(bucket, bucket.length * 2 + 1);
			}
			bucket[current++] = point;
			rect.expand(point);
			if (current > owner.bucket_size && !allSplitCoordinatesEqual()) {
				split();
			}
		}

		/** True when all stored points share every coordinate the tree splits on. */
		private boolean allSplitCoordinatesEqual() {
			double[] lo = rect.lower.internal;
			double[] hi = rect.upper.internal;
			for (int i = 0; i < owner.dimensions; ++i) {
				if (lo[i] != hi[i]) {
					return false;
				}
			}
			return true;
		}

		/** Replaces this bucket in the tree with a branch holding all of its points. */
		private void split() {
			BranchKD branch = new BranchKD(this);
			if (parent == null) {
				owner.root = branch;
			} else if (parent.isLeft(this)) {
				parent.left = branch;
			} else {
				parent.right = branch;
			}
			bucket = null;
			current = 0;
		}

		@Override
		protected PointKD nearest(PointKD k) {
			if (current == 0) {
				return null;
			}
			double nearestDist = Double.POSITIVE_INFINITY;
			int nearest = 0;
			for (int i = 0; i < current; i++) {
				double distance = k.distanceSq(bucket[i]);
				if (distance < nearestDist) {
					nearestDist = distance;
					nearest = i;
				}
			}
			return bucket[nearest];
		}

		@Override
		protected void nearestn(PriorityQueue<PointKD> heap, int num, PointKD k) {
			for (int i = 0; i < current; i++) {
				offerBounded(heap, num, bucket[i], k);
			}
		}

		@Override
		protected PointKD[] range(RectangleKD r) {
			PointKD[] tmp = new PointKD[current];
			int n = 0;
			for (int i = 0; i < current; i++) {
				if (r.contains(bucket[i])) {
					tmp[n++] = bucket[i];
				}
			}
			return Arrays.copyOf(tmp, n);
		}
	}

	/** Interior node: points with coordinate {@code dim} above slice go right, the rest left. */
	protected class BranchKD extends NodeKD {
		protected NodeKD left, right;
		protected double slice;
		protected int dim;

		/** Builds a branch from an overflowing bucket and redistributes its points. */
		public BranchKD(BucketKD k) {
			owner = k.owner;
			parent = k.parent;
			rect = k.rect;
			depth = k.depth;
			dim = depth % owner.dimensions;
			left = new BucketKD(this);
			right = new BucketKD(this);
			slice = chooseSlice(k);
			for (int i = 0; i < k.current; i++) {
				add(k.bucket[i]);
			}
		}

		/**
		 * Uses the mean coordinate along dim, clamped into [min, max). This keeps
		 * the minimum on the left and the maximum on the right, so a dimension
		 * with any spread always separates the points.
		 */
		private double chooseSlice(BucketKD k) {
			double total = 0;
			for (int i = 0; i < k.current; i++) {
				total += k.bucket[i].internal[dim];
			}
			double mean = total / k.current;
			double low = rect.lower.internal[dim];
			double high = rect.upper.internal[dim];
			if (mean >= low && mean < high) {
				return mean;
			}
			return low;
		}

		@Override
		protected void add(PointKD k) {
			// Keep this branch's box current; queries prune against it.
			rect.expand(k);
			if (k.internal[dim] > slice) {
				right.add(k);
			} else {
				left.add(k);
			}
		}

		@Override
		protected PointKD nearest(PointKD k) {
			boolean goRight = k.internal[dim] > slice;
			NodeKD near = goRight ? right : left;
			NodeKD far = goRight ? left : right;
			PointKD best = near.nearest(k);
			double bestDist = best == null ? Double.POSITIVE_INFINITY : k.distanceSq(best);
			// The far side can only help if its box comes closer than the best so far.
			if (best == null || far.minDistanceSq(k) < bestDist) {
				PointKD candidate = far.nearest(k);
				if (candidate != null && (best == null || k.distanceSq(candidate) < bestDist)) {
					best = candidate;
				}
			}
			return best;
		}

		@Override
		protected void nearestn(PriorityQueue<PointKD> heap, int num, PointKD k) {
			boolean goRight = k.internal[dim] > slice;
			NodeKD near = goRight ? right : left;
			NodeKD far = goRight ? left : right;
			near.nearestn(heap, num, k);
			if (far.minDistanceSq(k) < worstDistanceSq(heap, num, k)) {
				far.nearestn(heap, num, k);
			}
		}

		@Override
		protected PointKD[] range(RectangleKD r) {
			PointKD[] fromLeft = r.intersects(left.rect) ? left.range(r) : new PointKD[0];
			PointKD[] fromRight = r.intersects(right.rect) ? right.range(r) : new PointKD[0];
			if (fromRight.length == 0) {
				return fromLeft;
			}
			if (fromLeft.length == 0) {
				return fromRight;
			}
			PointKD[] merged = Arrays.copyOf(fromLeft, fromLeft.length + fromRight.length);
			System.arraycopy(fromRight, 0, merged, fromLeft.length, fromRight.length);
			return merged;
		}

		/** True if kd is currently this branch's left child. */
		protected boolean isLeft(BucketKD kd) {
			return left == kd;
		}
	}
}
PointKD
package org.csdgn.util;
import java.io.Serializable;
/**
 * PointKD wraps a double array for use in k-dimensional structures. The
 * wrapping is done to improve readability and modularity. It is typically
 * used to define a point in k-dimensional space.
 * <p>
 * Constructors and setters copy the arrays they are given; {@link #get()}
 * exposes the live internal array.
 *
 * @author Chase
 */
public class PointKD implements Serializable {
	private static final long serialVersionUID = -841162798668123755L;
	protected double[] internal;

	/**
	 * Creates a point at the origin.
	 * @param dimensions - Number of dimensions for this PointKD.
	 */
	public PointKD(int dimensions) {
		internal = new double[dimensions];
	}

	/**
	 * Creates a point from a copy of the given coordinates.
	 * @param array - Coordinates to copy into this PointKD.
	 */
	public PointKD(double[] array) {
		internal = array.clone();
	}

	/**
	 * Copy constructor.
	 * @param point - PointKD to clone.
	 */
	public PointKD(PointKD point) {
		this.internal = point.internal.clone();
	}

	/**
	 * Returns the internal coordinate array.
	 * @return The internal array of this PointKD; changes made to it affect this PointKD.
	 */
	public double[] get() {
		return internal;
	}

	/**
	 * Sets the location of this point to a copy of another point's coordinates.
	 * @param point - PointKD to copy.
	 */
	public void set(PointKD point) {
		this.internal = point.internal.clone();
	}

	/**
	 * Sets the location of this point to a copy of the given coordinates.
	 * @param array - Array to copy.
	 */
	public void set(double[] array) {
		this.internal = array.clone();
	}

	/**
	 * Sets the value at dimension index.
	 * @param index - Index
	 * @param value - Value to set
	 */
	public void set(int index, double value) {
		internal[index] = value;
	}

	/**
	 * @return number of dimensions
	 */
	public int size() {
		return internal.length;
	}

	/**
	 * Returns the euclidean distance between this point and p.
	 * @param p - The Point to get the distance to.
	 * @return The distance between this and <b>p</b>.
	 */
	public double distance(PointKD p) {
		return distance(this.internal, p.internal);
	}

	/**
	 * Returns the squared euclidean distance between this point and p.
	 * This is cheaper than {@link #distance(PointKD)} and orders points the same way.
	 * @param p - The Point to get the distance to.
	 * @return The squared distance between this and <b>p</b>.
	 */
	public double distanceSq(PointKD p) {
		return distanceSq(this.internal, p.internal);
	}

	/**
	 * Returns an independent copy of this point.
	 */
	@Override
	public PointKD clone() {
		return new PointKD(internal);
	}

	/**
	 * Returns the class name followed by the coordinates, e.g. {@code PointKD[1.0,2.0]}.
	 */
	@Override
	public String toString() {
		StringBuilder output = new StringBuilder(getClass().getSimpleName()).append('[');
		for (int i = 0; i < internal.length; ++i) {
			if (i != 0) {
				output.append(',');
			}
			output.append(internal[i]);
		}
		return output.append(']').toString();
	}

	/**
	 * Returns the euclidean distance between two coordinate arrays. Only the
	 * first {@code a.length} entries are compared.
	 *
	 * @param a - The first set of numbers
	 * @param b - The second set of numbers
	 * @return The distance between <b>a</b> and <b>b</b>.
	 */
	public static final double distance(double[] a, double[] b) {
		return Math.sqrt(distanceSq(a, b));
	}

	/**
	 * Returns the squared euclidean distance between two coordinate arrays.
	 * Only the first {@code a.length} entries are compared.
	 *
	 * @param a - The first set of numbers
	 * @param b - The second set of numbers
	 * @return The distance squared between <b>a</b> and <b>b</b>.
	 */
	public static final double distanceSq(double[] a, double[] b) {
		double total = 0;
		for (int i = 0; i < a.length; ++i) {
			double delta = b[i] - a[i];
			total += delta * delta;
		}
		return total;
	}
}
RectangleKD
package org.csdgn.util;
import java.io.Serializable;
/**
 * An axis-aligned box in k-dimensional space, stored as a lower and an upper
 * corner. A box built with the no-argument constructor is empty (both corners
 * null) until {@link #expand(PointKD)} is first called. An empty box contains
 * nothing and intersects nothing.
 */
public class RectangleKD implements Serializable {
	private static final long serialVersionUID = 524388821816648020L;
	protected PointKD upper, lower;

	/**
	 * Creates an empty RectangleKD.
	 */
	public RectangleKD() {
		upper = lower = null;
	}

	/**
	 * Creates a RectangleKD spanning two corner points. Both points are copied.
	 * @param lower - corner with the smallest coordinates
	 * @param upper - corner with the largest coordinates
	 */
	public RectangleKD(PointKD lower, PointKD upper) {
		PointKD lowCopy = lower.clone();
		PointKD highCopy = upper.clone();
		this.lower = lowCopy;
		this.upper = highCopy;
	}

	/**
	 * Grows this RectangleKD just enough to include point. The first point
	 * given to an empty rectangle becomes both of its corners.
	 * @param point - Point used to expand this Rectangle
	 */
	public void expand(PointKD point) {
		if (upper != null) {
			double[] coords = point.internal;
			double[] hi = upper.internal;
			double[] lo = lower.internal;
			for (int d = 0; d < coords.length; d++) {
				if (coords[d] > hi[d]) {
					hi[d] = coords[d];
				}
				if (coords[d] < lo[d]) {
					lo[d] = coords[d];
				}
			}
		} else {
			upper = point.clone();
			lower = point.clone();
		}
	}

	/**
	 * Tests whether the given point lies inside this rectangle (bounds inclusive).
	 * @param point - PointKD to check.
	 * @return true if this rectangle is non-empty and contains the point.
	 */
	public boolean contains(PointKD point) {
		if (upper == null) {
			return false;
		}
		double[] coords = point.internal;
		for (int d = 0; d < coords.length; d++) {
			boolean outside = coords[d] > upper.internal[d] || coords[d] < lower.internal[d];
			if (outside) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Tests whether this rectangle overlaps another one (touching counts).
	 * @param rectangle - The rectangle to check intersection with.
	 * @return true if both rectangles are non-empty and overlap in every dimension.
	 */
	public boolean intersects(RectangleKD rectangle) {
		if (rectangle.upper == null || upper == null) {
			return false;
		}
		int dims = upper.size();
		for (int d = 0; d < dims; d++) {
			boolean disjoint = rectangle.upper.internal[d] < lower.internal[d]
					|| rectangle.lower.internal[d] > upper.internal[d];
			if (disjoint) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Returns the point of this rectangle closest to the given point. If the
	 * point is already inside, that same instance is returned; otherwise a new
	 * point with each coordinate clamped into the rectangle's bounds is returned.
	 * @param point - PointKD to get the nearest point to.
	 * @return the nearest point in this RectangleKD, or null if the rectangle is empty.
	 */
	public PointKD getNearest(PointKD point) {
		if (upper == null) {
			return null;
		}
		if (contains(point)) {
			return point;
		}
		PointKD clamped = new PointKD(point);
		double[] coords = clamped.internal;
		int dims = upper.size();
		for (int d = 0; d < dims; d++) {
			if (coords[d] > upper.internal[d]) {
				coords[d] = upper.internal[d];
			}
			if (coords[d] < lower.internal[d]) {
				coords[d] = lower.internal[d];
			}
		}
		return clamped;
	}
}
PriorityDeque
package org.csdgn.util;
import java.util.Comparator;
/**
 * A sorted deque backed by a doubly-linked list, with an optional maximum
 * size and either a comparator or natural ordering.
 * <p>
 * Items are kept in ascending order from the bottom (lowest) to the top
 * (highest); items that compare equal keep their insertion order. When a
 * maximum size greater than zero is set and the deque is full, an offered item
 * that does not compare lower than the current top is discarded. Otherwise
 * the item is inserted and the top is evicted. Insertion walks the list from
 * the bottom, so it takes time linear in the current size.
 *
 * @author Chase
 * @param <E> element type; must implement {@link Comparable} when no comparator is given
 */
public class PriorityDeque<E> {
	private class Item {
		public Item down, up;
		public E obj;
		public Item(E item) {
			obj = item;
			down = up = null;
		}
	}
	private final Comparator<? super E> comparator;
	private Item bottom, top;
	private int maximum_size;
	private int size;

	/**
	 * Creates an unbounded deque using natural ordering.
	 */
	public PriorityDeque() {
		this(null, -1);
	}

	/**
	 * Creates a deque using natural ordering with a defined maximum number of entries.
	 * @param maximum - maximum number of items; zero or less means unlimited
	 */
	public PriorityDeque(int maximum) {
		this(null, maximum);
	}

	/**
	 * Creates an unbounded deque ordered by the given comparator.
	 * @param comp - ordering to use, or null for natural ordering
	 */
	public PriorityDeque(Comparator<? super E> comp) {
		this(comp, -1);
	}

	/**
	 * Creates a deque with a defined comparator and a limited number of items.
	 * @param comp - ordering to use, or null for natural ordering
	 * @param maximum - maximum number of items; zero or less means unlimited
	 */
	public PriorityDeque(Comparator<? super E> comp, int maximum) {
		comparator = comp;
		maximum_size = maximum;
		bottom = top = null;
		size = 0;
	}

	/**
	 * Adds an item to this deque at its sorted position.
	 * @param value - item to add
	 */
	public void offer(E value) {
		if (bottom == null) {
			bottom = top = new Item(value);
			++size;
			return;
		}
		// When full, anything not strictly below the top would be evicted at once.
		if (maximum_size > 0 && size >= maximum_size && compare(value, top.obj) >= 0) {
			return;
		}
		Item item = new Item(value);
		// Find the first item strictly greater than value; equal items stay below it.
		Item above = bottom;
		while (above != null && compare(value, above.obj) >= 0) {
			above = above.up;
		}
		if (above == null) {
			item.down = top;
			top.up = item;
			top = item;
		} else {
			item.up = above;
			item.down = above.down;
			if (above.down == null) {
				bottom = item;
			} else {
				above.down.up = item;
			}
			above.down = item;
		}
		++size;
		if (maximum_size > 0 && size > maximum_size) {
			pollTop();
		}
	}

	/**
	 * Removes and returns an item from the bottom of the list (lowest value).
	 * @return the lowest value (bottom), or null if the deque is empty
	 */
	public E pollBottom() {
		if (bottom == null) return null;
		Item tmp = bottom;
		if (bottom == top) bottom = top = null;
		else {
			bottom = bottom.up;
			bottom.down = null;
			tmp.up = null;
		}
		--size;
		return tmp.obj;
	}

	/**
	 * Returns but does not remove the item from the bottom of the list.
	 * @return the lowest value (bottom), or null if the deque is empty
	 */
	public E peekBottom() {
		if (bottom == null) return null;
		return bottom.obj;
	}

	/**
	 * Removes and returns an item from the top of the list (highest value).
	 * @return the highest value (top), or null if the deque is empty
	 */
	public E pollTop() {
		if (top == null) return null;
		Item tmp = top;
		if (bottom == top) bottom = top = null;
		else {
			top = top.down;
			top.up = null;
			tmp.down = null;
		}
		--size;
		return tmp.obj;
	}

	/**
	 * Returns but does not remove the item from the top of the list (highest value).
	 * @return the highest value (top), or null if the deque is empty
	 */
	public E peekTop() {
		if (top == null) return null;
		return top.obj;
	}

	/**
	 * @return the number of items currently in this deque
	 */
	public int size() {
		return size;
	}

	/**
	 * @return true if this deque holds no items
	 */
	public boolean isEmpty() {
		return bottom == null;
	}

	/**
	 * Returns an Object array holding every item in this deque, ordered from
	 * bottom (lowest) to top (highest).
	 * @return an Object array with everything in this deque.
	 */
	public Object[] toArray() {
		Object[] array = new Object[size];
		int i = 0;
		for (Item current = bottom; current != null; current = current.up) {
			array[i++] = current.obj;
		}
		return array;
	}

	/**
	 * Compares value with other using the comparator, or natural ordering when
	 * no comparator was supplied.
	 */
	@SuppressWarnings("unchecked")
	private int compare(E value, E other) {
		if (comparator != null) {
			return comparator.compare(value, other);
		}
		// Natural ordering: value must be Comparable, as documented on the class.
		return ((Comparable<? super E>) value).compareTo(other);
	}
}