User:Duyn/BucketKdTree
Jump to navigation
Jump to search
Below is the full source code for a tree that will form the basis of a future k-d tree tutorial I'm working on. This code is released under the RWPCL. It is reasonably well optimised—in theory, anyway. I don't specialise in writing fast code, but the algorithms used are designed to rule out whole sub-trees as early as possible.
Exemplar.java
Base class for items added to the tree. Sub-classes can carry useful payload data, such as guess factors.
import java.util.Arrays; /** * A sample point in multi-dimensional space. Needed because each sample * may contain an arbitrary payload. * * Note: this class does not make allowance for a payload. Sub-class if * you want to store something more than just data points. * * @author dnn */ public class Exemplar { public final double[] domain; public Exemplar(double[] domain) { this.domain = domain; } public boolean domainEquals(final Exemplar other) { return Arrays.equals(domain, other.domain); } @Override public String toString() { return Arrays.toString(domain); } }
BucketKdTree.java
Tree with k-nearest-neighbour search. Does not support deletion or rebalancing—the tree must be rebuilt if you need either.
import java.util.*; /** * A k-dimensional binary partitioning tree which splits space on the * mean of the dimension with the largest variance. Points are held in * buckets so we can pick a better split point than whatever comes first. * * Does not store tree depth. If you want balance, re-build the tree * periodically. * * Optimisations in this tree assume distance metric is euclidian distance. * May work if retrofitted with other metrics, but that is purely * accidental. * * Note: results can become unpredictable if values are different but so * close together that rounding errors in computing their mean result in * all exemplars being on one side of the mean. Performance degrades when * this occurs. Nearest neighbour search tested to work up to range * [1, 1 + 5e-16). * * Ideas for path ordering and bounds come from: * NEAL SAMPLE, MATTHEW HAINES, MARK ARNOLD, TIMOTHY PURCELL, * 'Optimizing Search Strategies in k-d Trees' * http://ilpubs.stanford.edu:8090/723/ * * Computation of variance from: * John Cook, 'Accurately computing running variance' * http://www.johndcook.com/standard_deviation.html * * Terminology note: points are called Exemplars. They must all be * descended from Exemplar class. Position in k-d space is stored in each * exemplar's domain member. This is to avoid conflicting with already * existing classes referring to geometric points. * * Terminology comes from: * Andrew Moore, 'An intoductory tutorial on kd' * http://www.autonlab.org/autonweb/14665 * * @author dnn */ public class BucketKdTree<T extends Exemplar> { // Only leaf nodes contain points private List<T> exemplars = new LinkedList<T>(); // These aren't initialised until add() is called. 
private double[] exMean; private double[] exSumSqDev; // Optimisation when tree contains large number of duplicates private boolean exemplarsAreUniform = true; private int bucketSize; private BucketKdTree<T> left, right; private int splitDim; private double split; private int dimensions = 0; // Optimisation for searches. This lets us skip a node if its // scope intersects with a search hypersphere but it doesn't contain // any points that actually intersect. private double[] maxBounds; private double[] minBounds; public BucketKdTree(int bucketSize) { this.bucketSize = bucketSize; } // // PUBLIC METHODS // public void add(T ex) { BucketKdTree<T> tree = addNoSplit(this, ex); if (shouldSplit(tree)) { split(tree); } } public void addAll(Collection<T> exs) { // Some spurious function calls. Optimised for readability over // efficiency. final Set<BucketKdTree<T>> modTrees = new HashSet<BucketKdTree<T>>(); for(T ex : exs) { modTrees.add(addNoSplit(this, ex)); } for(BucketKdTree<T> tree : modTrees) { if (shouldSplit(tree)) { split(tree); } } } public SortedMap<Double, List<T>> search(double[] query, int nMinResults) { // Forward to a static method to avoid accidental reference to // instance variables while descending the tree return search(this, query, nMinResults); } @Override public String toString() { return toString(""); } // // IMPLEMENTATION DETAILS // private boolean isTree() { return left != null; } // Addition // Adds an exemplar without splitting overflowing leaves. // Returns leaf to which exemplar was added. private static <T extends Exemplar> BucketKdTree<T> addNoSplit(BucketKdTree<T> tree, T ex) { // Some spurious function calls. Optimised for readability over // efficiency. BucketKdTree<T> cursor = tree; while (cursor != null) { updateBounds(cursor, ex); if (cursor.isTree()) { // Sub-tree cursor = ex.domain[cursor.splitDim] <= cursor.split ? 
cursor.left : cursor.right; } else { // Leaf // Infer dimensions if we haven't already if (cursor.dimensions == 0) cursor.dimensions = ex.domain.length; // Add exemplar to leaf cursor.exemplars.add(ex); // Calculate running mean and sum of squared deviations final int nExs = cursor.exemplars.size(); final int dims = cursor.dimensions; if (nExs == 1) { cursor.exMean = Arrays.copyOf(ex.domain, dims); cursor.exSumSqDev = new double[dims]; } else { for(int d = 0; d < dims; d++) { final double coord = ex.domain[d]; final double oldMean = cursor.exMean[d], newMean; cursor.exMean[d] = newMean = oldMean + (coord - oldMean)/nExs; cursor.exSumSqDev[d] = cursor.exSumSqDev[d] + (coord - oldMean)*(coord - newMean); } } // Check that exemplars are still uniform if (cursor.exemplarsAreUniform) { final List<T> cExs = cursor.exemplars; if (cExs.size() > 0 && !ex.domainEquals(cExs.get(0))) cursor.exemplarsAreUniform = false; } // Finished walking return cursor; } } throw new RuntimeException("Walked tree without adding anything"); } private static <T extends Exemplar> void updateBounds(BucketKdTree<T> tree, Exemplar ex) { final int dims = ex.domain.length; if (tree.maxBounds == null) { tree.maxBounds = Arrays.copyOf(ex.domain, dims); tree.minBounds = Arrays.copyOf(ex.domain, dims); } else { for(int d = 0; d < dims; d++) { final double dimVal = ex.domain[d]; if (dimVal > tree.maxBounds[d]) tree.maxBounds[d] = dimVal; else if (dimVal < tree.minBounds[d]) tree.minBounds[d] = dimVal; } } } // Splitting (internal operation) private static <T extends Exemplar> boolean shouldSplit(BucketKdTree<T> tree) { return tree.exemplars.size() > tree.bucketSize && !tree.exemplarsAreUniform; } @SuppressWarnings("unchecked") private static <T extends Exemplar> void split(BucketKdTree<T> tree) { assert !tree.exemplarsAreUniform; // Find dimension with largest variance to split on double largestVar = -1; int splitDim = 0; for(int d = 0; d < tree.dimensions; d++) { // Don't need to divide by number of 
exemplars to find largest // variance final double var = tree.exSumSqDev[d]; if (var > largestVar) { largestVar = var; splitDim = d; } } // Find mean as position for our split double splitValue = tree.exMean[splitDim]; // Check that our split actually splits our data. This also lets // us bulk load exemplars into sub-trees, which is more likely // to keep optimal balance. final List<T> leftExs = new LinkedList<T>(); final List<T> rightExs = new LinkedList<T>(); for(T s : tree.exemplars) { if (s.domain[splitDim] <= splitValue) leftExs.add(s); else rightExs.add(s); } int leftSize = leftExs.size(); final int treeSize = tree.exemplars.size(); if (leftSize == treeSize || leftSize == 0) { System.err.println( "WARNING: Randomly splitting non-uniform tree"); // We know the exemplars aren't all the same, so try picking // an exemplar and a dimension at random for our split point // This might take several tries, so we copy our exemplars to // an array to speed up process of picking a random point Object[] exs = tree.exemplars.toArray(); while (leftSize == treeSize || leftSize == 0) { leftExs.clear(); rightExs.clear(); splitDim = (int) Math.floor(Math.random()*tree.dimensions); final int splitPtIdx = (int) Math.floor(Math.random()*exs.length); // Cast is inevitable consequence of java's inability to // create a generic array splitValue = ((T)exs[splitPtIdx]).domain[splitDim]; for(T s : tree.exemplars) { if (s.domain[splitDim] <= splitValue) leftExs.add(s); else rightExs.add(s); } leftSize = leftExs.size(); } } // We have found a valid split. 
Start building our sub-trees final BucketKdTree<T> left = new BucketKdTree<T>(tree.bucketSize); final BucketKdTree<T> right = new BucketKdTree<T>(tree.bucketSize); left.addAll(leftExs); right.addAll(rightExs); // Finally, commit the split tree.splitDim = splitDim; tree.split = splitValue; tree.left = left; tree.right = right; // Let go of exemplars (and their running stats) held in this leaf tree.exemplars = null; tree.exMean = tree.exSumSqDev = null; } // Searching // May return more results than requested if multiple exemplars have // same distance from target. // // Note: this function works with squared distances to avoid sqrt() // operations private static <T extends Exemplar> SortedMap<Double, List<T>> search(BucketKdTree<T> tree, double[] query, int nMinResults) { // distance => list of points that distance away from query final NavigableMap<Double, List<T>> results = new TreeMap<Double, List<T>>(); final SearchState state = new SearchState(); final Deque<SearchStackEntry<T>> stack = new LinkedList<SearchStackEntry<T>>(); stack.addFirst(new SearchStackEntry<T>(state.maxDistance, tree)); while (!stack.isEmpty()) { final SearchStackEntry<T> entry = stack.removeFirst(); final BucketKdTree<T> cur = entry.tree; if (cur.isTree()) { searchTree(query, nMinResults, cur, state, stack); } else if (entry.minDFromQ <= state.maxDistance || state.nResults < nMinResults) { searchLeaf(query, nMinResults, cur, state, results); } } return results; } private static <T extends Exemplar> void searchTree(double[] query, int nMinResults, BucketKdTree<T> tree, SearchState searchState, Deque<SearchStackEntry<T>> stack) { // Left is presumed near. This is verified further down. BucketKdTree<T> nearTree = tree.left, farTree = tree.right; // These variables let us skip empty sub-trees boolean nearEmpty = nearTree.minBounds == null; boolean farEmpty = farTree.minBounds == null; // Find distance from nearest possible point in each // sub-tree to query. 
If that is greater than max distance, // we can rule out that sub-tree. double nearD = nearEmpty ? Double.POSITIVE_INFINITY : minDistanceSqFrom(query, nearTree.minBounds, nearTree.maxBounds); double farD = farEmpty ? Double.POSITIVE_INFINITY : minDistanceSqFrom(query, farTree.minBounds, farTree.maxBounds); // Swap near and far if they're incorrect if (farD < nearD) { final double tmpD = nearD; final BucketKdTree<T> tmpTree = nearTree; final boolean tmpEmpty = nearEmpty; nearD = farD; nearTree = farTree; nearEmpty = farEmpty; farD = tmpD; farTree = tmpTree; farEmpty = tmpEmpty; } // Add nearest sub-tree to stack later so we descend it // first. This is likely to constrict our max distance // sooner, resulting in less visited nodes if (!farEmpty && (farD <= searchState.maxDistance || searchState.nResults < nMinResults)) { stack.addFirst(new SearchStackEntry<T>(farD, farTree)); } if (!nearEmpty && (nearD <= searchState.maxDistance || searchState.nResults < nMinResults)) { stack.addFirst(new SearchStackEntry<T>(nearD, nearTree)); } } private static <T extends Exemplar> void searchLeaf(double[] query, int nMinResults, BucketKdTree<T> leaf, SearchState searchState, NavigableMap<Double, List<T>> results) { // Keep track of elements at max distance so we know // whether we can just drop entire list of furthest // exemplars final int nMinResultsBeforeAddition = nMinResults - 1; for(T ex : leaf.exemplars) { final double exD = distanceSqFrom(query, ex.domain); if (searchState.nResults < nMinResults || exD == searchState.maxDistance) { // Blindly add this exemplar List<T> exsAtD = getOrElseInit(results, exD); if (leaf.exemplarsAreUniform) { // No need to go through every one if all // exemplars are the same exsAtD.addAll(leaf.exemplars); searchState.nResults += leaf.exemplars.size(); } else { exsAtD.add(ex); searchState.nResults++; } // Update information about furthest exemplars if (exD > searchState.maxDistance || searchState.maxDistance == Double.POSITIVE_INFINITY) { 
searchState.maxDistance = exD; } if (searchState.maxDistance == exD) searchState.nExsAtMaxD = exsAtD.size(); } else if (exD < searchState.maxDistance) { // Point closer than furthest neighbour if (searchState.nResults - searchState.nExsAtMaxD >= nMinResultsBeforeAddition) { // Dropping furthest exemplars won't leave // us with too little to meet return // minimum results.remove(searchState.maxDistance); searchState.nResults -= searchState.nExsAtMaxD; } // Add new exemplar List<T> exsAtD = getOrElseInit(results, exD); if (leaf.exemplarsAreUniform) { // No need to go through every one if all // exemplars are the same exsAtD.addAll(leaf.exemplars); searchState.nResults += leaf.exemplars.size(); } else { exsAtD.add(ex); searchState.nResults++; } // Update information about furthest exemplars Map.Entry<Double, List<T>> lastEnt = results.lastEntry(); searchState.maxDistance = lastEnt.getKey(); searchState.nExsAtMaxD = lastEnt.getValue().size(); } // exD < maxDistance // No need to keep going if all exemplars are // the same if (leaf.exemplarsAreUniform) break; } // for(T ex : cur.exemplars) } private static <T extends Exemplar> List<T> getOrElseInit(Map<Double, List<T>> results, double dst) { List<T> lst = results.get(dst); if (lst == null) { lst = new LinkedList<T>(); results.put(dst, lst); } return lst; } // Dumping to string (debug) private String toString(String indent) { if (isTree()){ return String.format("%s{|%d|%f|\n%s\n%s} ", indent, splitDim, split, left.toString(indent + "\t"), right.toString(indent + "\t")); } else { return String.format("%sL%s", indent, Arrays.toString(exemplars.toArray())); } } // Distance calculations // Gets distance from target of nearest point on hyper-rect defined // by supplied min and max bounds private static double minDistanceSqFrom(double[] target, double[] min, double[] max) { // Note: profiling shows this is called lots of times, so it pays // to be well optimised double distanceSq = 0; for(int d = 0; d < target.length; d++) { 
final double coord = target[d]; double nearCoord; if (((nearCoord = min[d]) > coord) || ((nearCoord = max[d]) < coord)) { final double dst = nearCoord - coord; distanceSq += dst*dst; } } return distanceSq; } // Accessible to testing static double distanceSqFrom(double[] p1, double[] p2) { // Note: profiling shows this is called lots of times, so it pays // to be well optimised double dSq = 0; for(int d = 0; d < p1.length; d++) { final double dst = p1[d] - p2[d]; if (dst != 0) dSq += dst*dst; } return dSq; } // // class SearchStackEntry // // Stores a precomputed distance so we don't have to do it again // when we pop the tree off the search stack. private static class SearchStackEntry<T extends Exemplar> { public final double minDFromQ; public final BucketKdTree<T> tree; public SearchStackEntry(double minDFromQ, BucketKdTree<T> tree) { this.minDFromQ = minDFromQ; this.tree = tree; } } // // class SearchState // // Holds data about current state of the search. Used for live updating // of pruning distance. private static class SearchState { int nResults = 0; double maxDistance = Double.POSITIVE_INFINITY; int nExsAtMaxD = 0; } }