Difference between revisions of "User:Rednaxela/kD-Tree"

From Robowiki
Jump to navigation Jump to search
m (Remove some silly newlines that crept in)
m (more reformatting)
Line 36: Line 36:
 
  */
 
  */
 
public class KdTree<T> {
 
public class KdTree<T> {
// Static variables
+
    // Static variables
private static final int bucketSize = 32;
+
    private static final int bucketSize = 32;
  
// All types
+
    // All types
private final int dimensions;
+
    private final int dimensions;
  
// Root only
+
    // Root only
private final HashMap<Object, T> map;
+
    private final HashMap<Object, T> map;
private double[] weights;
+
    private double[] weights;
  
// Leaf only
+
    // Leaf only
private double[][] locations;
+
    private double[][] locations;
private int locationCount;
+
    private int locationCount;
  
// Stem only
+
    // Stem only
private KdTree<T> left, right;
+
    private KdTree<T> left, right;
private int splitDimension;
+
    private int splitDimension;
private double splitValue;
+
    private double splitValue;
  
// Bounds
+
    // Bounds
private double[] minLimit, maxLimit;
+
    private double[] minLimit, maxLimit;
  
/**
+
    /**
* Extends the bounds of this node do include a new location
+
    * Extends the bounds of this node do include a new location
*/
+
    */
private final void extendBounds(double[] location) {
+
    private final void extendBounds(double[] location) {
if (minLimit == null) {
+
        if (minLimit == null) {
minLimit = Arrays.copyOf(location, dimensions);
+
            minLimit = Arrays.copyOf(location, dimensions);
maxLimit = Arrays.copyOf(location, dimensions);
+
            maxLimit = Arrays.copyOf(location, dimensions);
return;
+
            return;
}
+
        }
  
for (int i=0; i<dimensions; i++) {
+
        for (int i=0; i<dimensions; i++) {
if (minLimit[i] > location[i]) {
+
            if (minLimit[i] > location[i]) {
minLimit[i] = location[i];
+
                minLimit[i] = location[i];
}
+
            }
if (maxLimit[i] < location[i]) {
+
            if (maxLimit[i] < location[i]) {
maxLimit[i] = location[i];
+
                maxLimit[i] = location[i];
}
+
            }
}
+
        }
}
+
    }
  
/**
+
    /**
* Find the widest axis of the bounds of this node
+
    * Find the widest axis of the bounds of this node
*/
+
    */
private final int findWidestAxis() {
+
    private final int findWidestAxis() {
int widest = 0;
+
        int widest = 0;
double width = (maxLimit[0] - minLimit[0]);
+
        double width = (maxLimit[0] - minLimit[0]);
for (int i = 1; i < dimensions; i++) {
+
        for (int i = 1; i < dimensions; i++) {
double nwidth = maxLimit[i] - minLimit[i];
+
            double nwidth = maxLimit[i] - minLimit[i];
if (nwidth > width) {
+
            if (nwidth > width) {
widest = i;
+
                widest = i;
width = nwidth;
+
                width = nwidth;
}
+
            }
}
+
        }
return widest;
+
        return widest;
}
+
    }
  
// Main constructor
+
    // Main constructor
public KdTree(int dimensions) {
+
    public KdTree(int dimensions) {
this.dimensions = dimensions;
+
        this.dimensions = dimensions;
  
// Init as leaf
+
        // Init as leaf
this.locations = new double[bucketSize][];
+
        this.locations = new double[bucketSize][];
this.locationCount = 0;
+
        this.locationCount = 0;
  
// Init as root
+
        // Init as root
this.map = new HashMap<Object, T>();
+
        this.map = new HashMap<Object, T>();
this.weights = new double[dimensions];
+
        this.weights = new double[dimensions];
Arrays.fill(this.weights, 1.0);
+
        Arrays.fill(this.weights, 1.0);
}
+
    }
  
// Child constructor
+
    // Child constructor
private KdTree(KdTree<T> parent, boolean right) {
+
    private KdTree(KdTree<T> parent, boolean right) {
this.dimensions = parent.dimensions;
+
        this.dimensions = parent.dimensions;
  
// Init as leaf
+
        // Init as leaf
this.locations = new double[bucketSize][];
+
        this.locations = new double[bucketSize][];
this.locationCount = 0;
+
        this.locationCount = 0;
  
// Init as non-root
+
        // Init as non-root
this.map = null;
+
        this.map = null;
}
+
    }
  
/**
+
    /**
* Add a point and associated value to the tree
+
    * Add a point and associated value to the tree
*/
+
    */
public static <T> void addPoint(KdTree<T> tree, double[] location, T
+
    public static <T> void addPoint(KdTree<T> tree, double[] location, T
value) {
+
            value) {
KdTree<T> cursor = tree;
+
        KdTree<T> cursor = tree;
  
while (cursor.locations == null || cursor.locationCount >=
+
        while (cursor.locations == null || cursor.locationCount >=
cursor.locations.length) {
+
            cursor.locations.length) {
if (cursor.locations != null) {
+
            if (cursor.locations != null) {
cursor.splitDimension = cursor.findWidestAxis();
+
                cursor.splitDimension = cursor.findWidestAxis();
cursor.splitValue = (cursor.minLimit[cursor.splitDimension] +
+
                cursor.splitValue = (cursor.minLimit[cursor.splitDimension] +
cursor.maxLimit[cursor.splitDimension]) * 0.5;
+
                        cursor.maxLimit[cursor.splitDimension]) * 0.5;
  
// Don't split node if it has no width in any axis. Double the bucket size instead
+
                // Don't split node if it has no width in any axis. Double the bucket size instead
if ((cursor.minLimit[cursor.splitDimension] - cursor.maxLimit[cursor.splitDimension]) == 0) {
+
                if ((cursor.minLimit[cursor.splitDimension] -   cursor.maxLimit[cursor.splitDimension]) == 0) {
cursor.locations = Arrays.copyOf(cursor.locations,
+
                    cursor.locations = Arrays.copyOf(cursor.locations, cursor.locations.length * 2);
cursor.locations.length * 2);
+
                    break;
break;
+
                }
}
 
  
// Create child leaves
+
                // Create child leaves
KdTree<T> left = new KdTree<T>(cursor, false);
+
                KdTree<T> left = new KdTree<T>(cursor, false);
KdTree<T> right = new KdTree<T>(cursor, true);
+
                KdTree<T> right = new KdTree<T>(cursor, true);
  
// Move locations into children
+
                // Move locations into children
for (double[] oldLocation : cursor.locations) {
+
                for (double[] oldLocation : cursor.locations) {
if (oldLocation[cursor.splitDimension] > cursor.splitValue) {
+
                    if (oldLocation[cursor.splitDimension] > cursor.splitValue) {
// Right
+
                        // Right
right.locations[right.locationCount] = oldLocation;
+
                        right.locations[right.locationCount] = oldLocation;
right.locationCount++;
+
                        right.locationCount++;
right.extendBounds(oldLocation);
+
                        right.extendBounds(oldLocation);
}
+
                    }
else {
+
                    else {
// Left
+
                        // Left
left.locations[left.locationCount] = oldLocation;
+
                        left.locations[left.locationCount] = oldLocation;
left.locationCount++;
+
                        left.locationCount++;
left.extendBounds(oldLocation);
+
                        left.extendBounds(oldLocation);
}
+
                    }
}
+
                }
  
// Make into stem
+
                // Make into stem
cursor.left = left;
+
                cursor.left = left;
cursor.right = right;
+
                cursor.right = right;
cursor.locations = null;
+
                cursor.locations = null;
}
+
            }
  
cursor.extendBounds(location);
+
            cursor.extendBounds(location);
  
if (location[cursor.splitDimension] > cursor.splitValue) {
+
            if (location[cursor.splitDimension] > cursor.splitValue) {
cursor = cursor.right;
+
                cursor = cursor.right;
}
+
            }
else {
+
            else {
cursor = cursor.left;
+
                cursor = cursor.left;
}
+
            }
}
+
        }
  
cursor.locations[cursor.locationCount] = location;
+
        cursor.locations[cursor.locationCount] = location;
cursor.locationCount++;
+
        cursor.locationCount++;
cursor.extendBounds(location);
+
        cursor.extendBounds(location);
  
tree.map.put(location, value);
+
        tree.map.put(location, value);
}
+
    }
  
/**
+
    /**
* Enumeration representing the status of a node during the running  
+
    * Enumeration representing the status of a node during the running  
*/
+
    */
private static enum Status {
+
    private static enum Status {
NONE,
+
        NONE,
LEFTVISITED,
+
        LEFTVISITED,
RIGHTVISITED,
+
        RIGHTVISITED,
ALLVISITED
+
        ALLVISITED
}
+
    }
  
/**
+
    /**
* Stores a distance and value to output
+
    * Stores a distance and value to output
*/
+
    */
public static class Entry<T> {
+
    public static class Entry<T> {
public final double distance;
+
        public final double distance;
public final T value;
+
        public final T value;
private Entry(double distance, T value) {
+
        private Entry(double distance, T value) {
this.distance = distance;
+
            this.distance = distance;
this.value = value;
+
            this.value = value;
}
+
        }
}
+
    }
  
/**
+
    /**
* Calculates the nearest 'count' points to 'location', with an arbitrary weighting on dimensions
+
    * Calculates the nearest 'count' points to 'location', with an arbitrary weighting on dimensions
*/
+
    */
public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
+
    public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
double[] location, int count, double[] weights) {
+
            double[] location, int count, double[] weights) {
tree.weights = weights;
+
        tree.weights = weights;
return nearestNeighbor(tree, location, count);
+
        return nearestNeighbor(tree, location, count);
}
+
    }
  
