Difference between revisions of "User:Rednaxela/kD-Tree"
Jump to navigation
Jump to search
m (Remove some silly newlines that crept in) |
m (more reformatting) |
||
Line 36: | Line 36: | ||
*/ | */ | ||
public class KdTree<T> { | public class KdTree<T> { | ||
− | + | // Static variables | |
− | + | private static final int bucketSize = 32; | |
− | + | // All types | |
− | + | private final int dimensions; | |
− | + | // Root only | |
− | + | private final HashMap<Object, T> map; | |
− | + | private double[] weights; | |
− | + | // Leaf only | |
− | + | private double[][] locations; | |
− | + | private int locationCount; | |
− | + | // Stem only | |
− | + | private KdTree<T> left, right; | |
− | + | private int splitDimension; | |
− | + | private double splitValue; | |
− | + | // Bounds | |
− | + | private double[] minLimit, maxLimit; | |
− | + | /** | |
− | + | * Extends the bounds of this node do include a new location | |
− | + | */ | |
− | + | private final void extendBounds(double[] location) { | |
− | + | if (minLimit == null) { | |
− | + | minLimit = Arrays.copyOf(location, dimensions); | |
− | + | maxLimit = Arrays.copyOf(location, dimensions); | |
− | + | return; | |
− | + | } | |
− | + | for (int i=0; i<dimensions; i++) { | |
− | + | if (minLimit[i] > location[i]) { | |
− | + | minLimit[i] = location[i]; | |
− | + | } | |
− | + | if (maxLimit[i] < location[i]) { | |
− | + | maxLimit[i] = location[i]; | |
− | + | } | |
− | + | } | |
− | + | } | |
− | + | /** | |
− | + | * Find the widest axis of the bounds of this node | |
− | + | */ | |
− | + | private final int findWidestAxis() { | |
− | + | int widest = 0; | |
− | + | double width = (maxLimit[0] - minLimit[0]); | |
− | + | for (int i = 1; i < dimensions; i++) { | |
− | + | double nwidth = maxLimit[i] - minLimit[i]; | |
− | + | if (nwidth > width) { | |
− | + | widest = i; | |
− | + | width = nwidth; | |
− | + | } | |
− | + | } | |
− | + | return widest; | |
− | + | } | |
− | + | // Main constructor | |
− | + | public KdTree(int dimensions) { | |
− | + | this.dimensions = dimensions; | |
− | + | // Init as leaf | |
− | + | this.locations = new double[bucketSize][]; | |
− | + | this.locationCount = 0; | |
− | + | // Init as root | |
− | + | this.map = new HashMap<Object, T>(); | |
− | + | this.weights = new double[dimensions]; | |
− | + | Arrays.fill(this.weights, 1.0); | |
− | + | } | |
− | + | // Child constructor | |
− | + | private KdTree(KdTree<T> parent, boolean right) { | |
− | + | this.dimensions = parent.dimensions; | |
− | + | // Init as leaf | |
− | + | this.locations = new double[bucketSize][]; | |
− | + | this.locationCount = 0; | |
− | + | // Init as non-root | |
− | + | this.map = null; | |
− | + | } | |
− | + | /** | |
− | + | * Add a point and associated value to the tree | |
− | + | */ | |
− | + | public static <T> void addPoint(KdTree<T> tree, double[] location, T | |
− | + | value) { | |
− | + | KdTree<T> cursor = tree; | |
− | + | while (cursor.locations == null || cursor.locationCount >= | |
− | + | cursor.locations.length) { | |
− | + | if (cursor.locations != null) { | |
− | + | cursor.splitDimension = cursor.findWidestAxis(); | |
− | + | cursor.splitValue = (cursor.minLimit[cursor.splitDimension] + | |
− | + | cursor.maxLimit[cursor.splitDimension]) * 0.5; | |
− | + | // Don't split node if it has no width in any axis. Double the bucket size instead | |
− | + | if ((cursor.minLimit[cursor.splitDimension] - cursor.maxLimit[cursor.splitDimension]) == 0) { | |
− | + | cursor.locations = Arrays.copyOf(cursor.locations, cursor.locations.length * 2); | |
− | + | break; | |
− | + | } | |
− | |||
− | + | // Create child leaves | |
− | + | KdTree<T> left = new KdTree<T>(cursor, false); | |
− | + | KdTree<T> right = new KdTree<T>(cursor, true); | |
− | + | // Move locations into children | |
− | + | for (double[] oldLocation : cursor.locations) { | |
− | + | if (oldLocation[cursor.splitDimension] > cursor.splitValue) { | |
− | + | // Right | |
− | + | right.locations[right.locationCount] = oldLocation; | |
− | + | right.locationCount++; | |
− | + | right.extendBounds(oldLocation); | |
− | + | } | |
− | + | else { | |
− | + | // Left | |
− | + | left.locations[left.locationCount] = oldLocation; | |
− | + | left.locationCount++; | |
− | + | left.extendBounds(oldLocation); | |
− | + | } | |
− | + | } | |
− | + | // Make into stem | |
− | + | cursor.left = left; | |
− | + | cursor.right = right; | |
− | + | cursor.locations = null; | |
− | + | } | |
− | + | cursor.extendBounds(location); | |
− | + | if (location[cursor.splitDimension] > cursor.splitValue) { | |
− | + | cursor = cursor.right; | |
− | + | } | |
− | + | else { | |
− | + | cursor = cursor.left; | |
− | + | } | |
− | + | } | |
− | + | cursor.locations[cursor.locationCount] = location; | |
− | + | cursor.locationCount++; | |
− | + | cursor.extendBounds(location); | |
− | + | tree.map.put(location, value); | |
− | + | } | |
− | + | /** | |
− | + | * Enumeration representing the status of a node during the running | |
− | + | */ | |
− | + | private static enum Status { | |
− | + | NONE, | |
− | + | LEFTVISITED, | |
− | + | RIGHTVISITED, | |
− | + | ALLVISITED | |
− | + | } | |
− | + | /** | |
− | + | * Stores a distance and value to output | |
− | + | */ | |
− | + | public static class Entry<T> { | |
− | + | public final double distance; | |
− | + | public final T value; | |
− | + | private Entry(double distance, T value) { | |
− | + | this.distance = distance; | |
− | + | this.value = value; | |
− | + | } | |
− | + | } | |
− | + | /** | |
− | + | * Calculates the nearest 'count' points to 'location', with an arbitrary weighting on dimensions | |
− | + | */ | |
− | + | public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree, | |
− | + | double[] location, int count, double[] weights) { | |
− | + | tree.weights = weights; | |
− | + | return nearestNeighbor(tree, location, count); | |
− | + | } | |
− | + | /** | |
− | + | * Calculates the nearest 'count' points to 'location' | |
− | + | */ | |
− | + | public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree, | |
− | + | double[] location, int count) { | |
− | + | KdTree<T> cursor = tree; | |
− | + | Status status = Status.NONE; | |
− | + | Stack<KdTree<T>> stack = new Stack<KdTree<T>>(); | |
− | + | Stack<Status> statusStack = new Stack<Status>(); | |
− | + | double range = Double.POSITIVE_INFINITY; | |
− | + | ResultHeap resultHeap = new ResultHeap(count); | |
− | + | do { | |
− | + | if (status == Status.ALLVISITED) { | |
− | + | // At a fully visited part. Move up the tree | |
− | + | cursor = stack.pop(); | |
− | + | status = statusStack.pop(); | |
− | + | continue; | |
− | + | } | |
− | + | if (status == Status.NONE && cursor.locations != null) { | |
− | + | // At a leaf. Use the data. | |
− | + | for (int i=0; i<cursor.locationCount; i++) { | |
− | + | double dist = sqrPointDist(cursor.locations[i], location, tree.weights); | |
− | + | resultHeap.addValue(dist, cursor.locations[i]); | |
− | + | } | |
− | + | range = resultHeap.getMaxDist(); | |
− | + | if (stack.empty()) { | |
− | + | break; | |
− | + | } | |
− | + | cursor = stack.pop(); | |
− | + | status = statusStack.pop(); | |
− | + | continue; | |
− | + | } | |
− | + | // Going to descend | |
− | + | KdTree<T> nextCursor = null; | |
− | + | if (status == Status.NONE) { | |
− | + | // At a fresh node, descend the most probably useful direction | |
− | + | if (location[cursor.splitDimension] > cursor.splitValue) { | |
− | + | // Descend right | |
− | + | nextCursor = cursor.right; | |
− | + | status = Status.RIGHTVISITED; | |
− | + | } | |
− | + | else { | |
− | + | // Descend left; | |
− | + | nextCursor = cursor.left; | |
− | + | status = Status.LEFTVISITED; | |
− | + | } | |
− | + | } | |
− | + | else if (status == Status.LEFTVISITED) { | |
− | + | // Left node visited, descend right. | |
− | + | nextCursor = cursor.right; | |
− | + | status = Status.ALLVISITED; | |
− | + | } | |
− | + | else if (status == Status.RIGHTVISITED) { | |
− | + | // Right node visited, descend left. | |
− | + | nextCursor = cursor.left; | |
− | + | status = Status.ALLVISITED; | |
− | + | } | |
− | + | // Check if it's worth descending. Assume it is if it's sibling has not been visited yet. | |
− | + | if (status == Status.ALLVISITED) { | |
− | + | if (nextCursor.locationCount == 0 || sqrPointRegionDist(location, nextCursor.minLimit, nextCursor.maxLimit, tree.weights) > range) { | |
− | + | continue; | |
− | + | } | |
− | + | } | |
− | + | // Descend down the tree | |
− | + | stack.push(cursor); | |
− | + | statusStack.push(status); | |
− | + | cursor = nextCursor; | |
− | + | status = Status.NONE; | |
− | + | } while (stack.size() > 0 || status != Status.ALLVISITED); | |
− | + | ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(count); | |
− | + | Object[] data = resultHeap.getData(); | |
− | + | double[] dist = resultHeap.getDistances(); | |
− | + | for (int i=0; i<resultHeap.values; i++) { | |
− | + | T value = tree.map.get(data[i]); | |
− | + | results.add(new Entry<T>(dist[i], value)); | |
− | + | } | |
− | + | return results; | |
− | + | } | |
− | + | /** | |
− | + | * Calculates the (squared euclidean) distance between two points | |
− | + | */ | |
− | + | private static final double sqrPointDist(double[] p1, double[] p2, double[] weights) { | |
− | + | double d = 0; | |
− | + | for (int i=0; i<p1.length; i++) { | |
− | + | double diff = (p1[i] - p2[i]) * weights[i]; | |
− | + | d += diff * diff; | |
− | + | } | |
− | + | return d; | |
− | + | } | |
− | + | /** | |
− | + | * Calculates the closest (squared euclidean) distance between in a point and a bounding region | |
− | + | */ | |
− | + | private static final double sqrPointRegionDist(double[] point, double[] min, double[] max, double[] weights) { | |
− | + | double d = 0; | |
− | + | for (int i=0; i<point.length; i++) { | |
− | + | if (point[i] > max[i]) { | |
− | + | double diff = (point[i] - max[i]) * weights[i]; | |
− | + | d += diff * diff; | |
− | + | } else if (point[i] < min[i]) { | |
− | + | double diff = (point[i] - min[i]) * weights[i]; | |
− | + | d += diff * diff; | |
− | + | } | |
− | + | } | |
− | + | return d; | |
− | + | } | |
− | + | /** | |
− | + | * Class for tracking up to 'size' closest values | |
− | + | */ | |
− | + | private static class ResultHeap { | |
− | + | private final Object[] data; | |
− | + | private final double[] distance; | |
− | + | private final int size; | |
− | + | private int values; | |
− | + | public ResultHeap(int size) { | |
− | + | this.data = new Object[size+1]; | |
− | + | this.distance = new double[size+1]; | |
− | + | this.size = size; | |
− | + | this.values = 0; | |
− | + | } | |
− | + | public void addValue(double dist, Object value) { | |
− | + | if (values == size && dist >= distance[0]) { | |
− | + | return; | |
− | + | } | |
− | + | // Insert value | |
− | + | data[values] = value; | |
− | + | distance[values] = dist; | |
− | + | values++; | |
− | + | // Up-Heapify | |
− | + | for (int c = values-1, p = (c-1)/2; c != 0 && distance[c] > distance[p]; c = p, p = (c-1)/2) { | |
− | + | Object pData = data[p]; | |
− | + | double pDist = distance[p]; | |
− | + | data[p] = data[c]; | |
− | + | distance[p] = distance[c]; | |
− | + | data[c] = pData; | |
− | + | distance[c] = pDist; | |
− | + | } | |
− | + | // If too big, remove the highest value | |
− | + | if (values > size) { | |
− | + | // Move the last entry to the top | |
− | + | values--; | |
− | + | data[0] = data[values]; | |
− | + | distance[0] = distance[values]; | |
− | + | // Down-Heapify | |
− | + | for (int p = 0, c = 1; c < values; p = c,c = p*2+1) { | |
− | + | if (c+1 < values && distance[c] < distance[c+1]) { | |
− | + | c++; | |
− | + | } | |
− | + | if (distance[p] < distance[c]) { | |
− | + | // Swap the points | |
− | + | Object pData = data[p]; | |
− | + | double pDist = distance[p]; | |
− | + | data[p] = data[c]; | |
− | + | distance[p] = distance[c]; | |
− | + | data[c] = pData; | |
− | + | distance[c] = pDist; | |
− | + | } | |
− | + | else { | |
− | + | break; | |
− | + | } | |
− | + | } | |
− | + | } | |
− | + | } | |
− | + | public double getMaxDist() { | |
− | + | if (values < size) { | |
− | + | return Double.POSITIVE_INFINITY; | |
− | + | } | |
− | + | return distance[0]; | |
− | + | } | |
− | + | public Object[] getData() { | |
− | + | return data; | |
− | + | } | |
− | + | public double[] getDistances() { | |
− | + | return distance; | |
− | + | } | |
− | + | } | |
} | } | ||
+ | |||
</pre></code> | </pre></code> |
Revision as of 21:20, 26 August 2009
A nice efficent small kD-Tree. It's quite fast... Feel free to use
/**
* Copyright 2009 Rednaxela
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would be
* appreciated but is not required.
*
* 2. This notice may not be removed or altered from any source
* distribution.
*/
package ags.utils.newtree2;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Stack;
/**
* An efficent well-optimized kd-tree
*
* @author Rednaxela
*/
public class KdTree<T> {
// Static variables
private static final int bucketSize = 32;
// All types
private final int dimensions;
// Root only
private final HashMap<Object, T> map;
private double[] weights;
// Leaf only
private double[][] locations;
private int locationCount;
// Stem only
private KdTree<T> left, right;
private int splitDimension;
private double splitValue;
// Bounds
private double[] minLimit, maxLimit;
/**
* Extends the bounds of this node do include a new location
*/
private final void extendBounds(double[] location) {
if (minLimit == null) {
minLimit = Arrays.copyOf(location, dimensions);
maxLimit = Arrays.copyOf(location, dimensions);
return;
}
for (int i=0; i<dimensions; i++) {
if (minLimit[i] > location[i]) {
minLimit[i] = location[i];
}
if (maxLimit[i] < location[i]) {
maxLimit[i] = location[i];
}
}
}
/**
* Find the widest axis of the bounds of this node
*/
private final int findWidestAxis() {
int widest = 0;
double width = (maxLimit[0] - minLimit[0]);
for (int i = 1; i < dimensions; i++) {
double nwidth = maxLimit[i] - minLimit[i];
if (nwidth > width) {
widest = i;
width = nwidth;
}
}
return widest;
}
// Main constructor
public KdTree(int dimensions) {
this.dimensions = dimensions;
// Init as leaf
this.locations = new double[bucketSize][];
this.locationCount = 0;
// Init as root
this.map = new HashMap<Object, T>();
this.weights = new double[dimensions];
Arrays.fill(this.weights, 1.0);
}
// Child constructor
private KdTree(KdTree<T> parent, boolean right) {
this.dimensions = parent.dimensions;
// Init as leaf
this.locations = new double[bucketSize][];
this.locationCount = 0;
// Init as non-root
this.map = null;
}
/**
* Add a point and associated value to the tree
*/
public static <T> void addPoint(KdTree<T> tree, double[] location, T
value) {
KdTree<T> cursor = tree;
while (cursor.locations == null || cursor.locationCount >=
cursor.locations.length) {
if (cursor.locations != null) {
cursor.splitDimension = cursor.findWidestAxis();
cursor.splitValue = (cursor.minLimit[cursor.splitDimension] +
cursor.maxLimit[cursor.splitDimension]) * 0.5;
// Don't split node if it has no width in any axis. Double the bucket size instead
if ((cursor.minLimit[cursor.splitDimension] - cursor.maxLimit[cursor.splitDimension]) == 0) {
cursor.locations = Arrays.copyOf(cursor.locations, cursor.locations.length * 2);
break;
}
// Create child leaves
KdTree<T> left = new KdTree<T>(cursor, false);
KdTree<T> right = new KdTree<T>(cursor, true);
// Move locations into children
for (double[] oldLocation : cursor.locations) {
if (oldLocation[cursor.splitDimension] > cursor.splitValue) {
// Right
right.locations[right.locationCount] = oldLocation;
right.locationCount++;
right.extendBounds(oldLocation);
}
else {
// Left
left.locations[left.locationCount] = oldLocation;
left.locationCount++;
left.extendBounds(oldLocation);
}
}
// Make into stem
cursor.left = left;
cursor.right = right;
cursor.locations = null;
}
cursor.extendBounds(location);
if (location[cursor.splitDimension] > cursor.splitValue) {
cursor = cursor.right;
}
else {
cursor = cursor.left;
}
}
cursor.locations[cursor.locationCount] = location;
cursor.locationCount++;
cursor.extendBounds(location);
tree.map.put(location, value);
}
/**
* Enumeration representing the status of a node during the running
*/
private static enum Status {
NONE,
LEFTVISITED,
RIGHTVISITED,
ALLVISITED
}
/**
* Stores a distance and value to output
*/
public static class Entry<T> {
public final double distance;
public final T value;
private Entry(double distance, T value) {
this.distance = distance;
this.value = value;
}
}
/**
* Calculates the nearest 'count' points to 'location', with an arbitrary weighting on dimensions
*/
public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
double[] location, int count, double[] weights) {
tree.weights = weights;
return nearestNeighbor(tree, location, count);
}
/**
* Calculates the nearest 'count' points to 'location'
*/
public static <T> List<Entry<T>> nearestNeighbor(KdTree<T> tree,
double[] location, int count) {
KdTree<T> cursor = tree;
Status status = Status.NONE;
Stack<KdTree<T>> stack = new Stack<KdTree<T>>();
Stack<Status> statusStack = new Stack<Status>();
double range = Double.POSITIVE_INFINITY;
ResultHeap resultHeap = new ResultHeap(count);
do {
if (status == Status.ALLVISITED) {
// At a fully visited part. Move up the tree
cursor = stack.pop();
status = statusStack.pop();
continue;
}
if (status == Status.NONE && cursor.locations != null) {
// At a leaf. Use the data.
for (int i=0; i<cursor.locationCount; i++) {
double dist = sqrPointDist(cursor.locations[i], location, tree.weights);
resultHeap.addValue(dist, cursor.locations[i]);
}
range = resultHeap.getMaxDist();
if (stack.empty()) {
break;
}
cursor = stack.pop();
status = statusStack.pop();
continue;
}
// Going to descend
KdTree<T> nextCursor = null;
if (status == Status.NONE) {
// At a fresh node, descend the most probably useful direction
if (location[cursor.splitDimension] > cursor.splitValue) {
// Descend right
nextCursor = cursor.right;
status = Status.RIGHTVISITED;
}
else {
// Descend left;
nextCursor = cursor.left;
status = Status.LEFTVISITED;
}
}
else if (status == Status.LEFTVISITED) {
// Left node visited, descend right.
nextCursor = cursor.right;
status = Status.ALLVISITED;
}
else if (status == Status.RIGHTVISITED) {
// Right node visited, descend left.
nextCursor = cursor.left;
status = Status.ALLVISITED;
}
// Check if it's worth descending. Assume it is if it's sibling has not been visited yet.
if (status == Status.ALLVISITED) {
if (nextCursor.locationCount == 0 || sqrPointRegionDist(location, nextCursor.minLimit, nextCursor.maxLimit, tree.weights) > range) {
continue;
}
}
// Descend down the tree
stack.push(cursor);
statusStack.push(status);
cursor = nextCursor;
status = Status.NONE;
} while (stack.size() > 0 || status != Status.ALLVISITED);
ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(count);
Object[] data = resultHeap.getData();
double[] dist = resultHeap.getDistances();
for (int i=0; i<resultHeap.values; i++) {
T value = tree.map.get(data[i]);
results.add(new Entry<T>(dist[i], value));
}
return results;
}
/**
* Calculates the (squared euclidean) distance between two points
*/
private static final double sqrPointDist(double[] p1, double[] p2, double[] weights) {
double d = 0;
for (int i=0; i<p1.length; i++) {
double diff = (p1[i] - p2[i]) * weights[i];
d += diff * diff;
}
return d;
}
/**
* Calculates the closest (squared euclidean) distance between in a point and a bounding region
*/
private static final double sqrPointRegionDist(double[] point, double[] min, double[] max, double[] weights) {
double d = 0;
for (int i=0; i<point.length; i++) {
if (point[i] > max[i]) {
double diff = (point[i] - max[i]) * weights[i];
d += diff * diff;
} else if (point[i] < min[i]) {
double diff = (point[i] - min[i]) * weights[i];
d += diff * diff;
}
}
return d;
}
/**
* Class for tracking up to 'size' closest values
*/
private static class ResultHeap {
private final Object[] data;
private final double[] distance;
private final int size;
private int values;
public ResultHeap(int size) {
this.data = new Object[size+1];
this.distance = new double[size+1];
this.size = size;
this.values = 0;
}
public void addValue(double dist, Object value) {
if (values == size && dist >= distance[0]) {
return;
}
// Insert value
data[values] = value;
distance[values] = dist;
values++;
// Up-Heapify
for (int c = values-1, p = (c-1)/2; c != 0 && distance[c] > distance[p]; c = p, p = (c-1)/2) {
Object pData = data[p];
double pDist = distance[p];
data[p] = data[c];
distance[p] = distance[c];
data[c] = pData;
distance[c] = pDist;
}
// If too big, remove the highest value
if (values > size) {
// Move the last entry to the top
values--;
data[0] = data[values];
distance[0] = distance[values];
// Down-Heapify
for (int p = 0, c = 1; c < values; p = c,c = p*2+1) {
if (c+1 < values && distance[c] < distance[c+1]) {
c++;
}
if (distance[p] < distance[c]) {
// Swap the points
Object pData = data[p];
double pDist = distance[p];
data[p] = data[c];
distance[p] = distance[c];
data[c] = pData;
distance[c] = pDist;
}
else {
break;
}
}
}
}
public double getMaxDist() {
if (values < size) {
return Double.POSITIVE_INFINITY;
}
return distance[0];
}
public Object[] getData() {
return data;
}
public double[] getDistances() {
return distance;
}
}
}