Difference between revisions of "User:Chase-san/Kd-Tree"
Jump to navigation
Jump to search
(Updated, now with nearest N neighbors) |
(Changed distance to use distanceSq in comparator) |
||
Line 66: | Line 66: | ||
@Override | @Override | ||
public int compare(PointKD o1, PointKD o2) { | public int compare(PointKD o1, PointKD o2) { | ||
− | double dist0 = point. | + | double dist0 = point.distanceSq(o1); |
− | double dist1 = point. | + | double dist1 = point.distanceSq(o2); |
return (dist0-dist1 < 0) ? -1 : 1; | return (dist0-dist1 < 0) ? -1 : 1; | ||
} | } |
Revision as of 22:52, 1 March 2010
Everyone and their brother has one of these now, me and Simonton started it, but I was to inexperienced to get anything written, I took an hour or two to rewrite it today, because I am no longer completely terrible at these things. So here is mine if you care to see it.
Contents
KDTreeB
package org.csdgn.util; import java.util.Comparator; public class KDTreeB { public final static int DEFAULT_BUCKET_SIZE = 200; protected NodeKD root; protected int dimensions; protected int bucket_size; protected long size = 0; /** * This creates a KDTreeB with the given number of dimensions and * the default bucket size. * @param dimensions */ public KDTreeB(int dimensions) { this(dimensions,DEFAULT_BUCKET_SIZE); } /** * This creates a KDTreeB with the given number of dimensions and * the given bucket size. * @param dimensions * @param bucket_size */ public KDTreeB(int dimensions, int bucket_size) { this.dimensions = dimensions; this.bucket_size = bucket_size; root = new BucketKD(this); } /** * Adds the given point to the tree. Uses a recursive algorithm. * @param point */ public void add(PointKD point) { root.add(point); } /** * Returns the nearest neighbor to the given point. * @param point - PointKD to find the nearest to. * @return - The nearest PointKD to */ public PointKD getNearestNeighbor(PointKD point) { return root.nearest(point); } /** * Returns the nearest num neighbors to the given point. * @param point - PointKD to find the nearest to. * @param num - The number of points to find. * @return - The nearest PointKD's to point. */ public PointKD[] getNearestNeighbors(final PointKD point, int num) { if(num == 1) { return new PointKD[] { getNearestNeighbor(point) }; } PriorityDeque<PointKD> queue = new PriorityDeque<PointKD>( new Comparator<PointKD>(){ @Override public int compare(PointKD o1, PointKD o2) { double dist0 = point.distanceSq(o1); double dist1 = point.distanceSq(o2); return (dist0-dist1 < 0) ? -1 : 1; } },num); root.nearestn(queue,point); PointKD[] array = new PointKD[num]; Object[] obj = queue.toArray(); for(int i=0; i<num; ++i) { array[i] = (PointKD)obj[i]; } return array; } /** * Returns all PointKD within a certain RectangleKD. * @param rect - area to get PointKD from * @return - All PointKD within rect. */ public PointKD[] getRange(RectangleKD rect) { return root.range(rect); } /** * Returns all PointKD within a certain range defined by an upper and lower PointKD. * @param low - lower bounds of area * @param high - upper bounds of area * @return - All PointKD between low and high. */ public PointKD[] getRange(PointKD low, PointKD high) { return this.getRange(new RectangleKD(low, high)); } protected abstract class NodeKD { protected KDTreeB owner; protected BranchKD parent; protected RectangleKD rect; protected int depth = 0; protected abstract void add(PointKD k); protected abstract PointKD nearest(PointKD k); protected abstract void nearestn(PriorityDeque<PointKD> queue, PointKD k); protected abstract PointKD[] range(RectangleKD r); } protected class BucketKD extends NodeKD { protected PointKD bucket[]; protected int current; public BucketKD(KDTreeB owner) { this.owner = owner; bucket = new PointKD[owner.bucket_size]; rect = new RectangleKD(); parent = null; } public BucketKD(BranchKD branch) { this(branch.owner); parent = branch; depth = branch.depth + 1; } @Override protected void add(PointKD point) { if(current >= bucket.length) { //Split the bucket into a branch BranchKD branch = new BranchKD(this); if(parent == null) { owner.root = branch; } else { if(parent.isLeft(this)) { parent.left = branch; } else { parent.right = branch; } } branch.add(point); bucket = null; current = 0; return; } bucket[current++] = point; rect.expand(point); } @Override protected PointKD nearest(PointKD k) { double nearestDist = Double.POSITIVE_INFINITY; int nearest = 0; for (int i = 0; i < current; i++) { double distance = k.distanceSq(bucket[i]); if (distance < nearestDist) { nearestDist = distance; nearest = i; } } return bucket[nearest]; } @Override protected void nearestn(PriorityDeque<PointKD> queue, PointKD k) { for(int i = 0; i < current; i++) { queue.offer(bucket[i]); } } @Override protected PointKD[] range(RectangleKD r) { PointKD[] tmp = new PointKD[current]; int n = 0; for (int i = 0; i < current; i++) { if (r.contains(bucket[i])) { tmp[n++] = bucket[i]; } } PointKD[] tmp2 = new PointKD[n]; System.arraycopy(tmp, 0, tmp2, 0, n); return tmp2; } } protected class BranchKD extends NodeKD { protected NodeKD left, right; protected double slice; protected int dim; public BranchKD(BucketKD k) { owner = k.owner; parent = k.parent; slice = 0; rect = k.rect; depth = k.depth; left = new BucketKD(this); right = new BucketKD(this); dim = depth % owner.dimensions; double total = 0; for (int i = 0; i < k.current; i++) { total += k.bucket[i].internal[dim]; } slice = total / k.current; for (int i = 0; i < k.current; i++) add(k.bucket[i]); } @Override protected void add(PointKD k) { if(k.internal[dim] > slice) { right.add(k); } else { left.add(k); } } @Override protected PointKD nearest(PointKD k) { PointKD near = null; if (k.internal[dim] > slice) { near = right.nearest(k); double t = near.distanceSq(k); if (k.distanceSq(left.rect.getNearest(k)) < t) { PointKD tmp = left.nearest(k); if (tmp.distanceSq(k) < t) { near = tmp; } } } else { near = left.nearest(k); double t = near.distanceSq(k); if (k.distanceSq(right.rect.getNearest(k)) < t) { PointKD tmp = right.nearest(k); if (tmp.distanceSq(k) < t) { near = tmp; } } } return near; } @Override protected void nearestn(PriorityDeque<PointKD> queue, PointKD k) { //TODO if(k.internal[dim] > slice) { right.nearestn(queue, k); double t = queue.peekBottom().distanceSq(k); if(k.distanceSq(left.rect.getNearest(k)) < t) { left.nearestn(queue, k); } } else { left.nearestn(queue, k); double t = queue.peekBottom().distanceSq(k); if(k.distanceSq(right.rect.getNearest(k)) < t) { right.nearestn(queue, k); } } } @Override protected PointKD[] range(RectangleKD r) { PointKD[] tmp = new PointKD[0]; if (r.intersects(left.rect)) { PointKD[] tmpl = left.range(r); if(0 == tmp.length) tmp = tmpl; } if (r.intersects(right.rect)) { PointKD[] tmpr = right.range(r); if (0 == tmp.length) tmp = tmpr; else if (0 < tmpr.length) { PointKD[] tmp2 = new PointKD[tmp.length + tmpr.length]; System.arraycopy(tmp, 0, tmp2, 0, tmp.length); System.arraycopy(tmpr, 0, tmp2, tmp.length, tmpr.length); tmp = tmp2; } } return tmp; } protected boolean isLeft(BucketKD kd) { if(left == kd) return true; return false; } } }
PointKD
package org.csdgn.util; import java.io.Serializable; /** * PointKD class is a class that wraps a double array, for use in K-Dimensional structures. * This wrapping is done to improve readability and modularity. Often used to define a * point in k-dimensional space. * * @author Chase */ public class PointKD implements Serializable { private static final long serialVersionUID = -841162798668123755L; protected double[] internal; /** * Constructor * @param dimensions - Number of dimensions for this PointKD. */ public PointKD(int dimensions) { internal = new double[dimensions]; } /** * Constructor * @param array - An array to use for the internal array of this PointKD. */ public PointKD(double[] array) { internal = array.clone(); } /** * Constructor * @param point - PointKD to clone. */ public PointKD(PointKD point) { this.internal = point.internal.clone(); } /** * Returns array. * @return The internal array of this pointKD, changes made to this will effect this PointKD. */ public double[] get() { return internal; } /** * Sets the location of this point. * @param point - PointKD to copy. */ public void set(PointKD point) { this.internal = point.internal.clone(); } /** * Sets the location of this point. * @param array - Array to copy. */ public void set(double[] array) { this.internal = array.clone(); } /** * Sets the value at dimension index * @param index - Index * @param value - Value to set */ public void set(int index, double value) { internal[index] = value; } /** * @return number of dimensions */ public int size() { return internal.length; } /** * Compares this to a selected point and returns the euclidean distance * between them. * * @param p * - The Point to get the distance to. * @return The distance between this and <b>p</b>. */ public double distance(PointKD p) { return distance(this.internal, p.internal); } /** * Compares this to a selected point and returns the squared euclidean * distance between them. * * @param p * - The Point to get the distance to. * @return The distance between this and <b>p</b>. */ public double distanceSq(PointKD p) { return distanceSq(this.internal, p.internal); } /** * Clones this point. */ public PointKD clone() { return new PointKD(internal); } /** * Prints the class name and the point coordinates. */ public String toString() { String output = getClass().getSimpleName() + "["; for (int i = 0; i < internal.length; ++i) { if (0 != i) output += ","; output += internal[i]; } return output + "]"; } /** * Compares arrays of double and returns the euclidean distance * between them. * * @param a - The first set of numbers * @param b - The second set of numbers * @return The distance squared between <b>a</b> and <b>b</b>. */ public static final double distance(double[] a, double[] b) { double total = 0; for (int i = 0; i < a.length; ++i) total += (b[i] - a[i]) * (b[i] - a[i]); return Math.sqrt(total); } /** * Compares arrays of double and returns the squared euclidean distance * between them. * * @param a - The first set of numbers * @param b - The second set of numbers * @return The distance squared between <b>a</b> and <b>b</b>. */ public static final double distanceSq(double[] a, double[] b) { double total = 0; for (int i = 0; i < a.length; ++i) total += (b[i] - a[i]) * (b[i] - a[i]); return total; } }
RectangleKD
package org.csdgn.util; import java.io.Serializable; public class RectangleKD implements Serializable { private static final long serialVersionUID = 524388821816648020L; protected PointKD upper, lower; /** * Creates an empty RectangleKD */ public RectangleKD() { upper = null; lower = null; } /** * Creates a RectangleKD from two points. * @param lower * @param upper */ public RectangleKD(PointKD lower, PointKD upper) { this.lower = lower.clone(); this.upper = upper.clone(); } /** * Expands this RectangleKD to include point. * @param point - Point used to expand this Rectangle */ public void expand(PointKD point) { if(upper == null) { upper = point.clone(); lower = point.clone(); return; } for(int i=0; i<point.size(); ++i) { if(upper.internal[i] < point.internal[i]) upper.internal[i] = point.internal[i]; if(lower.internal[i] > point.internal[i]) lower.internal[i] = point.internal[i]; } } /** * Checks if this RectangleKD contains the given point. * @param point - PointKD to check. * @return true - if this rectangle contains the given point. */ public boolean contains(PointKD point) { if(upper == null) return false; for(int i=0; i<point.size(); ++i) { if(point.internal[i] > upper.internal[i] || point.internal[i] < lower.internal[i]) return false; } return true; } /** * Checks if this table intersects the given table. * @param rectangle - The rectangle to check intersection with. * @return true - if this rectangle intersects the given rectangle. */ public boolean intersects(RectangleKD rectangle) { if(rectangle.upper == null || upper == null) return false; for(int i=0; i<upper.size(); ++i) { if(rectangle.upper.internal[i] < lower.internal[i] || rectangle.lower.internal[i] > upper.internal[i]) return false; } return true; } /** * Gets the nearest point in this RectangleKD to the given point * @param point - PointKD to get the nearest point to. * @return the nearest point in this RectangleKD to the given point. */ public PointKD getNearest(PointKD point) { if(upper == null) return null; if(contains(point)) return point; //This check may not be needed. PointKD nearest = new PointKD(point); for(int i = 0; i < upper.size(); ++i) { if(nearest.internal[i] > upper.internal[i]) nearest.internal[i] = upper.internal[i]; if(nearest.internal[i] < lower.internal[i]) nearest.internal[i] = lower.internal[i]; } return nearest; } }
PriorityDeque
package org.csdgn.util; import java.util.Comparator; /** * This is a home made PriorityDeque with maximum size limiter, and * comparator or natural ordering selection. * * @author Chase * @param <E> */ public class PriorityDeque<E> { private class Item { public Item down, up; public E obj; public Item(E item) { obj = item; down = up = null; } } private final Comparator<? super E> comparator; private Item bottom, top; private int maximum_size; private int size; /** * Generic Constructor */ public PriorityDeque() { this(null,-1); } /** * Constructor with a defined maximum number of entries * @param maximum */ public PriorityDeque(int maximum) { this(null,maximum); } /** * Constructor with a defined comparator and an unlimited number of items. * @param comp */ public PriorityDeque(Comparator<? super E> comp) { this(comp,-1); } /** * Constructor with a defined conparator and limited number of items. * @param comp * @param maximum */ public PriorityDeque(Comparator<? super E> comp, int maximum) { comparator = comp; maximum_size = maximum; bottom = top = null; size = 0; } /** * Adds an item to this deque. * @param value - item to add */ public void offer(E value) { //System.err.println("Offering: " + value.toString()); if(bottom == null) { //System.err.println("-Is first item."); bottom = top = new Item(value); return; } //do ordering etc if(comparator != null) { //System.err.println("-Comparator."); offerComparator(value); } else { //System.err.println("-Natural."); offerNatural(value); } } /** * Removes and returns an item from the bottom of the list (lowest value) * @return the lowest value (bottom) */ public E pollBottom() { if(bottom == null) return null; Item tmp = bottom; if(bottom == top) bottom = top = null; else { bottom = bottom.up; bottom.down = null; tmp.up = null; } --size; return tmp.obj; } /** * Returns but does not remove the item from the bottom of the list. * @return the lowest value (bottom) */ public E peekBottom() { if(bottom == null) return null; return bottom.obj; } /** * Removes and returns an item from the top of the list (highest value). * @return the highest value (top) */ public E pollTop() { if(top == null) return null; Item tmp = top; if(bottom == top) bottom = top = null; else { top = top.down; top.up = null; tmp.down = null; } --size; return tmp.obj; } /** * Returns but does not remove the item from the top of the list (highest value). * @return the highest value (top) */ public E peekTop() { if(top == null) return null; return top.obj; } /** * This is a generic toArray, returns a non-castable Object * array with all the items in this deque. * @return an Object array with everything in this deque. */ public Object[] toArray() { Object array[] = new Object[size+1]; Item current = bottom; int i = 0; while(current != null) { array[i++] = current.obj; current = current.up; } return array; } /** * Add/sorts a value using the comparator * @param value */ private void offerComparator(E value) { if(maximum_size > 0 && size >= maximum_size) { if(comparator.compare(value, top.obj) >= 0) { return; } } Item nItem = new Item(value); if(comparator.compare(value,bottom.obj) < 0) { //less than the bottom, put on the bottom nItem.up = bottom; bottom.down = nItem; bottom = nItem; ++size; if(maximum_size > 0 && size > maximum_size) { pollTop(); } return; } //start at the bottom Item current = bottom; while(current != null) { if(comparator.compare(value,current.obj) < 0) { nItem.up = current; nItem.down = current.down; current.down.up = nItem; current.down = nItem; ++size; if(maximum_size > 0 && size > maximum_size) { pollTop(); } return; } current = current.up; } //else put it on top nItem.down = top; top.up = nItem; top = nItem; ++size; if(maximum_size > 0 && size > maximum_size) { pollTop(); } } /** * Add/sorts a value using the natural order * @param value */ @SuppressWarnings("unchecked") private void offerNatural(E value) { Comparable<? super E> key = (Comparable<? super E>)value; if(maximum_size > 0 && size >= maximum_size) { if(key.compareTo(top.obj) >= 0) { return; } } Item nItem = new Item(value); if(key.compareTo(bottom.obj) < 0) { //less than the bottom, put on the bottom nItem.up = bottom; bottom.down = nItem; bottom = nItem; ++size; if(maximum_size > 0 && size > maximum_size) { pollTop(); } return; } //start at the bottom Item current = bottom; while(current != null) { if(key.compareTo(current.obj) < 0) { nItem.up = current; nItem.down = current.down; current.down.up = nItem; current.down = nItem; ++size; if(maximum_size > 0 && size > maximum_size) { pollTop(); } return; } current = current.up; } //else put it on top nItem.down = top; top.up = nItem; top = nItem; ++size; if(maximum_size > 0 && size > maximum_size) { pollTop(); } } }