/**
+
    /**
* Calculates the nearest 'count' points to 'location'
+
    * Calculates the nearest 'count' points to 'location'
*/
+
    */
public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
+
    public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
double[] location, int count) {
+
            double[] location, int count) {
KdTree<T> cursor = tree;
+
        KdTree<T> cursor = tree;
Status status = Status.NONE;
+
        Status status = Status.NONE;
Stack<KdTree<T>> stack = new Stack<KdTree<T>>();
+
        Stack<KdTree<T>> stack = new Stack<KdTree<T>>();
Stack<Status> statusStack = new Stack<Status>();
+
        Stack<Status> statusStack = new Stack<Status>();
double range = Double.POSITIVE_INFINITY;
+
        double range = Double.POSITIVE_INFINITY;
ResultHeap resultHeap = new ResultHeap(count);  
+
        ResultHeap resultHeap = new ResultHeap(count);  
  
do {
+
        do {
if (status == Status.ALLVISITED) {
+
            if (status == Status.ALLVISITED) {
// At a fully visited part. Move up the tree
+
                // At a fully visited part. Move up the tree
cursor = stack.pop();
+
                cursor = stack.pop();
status = statusStack.pop();
+
                status = statusStack.pop();
continue;
+
                continue;
}
+
            }
  
if (status == Status.NONE && cursor.locations != null) {
+
            if (status == Status.NONE && cursor.locations != null) {
// At a leaf. Use the data.
+
                // At a leaf. Use the data.
for (int i=0; i<cursor.locationCount; i++) {
+
                for (int i=0; i<cursor.locationCount; i++) {
double dist = sqrPointDist(cursor.locations[i], location, tree.weights);
+
                    double dist = sqrPointDist(cursor.locations[i], location, tree.weights);
resultHeap.addValue(dist, cursor.locations[i]);
+
                    resultHeap.addValue(dist, cursor.locations[i]);
}
+
                }
range = resultHeap.getMaxDist();
+
                range = resultHeap.getMaxDist();
  
if (stack.empty()) {
+
                if (stack.empty()) {
break;
+
                    break;
}
+
                }
cursor = stack.pop();
+
                cursor = stack.pop();
status = statusStack.pop();
+
                status = statusStack.pop();
continue;
+
                continue;
}
+
            }
  
// Going to descend
+
            // Going to descend
KdTree<T> nextCursor = null;
+
            KdTree<T> nextCursor = null;
if (status == Status.NONE) {
+
            if (status == Status.NONE) {
// At a fresh node, descend the most probably useful direction
+
                // At a fresh node, descend the most probably useful direction
if (location[cursor.splitDimension] > cursor.splitValue) {
+
                if (location[cursor.splitDimension] > cursor.splitValue) {
// Descend right
+
                    // Descend right
nextCursor = cursor.right;
+
                    nextCursor = cursor.right;
status = Status.RIGHTVISITED;
+
                    status = Status.RIGHTVISITED;
}
+
                }
else {
+
                else {
// Descend left;
+
                    // Descend left;
nextCursor = cursor.left;
+
                    nextCursor = cursor.left;
status = Status.LEFTVISITED;
+
                    status = Status.LEFTVISITED;
}
+
                }
}
+
            }
else if (status == Status.LEFTVISITED) {
+
            else if (status == Status.LEFTVISITED) {
// Left node visited, descend right.
+
                // Left node visited, descend right.
nextCursor = cursor.right;
+
                nextCursor = cursor.right;
status = Status.ALLVISITED;
+
                status = Status.ALLVISITED;
}
+
            }
else if (status == Status.RIGHTVISITED) {
+
            else if (status == Status.RIGHTVISITED) {
// Right node visited, descend left.
+
                // Right node visited, descend left.
nextCursor = cursor.left;
+
                nextCursor = cursor.left;
status = Status.ALLVISITED;
+
                status = Status.ALLVISITED;
}
+
            }
  
// Check if it's worth descending. Assume it is if it's sibling has not been visited yet.  
+
            // Check if it's worth descending. Assume it is if it's sibling has not been visited yet.  
if (status == Status.ALLVISITED) {
+
            if (status == Status.ALLVISITED) {
if (nextCursor.locationCount == 0 || sqrPointRegionDist(location, nextCursor.minLimit, nextCursor.maxLimit, tree.weights) > range) {
+
                if (nextCursor.locationCount == 0 || sqrPointRegionDist(location, nextCursor.minLimit, nextCursor.maxLimit, tree.weights) > range) {
continue;
+
                    continue;
}
+
                }
}
+
            }
  
// Descend down the tree
+
            // Descend down the tree
stack.push(cursor);
+
            stack.push(cursor);
statusStack.push(status);
+
            statusStack.push(status);
cursor = nextCursor;
+
            cursor = nextCursor;
status = Status.NONE;
+
            status = Status.NONE;
} while (stack.size() > 0 || status != Status.ALLVISITED);
+
        } while (stack.size() > 0 || status != Status.ALLVISITED);
  
ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(count);
+
        ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(count);
Object[] data = resultHeap.getData();
+
        Object[] data = resultHeap.getData();
double[] dist = resultHeap.getDistances();
+
        double[] dist = resultHeap.getDistances();
for (int i=0; i<resultHeap.values; i++) {
+
        for (int i=0; i<resultHeap.values; i++) {
T value = tree.map.get(data[i]);
+
            T value = tree.map.get(data[i]);
results.add(new Entry<T>(dist[i], value));
+
            results.add(new Entry<T>(dist[i], value));
}
+
        }
  
return results;
+
        return results;
}
+
    }
  
