User:Duyn/kd-tree Tutorial
This page will walk you through the implementation of a kd-tree. We start with a basic tree and progressively add levels of optimisation. Writing a kd-tree is not an easy task, and having some details explained can make the process easier. There are already other kd-trees on this wiki; see some of the others for ideas.
Theory
A k-d tree is a binary tree which successively splits a k-dimensional space in half. This lets us speed up a nearest neighbour search by not examining points in partitions which are too far away. A k-nearest neighbour search can be implemented from a nearest neighbour algorithm by not shrinking the search radius until k items have been found. This page won't make a distinction between the two.
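To make the pruning idea concrete, here is a brute-force k-NN search in plain Java (a sketch for illustration, not part of the tree we will build): it keeps the k closest points seen so far in a max-heap keyed on squared distance, so the heap's head is the current search radius. A kd-tree returns the same results but additionally skips whole partitions that lie farther away than that radius.

```java
import java.util.Comparator;
import java.util.PriorityQueue;

public class BruteForceKnn {

    static double distSq(double[] a, double[] b) {
        double s = 0;
        for (int d = 0; d < a.length; d++) {
            final double diff = a[d] - b[d];
            s += diff * diff;
        }
        return s;
    }

    // Returns the k points nearest to query, farthest first.
    static double[][] nearest(double[][] points, double[] query, int k) {
        // Max-heap: the head is the farthest of the current candidates,
        // so it is the one evicted when a closer point arrives.
        PriorityQueue<double[]> heap = new PriorityQueue<>(k,
            Comparator.comparingDouble((double[] p) -> distSq(p, query))
                .reversed());
        for (double[] p : points) {
            if (heap.size() < k) {
                heap.offer(p);
            } else if (distSq(p, query) < distSq(heap.peek(), query)) {
                heap.poll();
                heap.offer(p);
            }
        }
        double[][] out = new double[heap.size()][];
        for (int i = 0; i < out.length; i++) out[i] = heap.poll();
        return out;
    }

    public static void main(String[] args) {
        double[][] pts = { {0, 0}, {1, 1}, {5, 5}, {2, 2} };
        double[][] nn = nearest(pts, new double[] {0, 0}, 2);
        // Farthest of the two results comes out first: (1,1), then (0,0).
        System.out.println(nn[0][0] + "," + nn[1][0]); // prints 1.0,0.0
    }
}
```

A k-nearest neighbour search falls out of this naturally: the radius only starts shrinking once the heap holds k candidates, which is exactly the behaviour described above.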
An Exemplar class
We will call each item in our k-d tree an exemplar. Each exemplar has a domain—its co-ordinates in the k-d tree's space. Each exemplar could also have an arbitrary payload, but our tree does not need to know about that. It will only handle storing exemplars based on their domain and returning them in a nearest neighbour search.
We might already have a class elsewhere called Point which handles 2D co-ordinates. This terminology avoids conflict with that.
import java.util.Arrays;

public class Exemplar {
    public final double[] domain;

    public Exemplar(double[] domain) {
        this.domain = domain;
    }

    public final boolean collocated(final Exemplar other) {
        return Arrays.equals(domain, other.domain);
    }
}
While this class is fully usable as is, the domain of nearest neighbours is rarely of interest on its own. Often, useful data (such as GuessFactors) will be attached by sub-classing this Exemplar class. Our k-d tree will be parameterised based on this expectation.
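For example, a payload might be attached like this. This is only a sketch: the GuessFactorExemplar name and its guessFactor field are made up for illustration, and the Exemplar class is repeated (nested) so the snippet compiles on its own.

```java
import java.util.Arrays;

public class ExemplarPayloadDemo {

    static class Exemplar {
        public final double[] domain;

        public Exemplar(double[] domain) {
            this.domain = domain;
        }

        public final boolean collocated(final Exemplar other) {
            return Arrays.equals(domain, other.domain);
        }
    }

    // Hypothetical sub-class: the tree never inspects the payload,
    // it only ever sees the domain.
    static class GuessFactorExemplar extends Exemplar {
        public final double guessFactor; // hypothetical payload

        public GuessFactorExemplar(double[] domain, double guessFactor) {
            super(domain);
            this.guessFactor = guessFactor;
        }
    }

    public static void main(String[] args) {
        GuessFactorExemplar a =
            new GuessFactorExemplar(new double[] {0.2, 0.8}, 0.35);
        GuessFactorExemplar b =
            new GuessFactorExemplar(new double[] {0.2, 0.8}, -0.10);
        // Same domain, different payloads: still collocated.
        System.out.println(a.collocated(b)); // prints true
    }
}
```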
Basic Tree
First, we will build a basic kd-tree as described by Wikipedia's kd-tree page. We start off with a standard binary tree. Each tree is either a node with both left and right sub-trees defined, or a leaf carrying an exemplar. For simplicity, we won't allow nodes to contain any exemplars.
public class BasicKdTree<X extends Exemplar> {
    // Basic tree structure
    X data = null;
    BasicKdTree<X> left = null, right = null;

    // Only need to test one branch since we always populate both
    // branches at once
    private boolean isTree() {
        return left != null;
    }

    ...
}
Adding
Each of the public API functions defers to a private static method. This is to avoid accidentally referring to instance variables while we walk the tree. This is a common pattern we will use for much of the actual behaviour code for this tree.
public void add(X ex) {
    BasicKdTree.addToTree(this, ex);
}
To add an exemplar, we traverse the tree from top down until we find a leaf. If the leaf doesn't already have an exemplar, we can just stow it there. Otherwise, we split the leaf and put the free exemplars in each sub-tree.
To split a leaf, we cycle through split dimensions in order, as in most descriptions of kd-trees. Because we only allow each leaf to hold a single exemplar, we can use a simple splitting strategy: smallest on the left, largest on the right. The split value is halfway between the two points along the split dimension.
int splitDim = 0;
double split = Double.NaN;

private static <X extends Exemplar> void addToTree(
        BasicKdTree<X> tree, X ex) {
    while (tree != null) {
        if (tree.isTree()) {
            // Traverse in search of a leaf
            tree = ex.domain[tree.splitDim] <= tree.split
                ? tree.left : tree.right;
        } else {
            if (tree.data == null) {
                tree.data = ex;
            } else {
                // Split tree and add
                // Find smallest exemplar to be our split point
                final int d = tree.splitDim;
                X leftX = ex, rightX = tree.data;
                if (rightX.domain[d] < leftX.domain[d]) {
                    leftX = tree.data;
                    rightX = ex;
                }
                tree.split = 0.5*(leftX.domain[d] + rightX.domain[d]);

                final int nextSplitDim =
                    (tree.splitDim + 1)%tree.dimensions();

                tree.left = new BasicKdTree<X>();
                tree.left.splitDim = nextSplitDim;
                tree.left.data = leftX;

                tree.right = new BasicKdTree<X>();
                tree.right.splitDim = nextSplitDim;
                tree.right.data = rightX;
            }

            // Done.
            tree = null;
        }
    }
}

private int dimensions() {
    return data.domain.length;
}
Searching
Before we start coding a search method, we need a helper class to store search results along with their distance from the query point. This is called PrioNode because we will eventually re-use it to implement a custom priority queue.
public final class PrioNode<T> {
    public final double priority;
    public final T data;

    public PrioNode(double priority, T data) {
        this.priority = priority;
        this.data = data;
    }
}
Like our add() method, our search() method delegates to a static method to avoid introducing bugs by accidentally referring to member variables while we descend the tree.
public Iterable<? extends PrioNode<X>> search(double[] query, int nResults) {
    return BasicKdTree.search(this, query, nResults);
}
To do a nearest neighbours search, we walk the tree, preferring to search sub-trees which are on the same side of the split as the query point first. Once we have found our initial candidates, we can contract our search sphere. We only search the other sub-tree if our search sphere might spill over onto the other side of the split.
Results are collected in a java.util.PriorityQueue so we can easily remove the farthest exemplars as we find closer ones.
private static <X extends Exemplar> Iterable<? extends PrioNode<X>>
search(BasicKdTree<X> tree, double[] query, int nResults) {
    final Queue<PrioNode<X>> results = new PriorityQueue<PrioNode<X>>(
        nResults, new Comparator<PrioNode<X>>() {
            // max-heap: the head is the farthest result found so far
            public int compare(PrioNode<X> o1, PrioNode<X> o2) {
                return o1.priority == o2.priority ? 0
                    : o1.priority > o2.priority ? -1 : 1;
            }
        });
    final Deque<BasicKdTree<X>> stack = new LinkedList<BasicKdTree<X>>();
    stack.addLast(tree);
    while (!stack.isEmpty()) {
        tree = stack.removeLast();

        if (tree.isTree()) {
            // Guess nearest tree to query point
            BasicKdTree<X> nearTree = tree.left, farTree = tree.right;
            if (query[tree.splitDim] > tree.split) {
                nearTree = tree.right;
                farTree = tree.left;
            }

            // Only search far tree if our search sphere might
            // overlap with splitting plane
            if (results.size() < nResults
                    || sq(query[tree.splitDim] - tree.split)
                        <= results.peek().priority) {
                stack.addLast(farTree);
            }

            // Always search the nearest branch
            stack.addLast(nearTree);
        } else {
            final double dSq = distanceSqFrom(query, tree.data.domain);
            if (results.size() < nResults
                    || dSq < results.peek().priority) {
                while (results.size() >= nResults) {
                    results.poll();
                }
                results.offer(new PrioNode<X>(dSq, tree.data));
            }
        }
    }
    return results;
}

private static double sq(double n) {
    return n*n;
}
Our distance calculation is optimised because it will be called often. We use squared distances to avoid an unnecessary sqrt() operation. We also don't use Math.pow() because the JRE's default implementation must do some extra work before it can return with a simple multiplication (see the source in fdlibm).
private static double distanceSqFrom(double[] p1, double[] p2) {
    // Note: profiling shows this is called lots of times, so it pays
    // to be well optimised
    double dSq = 0;
    for (int d = 0; d < p1.length; d++) {
        final double dst = p1[d] - p2[d];
        if (dst != 0)
            dSq += dst*dst;
    }
    return dSq;
}
And that's it. Barely 200 lines of code and we have ourselves a working kd-tree.
Evaluation
As a guide to performance, we will use the k-NN algorithm benchmark with the Diamond vs CunobelinDC gun data. For comparison, the benchmark includes an optimised linear search to represent the lower bound of performance, and Rednaxela's tree, credited as the fastest tree on the wiki.
K-NEAREST NEIGHBOURS ALGORITHMS BENCHMARK
-----------------------------------------
Reading data from gun-data-Diamond-vs-jk.mini.CunobelinDC 0.3.csv.gz
Read 25621 saves and 10300 searches.

Running 30 repetition(s) for k-nearest neighbours searching:
:: 13 dimension(s); 40 neighbour(s)
Warming up the JIT with 5 repetitions first...

Running tests... COMPLETED.

RESULT << k-nearest neighbours search with Voidious' Linear search >>
: Average searching time       = 0.577 miliseconds
: Average worst searching time = 11.123 miliseconds
: Average adding time          = 0.55 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Rednaxela's Bucket kd-tree >>
: Average searching time       = 0.056 miliseconds
: Average worst searching time = 15.779 miliseconds
: Average adding time          = 1.65 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with duyn's basic kd-tree >>
: Average searching time       = 0.404 miliseconds
: Average worst searching time = 5.748 miliseconds
: Average adding time          = 0.68 microseconds
: Accuracy                     = 100%


BEST RESULT:
 - #1 Rednaxela's Bucket kd-tree [0.0557]
 - #2 duyn's basic kd-tree [0.404]
 - #3 Voidious' Linear search [0.5771]

Benchmark running time: 334.11 seconds
This test run showed that average add time was close to a linear search, while searching was 1 - 0.404/0.5771 ≈ 30% faster. It's still an order of magnitude slower than Rednaxela's included tree, which shows there is a lot to gain by selecting appropriate optimisations.
To put it in context, this is how our tree performs compared to other implementations bundled with the benchmark:
K-NEAREST NEIGHBOURS ALGORITHMS BENCHMARK
-----------------------------------------
Reading data from gun-data-Diamond-vs-jk.mini.CunobelinDC 0.3.csv.gz
Read 25621 saves and 10300 searches.

Running 30 repetition(s) for k-nearest neighbours searching:
:: 13 dimension(s); 40 neighbour(s)
Warming up the JIT with 5 repetitions first...

Running tests... COMPLETED.

RESULT << k-nearest neighbours search with Voidious' Linear search >>
: Average searching time       = 0.573 miliseconds
: Average worst searching time = 15.828 miliseconds
: Average adding time          = 0.56 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Simonton's Bucket PR k-d tree >>
: Average searching time       = 0.161 miliseconds
: Average worst searching time = 19.711 miliseconds
: Average adding time          = 1.47 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Nat's Bucket PR k-d tree >>
: Average searching time       = 0.381 miliseconds
: Average worst searching time = 315.549 miliseconds
: Average adding time          = 63.06 microseconds
: Accuracy                     = 31%

RESULT << k-nearest neighbours search with Voidious' Bucket PR k-d tree >>
: Average searching time       = 0.218 miliseconds
: Average worst searching time = 104.198 miliseconds
: Average adding time          = 1.45 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Rednaxela's Bucket kd-tree >>
: Average searching time       = 0.053 miliseconds
: Average worst searching time = 0.51 miliseconds
: Average adding time          = 1.7 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with duyn's basic kd-tree >>
: Average searching time       = 0.394 miliseconds
: Average worst searching time = 8.718 miliseconds
: Average adding time          = 0.68 microseconds
: Accuracy                     = 100%


BEST RESULT:
 - #1 Rednaxela's Bucket kd-tree [0.0534]
 - #2 Simonton's Bucket PR k-d tree [0.1611]
 - #3 Voidious' Bucket PR k-d tree [0.2182]
 - #4 Nat's Bucket PR k-d tree [0.3812]
 - #5 duyn's basic kd-tree [0.3937]
 - #6 Voidious' Linear search [0.5728]

Benchmark running time: 645.63 seconds
The following code was used in the benchmark:
- KNNBenchmark.java
public KNNBenchmark(final int dimension, final int numNeighbours,
        final SampleData[] samples, final int numReps) {
    final Class<?>[] searchAlgorithms = new Class<?>[] {
        FlatKNNSearch.class,
        SimontonTreeKNNSearch.class,
        NatTreeKNNSearch.class,
        VoidiousTreeKNNSearch.class,
+       DuynBasicKNNSearch.class,
        RednaxelaTreeKNNSearch.class,
    };
    ...
}
- DuynBasicKNNSearch.java
public class DuynBasicKNNSearch extends KNNImplementation {
    final BasicKdTree<StringExemplar> tree;

    public DuynBasicKNNSearch(final int dimension) {
        super(dimension);
        tree = new BasicKdTree<StringExemplar>();
    }

    @Override
    public void addPoint(final double[] location, final String value) {
        tree.add(new StringExemplar(location, value));
    }

    @Override
    public String getName() {
        return "duyn's basic kd-tree";
    }

    @Override
    public KNNPoint[] getNearestNeighbors(final double[] location,
            final int size) {
        final List<KNNPoint> justPoints = new LinkedList<KNNPoint>();
        for (PrioNode<StringExemplar> sr : tree.search(location, size)) {
            final double distance = sr.priority;
            final StringExemplar pt = sr.data;
            justPoints.add(new KNNPoint(pt.value, distance));
        }
        final KNNPoint[] retVal = new KNNPoint[justPoints.size()];
        return justPoints.toArray(retVal);
    }

    class StringExemplar extends Exemplar {
        public final String value;

        public StringExemplar(final double[] coords, final String value) {
            super(coords);
            this.value = value;
        }
    }
}