Difference between revisions of "User:Rednaxela/kD-Tree"
Jump to navigation
Jump to search
(Refactor heap code. Nearest neighbour search now with optional sequential sorting (descending order). This is for you ABC :)) |
(Finally, a tree that does what everyone wants ;)) |
||
Line 35: | Line 35: | ||
* @author Rednaxela | * @author Rednaxela | ||
*/ | */ | ||
− | public class KdTree<T> { | + | public abstract class KdTree<T> { |
// Static variables | // Static variables | ||
− | private static final int bucketSize = 32; | + | private static final int bucketSize = 32; |
// All types | // All types | ||
− | private final int dimensions; | + | private final int dimensions; |
− | private final KdTree<T> parent; | + | private final KdTree<T> parent; |
// Root only | // Root only | ||
− | private final HashMap<Object, T> map | + | private final HashMap<Object, T> map; |
− | |||
private final LinkedList<double[]> locationStack; | private final LinkedList<double[]> locationStack; | ||
− | private final Integer sizeLimit; | + | private final Integer sizeLimit; |
// Leaf only | // Leaf only | ||
− | private double[][] locations; | + | private double[][] locations; |
− | private int locationCount; | + | private int locationCount; |
// Stem only | // Stem only | ||
− | private KdTree<T> left, right; | + | private KdTree<T> left, right; |
− | private int splitDimension; | + | private int splitDimension; |
− | private double splitValue; | + | private double splitValue; |
// Bounds | // Bounds | ||
− | private double[] minLimit, maxLimit; | + | private double[] minLimit, maxLimit; |
− | private boolean singularity; | + | private boolean singularity; |
// Temporary | // Temporary | ||
− | private Status status; | + | private Status status; |
− | |||
− | |||
− | |||
− | |||
− | |||
− | |||
− | |||
/** | /** | ||
Line 76: | Line 68: | ||
* maxiumum size (after which it throws away old points) | * maxiumum size (after which it throws away old points) | ||
*/ | */ | ||
− | + | private KdTree(int dimensions, Integer sizeLimit) { | |
this.dimensions = dimensions; | this.dimensions = dimensions; | ||
Line 86: | Line 78: | ||
// Init as root | // Init as root | ||
this.map = new HashMap<Object, T>(); | this.map = new HashMap<Object, T>(); | ||
− | |||
− | |||
this.parent = null; | this.parent = null; | ||
this.sizeLimit = sizeLimit; | this.sizeLimit = sizeLimit; | ||
Line 129: | Line 119: | ||
KdTree<T> cursor = this; | KdTree<T> cursor = this; | ||
− | while (cursor.locations == null || cursor.locationCount >= | + | while (cursor.locations == null || cursor.locationCount >= cursor.locations.length) { |
− | |||
if (cursor.locations != null) { | if (cursor.locations != null) { | ||
− | cursor.splitDimension = cursor.findWidestAxis( | + | cursor.splitDimension = cursor.findWidestAxis(); |
cursor.splitValue = (cursor.minLimit[cursor.splitDimension] + cursor.maxLimit[cursor.splitDimension]) * 0.5; | cursor.splitValue = (cursor.minLimit[cursor.splitDimension] + cursor.maxLimit[cursor.splitDimension]) * 0.5; | ||
Line 146: | Line 135: | ||
} | } | ||
− | // Don't split node if it has no width in any axis. Double the bucket size instead | + | // Don't split node if it has no width in any axis. Double the |
+ | // bucket size instead | ||
if (cursor.minLimit[cursor.splitDimension] == cursor.maxLimit[cursor.splitDimension]) { | if (cursor.minLimit[cursor.splitDimension] == cursor.maxLimit[cursor.splitDimension]) { | ||
double[][] newLocations = new double[cursor.locations.length * 2][]; | double[][] newLocations = new double[cursor.locations.length * 2][]; | ||
Line 154: | Line 144: | ||
} | } | ||
− | // Don't let the split value be the same as the upper value as can happen due to rounding errors! | + | // Don't let the split value be the same as the upper value as |
+ | // can happen due to rounding errors! | ||
if (cursor.splitValue == cursor.maxLimit[cursor.splitDimension]) { | if (cursor.splitValue == cursor.maxLimit[cursor.splitDimension]) { | ||
cursor.splitValue = cursor.minLimit[cursor.splitDimension]; | cursor.splitValue = cursor.minLimit[cursor.splitDimension]; | ||
Line 160: | Line 151: | ||
// Create child leaves | // Create child leaves | ||
− | KdTree<T> left = new | + | KdTree<T> left = new ChildNode(cursor, false); |
− | KdTree<T> right = new | + | KdTree<T> right = new ChildNode(cursor, true); |
// Move locations into children | // Move locations into children | ||
Line 221: | Line 212: | ||
} | } | ||
− | for (int i=0; i<dimensions; i++) { | + | for (int i = 0; i < dimensions; i++) { |
if (Double.isNaN(location[i])) { | if (Double.isNaN(location[i])) { | ||
minLimit[i] = Double.NaN; | minLimit[i] = Double.NaN; | ||
Line 241: | Line 232: | ||
* Find the widest axis of the bounds of this node | * Find the widest axis of the bounds of this node | ||
*/ | */ | ||
− | private final int findWidestAxis( | + | private final int findWidestAxis() { |
int widest = 0; | int widest = 0; | ||
− | double width = (maxLimit[0] - minLimit[0]) * | + | double width = (maxLimit[0] - minLimit[0]) * getAxisWeightHint(0); |
if (Double.isNaN(width)) width = 0; | if (Double.isNaN(width)) width = 0; | ||
for (int i = 1; i < dimensions; i++) { | for (int i = 1; i < dimensions; i++) { | ||
− | double nwidth = (maxLimit[i] - minLimit[i]) * | + | double nwidth = (maxLimit[i] - minLimit[i]) * getAxisWeightHint(i); |
if (Double.isNaN(nwidth)) nwidth = 0; | if (Double.isNaN(nwidth)) nwidth = 0; | ||
if (nwidth > width) { | if (nwidth > width) { | ||
Line 257: | Line 248: | ||
/** | /** | ||
− | * Remove the oldest value from the tree. | + | * Remove the oldest value from the tree. Note: This cannot trim the bounds |
− | + | * of nodes, nor empty nodes, and thus you can't expect it to perfectly | |
− | + | * preserve the speed of the tree as you keep adding. | |
− | |||
− | |||
*/ | */ | ||
private void removeOld() { | private void removeOld() { | ||
Line 280: | Line 269: | ||
} | } | ||
− | for (int i=0; i<cursor.locationCount; i++) { | + | for (int i = 0; i < cursor.locationCount; i++) { |
if (cursor.locations[i] == location) { | if (cursor.locations[i] == location) { | ||
− | System.arraycopy(cursor.locations, i+1, cursor.locations, i, cursor.locationCount - i - 1); | + | System.arraycopy(cursor.locations, i + 1, cursor.locations, i, cursor.locationCount - i - 1); |
do { | do { | ||
cursor.locationCount--; | cursor.locationCount--; | ||
Line 290: | Line 279: | ||
} | } | ||
} | } | ||
− | // If we got here... we couldn't find the value to remove. Weird... | + | // If we got here... we couldn't find the value to remove. Weird... |
− | |||
− | |||
− | |||
− | |||
− | |||
− | |||
− | |||
} | } | ||
/** | /** | ||
− | * Enumeration representing the status of a node during the running | + | * Enumeration representing the status of a node during the running |
*/ | */ | ||
private static enum Status { | private static enum Status { | ||
− | NONE, | + | NONE, LEFTVISITED, RIGHTVISITED, ALLVISITED |
− | |||
− | |||
− | |||
} | } | ||
Line 315: | Line 294: | ||
public static class Entry<T> { | public static class Entry<T> { | ||
public final double distance; | public final double distance; | ||
− | public final T value; | + | public final T value; |
+ | |||
private Entry(double distance, T value) { | private Entry(double distance, T value) { | ||
this.distance = distance; | this.distance = distance; | ||
Line 329: | Line 309: | ||
cursor.status = Status.NONE; | cursor.status = Status.NONE; | ||
double range = Double.POSITIVE_INFINITY; | double range = Double.POSITIVE_INFINITY; | ||
− | ResultHeap resultHeap = new ResultHeap(count); | + | ResultHeap resultHeap = new ResultHeap(count); |
do { | do { | ||
Line 342: | Line 322: | ||
if (cursor.locationCount > 0) { | if (cursor.locationCount > 0) { | ||
if (cursor.singularity) { | if (cursor.singularity) { | ||
− | double dist = | + | double dist = pointDist(cursor.locations[0], location); |
if (dist <= range) { | if (dist <= range) { | ||
− | for (int i=0; i<cursor.locationCount; i++) { | + | for (int i = 0; i < cursor.locationCount; i++) { |
resultHeap.addValue(dist, cursor.locations[i]); | resultHeap.addValue(dist, cursor.locations[i]); | ||
} | } | ||
Line 350: | Line 330: | ||
} | } | ||
else { | else { | ||
− | for (int i=0; i<cursor.locationCount; i++) { | + | for (int i = 0; i < cursor.locationCount; i++) { |
− | double dist = | + | double dist = pointDist(cursor.locations[i], location); |
resultHeap.addValue(dist, cursor.locations[i]); | resultHeap.addValue(dist, cursor.locations[i]); | ||
} | } | ||
Line 391: | Line 371: | ||
} | } | ||
− | // Check if it's worth descending. Assume it is if it's sibling has not been visited yet. | + | // Check if it's worth descending. Assume it is if it's sibling has |
+ | // not been visited yet. | ||
if (cursor.status == Status.ALLVISITED) { | if (cursor.status == Status.ALLVISITED) { | ||
− | if (nextCursor.locationCount == 0 || (!nextCursor.singularity && | + | if (nextCursor.locationCount == 0 |
+ | || (!nextCursor.singularity && pointRegionDist(location, nextCursor.minLimit, | ||
+ | nextCursor.maxLimit) > range)) { | ||
continue; | continue; | ||
} | } | ||
Line 412: | Line 395: | ||
} | } | ||
else { | else { | ||
− | for (int i=0; i<resultHeap.values; i++) { | + | for (int i = 0; i < resultHeap.values; i++) { |
T value = this.map.get(resultHeap.data[i]); | T value = this.map.get(resultHeap.data[i]); | ||
results.add(new Entry<T>(resultHeap.distance[i], value)); | results.add(new Entry<T>(resultHeap.distance[i], value)); | ||
Line 419: | Line 402: | ||
return results; | return results; | ||
+ | } | ||
+ | |||
+ | // Override in subclasses | ||
+ | protected abstract double pointDist(double[] p1, double[] p2); | ||
+ | |||
+ | protected abstract double pointRegionDist(double[] point, double[] min, double[] max); | ||
+ | |||
+ | protected double getAxisWeightHint(int i) { | ||
+ | return 1.0; | ||
+ | } | ||
+ | |||
+ | /** | ||
+ | * Internal class for child nodes | ||
+ | */ | ||
+ | private class ChildNode extends KdTree<T> { | ||
+ | private ChildNode(KdTree<T> parent, boolean right) { | ||
+ | super(parent, right); | ||
+ | } | ||
+ | |||
+ | // Distance measurements are always called from the root node | ||
+ | protected double pointDist(double[] p1, double[] p2) { | ||
+ | throw new IllegalStateException(); | ||
+ | } | ||
+ | |||
+ | protected double pointRegionDist(double[] point, double[] min, double[] max) { | ||
+ | throw new IllegalStateException(); | ||
+ | } | ||
+ | } | ||
+ | |||
+ | /** | ||
+ | * Class for tree with Weighted Squared Elucidian distancing | ||
+ | */ | ||
+ | public static class WeightedSqrElucid<T> extends KdTree<T> { | ||
+ | private double[] weights; | ||
+ | |||
+ | public WeightedSqrElucid(int dimensions, Integer sizeLimit) { | ||
+ | super(dimensions, sizeLimit); | ||
+ | this.weights = new double[dimensions]; | ||
+ | Arrays.fill(this.weights, 1.0); | ||
+ | } | ||
+ | |||
+ | public void setWeights(double[] weights) { | ||
+ | this.weights = weights; | ||
+ | } | ||
+ | |||
+ | protected double getAxisWeightHint(int i) { | ||
+ | return weights[i]; | ||
+ | } | ||
+ | |||
+ | protected double pointDist(double[] p1, double[] p2) { | ||
+ | double d = 0; | ||
+ | |||
+ | for (int i = 0; i < p1.length; i++) { | ||
+ | double diff = (p1[i] - p2[i]) * weights[i]; | ||
+ | if (!Double.isNaN(diff)) { | ||
+ | d += diff * diff; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | return d; | ||
+ | } | ||
+ | |||
+ | protected double pointRegionDist(double[] point, double[] min, double[] max) { | ||
+ | double d = 0; | ||
+ | |||
+ | for (int i = 0; i < point.length; i++) { | ||
+ | double diff = 0; | ||
+ | if (point[i] > max[i]) { | ||
+ | diff = (point[i] - max[i]) * weights[i]; | ||
+ | } | ||
+ | else if (point[i] < min[i]) { | ||
+ | diff = (point[i] - min[i]) * weights[i]; | ||
+ | } | ||
+ | |||
+ | if (!Double.isNaN(diff)) { | ||
+ | d += diff * diff; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | return d; | ||
+ | } | ||
} | } | ||
/** | /** | ||
− | * | + | * Class for tree with Unweighted Squared Elucidian distancing |
*/ | */ | ||
− | + | public static class SqrElucid<T> extends KdTree<T> { | |
− | + | public SqrElucid(int dimensions, Integer sizeLimit) { | |
+ | super(dimensions, sizeLimit); | ||
+ | } | ||
− | for (int i=0; i<p1.length; i++) { | + | protected double pointDist(double[] p1, double[] p2) { |
− | + | double d = 0; | |
− | + | ||
− | + | for (int i = 0; i < p1.length; i++) { | |
+ | double diff = (p1[i] - p2[i]); | ||
+ | if (!Double.isNaN(diff)) { | ||
+ | d += diff * diff; | ||
+ | } | ||
} | } | ||
+ | |||
+ | return d; | ||
} | } | ||
− | return d; | + | protected double pointRegionDist(double[] point, double[] min, double[] max) { |
+ | double d = 0; | ||
+ | |||
+ | for (int i = 0; i < point.length; i++) { | ||
+ | double diff = 0; | ||
+ | if (point[i] > max[i]) { | ||
+ | diff = (point[i] - max[i]); | ||
+ | } | ||
+ | else if (point[i] < min[i]) { | ||
+ | diff = (point[i] - min[i]); | ||
+ | } | ||
+ | |||
+ | if (!Double.isNaN(diff)) { | ||
+ | d += diff * diff; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | return d; | ||
+ | } | ||
} | } | ||
/** | /** | ||
− | * | + | * Class for tree with Weighted Manhattan distancing |
*/ | */ | ||
− | private | + | public static class WeightedManhattan<T> extends KdTree<T> { |
− | + | private double[] weights; | |
+ | |||
+ | public WeightedManhattan(int dimensions, Integer sizeLimit) { | ||
+ | super(dimensions, sizeLimit); | ||
+ | this.weights = new double[dimensions]; | ||
+ | Arrays.fill(this.weights, 1.0); | ||
+ | } | ||
+ | |||
+ | public void setWeights(double[] weights) { | ||
+ | this.weights = weights; | ||
+ | } | ||
+ | |||
+ | protected double getAxisWeightHint(int i) { | ||
+ | return weights[i]; | ||
+ | } | ||
+ | |||
+ | protected double pointDist(double[] p1, double[] p2) { | ||
+ | double d = 0; | ||
− | + | for (int i = 0; i < p1.length; i++) { | |
− | + | double diff = (p1[i] - p2[i]); | |
− | + | if (!Double.isNaN(diff)) { | |
− | + | d += ((diff < 0) ? -diff : diff) * weights[i]; | |
− | + | } | |
− | |||
} | } | ||
− | if (!Double.isNaN(diff)) { | + | return d; |
− | + | } | |
+ | |||
+ | protected double pointRegionDist(double[] point, double[] min, double[] max) { | ||
+ | double d = 0; | ||
+ | |||
+ | for (int i = 0; i < point.length; i++) { | ||
+ | double diff = 0; | ||
+ | if (point[i] > max[i]) { | ||
+ | diff = (point[i] - max[i]); | ||
+ | } | ||
+ | else if (point[i] < min[i]) { | ||
+ | diff = (min[i] - point[i]); | ||
+ | } | ||
+ | |||
+ | if (!Double.isNaN(diff)) { | ||
+ | d += diff * weights[i]; | ||
+ | } | ||
} | } | ||
+ | |||
+ | return d; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | /** | ||
+ | * Class for tree with Manhattan distancing | ||
+ | */ | ||
+ | public static class Manhattan<T> extends KdTree<T> { | ||
+ | public Manhattan(int dimensions, Integer sizeLimit) { | ||
+ | super(dimensions, sizeLimit); | ||
} | } | ||
− | return d; | + | protected double pointDist(double[] p1, double[] p2) { |
+ | double d = 0; | ||
+ | |||
+ | for (int i = 0; i < p1.length; i++) { | ||
+ | double diff = (p1[i] - p2[i]); | ||
+ | if (!Double.isNaN(diff)) { | ||
+ | d += (diff < 0) ? -diff : diff; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | return d; | ||
+ | } | ||
+ | |||
+ | protected double pointRegionDist(double[] point, double[] min, double[] max) { | ||
+ | double d = 0; | ||
+ | |||
+ | for (int i = 0; i < point.length; i++) { | ||
+ | double diff = 0; | ||
+ | if (point[i] > max[i]) { | ||
+ | diff = (point[i] - max[i]); | ||
+ | } | ||
+ | else if (point[i] < min[i]) { | ||
+ | diff = (min[i] - point[i]); | ||
+ | } | ||
+ | |||
+ | if (!Double.isNaN(diff)) { | ||
+ | d += diff; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | return d; | ||
+ | } | ||
} | } | ||
Line 465: | Line 629: | ||
private final Object[] data; | private final Object[] data; | ||
private final double[] distance; | private final double[] distance; | ||
− | private final int size; | + | private final int size; |
− | private int values; | + | private int values; |
− | public Object removedData; | + | public Object removedData; |
− | public double removedDist; | + | public double removedDist; |
public ResultHeap(int size) { | public ResultHeap(int size) { | ||
Line 486: | Line 650: | ||
values++; | values++; | ||
} | } | ||
− | // If there is no room left in the heap, and the new entry is lower than the max entry | + | // If there is no room left in the heap, and the new entry is lower |
+ | // than the max entry | ||
else if (dist < distance[0]) { | else if (dist < distance[0]) { | ||
// Replace the max entry with the new entry | // Replace the max entry with the new entry | ||
Line 509: | Line 674: | ||
private void upHeapify(int c) { | private void upHeapify(int c) { | ||
− | for (int p = (c-1)/2; c != 0 && distance[c] > distance[p]; c = p, p = (c-1)/2) { | + | for (int p = (c - 1) / 2; c != 0 && distance[c] > distance[p]; c = p, p = (c - 1) / 2) { |
Object pData = data[p]; | Object pData = data[p]; | ||
double pDist = distance[p]; | double pDist = distance[p]; | ||
Line 520: | Line 685: | ||
private void downHeapify(int p) { | private void downHeapify(int p) { | ||
− | for (int c = p*2+1; c < values; p = c,c = p*2+1) { | + | for (int c = p * 2 + 1; c < values; p = c, c = p * 2 + 1) { |
− | if (c+1 < values && distance[c] < distance[c+1]) { | + | if (c + 1 < values && distance[c] < distance[c + 1]) { |
c++; | c++; | ||
} | } |
Revision as of 06:40, 2 September 2009
A nice efficent small kD-Tree. It's quite fast... Feel free to use
/**
* Copyright 2009 Rednaxela
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would be
* appreciated but is not required.
*
* 2. This notice may not be removed or altered from any source
* distribution.
*/
package ags.utils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
/**
* An efficient well-optimized kd-tree
*
* @author Rednaxela
*/
public abstract class KdTree<T> {
// Static variables
private static final int bucketSize = 32;
// All types
private final int dimensions;
private final KdTree<T> parent;
// Root only
private final HashMap<Object, T> map;
private final LinkedList<double[]> locationStack;
private final Integer sizeLimit;
// Leaf only
private double[][] locations;
private int locationCount;
// Stem only
private KdTree<T> left, right;
private int splitDimension;
private double splitValue;
// Bounds
private double[] minLimit, maxLimit;
private boolean singularity;
// Temporary
private Status status;
/**
* Construct a KdTree with a given number of dimensions and a limit on
* maxiumum size (after which it throws away old points)
*/
private KdTree(int dimensions, Integer sizeLimit) {
this.dimensions = dimensions;
// Init as leaf
this.locations = new double[bucketSize][];
this.locationCount = 0;
this.singularity = true;
// Init as root
this.map = new HashMap<Object, T>();
this.parent = null;
this.sizeLimit = sizeLimit;
if (sizeLimit != null) {
this.locationStack = new LinkedList<double[]>();
}
else {
this.locationStack = null;
}
}
/**
* Constructor for child nodes. Internal use only.
*/
private KdTree(KdTree<T> parent, boolean right) {
this.dimensions = parent.dimensions;
// Init as leaf
this.locations = new double[Math.max(bucketSize, parent.locationCount)][];
this.locationCount = 0;
this.singularity = true;
// Init as non-root
this.map = null;
this.parent = parent;
this.locationStack = null;
this.sizeLimit = null;
}
/**
* Get the number of points in the tree
*/
public int size() {
return locationCount;
}
/**
* Add a point and associated value to the tree
*/
public void addPoint(double[] location, T value) {
KdTree<T> cursor = this;
while (cursor.locations == null || cursor.locationCount >= cursor.locations.length) {
if (cursor.locations != null) {
cursor.splitDimension = cursor.findWidestAxis();
cursor.splitValue = (cursor.minLimit[cursor.splitDimension] + cursor.maxLimit[cursor.splitDimension]) * 0.5;
// Never split on infinity or NaN
if (cursor.splitValue == Double.POSITIVE_INFINITY) {
cursor.splitValue = Double.MAX_VALUE;
}
else if (cursor.splitValue == Double.NEGATIVE_INFINITY) {
cursor.splitValue = -Double.MAX_VALUE;
}
else if (Double.isNaN(cursor.splitValue)) {
cursor.splitValue = 0;
}
// Don't split node if it has no width in any axis. Double the
// bucket size instead
if (cursor.minLimit[cursor.splitDimension] == cursor.maxLimit[cursor.splitDimension]) {
double[][] newLocations = new double[cursor.locations.length * 2][];
System.arraycopy(cursor.locations, 0, newLocations, 0, cursor.locationCount);
cursor.locations = newLocations;
break;
}
// Don't let the split value be the same as the upper value as
// can happen due to rounding errors!
if (cursor.splitValue == cursor.maxLimit[cursor.splitDimension]) {
cursor.splitValue = cursor.minLimit[cursor.splitDimension];
}
// Create child leaves
KdTree<T> left = new ChildNode(cursor, false);
KdTree<T> right = new ChildNode(cursor, true);
// Move locations into children
for (double[] oldLocation : cursor.locations) {
if (oldLocation[cursor.splitDimension] > cursor.splitValue) {
// Right
right.locations[right.locationCount] = oldLocation;
right.locationCount++;
right.extendBounds(oldLocation);
}
else {
// Left
left.locations[left.locationCount] = oldLocation;
left.locationCount++;
left.extendBounds(oldLocation);
}
}
// Make into stem
cursor.left = left;
cursor.right = right;
cursor.locations = null;
}
cursor.locationCount++;
cursor.extendBounds(location);
if (location[cursor.splitDimension] > cursor.splitValue) {
cursor = cursor.right;
}
else {
cursor = cursor.left;
}
}
cursor.locations[cursor.locationCount] = location;
cursor.locationCount++;
cursor.extendBounds(location);
this.map.put(location, value);
if (this.sizeLimit != null) {
this.locationStack.add(location);
if (this.locationCount > this.sizeLimit) {
this.removeOld();
}
}
}
/**
* Extends the bounds of this node do include a new location
*/
private final void extendBounds(double[] location) {
if (minLimit == null) {
minLimit = new double[dimensions];
System.arraycopy(location, 0, minLimit, 0, dimensions);
maxLimit = new double[dimensions];
System.arraycopy(location, 0, maxLimit, 0, dimensions);
return;
}
for (int i = 0; i < dimensions; i++) {
if (Double.isNaN(location[i])) {
minLimit[i] = Double.NaN;
maxLimit[i] = Double.NaN;
singularity = false;
}
else if (minLimit[i] > location[i]) {
minLimit[i] = location[i];
singularity = false;
}
else if (maxLimit[i] < location[i]) {
maxLimit[i] = location[i];
singularity = false;
}
}
}
/**
* Find the widest axis of the bounds of this node
*/
private final int findWidestAxis() {
int widest = 0;
double width = (maxLimit[0] - minLimit[0]) * getAxisWeightHint(0);
if (Double.isNaN(width)) width = 0;
for (int i = 1; i < dimensions; i++) {
double nwidth = (maxLimit[i] - minLimit[i]) * getAxisWeightHint(i);
if (Double.isNaN(nwidth)) nwidth = 0;
if (nwidth > width) {
widest = i;
width = nwidth;
}
}
return widest;
}
/**
* Remove the oldest value from the tree. Note: This cannot trim the bounds
* of nodes, nor empty nodes, and thus you can't expect it to perfectly
* preserve the speed of the tree as you keep adding.
*/
private void removeOld() {
double[] location = this.locationStack.removeFirst();
KdTree<T> cursor = this;
// Remove from the HashMap
this.map.remove(location);
// Find the node where the point is
while (cursor.locations == null) {
if (location[cursor.splitDimension] > cursor.splitValue) {
cursor = cursor.right;
}
else {
cursor = cursor.left;
}
}
for (int i = 0; i < cursor.locationCount; i++) {
if (cursor.locations[i] == location) {
System.arraycopy(cursor.locations, i + 1, cursor.locations, i, cursor.locationCount - i - 1);
do {
cursor.locationCount--;
cursor = cursor.parent;
} while (cursor.parent != null);
return;
}
}
// If we got here... we couldn't find the value to remove. Weird...
}
/**
* Enumeration representing the status of a node during the running
*/
private static enum Status {
NONE, LEFTVISITED, RIGHTVISITED, ALLVISITED
}
/**
* Stores a distance and value to output
*/
public static class Entry<T> {
public final double distance;
public final T value;
private Entry(double distance, T value) {
this.distance = distance;
this.value = value;
}
}
/**
* Calculates the nearest 'count' points to 'location'
*/
public List<Entry<T>> nearestNeighbor(double[] location, int count, boolean sequentialSorting) {
KdTree<T> cursor = this;
cursor.status = Status.NONE;
double range = Double.POSITIVE_INFINITY;
ResultHeap resultHeap = new ResultHeap(count);
do {
if (cursor.status == Status.ALLVISITED) {
// At a fully visited part. Move up the tree
cursor = cursor.parent;
continue;
}
if (cursor.status == Status.NONE && cursor.locations != null) {
// At a leaf. Use the data.
if (cursor.locationCount > 0) {
if (cursor.singularity) {
double dist = pointDist(cursor.locations[0], location);
if (dist <= range) {
for (int i = 0; i < cursor.locationCount; i++) {
resultHeap.addValue(dist, cursor.locations[i]);
}
}
}
else {
for (int i = 0; i < cursor.locationCount; i++) {
double dist = pointDist(cursor.locations[i], location);
resultHeap.addValue(dist, cursor.locations[i]);
}
}
range = resultHeap.getMaxDist();
}
if (cursor.parent == null) {
break;
}
cursor = cursor.parent;
continue;
}
// Going to descend
KdTree<T> nextCursor = null;
if (cursor.status == Status.NONE) {
// At a fresh node, descend the most probably useful direction
if (location[cursor.splitDimension] > cursor.splitValue) {
// Descend right
nextCursor = cursor.right;
cursor.status = Status.RIGHTVISITED;
}
else {
// Descend left;
nextCursor = cursor.left;
cursor.status = Status.LEFTVISITED;
}
}
else if (cursor.status == Status.LEFTVISITED) {
// Left node visited, descend right.
nextCursor = cursor.right;
cursor.status = Status.ALLVISITED;
}
else if (cursor.status == Status.RIGHTVISITED) {
// Right node visited, descend left.
nextCursor = cursor.left;
cursor.status = Status.ALLVISITED;
}
// Check if it's worth descending. Assume it is if it's sibling has
// not been visited yet.
if (cursor.status == Status.ALLVISITED) {
if (nextCursor.locationCount == 0
|| (!nextCursor.singularity && pointRegionDist(location, nextCursor.minLimit,
nextCursor.maxLimit) > range)) {
continue;
}
}
// Descend down the tree
cursor = nextCursor;
cursor.status = Status.NONE;
} while (cursor.parent != null || cursor.status != Status.ALLVISITED);
ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(resultHeap.values);
if (sequentialSorting) {
while (resultHeap.values > 0) {
resultHeap.removeLargest();
T value = this.map.get(resultHeap.removedData);
results.add(new Entry<T>(resultHeap.removedDist, value));
}
}
else {
for (int i = 0; i < resultHeap.values; i++) {
T value = this.map.get(resultHeap.data[i]);
results.add(new Entry<T>(resultHeap.distance[i], value));
}
}
return results;
}
// Override in subclasses
protected abstract double pointDist(double[] p1, double[] p2);
protected abstract double pointRegionDist(double[] point, double[] min, double[] max);
protected double getAxisWeightHint(int i) {
return 1.0;
}
/**
* Internal class for child nodes
*/
private class ChildNode extends KdTree<T> {
private ChildNode(KdTree<T> parent, boolean right) {
super(parent, right);
}
// Distance measurements are always called from the root node
protected double pointDist(double[] p1, double[] p2) {
throw new IllegalStateException();
}
protected double pointRegionDist(double[] point, double[] min, double[] max) {
throw new IllegalStateException();
}
}
/**
* Class for tree with Weighted Squared Elucidian distancing
*/
public static class WeightedSqrElucid<T> extends KdTree<T> {
private double[] weights;
public WeightedSqrElucid(int dimensions, Integer sizeLimit) {
super(dimensions, sizeLimit);
this.weights = new double[dimensions];
Arrays.fill(this.weights, 1.0);
}
public void setWeights(double[] weights) {
this.weights = weights;
}
protected double getAxisWeightHint(int i) {
return weights[i];
}
protected double pointDist(double[] p1, double[] p2) {
double d = 0;
for (int i = 0; i < p1.length; i++) {
double diff = (p1[i] - p2[i]) * weights[i];
if (!Double.isNaN(diff)) {
d += diff * diff;
}
}
return d;
}
protected double pointRegionDist(double[] point, double[] min, double[] max) {
double d = 0;
for (int i = 0; i < point.length; i++) {
double diff = 0;
if (point[i] > max[i]) {
diff = (point[i] - max[i]) * weights[i];
}
else if (point[i] < min[i]) {
diff = (point[i] - min[i]) * weights[i];
}
if (!Double.isNaN(diff)) {
d += diff * diff;
}
}
return d;
}
}
/**
* Class for tree with Unweighted Squared Elucidian distancing
*/
public static class SqrElucid<T> extends KdTree<T> {
public SqrElucid(int dimensions, Integer sizeLimit) {
super(dimensions, sizeLimit);
}
protected double pointDist(double[] p1, double[] p2) {
double d = 0;
for (int i = 0; i < p1.length; i++) {
double diff = (p1[i] - p2[i]);
if (!Double.isNaN(diff)) {
d += diff * diff;
}
}
return d;
}
protected double pointRegionDist(double[] point, double[] min, double[] max) {
double d = 0;
for (int i = 0; i < point.length; i++) {
double diff = 0;
if (point[i] > max[i]) {
diff = (point[i] - max[i]);
}
else if (point[i] < min[i]) {
diff = (point[i] - min[i]);
}
if (!Double.isNaN(diff)) {
d += diff * diff;
}
}
return d;
}
}
/**
* Class for tree with Weighted Manhattan distancing
*/
public static class WeightedManhattan<T> extends KdTree<T> {
private double[] weights;
public WeightedManhattan(int dimensions, Integer sizeLimit) {
super(dimensions, sizeLimit);
this.weights = new double[dimensions];
Arrays.fill(this.weights, 1.0);
}
public void setWeights(double[] weights) {
this.weights = weights;
}
protected double getAxisWeightHint(int i) {
return weights[i];
}
protected double pointDist(double[] p1, double[] p2) {
double d = 0;
for (int i = 0; i < p1.length; i++) {
double diff = (p1[i] - p2[i]);
if (!Double.isNaN(diff)) {
d += ((diff < 0) ? -diff : diff) * weights[i];
}
}
return d;
}
protected double pointRegionDist(double[] point, double[] min, double[] max) {
double d = 0;
for (int i = 0; i < point.length; i++) {
double diff = 0;
if (point[i] > max[i]) {
diff = (point[i] - max[i]);
}
else if (point[i] < min[i]) {
diff = (min[i] - point[i]);
}
if (!Double.isNaN(diff)) {
d += diff * weights[i];
}
}
return d;
}
}
/**
* Class for tree with Manhattan distancing
*/
public static class Manhattan<T> extends KdTree<T> {
public Manhattan(int dimensions, Integer sizeLimit) {
super(dimensions, sizeLimit);
}
protected double pointDist(double[] p1, double[] p2) {
double d = 0;
for (int i = 0; i < p1.length; i++) {
double diff = (p1[i] - p2[i]);
if (!Double.isNaN(diff)) {
d += (diff < 0) ? -diff : diff;
}
}
return d;
}
protected double pointRegionDist(double[] point, double[] min, double[] max) {
double d = 0;
for (int i = 0; i < point.length; i++) {
double diff = 0;
if (point[i] > max[i]) {
diff = (point[i] - max[i]);
}
else if (point[i] < min[i]) {
diff = (min[i] - point[i]);
}
if (!Double.isNaN(diff)) {
d += diff;
}
}
return d;
}
}
/**
* Class for tracking up to 'size' closest values
*/
private static class ResultHeap {
private final Object[] data;
private final double[] distance;
private final int size;
private int values;
public Object removedData;
public double removedDist;
public ResultHeap(int size) {
this.data = new Object[size];
this.distance = new double[size];
this.size = size;
this.values = 0;
}
public void addValue(double dist, Object value) {
// If there is still room in the heap
if (values < size) {
// Insert new value at the end
data[values] = value;
distance[values] = dist;
upHeapify(values);
values++;
}
// If there is no room left in the heap, and the new entry is lower
// than the max entry
else if (dist < distance[0]) {
// Replace the max entry with the new entry
data[0] = value;
distance[0] = dist;
downHeapify(0);
}
}
public void removeLargest() {
if (values == 0) {
throw new IllegalStateException();
}
removedData = data[0];
removedDist = distance[0];
values--;
data[0] = data[values];
distance[0] = distance[values];
downHeapify(0);
}
private void upHeapify(int c) {
for (int p = (c - 1) / 2; c != 0 && distance[c] > distance[p]; c = p, p = (c - 1) / 2) {
Object pData = data[p];
double pDist = distance[p];
data[p] = data[c];
distance[p] = distance[c];
data[c] = pData;
distance[c] = pDist;
}
}
private void downHeapify(int p) {
for (int c = p * 2 + 1; c < values; p = c, c = p * 2 + 1) {
if (c + 1 < values && distance[c] < distance[c + 1]) {
c++;
}
if (distance[p] < distance[c]) {
// Swap the points
Object pData = data[p];
double pDist = distance[p];
data[p] = data[c];
distance[p] = distance[c];
data[c] = pData;
distance[c] = pDist;
}
else {
break;
}
}
}
public double getMaxDist() {
if (values < size) {
return Double.POSITIVE_INFINITY;
}
return distance[0];
}
}
}