/**
+
    /**
* Calculates the (squared euclidean) distance between two points
+
    * Calculates the (squared euclidean) distance between two points
*/
+
    */
private static final double sqrPointDist(double[] p1, double[] p2, double[] weights) {
+
    private static final double sqrPointDist(double[] p1, double[] p2, double[] weights) {
double d = 0;
+
        double d = 0;
  
for (int i=0; i<p1.length; i++) {
+
        for (int i=0; i<p1.length; i++) {
double diff = (p1[i] - p2[i]) * weights[i];
+
            double diff = (p1[i] - p2[i]) * weights[i];
d += diff * diff;
+
            d += diff * diff;
}
+
        }
  
return d;
+
        return d;
}
+
    }
  
/**
+
    /**
* Calculates the closest (squared euclidean) distance between in a point and a bounding region
+
    * Calculates the closest (squared euclidean) distance between in a point and a bounding region
*/
+
    */
private static final double sqrPointRegionDist(double[] point, double[] min, double[] max, double[] weights) {
+
    private static final double sqrPointRegionDist(double[] point, double[] min, double[] max, double[] weights) {
double d = 0;
+
        double d = 0;
  
for (int i=0; i<point.length; i++) {
+
        for (int i=0; i<point.length; i++) {
if (point[i] > max[i]) {
+
            if (point[i] > max[i]) {
double diff = (point[i] - max[i]) * weights[i];
+
                double diff = (point[i] - max[i]) * weights[i];
d += diff * diff;
+
                d += diff * diff;
} else if (point[i] < min[i]) {
+
            } else if (point[i] < min[i]) {
double diff = (point[i] - min[i]) * weights[i];
+
                double diff = (point[i] - min[i]) * weights[i];
d += diff * diff;
+
                d += diff * diff;
}
+
            }
}
+
        }
  
return d;
+
        return d;
}
+
    }
  
