Difference between revisions of "User:Rednaxela/kD-Tree"

From Robowiki
m (Remove some silly newlines that crept in)
(The repo now includes other contributions)
 
(38 intermediate revisions by 5 users not shown)
Line 1: Line 1:
A nice efficent small kD-Tree. It's quite fast... Feel free to use
+
A nice, efficient, small kD-Tree. Currently the fastest kD-Tree implementation on Robowiki. Feel free to use it.
  
<code><pre>
+
== Plans ==
 +
 
 +
Right now I'm working on a rewrite, intended to have cleaner code, follow Java conventions better, and be at least as fast. Current plans for the rewrite are:
 +
* '''''Done!''''' <s>'''Cleaner code:''' Follow Java/OOP conventions better, since much that I abandoned in the below code was not necessary for speed.</s>
 +
* '''''Done!''''' <s>'''Nearest Neighbor Iterator:''' Provides an iterator that returns nearest neighbors one at a time. This allows iterated fetching in case one doesn't know exactly how many neighbors one needs (i.e. if some are unusable data points due to other checks). Theoretical speed penalty should be very slim, perhaps even negligible.</s>
 +
* '''Further improved speed''': Yes, it's possible! Today I thought of three brand new techniques I should be able to use to increase speed further!
 +
:* '''''Done!''''' <s>'''Flexible path ordering:''' Since 'second choice' paths already have a full distance-to-bounding-box calculation done, why not use this information to check the 'paths not yet taken' in order of that computed distance rather than in tree-structure order? This should be closer to optimal.</s>
 +
:* '''''Unsuccessful. No improvement.''''' <s>'''Dimension-pruned distance calculations:''' With real data, there is often a situation where within a particular node, only some of the dimensions differ between points. It should be simple to track these 'unused' dimensions in a particular node and use this to optimize the distance calculation.</s>
 +
:* '''Implicit Subtrees:''' I thought about how I'm using an array to store the 'bucket', and thought "wouldn't it be nice to not have to calculate the distance for every single point in the bucket..." Well, it turns out, that can be avoided, all while keeping it in the nice compact array! It's just a matter of turning the bucket arrays into [[wikipedia:Implicit kd-tree|implicit kd-trees]]! This should keep the advantages of the bucket system for making the incrementally created tree balanced, while at the same time being more efficient!
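The "flexible path ordering" item above can be illustrated with a best-first traversal: pending nodes wait in a priority queue keyed on their distance to the query's bounding box, so the closest unexplored region is always visited next. The following is only a minimal sketch under stated assumptions, not the code from this page or from the rewrite. The names <code>BestFirstSketch</code>, <code>Node</code>, <code>Pending</code>, <code>leaf</code>, <code>stem</code> and <code>sqrBoxDist</code> are hypothetical, it finds a single nearest neighbor only, and it uses unweighted squared Euclidean distance.

```java
import java.util.PriorityQueue;

/** Hypothetical sketch of best-first search; not the code from this page. */
class BestFirstSketch {
    /** Either a leaf (points != null) or a stem (left/right != null). */
    static final class Node {
        double[] min, max;   // bounding box of every point below this node
        double[][] points;   // leaf only
        Node left, right;    // stem only
    }

    /** A not-yet-visited node together with its squared box distance. */
    private static final class Pending {
        final Node node;
        final double dist;
        Pending(Node node, double dist) { this.node = node; this.dist = dist; }
    }

    /** Builds a leaf and computes its bounding box (assumes at least one point). */
    static Node leaf(double[]... pts) {
        Node n = new Node();
        n.points = pts;
        n.min = pts[0].clone();
        n.max = pts[0].clone();
        for (double[] p : pts) {
            for (int i = 0; i < p.length; i++) {
                n.min[i] = Math.min(n.min[i], p[i]);
                n.max[i] = Math.max(n.max[i], p[i]);
            }
        }
        return n;
    }

    /** Builds a stem whose box is the union of its children's boxes. */
    static Node stem(Node left, Node right) {
        Node n = new Node();
        n.left = left;
        n.right = right;
        n.min = left.min.clone();
        n.max = left.max.clone();
        for (int i = 0; i < n.min.length; i++) {
            n.min[i] = Math.min(n.min[i], right.min[i]);
            n.max[i] = Math.max(n.max[i], right.max[i]);
        }
        return n;
    }

    static double sqrDist(double[] a, double[] b) {
        double d = 0;
        for (int i = 0; i < a.length; i++) d += (a[i] - b[i]) * (a[i] - b[i]);
        return d;
    }

    /** Squared distance from a point to the closest point of a node's box (0 if inside). */
    static double sqrBoxDist(double[] p, Node n) {
        double d = 0;
        for (int i = 0; i < p.length; i++) {
            double diff = Math.max(0, Math.max(n.min[i] - p[i], p[i] - n.max[i]));
            d += diff * diff;
        }
        return d;
    }

    /** Visits nodes in order of box distance and stops once no box can beat the best point. */
    static double[] nearest(Node root, double[] query) {
        PriorityQueue<Pending> queue =
                new PriorityQueue<>((a, b) -> Double.compare(a.dist, b.dist));
        queue.add(new Pending(root, sqrBoxDist(query, root)));
        double[] best = null;
        double bestDist = Double.POSITIVE_INFINITY;
        while (!queue.isEmpty()) {
            Pending next = queue.poll();
            if (next.dist >= bestDist) break; // every remaining box is at least this far away
            Node n = next.node;
            if (n.points != null) {
                for (double[] p : n.points) {
                    double d = sqrDist(p, query);
                    if (d < bestDist) { bestDist = d; best = p; }
                }
            } else {
                queue.add(new Pending(n.left, sqrBoxDist(query, n.left)));
                queue.add(new Pending(n.right, sqrBoxDist(query, n.right)));
            }
        }
        return best;
    }
}
```

The trade-off in this sketch is the priority queue itself: it removes the per-node visit status but adds heap operations, so whether it is faster in practice would have to be measured.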
 +
 
 +
I also plan to explore:
 +
* [[wikipedia:R-tree|R-Tree]]/[[wikipedia:X-tree|X-Tree]] type structures. They allow n-ary trees instead of only 2-ary trees like kd-trees, and are self-balancing. Might have good results.
 +
* [[wikipedia:VP-tree|VP-Tree]] type structures. Splitting based on distance to points may be more effective.
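To make the VP-tree idea concrete: each node stores a vantage point and a radius, points within the radius go to one child and the rest to the other, and the triangle inequality decides whether the second child can be skipped during a search. The following is a minimal sketch under stated assumptions, not something taken from this page. The class name <code>VpTreeSketch</code> and its members are hypothetical, the vantage point is simply the first point in the list rather than a carefully chosen one, and plain (not squared) Euclidean distance is used because the pruning relies on the triangle inequality.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical sketch of a vantage-point tree with single nearest-neighbor search. */
class VpTreeSketch {
    private final double[] vantage;               // point stored at this node
    private final double radius;                  // median distance from the vantage point
    private final VpTreeSketch inside, outside;   // children: dist <= radius / dist >= radius

    /** Builds the tree recursively. Assumes a non-empty list of equal-length points. */
    VpTreeSketch(List<double[]> points) {
        final double[] vp = points.get(0);
        vantage = vp;
        List<double[]> rest = new ArrayList<>(points.subList(1, points.size()));
        if (rest.isEmpty()) {
            radius = 0;
            inside = null;
            outside = null;
        } else {
            // Sort the remaining points by distance to the vantage point, split at the median
            rest.sort(Comparator.comparingDouble((double[] p) -> dist(vp, p)));
            int mid = rest.size() / 2;
            radius = dist(vp, rest.get(mid));
            inside = mid == 0 ? null : new VpTreeSketch(rest.subList(0, mid));
            outside = new VpTreeSketch(rest.subList(mid, rest.size()));
        }
    }

    /** Plain Euclidean distance; a true metric is needed for the pruning below. */
    static double dist(double[] a, double[] b) {
        double d = 0;
        for (int i = 0; i < a.length; i++) d += (a[i] - b[i]) * (a[i] - b[i]);
        return Math.sqrt(d);
    }

    /** Returns the stored point closest to the query. */
    double[] nearest(double[] query) {
        return search(query, null);
    }

    private double[] search(double[] query, double[] best) {
        double d = dist(query, vantage);
        if (best == null || d < dist(query, best)) best = vantage;
        // Descend first into the side of the radius the query falls on
        VpTreeSketch first = d < radius ? inside : outside;
        VpTreeSketch second = d < radius ? outside : inside;
        if (first != null) best = first.search(query, best);
        // By the triangle inequality, every point on the other side is at least
        // |d - radius| away from the query, so skip it if that cannot beat the best
        if (second != null && Math.abs(d - radius) < dist(query, best)) {
            best = second.search(query, best);
        }
        return best;
    }
}
```

Unlike a kd-tree, nothing here depends on individual axes, which is why such a structure only needs a distance function that satisfies the triangle inequality.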
 +
 
 +
If you have any comments on these plans, they would be appreciated at [[User talk:Rednaxela/kD-Tree]].
 +
 
 +
== The Code ==
 +
My latest released (circa 2010) version of this tree, aka my "3rd gen" one, is [https://gitlab.com/agschultz/robocode-knn-benchmark/-/tree/master/ags/utils/dataStructures/trees/thirdGenKD now on Gitlab]. It supports a KNN iterator that can save you computational time if you aren't sure exactly how many points you will need. This version also includes some weighted distance functions from [[User:Tkiesel|Tkiesel]] in 2012, and a bug fix by [[User:Xor|Xor]] from 2016.
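The benefit of an iterator-style interface is in how the caller consumes it: neighbors arrive nearest first, and the caller stops as soon as it has enough usable ones. The sketch below only illustrates that calling pattern and is not the API of the Gitlab code. The class <code>LazyNeighbors</code> and the helper <code>firstUsable</code> are hypothetical names, and a plain binary heap over all points stands in for the tree, so it does not reproduce the tree's pruning.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.function.Predicate;

/** Hypothetical stand-in: yields points nearest-first from a heap instead of a kd-tree. */
class LazyNeighbors implements Iterator<double[]> {
    private final PriorityQueue<double[]> heap;

    LazyNeighbors(List<double[]> points, double[] query) {
        // Order the heap by squared distance to the query point
        heap = new PriorityQueue<>(
                (a, b) -> Double.compare(sqrDist(a, query), sqrDist(b, query)));
        heap.addAll(points);
    }

    @Override
    public boolean hasNext() {
        return !heap.isEmpty();
    }

    @Override
    public double[] next() {
        return heap.remove(); // throws NoSuchElementException when exhausted
    }

    static double sqrDist(double[] a, double[] b) {
        double d = 0;
        for (int i = 0; i < a.length; i++) d += (a[i] - b[i]) * (a[i] - b[i]);
        return d;
    }

    /** Pulls neighbors until 'wanted' of them pass the caller's usability check. */
    static List<double[]> firstUsable(Iterator<double[]> it, int wanted,
                                      Predicate<double[]> usable) {
        List<double[]> out = new ArrayList<>();
        while (out.size() < wanted && it.hasNext()) {
            double[] p = it.next();
            if (usable.test(p)) out.add(p);
        }
        return out;
    }
}
```

With a fixed-count query the caller would have to guess how many extra neighbors to request to cover the rejected ones; with the iterator pattern it just keeps pulling until the check has passed often enough.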
 +
 
 +
(Looking at my old backups, it also looks like I have some unreleased test performance optimization variants dating to Jul 2013, but I'm not sure if they were fruitful.)
 +
 
 +
== Old Code ==
 +
 
 +
My old "2nd gen" version of my tree is as follows. This is outdated and the above "3rd gen" version is recommended over it.
 +
 
 +
<code><syntaxhighlight>
 
/**
 
/**
 
  * Copyright 2009 Rednaxela
 
  * Copyright 2009 Rednaxela
Line 22: Line 47:
 
  */
 
  */
  
package ags.utils.newtree2;
+
package ags.utils;
  
 
import java.util.ArrayList;
 
import java.util.ArrayList;
 
import java.util.Arrays;
 
import java.util.Arrays;
import java.util.HashMap;
+
import java.util.LinkedList;
 
import java.util.List;
 
import java.util.List;
import java.util.Stack;
 
  
 
/**
 
/**
  * An efficent well-optimized kd-tree
+
  * An efficient well-optimized kd-tree
 
  *  
 
  *  
 
  * @author Rednaxela
 
  * @author Rednaxela
 
  */
 
  */
public class KdTree<T> {
+
public abstract class KdTree<T> {
// Static variables
+
    // Static variables
private static final int bucketSize = 32;
+
    private static final int           bucketSize = 24;
 +
 
 +
    // All types
 +
    private final int                  dimensions;
 +
    private final KdTree<T>            parent;
 +
 
 +
    // Root only
 +
    private final LinkedList<double[]> locationStack;
 +
    private final Integer              sizeLimit;
 +
 
 +
    // Leaf only
 +
    private double[][]                locations;
 +
    private Object[]                  data;
 +
    private int                        locationCount;
 +
 
 +
    // Stem only
 +
    private KdTree<T>                  left, right;
 +
    private int                        splitDimension;
 +
    private double                    splitValue;
 +
 
 +
    // Bounds
 +
    private double[]                  minLimit, maxLimit;
 +
    private boolean                    singularity;
 +
 
 +
    // Temporary
 +
    private Status                    status;
 +
 
 +
    /**
 +
    * Construct a KdTree with a given number of dimensions and a limit on
 +
    * maximum size (after which it throws away old points)
 +
    */
 +
    private KdTree(int dimensions, Integer sizeLimit) {
 +
        this.dimensions = dimensions;
 +
 
 +
        // Init as leaf
 +
        this.locations = new double[bucketSize][];
 +
        this.data = new Object[bucketSize];
 +
        this.locationCount = 0;
 +
        this.singularity = true;
 +
 
 +
        // Init as root
 +
        this.parent = null;
 +
        this.sizeLimit = sizeLimit;
 +
        if (sizeLimit != null) {
 +
            this.locationStack = new LinkedList<double[]>();
 +
        }
 +
        else {
 +
            this.locationStack = null;
 +
        }
 +
    }
 +
 
 +
    /**
 +
    * Constructor for child nodes. Internal use only.
 +
    */
 +
    private KdTree(KdTree<T> parent, boolean right) {
 +
        this.dimensions = parent.dimensions;
 +
 
 +
        // Init as leaf
 +
        this.locations = new double[Math.max(bucketSize, parent.locationCount)][];
 +
        this.data = new Object[Math.max(bucketSize, parent.locationCount)];
 +
        this.locationCount = 0;
 +
        this.singularity = true;
 +
 
 +
        // Init as non-root
 +
        this.parent = parent;
 +
        this.locationStack = null;
 +
        this.sizeLimit = null;
 +
    }
 +
 
 +
    /**
 +
    * Get the number of points in the tree
 +
    */
 +
    public int size() {
 +
        return locationCount;
 +
    }
 +
 
 +
    /**
 +
    * Add a point and associated value to the tree
 +
    */
 +
    public void addPoint(double[] location, T value) {
 +
        KdTree<T> cursor = this;
 +
 
 +
        while (cursor.locations == null || cursor.locationCount >= cursor.locations.length) {
 +
            if (cursor.locations != null) {
 +
                cursor.splitDimension = cursor.findWidestAxis();
 +
                cursor.splitValue = (cursor.minLimit[cursor.splitDimension] + cursor.maxLimit[cursor.splitDimension]) * 0.5;
 +
 
 +
                // Never split on infinity or NaN
 +
                if (cursor.splitValue == Double.POSITIVE_INFINITY) {
 +
                    cursor.splitValue = Double.MAX_VALUE;
 +
                }
 +
                else if (cursor.splitValue == Double.NEGATIVE_INFINITY) {
 +
                    cursor.splitValue = -Double.MAX_VALUE;
 +
                }
 +
                else if (Double.isNaN(cursor.splitValue)) {
 +
                    cursor.splitValue = 0;
 +
                }
 +
 
 +
                // Don't split node if it has no width in any axis. Double the
 +
                // bucket size instead
 +
                if (cursor.minLimit[cursor.splitDimension] == cursor.maxLimit[cursor.splitDimension]) {
 +
                    double[][] newLocations = new double[cursor.locations.length * 2][];
 +
                    System.arraycopy(cursor.locations, 0, newLocations, 0, cursor.locationCount);
 +
                    cursor.locations = newLocations;
 +
                    Object[] newData = new Object[newLocations.length];
 +
                    System.arraycopy(cursor.data, 0, newData, 0, cursor.locationCount);
 +
                    cursor.data = newData;
 +
                    break;
 +
                }
 +
 
 +
                // Don't let the split value be the same as the upper value as
 +
                // can happen due to rounding errors!
 +
                if (cursor.splitValue == cursor.maxLimit[cursor.splitDimension]) {
 +
                    cursor.splitValue = cursor.minLimit[cursor.splitDimension];
 +
                }
 +
 
 +
                // Create child leaves
 +
                KdTree<T> left = new ChildNode(cursor, false);
 +
                KdTree<T> right = new ChildNode(cursor, true);
 +
 
 +
                // Move locations into children
 +
                for (int i = 0; i < cursor.locationCount; i++) {
 +
                    double[] oldLocation = cursor.locations[i];
 +
                    Object oldData = cursor.data[i];
 +
                    if (oldLocation[cursor.splitDimension] > cursor.splitValue) {
 +
                        // Right
 +
                        right.locations[right.locationCount] = oldLocation;
 +
                        right.data[right.locationCount] = oldData;
 +
                        right.locationCount++;
 +
                        right.extendBounds(oldLocation);
 +
                    }
 +
                    else {
 +
                        // Left
 +
                        left.locations[left.locationCount] = oldLocation;
 +
                        left.data[left.locationCount] = oldData;
 +
                        left.locationCount++;
 +
                        left.extendBounds(oldLocation);
 +
                    }
 +
                }
 +
 
 +
                // Make into stem
 +
                cursor.left = left;
 +
                cursor.right = right;
 +
                cursor.locations = null;
 +
                cursor.data = null;
 +
            }
 +
 
 +
            cursor.locationCount++;
 +
            cursor.extendBounds(location);
 +
 
 +
            if (location[cursor.splitDimension] > cursor.splitValue) {
 +
                cursor = cursor.right;
 +
            }
 +
            else {
 +
                cursor = cursor.left;
 +
            }
 +
        }
 +
 
 +
        cursor.locations[cursor.locationCount] = location;
 +
        cursor.data[cursor.locationCount] = value;
 +
        cursor.locationCount++;
 +
        cursor.extendBounds(location);
 +
 
 +
        if (this.sizeLimit != null) {
 +
            this.locationStack.add(location);
 +
            if (this.locationCount > this.sizeLimit) {
 +
                this.removeOld();
 +
            }
 +
        }
 +
    }
 +
 
 +
    /**
 +
    * Extends the bounds of this node to include a new location
 +
    */
 +
    private final void extendBounds(double[] location) {
 +
        if (minLimit == null) {
 +
            minLimit = new double[dimensions];
 +
            System.arraycopy(location, 0, minLimit, 0, dimensions);
 +
            maxLimit = new double[dimensions];
 +
            System.arraycopy(location, 0, maxLimit, 0, dimensions);
 +
            return;
 +
        }
 +
 
 +
        for (int i = 0; i < dimensions; i++) {
 +
            if (Double.isNaN(location[i])) {
 +
                minLimit[i] = Double.NaN;
 +
                maxLimit[i] = Double.NaN;
 +
                singularity = false;
 +
            }
 +
            else if (minLimit[i] > location[i]) {
 +
                minLimit[i] = location[i];
 +
                singularity = false;
 +
            }
 +
            else if (maxLimit[i] < location[i]) {
 +
                maxLimit[i] = location[i];
 +
                singularity = false;
 +
            }
 +
        }
 +
    }
 +
 
 +
    /**
 +
    * Find the widest axis of the bounds of this node
 +
    */
 +
    private final int findWidestAxis() {
 +
        int widest = 0;
 +
        double width = (maxLimit[0] - minLimit[0]) * getAxisWeightHint(0);
 +
        if (Double.isNaN(width)) width = 0;
 +
        for (int i = 1; i < dimensions; i++) {
 +
            double nwidth = (maxLimit[i] - minLimit[i]) * getAxisWeightHint(i);
 +
            if (Double.isNaN(nwidth)) nwidth = 0;
 +
            if (nwidth > width) {
 +
                widest = i;
 +
                width = nwidth;
 +
            }
 +
        }
 +
        return widest;
 +
    }
 +
 
 +
    /**
 +
    * Remove the oldest value from the tree. Note: This cannot trim the bounds
 +
    * of nodes, nor empty nodes, and thus you can't expect it to perfectly
 +
    * preserve the speed of the tree as you keep adding.
 +
    */
 +
    private void removeOld() {
 +
        double[] location = this.locationStack.removeFirst();
 +
        KdTree<T> cursor = this;
 +
 
 +
        // Find the node where the point is
 +
        while (cursor.locations == null) {
 +
            if (location[cursor.splitDimension] > cursor.splitValue) {
 +
                cursor = cursor.right;
 +
            }
 +
            else {
 +
                cursor = cursor.left;
 +
            }
 +
        }
 +
 
 +
        for (int i = 0; i < cursor.locationCount; i++) {
 +
            if (cursor.locations[i] == location) {
 +
                System.arraycopy(cursor.locations, i + 1, cursor.locations, i, cursor.locationCount - i - 1);
 +
                cursor.locations[cursor.locationCount-1] = null;
 +
                System.arraycopy(cursor.data, i + 1, cursor.data, i, cursor.locationCount - i - 1);
 +
                cursor.data[cursor.locationCount-1] = null;
 +
                do {
 +
                    cursor.locationCount--;
 +
                    cursor = cursor.parent;
 +
                } while (cursor != null);
 +
                return;
 +
            }
 +
        }
 +
        // If we got here... we couldn't find the value to remove. Weird...
 +
    }
 +
 
 +
    /**
 +
    * Enumeration representing the status of a node during the running
 +
    */
 +
    private static enum Status {
 +
        NONE, LEFTVISITED, RIGHTVISITED, ALLVISITED
 +
    }
 +
 
 +
    /**
 +
    * Stores a distance and value to output
 +
    */
 +
    public static class Entry<T> {
 +
        public final double distance;
 +
        public final T      value;
 +
 
 +
        private Entry(double distance, T value) {
 +
            this.distance = distance;
 +
            this.value = value;
 +
        }
 +
    }
  
// All types
+
    /**
private final int dimensions;
+
    * Calculates the nearest 'count' points to 'location'
 +
    */
 +
    @SuppressWarnings("unchecked")
 +
    public List<Entry<T>> nearestNeighbor(double[] location, int count, boolean sequentialSorting) {
 +
        KdTree<T> cursor = this;
 +
        cursor.status = Status.NONE;
 +
        double range = Double.POSITIVE_INFINITY;
 +
        ResultHeap resultHeap = new ResultHeap(count);
  
// Root only
+
        do {
private final HashMap<Object, T> map;
+
            if (cursor.status == Status.ALLVISITED) {
private double[] weights;
+
                // At a fully visited part. Move up the tree
 +
                cursor = cursor.parent;
 +
                continue;
 +
            }
  
// Leaf only
+
            if (cursor.status == Status.NONE && cursor.locations != null) {
private double[][] locations;
+
                // At a leaf. Use the data.
private int locationCount;
+
                if (cursor.locationCount > 0) {
 +
                    if (cursor.singularity) {
 +
                        double dist = pointDist(cursor.locations[0], location);
 +
                        if (dist <= range) {
 +
                            for (int i = 0; i < cursor.locationCount; i++) {
 +
                                resultHeap.addValue(dist, cursor.data[i]);
 +
                            }
 +
                        }
 +
                    }
 +
                    else {
 +
                        for (int i = 0; i < cursor.locationCount; i++) {
 +
                            double dist = pointDist(cursor.locations[i], location);
 +
                            resultHeap.addValue(dist, cursor.data[i]);
 +
                        }
 +
                    }
 +
                    range = resultHeap.getMaxDist();
 +
                }
  
// Stem only
+
                if (cursor.parent == null) {
private KdTree<T> left, right;
+
                    break;
private int splitDimension;
+
                }
private double splitValue;
+
                cursor = cursor.parent;
 +
                continue;
 +
            }
  
// Bounds
+
            // Going to descend
private double[] minLimit, maxLimit;
+
            KdTree<T> nextCursor = null;
 +
            if (cursor.status == Status.NONE) {
 +
                // At a fresh node, descend the most probably useful direction
 +
                if (location[cursor.splitDimension] > cursor.splitValue) {
 +
                    // Descend right
 +
                    nextCursor = cursor.right;
 +
                    cursor.status = Status.RIGHTVISITED;
 +
                }
 +
                else {
 +
                    // Descend left;
 +
                    nextCursor = cursor.left;
 +
                    cursor.status = Status.LEFTVISITED;
 +
                }
 +
            }
 +
            else if (cursor.status == Status.LEFTVISITED) {
 +
                // Left node visited, descend right.
 +
                nextCursor = cursor.right;
 +
                cursor.status = Status.ALLVISITED;
 +
            }
 +
            else if (cursor.status == Status.RIGHTVISITED) {
 +
                // Right node visited, descend left.
 +
                nextCursor = cursor.left;
 +
                cursor.status = Status.ALLVISITED;
 +
            }
  
/**
+
            // Check if it's worth descending. Assume it is if its sibling has
* Extends the bounds of this node do include a new location
+
            // not been visited yet.
*/
+
            if (cursor.status == Status.ALLVISITED) {
private final void extendBounds(double[] location) {
+
                if (nextCursor.locationCount == 0
if (minLimit == null) {
+
                        || (!nextCursor.singularity && pointRegionDist(location, nextCursor.minLimit,
minLimit = Arrays.copyOf(location, dimensions);
+
                                nextCursor.maxLimit) > range)) {
maxLimit = Arrays.copyOf(location, dimensions);
+
                    continue;
return;
+
                }
}
+
            }
  
for (int i=0; i<dimensions; i++) {
+
            // Descend down the tree
if (minLimit[i] > location[i]) {
+
            cursor = nextCursor;
minLimit[i] = location[i];
+
            cursor.status = Status.NONE;
}
+
        } while (cursor.parent != null || cursor.status != Status.ALLVISITED);
if (maxLimit[i] < location[i]) {
 
maxLimit[i] = location[i];
 
}
 
}
 
}
 
  
/**
+
        ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(resultHeap.values);
* Find the widest axis of the bounds of this node
+
        if (sequentialSorting) {
*/
+
            while (resultHeap.values > 0) {
private final int findWidestAxis() {
+
                resultHeap.removeLargest();
int widest = 0;
+
                results.add(new Entry<T>(resultHeap.removedDist, (T)resultHeap.removedData));
double width = (maxLimit[0] - minLimit[0]);
+
            }
for (int i = 1; i < dimensions; i++) {
+
        }
double nwidth = maxLimit[i] - minLimit[i];
+
        else {
if (nwidth > width) {
+
            for (int i = 0; i < resultHeap.values; i++) {
widest = i;
+
                results.add(new Entry<T>(resultHeap.distance[i], (T)resultHeap.data[i]));
width = nwidth;
+
            }
}
+
        }
}
 
return widest;
 
}
 
  
// Main constructor
+
        return results;
public KdTree(int dimensions) {
+
    }
this.dimensions = dimensions;
 
  
// Init as leaf
+
    // Override in subclasses
this.locations = new double[bucketSize][];
+
    protected abstract double pointDist(double[] p1, double[] p2);
this.locationCount = 0;
 
  
// Init as root
+
    protected abstract double pointRegionDist(double[] point, double[] min, double[] max);
this.map = new HashMap<Object, T>();
 
this.weights = new double[dimensions];
 
Arrays.fill(this.weights, 1.0);
 
}
 
  
// Child constructor
+
    protected double getAxisWeightHint(int i) {
private KdTree(KdTree<T> parent, boolean right) {
+
        return 1.0;
this.dimensions = parent.dimensions;
+
    }
  
// Init as leaf
+
    /**
this.locations = new double[bucketSize][];
+
    * Internal class for child nodes
this.locationCount = 0;
+
    */
 +
    private class ChildNode extends KdTree<T> {
 +
        private ChildNode(KdTree<T> parent, boolean right) {
 +
            super(parent, right);
 +
        }
  
// Init as non-root
+
        // Distance measurements are always called from the root node
this.map = null;
+
        protected double pointDist(double[] p1, double[] p2) {
}
+
            throw new IllegalStateException();
 +
        }
  
/**
+
        protected double pointRegionDist(double[] point, double[] min, double[] max) {
* Add a point and associated value to the tree
+
            throw new IllegalStateException();
*/
+
        }
public static <T> void addPoint(KdTree<T> tree, double[] location, T
+
    }
value) {
 
KdTree<T> cursor = tree;
 
  
while (cursor.locations == null || cursor.locationCount >=
+
    /**
cursor.locations.length) {
+
    * Class for tree with Weighted Squared Euclidean distancing
if (cursor.locations != null) {
+
    */
cursor.splitDimension = cursor.findWidestAxis();
+
    public static class WeightedSqrEuclid<T> extends KdTree<T> {
cursor.splitValue = (cursor.minLimit[cursor.splitDimension] +
+
        private double[] weights;
cursor.maxLimit[cursor.splitDimension]) * 0.5;
 
  
// Don't split node if it has no width in any axis. Double the bucket size instead
+
        public WeightedSqrEuclid(int dimensions, Integer sizeLimit) {
if ((cursor.minLimit[cursor.splitDimension] - cursor.maxLimit[cursor.splitDimension]) == 0) {
+
            super(dimensions, sizeLimit);
cursor.locations = Arrays.copyOf(cursor.locations,
+
            this.weights = new double[dimensions];
cursor.locations.length * 2);
+
            Arrays.fill(this.weights, 1.0);
break;
+
        }
}
 
  
// Create child leaves
+
        public void setWeights(double[] weights) {
KdTree<T> left = new KdTree<T>(cursor, false);
+
            this.weights = weights;
KdTree<T> right = new KdTree<T>(cursor, true);
+
        }
  
// Move locations into children
+
        protected double getAxisWeightHint(int i) {
for (double[] oldLocation : cursor.locations) {
+
            return weights[i];
if (oldLocation[cursor.splitDimension] > cursor.splitValue) {
+
        }
// Right
 
right.locations[right.locationCount] = oldLocation;
 
right.locationCount++;
 
right.extendBounds(oldLocation);
 
}
 
else {
 
// Left
 
left.locations[left.locationCount] = oldLocation;
 
left.locationCount++;
 
left.extendBounds(oldLocation);
 
}
 
}
 
  
// Make into stem
+
        protected double pointDist(double[] p1, double[] p2) {
cursor.left = left;
+
            double d = 0;
cursor.right = right;
 
cursor.locations = null;
 
}
 
  
cursor.extendBounds(location);
+
            for (int i = 0; i < p1.length; i++) {
 +
                double diff = (p1[i] - p2[i]) * weights[i];
 +
                if (!Double.isNaN(diff)) {
 +
                    d += diff * diff;
 +
                }
 +
            }
  
if (location[cursor.splitDimension] > cursor.splitValue) {
+
            return d;
cursor = cursor.right;
+
        }
}
 
else {
 
cursor = cursor.left;
 
}
 
}
 
  
cursor.locations[cursor.locationCount] = location;
+
        protected double pointRegionDist(double[] point, double[] min, double[] max) {
cursor.locationCount++;
+
            double d = 0;
cursor.extendBounds(location);
 
  
tree.map.put(location, value);
+
            for (int i = 0; i < point.length; i++) {
}
+
                double diff = 0;
 +
                if (point[i] > max[i]) {
 +
                    diff = (point[i] - max[i]) * weights[i];
 +
                }
 +
                else if (point[i] < min[i]) {
 +
                    diff = (point[i] - min[i]) * weights[i];
 +
                }
  
/**
+
                if (!Double.isNaN(diff)) {
* Enumeration representing the status of a node during the running
+
                    d += diff * diff;
*/
+
                }
private static enum Status {
+
            }
NONE,
 
LEFTVISITED,
 
RIGHTVISITED,
 
ALLVISITED
 
}
 
  
/**
+
            return d;
* Stores a distance and value to output
+
        }
*/
+
    }
public static class Entry<T> {
 
public final double distance;
 
public final T value;
 
private Entry(double distance, T value) {
 
this.distance = distance;
 
this.value = value;
 
}
 
}
 
  
/**
+
    /**
* Calculates the nearest 'count' points to 'location', with an arbitrary weighting on dimensions
+
    * Class for tree with Unweighted Squared Euclidean distancing
*/
+
    */
public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
+
    public static class SqrEuclid<T> extends KdTree<T> {
double[] location, int count, double[] weights) {
+
        public SqrEuclid(int dimensions, Integer sizeLimit) {
tree.weights = weights;
+
            super(dimensions, sizeLimit);
return nearestNeighbor(tree, location, count);
+
        }
}
 
  
/**
+
        protected double pointDist(double[] p1, double[] p2) {
* Calculates the nearest 'count' points to 'location'
+
            double d = 0;
*/
 
public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
 
double[] location, int count) {
 
KdTree<T> cursor = tree;
 
Status status = Status.NONE;
 
Stack<KdTree<T>> stack = new Stack<KdTree<T>>();
 
Stack<Status> statusStack = new Stack<Status>();
 
double range = Double.POSITIVE_INFINITY;
 
ResultHeap resultHeap = new ResultHeap(count);  
 
  
do {
+
            for (int i = 0; i < p1.length; i++) {
if (status == Status.ALLVISITED) {
+
                double diff = (p1[i] - p2[i]);
// At a fully visited part. Move up the tree
+
                if (!Double.isNaN(diff)) {
cursor = stack.pop();
+
                    d += diff * diff;
status = statusStack.pop();
+
                }
continue;
+
            }
}
 
  
if (status == Status.NONE && cursor.locations != null) {
+
            return d;
// At a leaf. Use the data.
+
        }
for (int i=0; i<cursor.locationCount; i++) {
 
double dist = sqrPointDist(cursor.locations[i], location, tree.weights);
 
resultHeap.addValue(dist, cursor.locations[i]);
 
}
 
range = resultHeap.getMaxDist();
 
  
if (stack.empty()) {
+
        protected double pointRegionDist(double[] point, double[] min, double[] max) {
break;
+
            double d = 0;
}
 
cursor = stack.pop();
 
status = statusStack.pop();
 
continue;
 
}
 
  
// Going to descend
+
            for (int i = 0; i < point.length; i++) {
KdTree<T> nextCursor = null;
+
                double diff = 0;
if (status == Status.NONE) {
+
                if (point[i] > max[i]) {
// At a fresh node, descend the most probably useful direction
+
                    diff = (point[i] - max[i]);
if (location[cursor.splitDimension] > cursor.splitValue) {
+
                }
// Descend right
+
                else if (point[i] < min[i]) {
nextCursor = cursor.right;
+
                    diff = (point[i] - min[i]);
status = Status.RIGHTVISITED;
+
                }
}
 
else {
 
// Descend left;
 
nextCursor = cursor.left;
 
status = Status.LEFTVISITED;
 
}
 
}
 
else if (status == Status.LEFTVISITED) {
 
// Left node visited, descend right.
 
nextCursor = cursor.right;
 
status = Status.ALLVISITED;
 
}
 
else if (status == Status.RIGHTVISITED) {
 
// Right node visited, descend left.
 
nextCursor = cursor.left;
 
status = Status.ALLVISITED;
 
}
 
  
// Check if it's worth descending. Assume it is if it's sibling has not been visited yet.
+
                if (!Double.isNaN(diff)) {
if (status == Status.ALLVISITED) {
+
                    d += diff * diff;
if (nextCursor.locationCount == 0 || sqrPointRegionDist(location, nextCursor.minLimit, nextCursor.maxLimit, tree.weights) > range) {
+
                }
continue;
+
            }
}
 
}
 
  
// Descend down the tree
+
            return d;
stack.push(cursor);
+
        }
statusStack.push(status);
+
    }
cursor = nextCursor;
 
status = Status.NONE;
 
} while (stack.size() > 0 || status != Status.ALLVISITED);
 
  
ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(count);
+
    /**
Object[] data = resultHeap.getData();
+
    * Class for tree with Weighted Manhattan distancing
double[] dist = resultHeap.getDistances();
+
    */
for (int i=0; i<resultHeap.values; i++) {
+
    public static class WeightedManhattan<T> extends KdTree<T> {
T value = tree.map.get(data[i]);
+
        private double[] weights;
results.add(new Entry<T>(dist[i], value));
 
}
 
  
return results;
+
        public WeightedManhattan(int dimensions, Integer sizeLimit) {
}
+
            super(dimensions, sizeLimit);
 +
            this.weights = new double[dimensions];
 +
            Arrays.fill(this.weights, 1.0);
 +
        }
  
/**
+
        public void setWeights(double[] weights) {
* Calculates the (squared euclidean) distance between two points
+
            this.weights = weights;
*/
+
        }
private static final double sqrPointDist(double[] p1, double[] p2, double[] weights) {
 
double d = 0;
 
  
for (int i=0; i<p1.length; i++) {
+
        protected double getAxisWeightHint(int i) {
</pre></code>
</syntaxhighlight></code>

Latest revision as of 01:58, 14 June 2021

A nice efficient small kD-Tree. Currently the fastest kD-Tree implementation on Robowiki. Feel free to use.

Plans

Right now I'm working on a rewrite, intended to have cleaner code, follow Java conventions better, and be at least as fast. Current plans for the rewrite are:

  • Done! Cleaner code: Follow Java/OOP conventions better, since much of what I abandoned in the code below was not necessary for speed.
  • Done! Nearest Neighbor Iterator: Provides an iterator over nearest neighbors. This allows iterated fetching in case one doesn't know exactly how many neighbors one needs (e.g. if some are unusable data points due to other checks). The theoretical speed penalty should be very slim, perhaps even negligible.
  • Further improved speed: Yes, it's possible! Today I thought of three brand new techniques I should be able to use to increase speed further!
      • Done! Flexible path ordering: Since 'second choice' paths already have a full distance-to-bounding-box calculation done, why not use this information to check the 'paths not yet taken' based on that computed distance rather than tree structure. Should be closer to optimal.
      • Unsuccessful. No improvement. Dimension-pruned distance calculations: With real data, there is often a situation where within a particular node, only some of the dimensions differ between points. It should be simple to track these 'unused' dimensions in a particular node and use this to optimize the distance calculation.
      • Implicit Subtrees: I thought about how I'm using an array to store the 'bucket', and thought "wouldn't it be nice to not have to calculate the distance for every single point in the bucket..." Well, it turns out that can be avoided, all while keeping it in the nice compact array! It's just a matter of turning the bucket arrays into implicit kd-trees! This should keep the advantages of the bucket system for making the incrementally created tree balanced, while at the same time being more efficient!
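The iterator and flexible path ordering ideas can be illustrated together as a best-first search. The following is only a minimal sketch of the general technique, not the code from the rewrite: the class `BestFirstSketch`, its `Node` and `Item` types, and the methods `build` and `nearestFirst` are all hypothetical names, and it assumes finite, non-NaN coordinates and unweighted squared Euclidean distance. Pending subtrees sit in a priority queue keyed by their distance to the bounding box, so the search always expands whichever untaken path is currently closest, and points come out one at a time in order of increasing distance.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

/** Hypothetical sketch: yields points lazily in order of increasing squared distance. */
class BestFirstSketch {
    /** Leaf (points != null) or stem (left/right != null); every node has a bounding box. */
    static final class Node {
        double[] min, max;
        List<double[]> points;
        Node left, right;
    }

    /** Pending work: a whole subtree (node != null) or a single point. */
    static final class Item {
        final double dist;
        final Node node;
        final double[] point;
        Item(double dist, Node node, double[] point) {
            this.dist = dist; this.node = node; this.point = point;
        }
    }

    /** Builds a simple midpoint-split tree from a non-empty list of points. */
    static Node build(List<double[]> pts, int bucketSize) {
        Node n = new Node();
        int dims = pts.get(0).length;
        n.min = pts.get(0).clone();
        n.max = pts.get(0).clone();
        int widest = 0;
        for (int i = 0; i < dims; i++) {
            for (double[] p : pts) {
                n.min[i] = Math.min(n.min[i], p[i]);
                n.max[i] = Math.max(n.max[i], p[i]);
            }
            if (n.max[i] - n.min[i] > n.max[widest] - n.min[widest]) widest = i;
        }
        // Small enough, or all points identical: keep as a leaf bucket
        if (pts.size() <= bucketSize || n.max[widest] == n.min[widest]) {
            n.points = pts;
            return n;
        }
        double split = (n.min[widest] + n.max[widest]) * 0.5;
        if (split == n.max[widest]) split = n.min[widest]; // guard against rounding
        List<double[]> lo = new ArrayList<>(), hi = new ArrayList<>();
        for (double[] p : pts) (p[widest] > split ? hi : lo).add(p);
        n.left = build(lo, bucketSize);
        n.right = build(hi, bucketSize);
        return n;
    }

    static double sqrDist(double[] a, double[] b) {
        double d = 0;
        for (int i = 0; i < a.length; i++) d += (a[i] - b[i]) * (a[i] - b[i]);
        return d;
    }

    /** Squared distance from a point to the nearest part of a bounding box (0 if inside). */
    static double sqrRegionDist(double[] p, double[] min, double[] max) {
        double d = 0;
        for (int i = 0; i < p.length; i++) {
            double diff = Math.max(0, Math.max(min[i] - p[i], p[i] - max[i]));
            d += diff * diff;
        }
        return d;
    }

    static Iterator<double[]> nearestFirst(Node root, double[] query) {
        PriorityQueue<Item> queue = new PriorityQueue<>(Comparator.comparingDouble((Item it) -> it.dist));
        queue.add(new Item(sqrRegionDist(query, root.min, root.max), root, null));
        return new Iterator<double[]>() {
            @Override public boolean hasNext() { return !queue.isEmpty(); }

            @Override public double[] next() {
                while (true) {
                    Item it = queue.remove(); // NoSuchElementException once exhausted
                    if (it.point != null) return it.point;
                    if (it.node.points != null) {
                        // Leaf: queue its points by exact distance
                        for (double[] p : it.node.points) queue.add(new Item(sqrDist(p, query), null, p));
                    } else {
                        // Stem: queue both children by bounding-box distance
                        for (Node c : new Node[] { it.node.left, it.node.right })
                            queue.add(new Item(sqrRegionDist(query, c.min, c.max), c, null));
                    }
                }
            }
        };
    }
}
```

A caller that needs an unknown number of usable neighbors can keep calling `next()` and stop as soon as it has enough. Putting individual points in the same queue as subtrees keeps the sketch short; a tuned implementation would probably keep them in separate structures.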

I also plan to explore:

  • R-Tree/X-Tree type structures. They allow n-ary trees instead of only 2-ary trees like kd-trees, and are self-balancing. Might have good results.
  • VP-Tree type structures. Splits based on distance to points may be more effective.
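To make the VP-Tree idea concrete, here is a rough sketch of a single vantage-point split. It is an illustration only, not something taken from this tree: the class `VantagePointSplitSketch`, the `Split` type, and the choice of the median distance as the radius are assumptions. Instead of cutting along one axis, the points are divided into those within a radius of a chosen vantage point and those beyond it.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical sketch of one vantage-point split (not part of the kd-tree on this page). */
class VantagePointSplitSketch {
    /** Result of a split: the radius, plus points inside (<= radius) and outside (> radius). */
    static final class Split {
        final double radius;
        final List<double[]> inside, outside;
        Split(double radius, List<double[]> inside, List<double[]> outside) {
            this.radius = radius; this.inside = inside; this.outside = outside;
        }
    }

    /** Plain Euclidean distance. */
    static double dist(double[] a, double[] b) {
        double d = 0;
        for (int i = 0; i < a.length; i++) d += (a[i] - b[i]) * (a[i] - b[i]);
        return Math.sqrt(d);
    }

    /** Splits a non-empty list of points at the median distance from the vantage point. */
    static Split split(double[] vantage, List<double[]> points) {
        List<double[]> sorted = new ArrayList<>(points);
        sorted.sort(Comparator.comparingDouble((double[] p) -> dist(vantage, p)));
        double radius = dist(vantage, sorted.get((sorted.size() - 1) / 2));
        List<double[]> inside = new ArrayList<>(), outside = new ArrayList<>();
        for (double[] p : sorted) (dist(vantage, p) <= radius ? inside : outside).add(p);
        return new Split(radius, inside, outside);
    }
}
```

Applying the same split recursively to `inside` and `outside` would give a full tree; how to choose good vantage points is left open here.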

Comments on these plans would be appreciated at: User talk:Rednaxela/kD-Tree

The Code

My latest released (circa 2010) version of this tree, aka my "3rd gen" one, is now on GitLab. It supports a KNN iterator that can save you computational time if you aren't sure exactly how many points you will need. This version also includes some weighted distance functions from Tkiesel in 2012, and a bug fix by Xor from 2016.

(Looking at my old backups, it also looks like I have some unreleased test performance optimization variants dating to Jul 2013, but I'm not sure whether they were fruitful.)

Old Code

My old "2nd gen" version of my tree is as follows. This is outdated and the above "3rd gen" version is recommended over it.

/**
 * Copyright 2009 Rednaxela
 * 
 * This software is provided 'as-is', without any express or implied
 * warranty. In no event will the authors be held liable for any damages
 * arising from the use of this software.
 * 
 * Permission is granted to anyone to use this software for any purpose,
 * including commercial applications, and to alter it and redistribute it
 * freely, subject to the following restrictions:
 * 
 *    1. The origin of this software must not be misrepresented; you must not
 *    claim that you wrote the original software. If you use this software
 *    in a product, an acknowledgment in the product documentation would be
 *    appreciated but is not required.
 * 
 *    2. This notice may not be removed or altered from any source
 *    distribution.
 */

package ags.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

/**
 * An efficient well-optimized kd-tree
 * 
 * @author Rednaxela
 */
public abstract class KdTree<T> {
    // Static variables
    private static final int           bucketSize = 24;

    // All types
    private final int                  dimensions;
    private final KdTree<T>            parent;

    // Root only
    private final LinkedList<double[]> locationStack;
    private final Integer              sizeLimit;

    // Leaf only
    private double[][]                 locations;
    private Object[]                   data;
    private int                        locationCount;

    // Stem only
    private KdTree<T>                  left, right;
    private int                        splitDimension;
    private double                     splitValue;

    // Bounds
    private double[]                   minLimit, maxLimit;
    private boolean                    singularity;

    // Temporary
    private Status                     status;

    /**
     * Construct a KdTree with a given number of dimensions and a limit on
     * maximum size (after which it throws away old points)
     */
    private KdTree(int dimensions, Integer sizeLimit) {
        this.dimensions = dimensions;

        // Init as leaf
        this.locations = new double[bucketSize][];
        this.data = new Object[bucketSize];
        this.locationCount = 0;
        this.singularity = true;

        // Init as root
        this.parent = null;
        this.sizeLimit = sizeLimit;
        if (sizeLimit != null) {
            this.locationStack = new LinkedList<double[]>();
        }
        else {
            this.locationStack = null;
        }
    }

    /**
     * Constructor for child nodes. Internal use only.
     */
    private KdTree(KdTree<T> parent, boolean right) {
        this.dimensions = parent.dimensions;

        // Init as leaf
        this.locations = new double[Math.max(bucketSize, parent.locationCount)][];
        this.data = new Object[Math.max(bucketSize, parent.locationCount)];
        this.locationCount = 0;
        this.singularity = true;

        // Init as non-root
        this.parent = parent;
        this.locationStack = null;
        this.sizeLimit = null;
    }

    /**
     * Get the number of points in the tree
     */
    public int size() {
        return locationCount;
    }

    /**
     * Add a point and associated value to the tree
     */
    public void addPoint(double[] location, T value) {
        KdTree<T> cursor = this;

        while (cursor.locations == null || cursor.locationCount >= cursor.locations.length) {
            if (cursor.locations != null) {
                cursor.splitDimension = cursor.findWidestAxis();
                cursor.splitValue = (cursor.minLimit[cursor.splitDimension] + cursor.maxLimit[cursor.splitDimension]) * 0.5;

                // Never split on infinity or NaN
                if (cursor.splitValue == Double.POSITIVE_INFINITY) {
                    cursor.splitValue = Double.MAX_VALUE;
                }
                else if (cursor.splitValue == Double.NEGATIVE_INFINITY) {
                    cursor.splitValue = -Double.MAX_VALUE;
                }
                else if (Double.isNaN(cursor.splitValue)) {
                    cursor.splitValue = 0;
                }

                // Don't split node if it has no width in any axis. Double the
                // bucket size instead
                if (cursor.minLimit[cursor.splitDimension] == cursor.maxLimit[cursor.splitDimension]) {
                    double[][] newLocations = new double[cursor.locations.length * 2][];
                    System.arraycopy(cursor.locations, 0, newLocations, 0, cursor.locationCount);
                    cursor.locations = newLocations;
                    Object[] newData = new Object[newLocations.length];
                    System.arraycopy(cursor.data, 0, newData, 0, cursor.locationCount);
                    cursor.data = newData;
                    break;
                }

                // Don't let the split value be the same as the upper value as
                // can happen due to rounding errors!
                if (cursor.splitValue == cursor.maxLimit[cursor.splitDimension]) {
                    cursor.splitValue = cursor.minLimit[cursor.splitDimension];
                }

                // Create child leaves
                KdTree<T> left = new ChildNode(cursor, false);
                KdTree<T> right = new ChildNode(cursor, true);

                // Move locations into children
                for (int i = 0; i < cursor.locationCount; i++) {
                    double[] oldLocation = cursor.locations[i];
                    Object oldData = cursor.data[i];
                    if (oldLocation[cursor.splitDimension] > cursor.splitValue) {
                        // Right
                        right.locations[right.locationCount] = oldLocation;
                        right.data[right.locationCount] = oldData;
                        right.locationCount++;
                        right.extendBounds(oldLocation);
                    }
                    else {
                        // Left
                        left.locations[left.locationCount] = oldLocation;
                        left.data[left.locationCount] = oldData;
                        left.locationCount++;
                        left.extendBounds(oldLocation);
                    }
                }

                // Make into stem
                cursor.left = left;
                cursor.right = right;
                cursor.locations = null;
                cursor.data = null;
            }

            cursor.locationCount++;
            cursor.extendBounds(location);

            if (location[cursor.splitDimension] > cursor.splitValue) {
                cursor = cursor.right;
            }
            else {
                cursor = cursor.left;
            }
        }

        cursor.locations[cursor.locationCount] = location;
        cursor.data[cursor.locationCount] = value;
        cursor.locationCount++;
        cursor.extendBounds(location);

        if (this.sizeLimit != null) {
            this.locationStack.add(location);
            if (this.locationCount > this.sizeLimit) {
                this.removeOld();
            }
        }
    }

    /**
     * Extends the bounds of this node to include a new location
     */
    private final void extendBounds(double[] location) {
        if (minLimit == null) {
            minLimit = new double[dimensions];
            System.arraycopy(location, 0, minLimit, 0, dimensions);
            maxLimit = new double[dimensions];
            System.arraycopy(location, 0, maxLimit, 0, dimensions);
            return;
        }

        for (int i = 0; i < dimensions; i++) {
            if (Double.isNaN(location[i])) {
                minLimit[i] = Double.NaN;
                maxLimit[i] = Double.NaN;
                singularity = false;
            }
            else if (minLimit[i] > location[i]) {
                minLimit[i] = location[i];
                singularity = false;
            }
            else if (maxLimit[i] < location[i]) {
                maxLimit[i] = location[i];
                singularity = false;
            }
        }
    }

    /**
     * Find the widest axis of the bounds of this node
     */
    private final int findWidestAxis() {
        int widest = 0;
        double width = (maxLimit[0] - minLimit[0]) * getAxisWeightHint(0);
        if (Double.isNaN(width)) width = 0;
        for (int i = 1; i < dimensions; i++) {
            double nwidth = (maxLimit[i] - minLimit[i]) * getAxisWeightHint(i);
            if (Double.isNaN(nwidth)) nwidth = 0;
            if (nwidth > width) {
                widest = i;
                width = nwidth;
            }
        }
        return widest;
    }

    /**
     * Remove the oldest value from the tree. Note: This cannot trim the bounds
     * of nodes, nor empty nodes, and thus you can't expect it to perfectly
     * preserve the speed of the tree as you keep adding.
     */
    private void removeOld() {
        double[] location = this.locationStack.removeFirst();
        KdTree<T> cursor = this;

        // Find the node where the point is
        while (cursor.locations == null) {
            if (location[cursor.splitDimension] > cursor.splitValue) {
                cursor = cursor.right;
            }
            else {
                cursor = cursor.left;
            }
        }

        for (int i = 0; i < cursor.locationCount; i++) {
            if (cursor.locations[i] == location) {
                System.arraycopy(cursor.locations, i + 1, cursor.locations, i, cursor.locationCount - i - 1);
                cursor.locations[cursor.locationCount-1] = null;
                System.arraycopy(cursor.data, i + 1, cursor.data, i, cursor.locationCount - i - 1);
                cursor.data[cursor.locationCount-1] = null;
                do {
                    cursor.locationCount--;
                    cursor = cursor.parent;
                } while (cursor != null);
                return;
            }
        }
        // If we got here... we couldn't find the value to remove. Weird...
    }

    /**
     * Enumeration representing the status of a node during a search
     */
    private static enum Status {
        NONE, LEFTVISITED, RIGHTVISITED, ALLVISITED
    }

    /**
     * Stores a distance and value to output
     */
    public static class Entry<T> {
        public final double distance;
        public final T      value;

        private Entry(double distance, T value) {
            this.distance = distance;
            this.value = value;
        }
    }

    /**
     * Calculates the nearest 'count' points to 'location'
     */
    @SuppressWarnings("unchecked")
    public List<Entry<T>> nearestNeighbor(double[] location, int count, boolean sequentialSorting) {
        KdTree<T> cursor = this;
        cursor.status = Status.NONE;
        double range = Double.POSITIVE_INFINITY;
        ResultHeap resultHeap = new ResultHeap(count);

        do {
            if (cursor.status == Status.ALLVISITED) {
                // At a fully visited part. Move up the tree
                cursor = cursor.parent;
                continue;
            }

            if (cursor.status == Status.NONE && cursor.locations != null) {
                // At a leaf. Use the data.
                if (cursor.locationCount > 0) {
                    if (cursor.singularity) {
                        double dist = pointDist(cursor.locations[0], location);
                        if (dist <= range) {
                            for (int i = 0; i < cursor.locationCount; i++) {
                                resultHeap.addValue(dist, cursor.data[i]);
                            }
                        }
                    }
                    else {
                        for (int i = 0; i < cursor.locationCount; i++) {
                            double dist = pointDist(cursor.locations[i], location);
                            resultHeap.addValue(dist, cursor.data[i]);
                        }
                    }
                    range = resultHeap.getMaxDist();
                }

                if (cursor.parent == null) {
                    break;
                }
                cursor = cursor.parent;
                continue;
            }

            // Going to descend
            KdTree<T> nextCursor = null;
            if (cursor.status == Status.NONE) {
                // At a fresh node, descend the most probably useful direction
                if (location[cursor.splitDimension] > cursor.splitValue) {
                    // Descend right
                    nextCursor = cursor.right;
                    cursor.status = Status.RIGHTVISITED;
                }
                else {
                    // Descend left
                    nextCursor = cursor.left;
                    cursor.status = Status.LEFTVISITED;
                }
            }
            else if (cursor.status == Status.LEFTVISITED) {
                // Left node visited, descend right.
                nextCursor = cursor.right;
                cursor.status = Status.ALLVISITED;
            }
            else if (cursor.status == Status.RIGHTVISITED) {
                // Right node visited, descend left.
                nextCursor = cursor.left;
                cursor.status = Status.ALLVISITED;
            }

            // Check if it's worth descending. Assume it is if its sibling has
            // not been visited yet.
            if (cursor.status == Status.ALLVISITED) {
                if (nextCursor.locationCount == 0
                        || (!nextCursor.singularity && pointRegionDist(location, nextCursor.minLimit,
                                nextCursor.maxLimit) > range)) {
                    continue;
                }
            }

            // Descend down the tree
            cursor = nextCursor;
            cursor.status = Status.NONE;
        } while (cursor.parent != null || cursor.status != Status.ALLVISITED);

        ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(resultHeap.values);
        if (sequentialSorting) {
            while (resultHeap.values > 0) {
                resultHeap.removeLargest();
                results.add(new Entry<T>(resultHeap.removedDist, (T)resultHeap.removedData));
            }
        }
        else {
            for (int i = 0; i < resultHeap.values; i++) {
                results.add(new Entry<T>(resultHeap.distance[i], (T)resultHeap.data[i]));
            }
        }

        return results;
    }

    // Override in subclasses
    protected abstract double pointDist(double[] p1, double[] p2);

    protected abstract double pointRegionDist(double[] point, double[] min, double[] max);

    protected double getAxisWeightHint(int i) {
        return 1.0;
    }

    /**
     * Internal class for child nodes
     */
    private class ChildNode extends KdTree<T> {
        private ChildNode(KdTree<T> parent, boolean right) {
            super(parent, right);
        }

        // Distance measurements are always called from the root node
        protected double pointDist(double[] p1, double[] p2) {
            throw new IllegalStateException();
        }

        protected double pointRegionDist(double[] point, double[] min, double[] max) {
            throw new IllegalStateException();
        }
    }

    /**
     * Class for tree with Weighted Squared Euclidean distancing
     */
    public static class WeightedSqrEuclid<T> extends KdTree<T> {
        private double[] weights;

        public WeightedSqrEuclid(int dimensions, Integer sizeLimit) {
            super(dimensions, sizeLimit);
            this.weights = new double[dimensions];
            Arrays.fill(this.weights, 1.0);
        }

        public void setWeights(double[] weights) {
            this.weights = weights;
        }

        protected double getAxisWeightHint(int i) {
            return weights[i];
        }

        protected double pointDist(double[] p1, double[] p2) {
            double d = 0;

            for (int i = 0; i < p1.length; i++) {
                double diff = (p1[i] - p2[i]) * weights[i];
                if (!Double.isNaN(diff)) {
                    d += diff * diff;
                }
            }

            return d;
        }

        protected double pointRegionDist(double[] point, double[] min, double[] max) {
            double d = 0;

            for (int i = 0; i < point.length; i++) {
                double diff = 0;
                if (point[i] > max[i]) {
                    diff = (point[i] - max[i]) * weights[i];
                }
                else if (point[i] < min[i]) {
                    diff = (point[i] - min[i]) * weights[i];
                }

                if (!Double.isNaN(diff)) {
                    d += diff * diff;
                }
            }

            return d;
        }
    }

    /**
     * Class for tree with Unweighted Squared Euclidean distancing
     */
    public static class SqrEuclid<T> extends KdTree<T> {
        public SqrEuclid(int dimensions, Integer sizeLimit) {
            super(dimensions, sizeLimit);
        }

        protected double pointDist(double[] p1, double[] p2) {
            double d = 0;

            for (int i = 0; i < p1.length; i++) {
                double diff = (p1[i] - p2[i]);
                if (!Double.isNaN(diff)) {
                    d += diff * diff;
                }
            }

            return d;
        }

        protected double pointRegionDist(double[] point, double[] min, double[] max) {
            double d = 0;

            for (int i = 0; i < point.length; i++) {
                double diff = 0;
                if (point[i] > max[i]) {
                    diff = (point[i] - max[i]);
                }
                else if (point[i] < min[i]) {
                    diff = (point[i] - min[i]);
                }

                if (!Double.isNaN(diff)) {
                    d += diff * diff;
                }
            }

            return d;
        }
    }

    /**
     * Class for tree with Weighted Manhattan distancing
     */
    public static class WeightedManhattan<T> extends KdTree<T> {
        private double[] weights;

        public WeightedManhattan(int dimensions, Integer sizeLimit) {
            super(dimensions, sizeLimit);
            this.weights = new double[dimensions];
            Arrays.fill(this.weights, 1.0);
        }

        public void setWeights(double[] weights) {
            this.weights = weights;
        }

        protected double getAxisWeightHint(int i) {
            return weights[i];
        }

        protected double pointDist(double[] p1, double[] p2) {
            double d = 0;

            for (int i = 0; i < p1.length; i++) {
                double diff = (p1[i] - p2[i]);
                if (!Double.isNaN(diff)) {
                    d += ((diff < 0) ? -diff : diff) * weights[i];
                }
            }

            return d;
        }

        protected double pointRegionDist(double[] point, double[] min, double[] max) {
            double d = 0;

            for (int i = 0; i < point.length; i++) {
                double diff = 0;
                if (point[i] > max[i]) {
                    diff = (point[i] - max[i]);
                }
                else if (point[i] < min[i]) {
                    diff = (min[i] - point[i]);
                }

                if (!Double.isNaN(diff)) {
                    d += diff * weights[i];
                }
            }

            return d;
        }
    }

    /**
     * Class for tree with Manhattan distancing
     */
    public static class Manhattan<T> extends KdTree<T> {
        public Manhattan(int dimensions, Integer sizeLimit) {
            super(dimensions, sizeLimit);
        }

        protected double pointDist(double[] p1, double[] p2) {
            double d = 0;

            for (int i = 0; i < p1.length; i++) {
                double diff = (p1[i] - p2[i]);
                if (!Double.isNaN(diff)) {
                    d += (diff < 0) ? -diff : diff;
                }
            }

            return d;
        }

        protected double pointRegionDist(double[] point, double[] min, double[] max) {
            double d = 0;

            for (int i = 0; i < point.length; i++) {
                double diff = 0;
                if (point[i] > max[i]) {
                    diff = (point[i] - max[i]);
                }
                else if (point[i] < min[i]) {
                    diff = (min[i] - point[i]);
                }

                if (!Double.isNaN(diff)) {
                    d += diff;
                }
            }

            return d;
        }
    }

    /**
     * Class for tracking up to 'size' closest values
     */
    private static class ResultHeap {
        private final Object[] data;
        private final double[] distance;
        private final int      size;
        private int            values;
        public Object          removedData;
        public double          removedDist;

        public ResultHeap(int size) {
            this.data = new Object[size];
            this.distance = new double[size];
            this.size = size;
            this.values = 0;
        }

        public void addValue(double dist, Object value) {
            // If there is still room in the heap
            if (values < size) {
                // Insert new value at the end
                data[values] = value;
                distance[values] = dist;
                upHeapify(values);
                values++;
            }
            // If there is no room left in the heap, and the new entry is lower
            // than the max entry
            else if (dist < distance[0]) {
                // Replace the max entry with the new entry
                data[0] = value;
                distance[0] = dist;
                downHeapify(0);
            }
        }

        public void removeLargest() {
            if (values == 0) {
                throw new IllegalStateException();
            }

            removedData = data[0];
            removedDist = distance[0];
            values--;
            data[0] = data[values];
            distance[0] = distance[values];
            downHeapify(0);
        }

        private void upHeapify(int c) {
            for (int p = (c - 1) / 2; c != 0 && distance[c] > distance[p]; c = p, p = (c - 1) / 2) {
                Object pData = data[p];
                double pDist = distance[p];
                data[p] = data[c];
                distance[p] = distance[c];
                data[c] = pData;
                distance[c] = pDist;
            }
        }

        private void downHeapify(int p) {
            for (int c = p * 2 + 1; c < values; p = c, c = p * 2 + 1) {
                if (c + 1 < values && distance[c] < distance[c + 1]) {
                    c++;
                }
                if (distance[p] < distance[c]) {
                    // Swap the points
                    Object pData = data[p];
                    double pDist = distance[p];
                    data[p] = data[c];
                    distance[p] = distance[c];
                    data[c] = pData;
                    distance[c] = pDist;
                }
                else {
                    break;
                }
            }
        }

        public double getMaxDist() {
            if (values < size) {
                return Double.POSITIVE_INFINITY;
            }
            return distance[0];
        }
    }
}
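The ResultHeap at the end of the listing keeps only the 'size' closest entries in a max-heap, so the current worst distance is always at the root and can serve as the pruning range. The same behaviour can be approximated with the standard library. This is a sketch rather than a drop-in replacement: the class name `BoundedNearestSketch` and its methods are hypothetical, and it allocates one object per entry, which the hand-rolled arrays above avoid.

```java
import java.util.Comparator;
import java.util.PriorityQueue;

/** Hypothetical sketch: tracks up to 'capacity' closest values using java.util.PriorityQueue. */
class BoundedNearestSketch<T> {
    private static final class Entry<T> {
        final double dist;
        final T value;
        Entry(double dist, T value) { this.dist = dist; this.value = value; }
    }

    private final int capacity;
    private final PriorityQueue<Entry<T>> heap; // max-heap: largest distance at the head

    BoundedNearestSketch(int capacity) {
        if (capacity < 1) throw new IllegalArgumentException("capacity must be at least 1");
        this.capacity = capacity;
        this.heap = new PriorityQueue<>(capacity,
                Comparator.comparingDouble((Entry<T> e) -> e.dist).reversed());
    }

    /** Keeps the value only if there is room or it beats the current worst entry. */
    void add(double dist, T value) {
        if (heap.size() < capacity) {
            heap.add(new Entry<>(dist, value));
        } else if (dist < heap.peek().dist) {
            heap.poll(); // drop the current worst entry
            heap.add(new Entry<>(dist, value));
        }
    }

    /** Pruning range: infinite until the heap is full, then the worst kept distance. */
    double maxDist() {
        return heap.size() < capacity ? Double.POSITIVE_INFINITY : heap.peek().dist;
    }

    int size() {
        return heap.size();
    }
}
```

Returning infinity from `maxDist()` until the heap is full mirrors `getMaxDist()` in the listing: no subtree may be pruned before enough candidates have been collected.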