User:Duyn/kd-tree Tutorial
[this is the beginning of a tutorial on writing a k-d tree].
This tutorial will walk you through the process of writing a k-d tree for k-nearest neighbour search. The tree here will hold multiple items in each leaf, splitting a leaf only when it overflows. It will split on the mean of the dimension with the largest variance. There are already several k-d trees on this wiki; see some of the others for ideas.
A k-d tree is a binary tree which successively splits a k-dimensional space into two parts. This lets us speed up a nearest neighbour search by not examining points in partitions which are too far away. A k-nearest neighbour search can be built from a nearest neighbour algorithm by not shrinking the search radius until k items have been found. The rest of this tutorial will refer to both as nearest neighbour queries.
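To make the k-nearest idea concrete without any tree, here is a minimal linear-scan sketch. The class LinearKnn and its methods are hypothetical and not part of the tree built in this tutorial. It keeps the search radius at infinity until k candidates are held, and only then shrinks it to the k-th best squared distance. Unlike the tree below, which may return extra results tied at the furthest distance, this sketch keeps exactly k.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical brute-force k-nearest neighbour search (not the tutorial's tree).
class LinearKnn {
    static double distSq(double[] a, double[] b) {
        double s = 0;
        for (int d = 0; d < a.length; d++) {
            double diff = a[d] - b[d];
            s += diff * diff;
        }
        return s;
    }

    // Returns the indices of the k points nearest to query, in no particular order.
    static List<Integer> kNearest(double[][] points, double[] query, int k) {
        // Max-heap on squared distance: the head is the furthest candidate kept.
        PriorityQueue<double[]> heap =
            new PriorityQueue<>((x, y) -> Double.compare(y[0], x[0]));
        double radiusSq = Double.POSITIVE_INFINITY;
        for (int i = 0; i < points.length; i++) {
            double dSq = distSq(points[i], query);
            if (heap.size() < k) {
                heap.add(new double[] {dSq, i});
            } else if (dSq < radiusSq) {
                heap.poll();                      // drop the furthest candidate
                heap.add(new double[] {dSq, i});
            } else {
                continue;                         // outside the current radius
            }
            // The radius only starts shrinking once k candidates are held.
            if (heap.size() == k) radiusSq = heap.peek()[0];
        }
        List<Integer> result = new ArrayList<>();
        for (double[] e : heap) result.add((int) e[1]);
        return result;
    }
}
```

A brute-force scan like this can also serve as a reference when testing a tree's results on small data sets.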
An Exemplar class
We will call each item in our k-d tree an exemplar. Each exemplar has a domain—its spatial co-ordinates in the k-d tree's space. Each exemplar could also have an arbitrary payload, but our tree does not need to know about that. It will only handle storing exemplars based on their domain and returning them in a nearest neighbour search.
You might already have a class somewhere called Point which handles 2D co-ordinates. Calling ours Exemplar avoids a name clash with it.
import java.util.Arrays;

public class Exemplar {
    public final double[] domain;

    public Exemplar(final double[] coords) {
        this.domain = coords;
    }

    // Shorthand. Shorter than calling Arrays.equals() each time.
    public boolean domainEquals(final Exemplar other) {
        return Arrays.equals(domain, other.domain);
    }
}
While this class is fully usable as is, you will rarely be interested in just the domain of the nearest neighbours found by a search. Specific data (e.g. guess factors) is expected to be attached by subclassing Exemplar, so our k-d tree is parameterised by the exemplar type.
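As an illustration of that subclassing, here is a sketch of an exemplar carrying a guess factor. The subclass name GuessFactorExemplar and its guessFactor field are hypothetical; the base class is repeated so the snippet compiles on its own.

```java
import java.util.Arrays;

class Exemplar {
    public final double[] domain;

    public Exemplar(final double[] coords) {
        this.domain = coords;
    }

    public boolean domainEquals(final Exemplar other) {
        return Arrays.equals(domain, other.domain);
    }
}

// Hypothetical subclass: the tree only ever looks at domain, never at the payload.
class GuessFactorExemplar extends Exemplar {
    public final double guessFactor;

    public GuessFactorExemplar(double[] coords, double guessFactor) {
        super(coords);
        this.guessFactor = guessFactor;
    }
}
```

Because domainEquals compares only the domain, two such exemplars with different payloads still count as duplicates when the tree checks whether a leaf is uniform.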
Basic Tree
Here is a basic tree structure:
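import java.util.*;

public class BucketKdTree<T extends Exemplar> {
    // Exemplars held while this node is a leaf; null once it has split.
    private List<T> exemplars = new LinkedList<T>();
    // Sub-trees: both null for a leaf, both non-null after a split.
    private BucketKdTree<T> left, right;
    // Number of exemplars a leaf may hold before it tries to split.
    private int bucketSize;
    private final int dimensions;

    public BucketKdTree(int bucketSize, int dimensions) {
        this.bucketSize = bucketSize;
        this.dimensions = dimensions;
    }

    private boolean isTree() {
        return left != null;
    }
}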
Each node is either a tree, with both left and right sub-trees defined, or a leaf holding exemplars. A node never needs to be both: it splits on a mean, which need not coincide with any actual exemplar, so an internal node has no exemplar of its own to keep.
The bucket size and number of dimensions must be passed into the constructor. We could infer the number of dimensions from the first point added, but this is simpler. Bucket size is not final because in principle it could be varied, though our implementation does not do so.
Adding
Each of the public API methods defers the actual addition to a private static method. Because a static method has no this, we cannot accidentally refer to the root's instance variables while we walk the tree. We will use this pattern for much of the tree's behaviour code.
We decide whether to split a leaf only after the add has been completed.
// One at a time
public void add(T ex) {
    BucketKdTree<T> tree = addNoSplit(this, ex);
    if (shouldSplit(tree)) {
        split(tree);
    }
}

// Bulk add gives us more data to choose a better split point
public void addAll(Collection<T> exs) {
    // Some spurious function calls. Optimised for readability over
    // efficiency.
    final Set<BucketKdTree<T>> modTrees = new HashSet<BucketKdTree<T>>();
    for (T ex : exs) {
        modTrees.add(addNoSplit(this, ex));
    }
    for (BucketKdTree<T> tree : modTrees) {
        if (shouldSplit(tree)) {
            split(tree);
        }
    }
}
To add an exemplar, we traverse the tree from top down until we find a leaf. Then we add the exemplar to the list at that leaf.
To decide which sub-tree to traverse, each tree stores two values—a splitting dimension and a splitting value. If our new exemplar's domain along the splitting dimension is greater than the tree's splitting value, we put it in the right sub-tree. Otherwise, it goes in the left one.
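The descent rule can be sketched in isolation. SketchNode and Descend are hypothetical stand-ins that hold only the fields the walk needs; as in the addNoSplit code later in this section, a coordinate exactly equal to the splitting value goes left.

```java
// Hypothetical stand-in for a tree node: only the fields the walk needs.
class SketchNode {
    int splitDim;
    double split;
    SketchNode left, right;
    final String name;  // label used only to identify leaves in examples

    SketchNode(String name) { this.name = name; }

    boolean isTree() { return left != null; }
}

class Descend {
    // Walks from root to the leaf that would receive a point with this
    // domain. A coordinate equal to the split value goes left.
    static SketchNode findLeaf(SketchNode root, double[] domain) {
        SketchNode cursor = root;
        while (cursor.isTree()) {
            cursor = domain[cursor.splitDim] <= cursor.split ? cursor.left : cursor.right;
        }
        return cursor;
    }
}
```

For a root splitting dimension 0 at 5.0, whose right child splits dimension 1 at 2.0, the point (5.0, 9.0) lands in the root's left leaf, while (6.0, 2.0) and (6.0, 3.0) land in the right child's left and right leaves respectively.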
Since adding takes little time compared to searching, we take this opportunity to make some optimisations:
- We keep track of the actual hyperrect occupied by the points in each tree, defined by maxBounds and minBounds. This lets us rule out a tree even when its region of space intersects our search sphere, as long as the hyperrect bounding its actual points lies entirely outside that sphere.
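- To save us a full iteration when we come to split a leaf, we maintain the running mean <math>M_k</math> (stored in exMean) and the running sum of squared deviations <math>S_k</math> (stored in exSumSqDev) for each dimension using Welford's method, where <math>x_k</math> is the k-th value added along that dimension:
<math>M_1 = x_1, \qquad S_1 = 0</math>
<math>M_k = M_{k-1} + \frac{x_k - M_{k-1}}{k}</math>
<math>S_k = S_{k-1} + (x_k - M_{k-1})(x_k - M_k)</math>
The sample variance after k exemplars is then <math>S_k/(k-1)</math>.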
- We keep track of whether all exemplars in this leaf have the same domain. If they do, we know that comparisons on one exemplar apply to all exemplars in that leaf.
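A standalone sketch of the Welford update, using a hypothetical RunningStats class in place of a leaf's exMean and exSumSqDev arrays, shows the recurrences producing the usual mean and sample variance:

```java
// Hypothetical per-dimension running statistics using Welford's method.
class RunningStats {
    private final double[] mean;      // M_k for each dimension
    private final double[] sumSqDev;  // S_k for each dimension
    private int n = 0;

    RunningStats(int dims) {
        mean = new double[dims];
        sumSqDev = new double[dims];
    }

    void add(double[] x) {
        n++;
        for (int d = 0; d < mean.length; d++) {
            double oldMean = mean[d];
            // With n == 1 this gives M_1 = x_1 and S_1 = 0.
            mean[d] = oldMean + (x[d] - oldMean) / n;
            sumSqDev[d] += (x[d] - oldMean) * (x[d] - mean[d]);
        }
    }

    double mean(int d) { return mean[d]; }

    // Sample variance S_k / (k - 1); needs at least two values.
    double variance(int d) { return sumSqDev[d] / (n - 1); }
}
```

For the values 2, 4, 4, 4, 5, 5, 7, 9 the mean is 5 and the sum of squared deviations is 32, so the sample variance is 32/7.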
The final adding code:
private int splitDim;
private double split;

// These aren't initialised until add() is called.
private double[] exMean;
private double[] exSumSqDev;

// Optimisation when tree contains large number of duplicates
private boolean exemplarsAreUniform = true;

// Optimisation for searches. This lets us skip a node if its
// scope intersects with a search hypersphere but it doesn't contain
// any points that actually intersect.
private double[] maxBounds;
private double[] minBounds;

// Adds an exemplar without splitting overflowing leaves.
// Returns leaf to which exemplar was added.
private static <T extends Exemplar> BucketKdTree<T> addNoSplit(BucketKdTree<T> tree, T ex) {
    // Some spurious function calls. Optimised for readability over
    // efficiency.
    BucketKdTree<T> cursor = tree;
    while (cursor != null) {
        updateBounds(cursor, ex);
        if (cursor.isTree()) {
            // Sub-tree
            cursor = ex.domain[cursor.splitDim] <= cursor.split
                ? cursor.left : cursor.right;
        } else {
            // Leaf
            cursor.exemplars.add(ex);
            final int nExs = cursor.exemplars.size();
            if (nExs == 1) {
                cursor.exMean = Arrays.copyOf(ex.domain, cursor.dimensions);
                cursor.exSumSqDev = new double[cursor.dimensions];
            } else {
                for (int d = 0; d < cursor.dimensions; d++) {
                    final double coord = ex.domain[d];
                    final double oldExMean = cursor.exMean[d];
                    final double newMean = cursor.exMean[d] =
                        oldExMean + (coord - oldExMean)/nExs;
                    final double oldSumSqDev = cursor.exSumSqDev[d];
                    cursor.exSumSqDev[d] = oldSumSqDev
                        + (coord - oldExMean)*(coord - newMean);
                }
            }
            if (cursor.exemplarsAreUniform) {
                final List<T> cExs = cursor.exemplars;
                if (cExs.size() > 0 && !ex.domainEquals(cExs.get(0)))
                    cursor.exemplarsAreUniform = false;
            }
            return cursor;
        }
    }
    throw new RuntimeException("Walked tree without adding anything");
}

private static <T extends Exemplar> void updateBounds(BucketKdTree<T> tree, Exemplar ex) {
    final int dims = tree.dimensions;
    if (tree.maxBounds == null) {
        tree.maxBounds = Arrays.copyOf(ex.domain, dims);
        tree.minBounds = Arrays.copyOf(ex.domain, dims);
    } else {
        for (int d = 0; d < dims; d++) {
            final double dimVal = ex.domain[d];
            if (dimVal > tree.maxBounds[d])
                tree.maxBounds[d] = dimVal;
            else if (dimVal < tree.minBounds[d])
                tree.minBounds[d] = dimVal;
        }
    }
}
Splitting
We only split a leaf when the number of exemplars it holds exceeds its bucket size. It is only worth splitting if the exemplars don't all have the same domain, since identical points cannot be separated.
private static <T extends Exemplar> boolean shouldSplit(BucketKdTree<T> tree) {
    return tree.exemplars.size() > tree.bucketSize
        && !tree.exemplarsAreUniform;
}
Thanks to our pre-computation, splitting is straightforward most of the time. We iterate through each dimension to find the one with the largest variance, then look up the mean of that dimension directly. Dividing each sum of squared deviations by n - 1 is unnecessary here: it scales every dimension equally, so it does not change which one is largest.
To make sure the mean actually divides our data, we separate our exemplars into two lists, one destined for each sub-tree. If all the exemplars end up in only one of the lists, then our split point has failed to separate them. This is most likely due to rounding error when our exemplars are very close together. As a fallback, we pick a random exemplar and a random dimension and split at that exemplar's coordinate, repeating until the exemplars are parted. We know this must succeed eventually because our exemplars are not uniform: at least two of them differ in some dimension, and splitting at the smaller of those two coordinates puts one exemplar on each side.
Finally, we bulk load our sub-trees, store information about our split point and let go of the exemplars stored in the tree.
@SuppressWarnings("unchecked")
private static <T extends Exemplar> void split(BucketKdTree<T> tree) {
    assert !tree.exemplarsAreUniform;
    // Find dimension with largest variance to split on. Every dimension
    // shares the same n - 1 divisor, so comparing the sums of squared
    // deviations directly picks the same dimension.
    double largestSumSqDev = -1;
    int splitDim = 0;
    for (int d = 0; d < tree.dimensions; d++) {
        final double sumSqDev = tree.exSumSqDev[d];
        if (sumSqDev > largestSumSqDev) {
            largestSumSqDev = sumSqDev;
            splitDim = d;
        }
    }

    // Find mean as position for our split
    double splitValue = tree.exMean[splitDim];

    // Check that our split actually splits our data. This also lets
    // us bulk load exemplars into sub-trees, which is more likely
    // to keep optimal balance.
    final List<T> leftExs = new LinkedList<T>();
    final List<T> rightExs = new LinkedList<T>();
    for (T s : tree.exemplars) {
        if (s.domain[splitDim] <= splitValue)
            leftExs.add(s);
        else
            rightExs.add(s);
    }
    int leftSize = leftExs.size();
    final int treeSize = tree.exemplars.size();
    if (leftSize == treeSize || leftSize == 0) {
        System.err.println("WARNING: Randomly splitting non-uniform tree");
        // We know the exemplars aren't all the same, so try picking
        // an exemplar and a dimension at random for our split point.
        // This might take several tries, so we copy our exemplars to
        // an array to speed up the process of picking a random point.
        Object[] exs = tree.exemplars.toArray();
        while (leftSize == treeSize || leftSize == 0) {
            leftExs.clear();
            rightExs.clear();
            splitDim = (int) Math.floor(Math.random()*tree.dimensions);
            final int splitPtIdx = (int) Math.floor(Math.random()*exs.length);
            // Cast is inevitable consequence of Java's inability to
            // create a generic array
            splitValue = ((T) exs[splitPtIdx]).domain[splitDim];
            for (T s : tree.exemplars) {
                if (s.domain[splitDim] <= splitValue)
                    leftExs.add(s);
                else
                    rightExs.add(s);
            }
            leftSize = leftExs.size();
        }
    }

    // We have found a valid split. Start building our sub-trees
    final BucketKdTree<T> left = new BucketKdTree<T>(tree.bucketSize, tree.dimensions);
    final BucketKdTree<T> right = new BucketKdTree<T>(tree.bucketSize, tree.dimensions);
    left.addAll(leftExs);
    right.addAll(rightExs);

    // Finally, commit the split
    tree.splitDim = splitDim;
    tree.split = splitValue;
    tree.left = left;
    tree.right = right;

    // Let go of exemplars (and their running stats) held in this leaf
    tree.exemplars = null;
    tree.exMean = tree.exSumSqDev = null;
}
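The split selection, and the argument for why a valid split always exists, can be checked in a short sketch. SplitChooser and its methods are hypothetical, and the statistics are recomputed in two passes rather than read from running totals. Its fallbackSplit is a deterministic alternative to the random retry above, built directly on the reasoning in the Splitting section: any two exemplars that differ in some dimension give a valid split at the smaller of their two coordinates.

```java
import java.util.List;

// Hypothetical split chooser: largest-variance dimension, split at its mean.
class SplitChooser {
    // Returns {splitDim, splitValue}.
    static double[] chooseSplit(List<double[]> pts) {
        final int dims = pts.get(0).length;
        final int n = pts.size();
        int bestDim = 0;
        double bestSumSqDev = -1, bestMean = 0;
        for (int d = 0; d < dims; d++) {
            double mean = 0;
            for (double[] p : pts) mean += p[d];
            mean /= n;
            double sumSqDev = 0;
            for (double[] p : pts) sumSqDev += (p[d] - mean) * (p[d] - mean);
            if (sumSqDev > bestSumSqDev) {
                bestSumSqDev = sumSqDev;
                bestDim = d;
                bestMean = mean;
            }
        }
        if (separates(pts, bestDim, bestMean)) return new double[] {bestDim, bestMean};
        return fallbackSplit(pts);
    }

    // Deterministic fallback: find a point differing from the first one in
    // some dimension d. Splitting at the smaller of the two coordinates sends
    // the smaller to the left ("<=") and the larger to the right.
    static double[] fallbackSplit(List<double[]> pts) {
        final double[] first = pts.get(0);
        for (double[] p : pts) {
            for (int d = 0; d < first.length; d++) {
                if (p[d] != first[d]) return new double[] {d, Math.min(p[d], first[d])};
            }
        }
        throw new IllegalArgumentException("all points share the same domain");
    }

    // True if "<= value goes left" leaves at least one point on each side.
    static boolean separates(List<double[]> pts, int dim, double value) {
        int left = 0;
        for (double[] p : pts) if (p[dim] <= value) left++;
        return left > 0 && left < pts.size();
    }
}
```

The random retry also works; the deterministic version simply trades randomness for a guarantee that it finishes after a single scan of the exemplars.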
Searching
Before we can start searching, we need to define our distance metric. This whole tutorial assumes Euclidean distance, and the search works with squared distances to avoid square roots. Optimisations may have to be revised if you are using a different distance metric.
// Gets squared distance from target of nearest point on hyper-rect
// defined by supplied min and max bounds
private static double minDistanceSqFrom(double[] target, double[] min, double[] max) {
    // Note: profiling shows this is called lots of times, so it pays
    // to be well optimised
    double distanceSq = 0;
    for (int d = 0; d < target.length; d++) {
        final double coord = target[d];
        double nearCoord;
        if (((nearCoord = min[d]) > coord) || ((nearCoord = max[d]) < coord)) {
            final double dst = nearCoord - coord;
            distanceSq += dst*dst;
        }
    }
    return distanceSq;
}

// Accessible to testing
static double distanceSqFrom(double[] p1, double[] p2) {
    // Note: profiling shows this is called lots of times, so it pays
    // to be well optimised
    double dSq = 0;
    for (int d = 0; d < p1.length; d++) {
        final double dst = p1[d] - p2[d];
        if (dst != 0)
            dSq += dst*dst;
    }
    return dSq;
}
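A worked example of the box distance helps when checking pruning by hand. The hypothetical BoxDistance class below computes the same quantity as minDistanceSqFrom, but clamps each coordinate into [min, max] instead of using the branch above. For the box [0, 2] by [0, 2], a query inside the box gives 0, (3, 1) gives 1, and (3, 4) gives 1 + 4 = 5.

```java
// Hypothetical clamp-based equivalent of minDistanceSqFrom.
class BoxDistance {
    // Squared distance from target to the nearest point of the box [min, max].
    static double minDistanceSqFrom(double[] target, double[] min, double[] max) {
        double distanceSq = 0;
        for (int d = 0; d < target.length; d++) {
            double coord = target[d];
            // Nearest box coordinate: coord itself if inside, else the nearer bound.
            double near = Math.max(min[d], Math.min(coord, max[d]));
            double dst = near - coord;
            distanceSq += dst * dst;
        }
        return distanceSq;
    }
}
```

With a maximum search distance of 4 (squared), the query (3, 4) would let the search rule out this box, since 5 > 4.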
Before we dive into searching, we will introduce two helper classes.
- SearchStackEntry stores a tree along with its minimum distance from the query point. We put these on our tree-walking stack instead of bare trees, so we don't have to recompute this minimum distance when we later pop the tree.
- SearchState holds the variables that are updated after each nearest neighbour is added. Its sole purpose is to let us split the searching of sub-trees and leaves into separate methods without sacrificing performance.
//
// class SearchStackEntry
//
// Stores a precomputed distance so we don't have to do it again
// when we pop the tree off the search stack.
private static class SearchStackEntry<T extends Exemplar> {
    public final double minDFromQ;
    public final BucketKdTree<T> tree;

    public SearchStackEntry(double minDFromQ, BucketKdTree<T> tree) {
        this.minDFromQ = minDFromQ;
        this.tree = tree;
    }
}

//
// class SearchState
//
// Holds data about current state of the search. Used for live updating
// of pruning distance.
private static class SearchState {
    int nResults = 0;
    double maxDistance = Double.POSITIVE_INFINITY;
    int nExsAtMaxD = 0;
}
Our search method repeatedly pops our search stack, decides whether it's looking at a sub-tree or a leaf, and passes the call on to the appropriate method.
public SortedMap<Double, List<T>> search(double[] query, int nMinResults) {
    // Forward to a static method to avoid accidental reference to
    // instance variables while descending the tree
    return search(this, query, nMinResults);
}

// May return more results than requested if multiple exemplars have
// same distance from target.
//
// Note: this function works with squared distances to avoid sqrt()
// operations
private static <T extends Exemplar> SortedMap<Double, List<T>> search(
        BucketKdTree<T> tree, double[] query, int nMinResults) {
    // distance => list of points that distance away from query
    final NavigableMap<Double, List<T>> results = new TreeMap<Double, List<T>>();
    final SearchState state = new SearchState();
    final Deque<SearchStackEntry<T>> stack = new LinkedList<SearchStackEntry<T>>();
    stack.addFirst(new SearchStackEntry<T>(state.maxDistance, tree));
    while (!stack.isEmpty()) {
        final SearchStackEntry<T> entry = stack.removeFirst();
        final BucketKdTree<T> cur = entry.tree;
        if (cur.isTree()) {
            searchTree(query, nMinResults, cur, state, stack);
        } else if (entry.minDFromQ <= state.maxDistance
                || state.nResults < nMinResults) {
            searchLeaf(query, nMinResults, cur, state, results);
        }
    }
    return results;
}
Searching a sub-tree
Here is where two of our optimisations come into play. For each non-empty sub-tree, we compute the minimum distance from our query point to the smallest hyperrect bounding that sub-tree's exemplars. If this distance exceeds our maximum search distance, we can rule out the entire sub-tree. However, if we haven't yet found k candidate nearest neighbours, either sub-tree may contain candidates, so we add both to the search stack anyway.
As a heuristic, we want to search first the sub-tree whose bounding hyperrect is closest to the query point, since this is more likely to contract our search distance early. Because the search stack is last-in, first-out, we push the farther sub-tree first.
private static <T extends Exemplar> void searchTree(double[] query, int nMinResults,
        BucketKdTree<T> tree, SearchState searchState,
        Deque<SearchStackEntry<T>> stack) {
    // Left is presumed near. This is verified further down.
    BucketKdTree<T> nearTree = tree.left, farTree = tree.right;
    // These variables let us skip empty sub-trees
    boolean nearEmpty = nearTree.minBounds == null;
    boolean farEmpty = farTree.minBounds == null;

    // Find distance from nearest possible point in each
    // sub-tree to query. If that is greater than max distance,
    // we can rule out that sub-tree.
    double nearD = nearEmpty ? Double.POSITIVE_INFINITY
        : minDistanceSqFrom(query, nearTree.minBounds, nearTree.maxBounds);
    double farD = farEmpty ? Double.POSITIVE_INFINITY
        : minDistanceSqFrom(query, farTree.minBounds, farTree.maxBounds);

    // Swap near and far if they're incorrect
    if (farD < nearD) {
        final double tmpD = nearD;
        final BucketKdTree<T> tmpTree = nearTree;
        final boolean tmpEmpty = nearEmpty;
        nearD = farD;
        nearTree = farTree;
        nearEmpty = farEmpty;
        farD = tmpD;
        farTree = tmpTree;
        farEmpty = tmpEmpty;
    }

    // Add nearest sub-tree to stack later so we descend it
    // first. This is likely to constrict our max distance
    // sooner, resulting in fewer visited nodes
    if (!farEmpty && (farD <= searchState.maxDistance
            || searchState.nResults < nMinResults)) {
        stack.addFirst(new SearchStackEntry<T>(farD, farTree));
    }
    if (!nearEmpty && (nearD <= searchState.maxDistance
            || searchState.nResults < nMinResults)) {
        stack.addFirst(new SearchStackEntry<T>(nearD, nearTree));
    }
}
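The ordering and pruning in searchTree can be seen with labels standing in for sub-trees. PushOrder.popOrder is a hypothetical helper that applies the same swap, push order and pruning test, and returns the order in which the entries would be popped:

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical illustration of searchTree's push order and pruning test.
class PushOrder {
    static List<String> popOrder(double leftD, double rightD,
                                 double maxDistance, int nResults, int k) {
        String near = "left", far = "right";
        double nearD = leftD, farD = rightD;
        if (farD < nearD) {  // swap so that "near" really is nearer
            near = "right"; far = "left";
            nearD = rightD; farD = leftD;
        }
        Deque<String> stack = new ArrayDeque<>();
        // Far is pushed first, so near ends up on top of the LIFO stack.
        if (farD <= maxDistance || nResults < k) stack.addFirst(far);
        if (nearD <= maxDistance || nResults < k) stack.addFirst(near);
        List<String> order = new ArrayList<>();
        while (!stack.isEmpty()) order.add(stack.removeFirst());
        return order;
    }
}
```

With the left sub-tree at squared distance 4, the right at 1, a maximum distance of 2 and k results already found, only the right sub-tree is pushed. With fewer than k results, both are pushed and the nearer (right) one is popped first.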
Searching a leaf
When we reach a leaf, we finally have a chance to contract our search distance. If we haven't yet found k candidate nearest neighbours, or if the exemplar we're checking lies exactly at our maximum search distance, we add the exemplar unconditionally and update our maximum search distance.
Otherwise, we check the exemplar's distance from our query point. If it's closer than our furthest candidates, we drop the group of candidates at the furthest distance (unless doing so would leave us with fewer than k) to make room for our new neighbour. Finally, we update our maximum search distance so we can skip any trees or points outside it.
If all the exemplars in a leaf are uniform, we don't need to do this check for every exemplar. We just do it for the first one, and either add or reject them all.
private static <T extends Exemplar> void searchLeaf(double[] query, int nMinResults,
        BucketKdTree<T> leaf, SearchState searchState,
        NavigableMap<Double, List<T>> results) {
    // Keep track of elements at max distance so we know
    // whether we can just drop entire list of furthest
    // exemplars
    final int nMinResultsBeforeAddition = nMinResults - 1;
    for (T ex : leaf.exemplars) {
        final double exD = distanceSqFrom(query, ex.domain);
        if (searchState.nResults < nMinResults
                || exD == searchState.maxDistance) {
            // Blindly add this exemplar
            List<T> exsAtD = getOrElseInit(results, exD);
            if (leaf.exemplarsAreUniform) {
                // No need to go through every one if all
                // exemplars are the same
                exsAtD.addAll(leaf.exemplars);
                searchState.nResults += leaf.exemplars.size();
            } else {
                exsAtD.add(ex);
                searchState.nResults++;
            }

            // Update information about furthest exemplars
            if (exD > searchState.maxDistance
                    || searchState.maxDistance == Double.POSITIVE_INFINITY) {
                searchState.maxDistance = exD;
            }
            if (searchState.maxDistance == exD)
                searchState.nExsAtMaxD = exsAtD.size();
        } else if (exD < searchState.maxDistance) {
            // Point closer than furthest neighbour
            if (searchState.nResults - searchState.nExsAtMaxD
                    >= nMinResultsBeforeAddition) {
                // Dropping furthest exemplars won't leave
                // us with too few to meet return
                // minimum
                results.remove(searchState.maxDistance);
                searchState.nResults -= searchState.nExsAtMaxD;
            }

            // Add new exemplar
            List<T> exsAtD = getOrElseInit(results, exD);
            if (leaf.exemplarsAreUniform) {
                // No need to go through every one if all
                // exemplars are the same
                exsAtD.addAll(leaf.exemplars);
                searchState.nResults += leaf.exemplars.size();
            } else {
                exsAtD.add(ex);
                searchState.nResults++;
            }

            // Update information about furthest exemplars
            Map.Entry<Double, List<T>> lastEnt = results.lastEntry();
            searchState.maxDistance = lastEnt.getKey();
            searchState.nExsAtMaxD = lastEnt.getValue().size();
        } // exD < maxDistance

        // No need to keep going if all exemplars are
        // the same
        if (leaf.exemplarsAreUniform)
            break;
    } // for (T ex : leaf.exemplars)
}

private static <T extends Exemplar> List<T> getOrElseInit(Map<Double, List<T>> results, double dst) {
    List<T> lst = results.get(dst);
    if (lst == null) {
        lst = new LinkedList<T>();
        results.put(dst, lst);
    }
    return lst;
}
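The result bookkeeping described in this section, including ties, can be sketched with a TreeMap from squared distance to items. KnnResults is a hypothetical, simplified class with string items and no uniform-leaf shortcut; its offer method follows the same keep, drop and tie rules as searchLeaf:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical simplified version of searchLeaf's result bookkeeping.
class KnnResults {
    // squared distance -> items at that distance
    final TreeMap<Double, List<String>> results = new TreeMap<>();
    final int k;
    int nResults = 0;

    KnnResults(int k) { this.k = k; }

    // Pruning distance: infinite until at least k items are held.
    double maxDistance() {
        return nResults < k ? Double.POSITIVE_INFINITY : results.lastKey();
    }

    void offer(double dSq, String item) {
        double maxD = maxDistance();
        if (dSq > maxD) return;  // outside the search radius
        if (dSq < maxD && nResults >= k) {
            Map.Entry<Double, List<String>> furthest = results.lastEntry();
            int nAtMax = furthest.getValue().size();
            // Drop the furthest group only if k - 1 others remain,
            // so that adding this item still leaves at least k.
            if (nResults - nAtMax >= k - 1) {
                results.remove(furthest.getKey());
                nResults -= nAtMax;
            }
        }
        // Items tied at the current maximum distance are simply added.
        results.computeIfAbsent(dSq, key -> new ArrayList<>()).add(item);
        nResults++;
    }
}
```

With k = 2, offering distances 4, 1, 4 and then 2 keeps both items tied at 4 until the item at 2 arrives, at which point the whole group at 4 is dropped. With k = 3 and three items tied at 5, a closer item is added without dropping any of them, because dropping that group would leave fewer than three.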
Assessment
Full Source Code
For the full source code to the tree built in this tutorial, see duyn's Bucket kd-tree.