/**
+
    /**
* Class for tracking up to 'size' closest values
+
    * Class for tracking up to 'size' closest values
*/
+
    */
private static class ResultHeap {
+
    private static class ResultHeap {
private final Object[] data;
+
        private final Object[] data;
private final double[] distance;
+
        private final double[] distance;
private final int size;
+
        private final int size;
private int values;
+
        private int values;
  
public ResultHeap(int size) {
+
        public ResultHeap(int size) {
this.data = new Object[size+1];
+
            this.data = new Object[size+1];
this.distance = new double[size+1];
+
            this.distance = new double[size+1];
this.size = size;
+
            this.size = size;
this.values = 0;
+
            this.values = 0;
}
+
        }
  
public void addValue(double dist, Object value) {
+
        public void addValue(double dist, Object value) {
if (values == size && dist >= distance[0]) {
+
            if (values == size && dist >= distance[0]) {
return;
+
                return;
}
+
            }
  
// Insert value
+
            // Insert value
data[values] = value;
+
            data[values] = value;
distance[values] = dist;
+
            distance[values] = dist;
values++;
+
            values++;
  
// Up-Heapify
+
            // Up-Heapify
for (int c = values-1, p = (c-1)/2; c != 0 && distance[c] > distance[p]; c = p, p = (c-1)/2) {
+
            for (int c = values-1, p = (c-1)/2; c != 0 && distance[c] > distance[p]; c = p, p = (c-1)/2) {
Object pData = data[p];
+
                Object pData = data[p];
double pDist = distance[p];
+
                double pDist = distance[p];
data[p] = data[c];
+
                data[p] = data[c];
distance[p] = distance[c];
+
                distance[p] = distance[c];
data[c] = pData;
+
                data[c] = pData;
distance[c] = pDist;
+
                distance[c] = pDist;
}
+
            }
  
// If too big, remove the highest value
+
            // If too big, remove the highest value
if (values > size) {
+
            if (values > size) {
// Move the last entry to the top
+
                // Move the last entry to the top
values--;
+
                values--;
data[0] = data[values];
+
                data[0] = data[values];
distance[0] = distance[values];
+
                distance[0] = distance[values];
  
// Down-Heapify
+
                // Down-Heapify
for (int p = 0, c = 1; c < values; p = c,c = p*2+1) {
+
                for (int p = 0, c = 1; c < values; p = c,c = p*2+1) {
if (c+1 < values && distance[c] < distance[c+1]) {
+
                    if (c+1 < values && distance[c] < distance[c+1]) {
c++;
+
                        c++;
}
+
                    }
if (distance[p] < distance[c]) {
+
                    if (distance[p] < distance[c]) {
// Swap the points
+
                        // Swap the points
Object pData = data[p];
+
                        Object pData = data[p];
double pDist = distance[p];
+
                        double pDist = distance[p];
data[p] = data[c];
+
                        data[p] = data[c];
distance[p] = distance[c];
+
                        distance[p] = distance[c];
data[c] = pData;
+
                        data[c] = pData;
distance[c] = pDist;
+
                        distance[c] = pDist;
}
+
                    }
else {
+
                    else {
break;
+
                        break;
}
+
                    }
}
+
                }
}
+
            }
}
+
        }
  
public double getMaxDist() {
+
        public double getMaxDist() {
if (values < size) {
+
            if (values < size) {
return Double.POSITIVE_INFINITY;
+
                return Double.POSITIVE_INFINITY;
}
+
            }
return distance[0];
+
            return distance[0];
}
+
        }
  
public Object[] getData() {
+
        public Object[] getData() {
return data;
+
            return data;
}
+
        }
  
public double[] getDistances() {
+
        public double[] getDistances() {
return distance;
+
            return distance;
}
+
        }
}
+
    }
 
}
 
}
 +
 
</pre></code>
 
</pre></code>

Revision as of 21:20, 26 August 2009

A nice efficent small kD-Tree. It's quite fast... Feel free to use

/**
 * Copyright 2009 Rednaxela
 * 
 * This software is provided 'as-is', without any express or implied
 * warranty. In no event will the authors be held liable for any damages
 * arising from the use of this software.
 * 
 * Permission is granted to anyone to use this software for any purpose,
 * including commercial applications, and to alter it and redistribute it
 * freely, subject to the following restrictions:
 * 
 *    1. The origin of this software must not be misrepresented; you must not
 *    claim that you wrote the original software. If you use this software
 *    in a product, an acknowledgment in the product documentation would be
 *    appreciated but is not required.
 * 
 *    2. This notice may not be removed or altered from any source
 *    distribution.
 */

package ags.utils.newtree2;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Stack;

/**
 * An efficent well-optimized kd-tree
 * 
 * @author Rednaxela
 */
public class KdTree<T> {
    // Static variables
    private static final int bucketSize = 32;

    // All types
    private final int dimensions;

    // Root only
    private final HashMap<Object, T> map;
    private double[] weights;

    // Leaf only
    private double[][] locations;
    private int locationCount;

    // Stem only
    private KdTree<T> left, right;
    private int splitDimension;
    private double splitValue;

    // Bounds
    private double[] minLimit, maxLimit;

    /**
     * Extends the bounds of this node do include a new location
     */
    private final void extendBounds(double[] location) {
        if (minLimit == null) {
            minLimit = Arrays.copyOf(location, dimensions);
            maxLimit = Arrays.copyOf(location, dimensions);
            return;
        }

        for (int i=0; i<dimensions; i++) {
            if (minLimit[i] > location[i]) {
                minLimit[i] = location[i];
            }
            if (maxLimit[i] < location[i]) {
                maxLimit[i] = location[i];
            }
        }
    }

    /**
     * Find the widest axis of the bounds of this node
     */
    private final int findWidestAxis() {
        int widest = 0;
        double width = (maxLimit[0] - minLimit[0]);
        for (int i = 1; i < dimensions; i++) {
            double nwidth = maxLimit[i] - minLimit[i];
            if (nwidth > width) {
                widest = i;
                width = nwidth;
            }
        }
        return widest;
    }

    // Main constructor
    public KdTree(int dimensions) {
        this.dimensions = dimensions;

        // Init as leaf
        this.locations = new double[bucketSize][];
        this.locationCount = 0;

        // Init as root
        this.map = new HashMap<Object, T>();
        this.weights = new double[dimensions];
        Arrays.fill(this.weights, 1.0);
    }

    // Child constructor
    private KdTree(KdTree<T> parent, boolean right) {
        this.dimensions = parent.dimensions;

        // Init as leaf
        this.locations = new double[bucketSize][];
        this.locationCount = 0;

        // Init as non-root
        this.map = null;
    }

    /**
     * Add a point and associated value to the tree
     */
    public static <T> void addPoint(KdTree<T> tree, double[] location, T
            value) {
        KdTree<T> cursor = tree;

        while (cursor.locations == null || cursor.locationCount >=
            cursor.locations.length) {
            if (cursor.locations != null) {
                cursor.splitDimension = cursor.findWidestAxis();
                cursor.splitValue = (cursor.minLimit[cursor.splitDimension] +
                        cursor.maxLimit[cursor.splitDimension]) * 0.5;

                // Don't split node if it has no width in any axis. Double the bucket size instead
                if ((cursor.minLimit[cursor.splitDimension] -    cursor.maxLimit[cursor.splitDimension]) == 0) {
                    cursor.locations = Arrays.copyOf(cursor.locations, cursor.locations.length * 2);
                    break;
                }

                // Create child leaves
                KdTree<T> left = new KdTree<T>(cursor, false);
                KdTree<T> right = new KdTree<T>(cursor, true);

                // Move locations into children
                for (double[] oldLocation : cursor.locations) {
                    if (oldLocation[cursor.splitDimension] > cursor.splitValue) {
                        // Right
                        right.locations[right.locationCount] = oldLocation;
                        right.locationCount++;
                        right.extendBounds(oldLocation);
                    }
                    else {
                        // Left
                        left.locations[left.locationCount] = oldLocation;
                        left.locationCount++;
                        left.extendBounds(oldLocation);
                    }
                }

                // Make into stem
                cursor.left = left;
                cursor.right = right;
                cursor.locations = null;
            }

            cursor.extendBounds(location);

            if (location[cursor.splitDimension] > cursor.splitValue) {
                cursor = cursor.right;
            }
            else {
                cursor = cursor.left;
            }
        }

        cursor.locations[cursor.locationCount] = location;
        cursor.locationCount++;
        cursor.extendBounds(location);

        tree.map.put(location, value);
    }

    /**
     * Enumeration representing the status of a node during the running 
     */
    private static enum Status {
        NONE,
        LEFTVISITED,
        RIGHTVISITED,
        ALLVISITED
    }

    /**
     * Stores a distance and value to output
     */
    public static class Entry<T> {
        public final double distance;
        public final T value;
        private Entry(double distance, T value) {
            this.distance = distance;
            this.value = value;
        }
    }

    /**
     * Calculates the nearest 'count' points to 'location', with an arbitrary weighting on dimensions
     */
    public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
            double[] location, int count, double[] weights) {
        tree.weights = weights;
        return nearestNeighbor(tree, location, count);
    }

    /**
     * Calculates the nearest 'count' points to 'location'
     */
    public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
            double[] location, int count) {
        KdTree<T> cursor = tree;
        Status status = Status.NONE;
        Stack<KdTree<T>> stack = new Stack<KdTree<T>>();
        Stack<Status> statusStack = new Stack<Status>();
        double range = Double.POSITIVE_INFINITY;
        ResultHeap resultHeap = new ResultHeap(count); 

        do {
            if (status == Status.ALLVISITED) {
                // At a fully visited part. Move up the tree
                cursor = stack.pop();
                status = statusStack.pop();
                continue;
            }

            if (status == Status.NONE && cursor.locations != null) {
                // At a leaf. Use the data.
                for (int i=0; i<cursor.locationCount; i++) {
                    double dist = sqrPointDist(cursor.locations[i], location, tree.weights);
                    resultHeap.addValue(dist, cursor.locations[i]);
                }
                range = resultHeap.getMaxDist();

                if (stack.empty()) {
                    break;
                }
                cursor = stack.pop();
                status = statusStack.pop();
                continue;
            }

            // Going to descend
            KdTree<T> nextCursor = null;
            if (status == Status.NONE) {
                // At a fresh node, descend the most probably useful direction
                if (location[cursor.splitDimension] > cursor.splitValue) {
                    // Descend right
                    nextCursor = cursor.right;
                    status = Status.RIGHTVISITED;
                }
                else {
                    // Descend left;
                    nextCursor = cursor.left;
                    status = Status.LEFTVISITED;
                }
            }
            else if (status == Status.LEFTVISITED) {
                // Left node visited, descend right.
                nextCursor = cursor.right;
                status = Status.ALLVISITED;
            }
            else if (status == Status.RIGHTVISITED) {
                // Right node visited, descend left.
                nextCursor = cursor.left;
                status = Status.ALLVISITED;
            }

            // Check if it's worth descending. Assume it is if it's sibling has not been visited yet. 
            if (status == Status.ALLVISITED) {
                if (nextCursor.locationCount == 0 || sqrPointRegionDist(location, nextCursor.minLimit, nextCursor.maxLimit, tree.weights) > range) {
                    continue;
                }
            }

            // Descend down the tree
            stack.push(cursor);
            statusStack.push(status);
            cursor = nextCursor;
            status = Status.NONE;
        } while (stack.size() > 0 || status != Status.ALLVISITED);

        ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(count);
        Object[] data = resultHeap.getData();
        double[] dist = resultHeap.getDistances();
        for (int i=0; i<resultHeap.values; i++) {
            T value = tree.map.get(data[i]);
            results.add(new Entry<T>(dist[i], value));
        }

        return results;
    }

    /**
     * Calculates the (squared euclidean) distance between two points
     */
    private static final double sqrPointDist(double[] p1, double[] p2, double[] weights) {
        double d = 0;

        for (int i=0; i<p1.length; i++) {
            double diff = (p1[i] - p2[i]) * weights[i];
            d += diff * diff;
        }

        return d;
    }

    /**
     * Calculates the closest (squared euclidean) distance between in a point and a bounding region
     */
    private static final double sqrPointRegionDist(double[] point, double[] min, double[] max, double[] weights) {
        double d = 0;

        for (int i=0; i<point.length; i++) {
            if (point[i] > max[i]) {
                double diff = (point[i] - max[i]) * weights[i];
                d += diff * diff;
            } else if (point[i] < min[i]) {
                double diff = (point[i] - min[i]) * weights[i];
                d += diff * diff;
            }
        }

        return d;
    }

    /**
     * Class for tracking up to 'size' closest values
     */
    private static class ResultHeap {
        private final Object[] data;
        private final double[] distance;
        private final int size;
        private int values;

        public ResultHeap(int size) {
            this.data = new Object[size+1];
            this.distance = new double[size+1];
            this.size = size;
            this.values = 0;
        }

        public void addValue(double dist, Object value) {
            if (values == size && dist >= distance[0]) {
                return;
            }

            // Insert value
            data[values] = value;
            distance[values] = dist;
            values++;

            // Up-Heapify
            for (int c = values-1, p = (c-1)/2; c != 0 && distance[c] > distance[p]; c = p, p = (c-1)/2) {
                Object pData = data[p];
                double pDist = distance[p];
                data[p] = data[c];
                distance[p] = distance[c];
                data[c] = pData;
                distance[c] = pDist;
            }

            // If too big, remove the highest value
            if (values > size) {
                // Move the last entry to the top
                values--;
                data[0] = data[values];
                distance[0] = distance[values];

                // Down-Heapify
                for (int p = 0, c = 1; c < values; p = c,c = p*2+1) {
                    if (c+1 < values && distance[c] < distance[c+1]) {
                        c++;
                    }
                    if (distance[p] < distance[c]) {
                        // Swap the points
                        Object pData = data[p];
                        double pDist = distance[p];
                        data[p] = data[c];
                        distance[p] = distance[c];
                        data[c] = pData;
                        distance[c] = pDist;
                    }
                    else {
                        break;
                    }
                }
            }
        }

        public double getMaxDist() {
            if (values < size) {
                return Double.POSITIVE_INFINITY;
            }
            return distance[0];
        }

        public Object[] getData() {
            return data;
        }

        public double[] getDistances() {
            return distance;
        }
    }
}