Difference between revisions of "User:Duyn/kd-tree Tutorial"

From Robowiki
(→‎Extensions: Removed comment on standardisation, added heap suggestion)
(Start of re-write to be a true tutorial, and not just a code walk-through.)
Line 1: Line 1:
( Walkthrough: An Optimised kd-tree. This is the beginning of a walk through on writing a k-d tree )
+
This page will walk you through the implementation of a ''k''d-tree. We start with a basic tree and progressively add levels of optimisation. Writing a ''k''d-tree is not an easy task; having some details explained can make the process easier. There are already other ''k''d-trees on this wiki; see some of the others for ideas.
 
 
This page will walk you through the implementation of an optimised k-d tree. We will end up with a bucket k-d tree which splits on the mean of the dimension with the largest variance. Writing a k-d tree is not an easy task; having some details explained can make the process easier. There are already other k-d trees on this wiki; see some of the others for more ideas.
 
  
 
==Theory==
 
 
A k-d tree is a binary tree which successively splits a ''k''-dimensional space in half. This lets us speed up a nearest neighbour search by not examining points in partitions which are too far away. A ''k''-nearest neighbour search can be implemented from a nearest neighbour algorithm by not shrinking the search radius until ''k'' items have been found. This page won't make a distinction between the two.
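As a point of reference, the idea of not shrinking the search radius until ''k'' items have been found can be sketched without any tree at all. The class below is a hypothetical linear-scan baseline (the name <code>LinearKnn</code> and its methods are illustrative, not part of the tutorial's code); it assumes squared Euclidean distance and keeps candidates in a max-heap so the farthest one is always at hand.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical baseline, not part of the tutorial's tree.
class LinearKnn {
  // Squared Euclidean distance between two points of equal dimension.
  static double distanceSq(double[] a, double[] b) {
    double sum = 0;
    for (int d = 0; d < a.length; d++) {
      final double diff = a[d] - b[d];
      sum += diff * diff;
    }
    return sum;
  }

  // Returns the k points nearest to query, closest first.
  static List<double[]> nearest(List<double[]> points,
                                final double[] query, int k) {
    // Max-heap on distance: the head is the farthest candidate so far.
    PriorityQueue<double[]> heap = new PriorityQueue<double[]>(
      Math.max(1, k),
      new Comparator<double[]>() {
        public int compare(double[] p1, double[] p2) {
          return Double.compare(distanceSq(p2, query),
                                distanceSq(p1, query));
        }
      });
    for (double[] p : points) {
      if (heap.size() < k) {
        // Fewer than k candidates: the radius is still unbounded.
        heap.offer(p);
      } else if (k > 0
        && distanceSq(p, query) < distanceSq(heap.peek(), query))
      {
        // Closer than the farthest candidate: replace it, which
        // shrinks the search radius.
        heap.poll();
        heap.offer(p);
      }
    }
    // Polling yields farthest first, so reverse at the end.
    List<double[]> result = new ArrayList<double[]>();
    while (!heap.isEmpty()) result.add(heap.poll());
    Collections.reverse(result);
    return result;
  }
}
```

A tree-based search follows the same accept-then-shrink rule; it only adds a way to skip whole groups of points that cannot beat the current farthest candidate.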
 
 
Some of the non-obvious optimisations used in this tree are:
 
* '''Path ordering'''. Our search will descend the tree with the edge closest to the query point first. This is heuristically more likely to contain closer neighbours, which will lead to contracting our search radius sooner. It also ensures we start searching from the leaf which would hold the query point if it were added. Combined with a good bucket size, this lets us quickly load up some decent candidates before we have even examined any other branches.
 
* '''Bounds-overlap-ball testing'''. Each tree stores the bounds of the smallest hyperrect which contains all the data in that tree. This can be smaller, perhaps significantly so, than the region of space that sub-tree occupies, giving us more opportunities to skip sub-trees.
 
* '''Splitting on mean'''. While most descriptions of kd-trees describe splitting on the median, we will split on the mean. This is faster for us because we will have already calculated the mean in finding the dimension with largest variance.
 
  
 
==An Exemplar class==
 
Line 16: Line 9:
 
We might already have a class elsewhere called Point which handles 2D co-ordinates. This terminology avoids conflict with that.
 
  
<pre>
+
public class Exemplar {
public class Exemplar {
+
  public final double[] domain;
  public final double[] domain;
+
 +
  public Exemplar(double[] domain) {
 +
    this.domain = domain;
 +
  }
 +
 +
  public final boolean
 +
  collocated(final Exemplar other) {
 +
    return Arrays.equals(domain, other.domain);
 +
  }
 +
}
  
  public Exemplar(final double[] coords) {
+
While this class is fully usable as is, rarely will the domain of nearest neighbours be of any interest. Often, useful data (such as [[GuessFactor]]s) will be loaded by sub-classing this Exemplar class. Our k-d tree will be parameterised based on this expectation.
    this.domain = coords;
 
  }
 
 
 
  // Short hand. Shorter than calling Arrays.equals() each time.
 
  public boolean domainEquals(final Exemplar other) {
 
    return Arrays.equals(domain, other.domain);
 
  }
 
}
 
</pre>
 
 
 
While this class is fully usable as is, rarely will the domain of nearest neighbours be of any interest. Often, useful data (such as [[guess factor]]s) will be loaded by sub-classing this Exemplar class. Our k-d tree will be parameterised based on this expectation.
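A minimal sketch of such a sub-class follows. The name <code>GuessFactorExemplar</code> and its <code>guessFactor</code> field are assumptions made for illustration; a cut-down copy of <code>Exemplar</code> is repeated only so the example compiles on its own.

```java
import java.util.Arrays;

// Cut-down copy of the tutorial's Exemplar so this sketch stands alone.
class Exemplar {
  public final double[] domain;

  public Exemplar(final double[] coords) {
    this.domain = coords;
  }

  public boolean domainEquals(final Exemplar other) {
    return Arrays.equals(domain, other.domain);
  }
}

// Hypothetical sub-class: the name and payload field are illustrative.
class GuessFactorExemplar extends Exemplar {
  // Payload carried alongside the co-ordinates used for searching.
  public final double guessFactor;

  public GuessFactorExemplar(final double[] coords,
                             final double guessFactor) {
    super(coords);
    this.guessFactor = guessFactor;
  }
}
```

Because the tree is parameterised on a sub-type of <code>Exemplar</code>, search results would come back already typed as the sub-class, so the payload can be read without a cast.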
 
  
 
==Basic Tree==
 
We start off with a normal tree. Each tree is either a '''node''' with both left and right sub-trees defined, or a '''leaf''' with a list of exemplars. Because of our splitting algorithm, it is not worth the added complexity to allow a tree to be both since the split point might not correspond with any actual exemplars.
+
First, we will build a basic ''k''d-tree as described by Wikipedia's [http://en.wikipedia.org/w/index.php?title=Kd-tree&oldid=346156055 kd-tree] page. We start off with a standard binary tree. Each tree is either a '''node''' with both left and right sub-trees defined, or a '''leaf''' carrying an exemplar. For simplicity, we won't allow nodes to contain any exemplars.
  
We will explicitly pass bucket size into the constructor. Because trees will not have a reference to their parents, each tree must keep a local copy of its bucket size. This is not a static constant because we might want to build multiple trees with different bucket sizes. Dimension is declared here, but it will be inferred from the dimension of the first point added.
+
public class BasicKdTree<X extends Exemplar> {
 +
  // Basic tree structure
 +
  X data = null;
 +
  BasicKdTree<X> left = null, right = null;
 +
 +
  // Only need to test one branch since we always populate both
 +
  // branches at once
 +
  private boolean
 +
  isTree() { return left != null; }
 +
 +
  ...
 +
}
  
<pre>
+
===Adding===
public class BucketKdTree<T extends Exemplar> {
+
Each of the public API functions defers to a private static method. This is to avoid accidentally referring to instance variables while we walk the tree. This is a common pattern we will use for much of the actual behaviour code for this tree.
  private List<T> exemplars = new LinkedList<T>();
 
  private BucketKdTree<T> left, right;
 
  private int bucketSize;
 
  // Zero dimension means uninitialised
 
  private int dimensions = 0;
 
  
  public BucketKdTree(int bucketSize) {
+
public void
    this.bucketSize = bucketSize;
+
add(X ex) {
  }
+
  BasicKdTree.addToTree(this, ex);
 +
}
  
  private boolean
+
To add an exemplar, we traverse the tree from top down until we find a leaf. If the leaf doesn't already have an exemplar, we can just stow it there. Otherwise, we split the leaf and place one of the two exemplars in each sub-tree.
  isTree() { return left != null; }
 
}
 
</pre>
 
  
==Adding==
+
To split a leaf, we cycle through split dimensions in order as in most descriptions of ''k''d-trees. Because we only allow each leaf to hold a single exemplar, we can use a simple splitting strategy: smallest on the left, largest on the right. The split value is half way between the two points along the split dimension.
 +
 +
int splitDim = 0;
 +
double split = Double.NaN;
 +
 +
private static <X extends Exemplar> void
 +
addToTree(BasicKdTree<X> tree, X ex) {
 +
  while(tree != null) {
 +
    if (tree.isTree()) {
 +
      // Traverse in search of a leaf
 +
      tree = ex.domain[tree.splitDim] <= tree.split
 +
        ? tree.left : tree.right;
 +
    } else {
 +
      if (tree.data == null) {
 +
        tree.data = ex;
 +
      } else {
 +
        // Split tree and add
 +
        // Find smallest exemplar to be our split point
 +
        final int d = tree.splitDim;
 +
        X leftX = ex, rightX = tree.data;
 +
        if (rightX.domain[d] < leftX.domain[d]) {
 +
          leftX = tree.data;
 +
          rightX = ex;
 +
        }
 +
        tree.split = 0.5*(leftX.domain[d] + rightX.domain[d]);
 +
       
 +
        final int nextSplitDim =
 +
          (tree.splitDim + 1)%tree.dimensions();
 +
       
 +
        tree.left = new BasicKdTree<X>();
 +
        tree.left.splitDim = nextSplitDim;
 +
        tree.left.data = leftX;
 +
       
 +
        tree.right = new BasicKdTree<X>();
 +
        tree.right.splitDim = nextSplitDim;
 +
        tree.right.data = rightX;
 +
      }
 +
 +
      // Done.
 +
      tree = null;
 +
    }
 +
  }
 +
}
 +
 +
private int
 +
dimensions() { return data.domain.length; }
  
Each of the public API functions defers the actual addition to another private static function. This is to avoid accidentally referring to instance variables while we walk the tree. This is a common pattern we will use for much of the actual behaviour code for this tree.
+
===Searching===
 +
Before we start coding a search method, we need a helper class to store search results along with their distance from the query point. This is called <code>PrioNode</code> because we will eventually re-use it to implement a custom priority queue.
  
We decide whether to split a leaf only after the add has been completed.
+
public final class PrioNode<T> {
 +
  public final double priority;
 +
  public final T data;
 +
 +
  public PrioNode(double priority, T data) {
 +
    this.priority = priority;
 +
    this.data = data;
 +
  }
 +
}
  
<pre>
+
As with our <code>add()</code> method, our <code>search()</code> method delegates to a static method to avoid introducing bugs by accidentally referring to member variables while we descend the tree.
// One at a time
 
public void
 
add(T ex) {
 
  BucketKdTree<T> tree = addNoSplit(this, ex);
 
  if (shouldSplit(tree)) {
 
    split(tree);
 
  }
 
}
 
  
// Bulk add gives us more data to choose a better split point
+
public Iterable<? extends PrioNode<X>>
public void
+
search(double[] query, int nResults) {
addAll(Collection<T> exs) {
+
  return BasicKdTree.search(this, query, nResults);
  final Set<BucketKdTree<T>> modTrees =
+
}
    new HashSet<BucketKdTree<T>>();
 
  for(T ex : exs) {
 
    modTrees.add(addNoSplit(this, ex));
 
  }
 
  
  for(BucketKdTree<T> tree : modTrees) {
+
To do a nearest neighbours search, we walk the tree, preferring to search sub-trees which are on the same side of the split as the query point first. Once we have found our initial candidates, we can contract our search sphere. We only search the other sub-tree if our search sphere might spill over onto the other side of the split.
    if (shouldSplit(tree)) {
 
      split(tree);
 
    }
 
  }
 
}
 
</pre>
 
  
To add an exemplar, we traverse the tree from top down until we find a leaf. Then we add the exemplar to the list at that leaf. For now, there is no special ordering to the list of exemplars in a leaf.
+
Results are collected in a <code>java.util.PriorityQueue</code> so we can easily remove the farthest exemplars as we find closer ones.
  
To decide which sub-tree to traverse, each tree stores two values&mdash;a splitting dimension and a splitting value. If our new exemplar's domain along the splitting dimension is greater than the tree's splitting value, we put it in the right sub-tree. Otherwise, it goes in the left one.
+
private static <X extends Exemplar> Iterable<? extends PrioNode<X>>
 
+
search(BasicKdTree<X> tree, double[] query, int nResults) {
Since adding takes little time compared to searching, we take this opportunity to make some optimisations:
+
  final Queue<PrioNode<X>> results =
 
+
    new PriorityQueue<PrioNode<X>>(nResults,
* We keep track of the actual hyperrect the points in this tree occupy. This lets us rule out a tree, even though its space may intersect with our search sphere, if it has no chance of containing any points within our search sphere. This hyperrect is defined by <code>maxBounds</code> and <code>minBounds</code>.
+
      new Comparator<PrioNode<X>>() {
 
+
* To save us having to do a full iteration when we come to split a leaf, we compute the running mean and variance for each dimension using [http://www.johndcook.com/standard_deviation.html Welford's method]:
+
        // max-heap on distance: farthest result sits at the head
    <math>
+
        public int
    M_1 = x_1,\qquad S_1 = 0</math>
+
        compare(PrioNode<X> o1, PrioNode<X> o2) {
    <math>
+
          return o1.priority == o2.priority ? 0
    M_k = M_{k-1} + {x_k - M_{k-1} \over k} \qquad(exMean)</math>
+
            : o1.priority > o2.priority ? -1
    <math>
+
            : 1;
    S_k = S_{k-1} + (x_k - M_{k-1})\times(x_k - M_k) \qquad(exSumSqDev)</math>
+
        }
 
+
* We keep track of whether all exemplars in this leaf have the same domain. If they do, we know that comparisons on one exemplar apply to all exemplars in that leaf.
+
      }
 
+
    );
<pre>
+
  final Deque<BasicKdTree<X>> stack
private int splitDim;
+
    = new LinkedList<BasicKdTree<X>>();
private double split;
+
  stack.addLast(tree);
 
+
  while (!stack.isEmpty()) {
// These aren't initialised until add() is called.
+
    tree = stack.removeLast();
private double[] exMean;
+
   
private double[] exSumSqDev;
+
    if (tree.isTree()) {
 
+
      // Guess nearest tree to query point
// Optimisation when tree contains large number of duplicates
+
      BasicKdTree<X> nearTree = tree.left, farTree = tree.right;
private boolean exemplarsAreUniform = true;
+
      if (query[tree.splitDim] > tree.split) {
 
+
        nearTree = tree.right;
// Optimisation for searches. This lets us skip a node if its
+
        farTree = tree.left;
// scope intersects with a search hypersphere but it doesn't contain
+
      }
// any points that actually intersect.
+
     
private double[] maxBounds;
+
      // Only search far tree if our search sphere might
private double[] minBounds;
+
      // overlap with splitting plane
 
+
      if (results.size() < nResults
// Adds an exemplar without splitting overflowing leaves.
+
        || sq(query[tree.splitDim] - tree.split)
// Returns leaf to which exemplar was added.
+
          <= results.peek().priority)
private static <T extends Exemplar> BucketKdTree<T>
+
      {
addNoSplit(BucketKdTree<T> tree, T ex) {
+
        stack.addLast(farTree);
  BucketKdTree<T> cursor = tree;
+
      }
  while (cursor != null) {
+
    updateBounds(cursor, ex);
+
      // Always search the nearest branch
    if (cursor.isTree()) {
+
      stack.addLast(nearTree);
      // Sub-tree
+
    } else {
      cursor = ex.domain[cursor.splitDim] <= cursor.split
+
      final double dSq = distanceSqFrom(query, tree.data.domain);
        ? cursor.left : cursor.right;
+
      if (results.size() < nResults
    } else {
+
        || dSq < results.peek().priority)
      // Leaf
+
      {
 
+
        while (results.size() >= nResults) {
      // Infer dimensions if we haven't already
+
          results.poll();
      if (cursor.dimensions == 0)
+
        }
        cursor.dimensions = ex.domain.length;
+
 
+
        results.offer(new PrioNode<X>(dSq, tree.data));
      // Add exemplar to leaf
+
      }
      cursor.exemplars.add(ex);
+
    }
 
+
  }
      // Calculate running mean and sum of squared deviations
+
  return results;
      final int nExs = cursor.exemplars.size();
+
}
      final int dims = cursor.dimensions;
+
 
      if (nExs == 1) {
+
private static double
        cursor.exMean = Arrays.copyOf(ex.domain, dims);
+
sq(double n) { return n*n; }
        cursor.exSumSqDev = new double[dims];
 
      } else {
 
        for(int d = 0; d < dims; d++) {
 
          final double coord = ex.domain[d];
 
          final double oldMean = cursor.exMean[d], newMean;
 
          cursor.exMean[d] = newMean =
 
            oldMean + (coord - oldMean)/nExs;
 
          cursor.exSumSqDev[d] = cursor.exSumSqDev[d]
 
            + (coord - oldMean)*(coord - newMean);
 
        }
 
      }
 
 
 
      // Check that exemplars are still uniform
 
      if (cursor.exemplarsAreUniform) {
 
        final List<T> cExs = cursor.exemplars;
 
        if (cExs.size() > 0 && !ex.domainEquals(cExs.get(0)))
 
          cursor.exemplarsAreUniform = false;
 
      }
 
 
 
      // Finished walking
 
      return cursor;
 
    }
 
  }
 
  throw new RuntimeException("Walked tree without adding anything");
 
}
 
 
 
private static <T extends Exemplar> void
 
updateBounds(BucketKdTree<T> tree, Exemplar ex) {
 
  final int dims = ex.domain.length;
 
  if (tree.maxBounds == null) {
 
    tree.maxBounds = Arrays.copyOf(ex.domain, dims);
 
    tree.minBounds = Arrays.copyOf(ex.domain, dims);
 
  } else {
 
    for(int d = 0; d < dims; d++) {
 
      final double dimVal = ex.domain[d];
 
      if (dimVal > tree.maxBounds[d])
 
        tree.maxBounds[d] = dimVal;
 
      else if (dimVal < tree.minBounds[d])
 
        tree.minBounds[d] = dimVal;
 
    }
 
  }
 
}
 
</pre>
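The running mean and sum of squared deviations kept by <code>addNoSplit()</code> can be checked in isolation. The sketch below applies the same update to a single dimension; the class name <code>RunningStats</code> is hypothetical, and the tree stores the equivalent values per dimension in <code>exMean</code> and <code>exSumSqDev</code>.

```java
// Hypothetical stand-alone helper for one dimension; the tutorial
// keeps these values per leaf in exMean and exSumSqDev instead.
class RunningStats {
  private int n = 0;
  private double mean = 0;      // M_k
  private double sumSqDev = 0;  // S_k

  // Folds one more value into the running statistics.
  void add(double x) {
    n++;
    final double oldMean = mean;
    mean = oldMean + (x - oldMean) / n;
    // For the first value this adds (x - 0)*(x - x) = 0, so S_1 = 0.
    sumSqDev += (x - oldMean) * (x - mean);
  }

  double mean() { return mean; }

  double sumSqDev() { return sumSqDev; }

  // Sample variance; undefined for fewer than two values.
  double variance() {
    return n > 1 ? sumSqDev / (n - 1) : Double.NaN;
  }
}
```

For the values 2, 4, 4, 4, 5, 5, 7, 9 the mean is 5 and the sum of squared deviations is 32, which a two-pass calculation confirms.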
 
 
 
==Splitting==
 
We only split when a leaf's exemplars exceed its bucket size. It is only worth splitting if the exemplars don't all have the same domain.
 
 
 
<pre>
 
private static <T extends Exemplar> boolean
 
shouldSplit(BucketKdTree<T> tree) {
 
  return tree.exemplars.size() > tree.bucketSize
 
    && !tree.exemplarsAreUniform;
 
}
 
</pre>
 
 
 
Thanks to our pre-computation, splitting is straight-forward&mdash;most of the time. We iterate through each dimension to find the one with the largest variance (skip the unnecessary division), then we can directly look up the mean of that dimension.
 
 
 
To make sure the point actually does divide our data, we separate our data into two lists destined for each sub-tree. If all the exemplars end up in only one of the lists, then our split point has failed to actually separate our exemplars. This is most likely due to rounding error when our exemplars are really close together. At a loss for what to do, we simply pick a random point and a random dimension until we find something that parts our exemplars. We know we must find one eventually because our exemplars are not uniform: at least one of them is strictly smaller than another in at least one dimension, so choosing that exemplar and that dimension separates the two.
 
 
 
Finally, we bulk load our sub-trees, store information about our split point and let go of the exemplars stored in the tree.
 
 
 
<pre>
 
@SuppressWarnings("unchecked") private static <T extends Exemplar> void
 
split(BucketKdTree<T> tree) {
 
  assert !tree.exemplarsAreUniform;
 
  // Find dimension with largest variance to split on
 
  double largestVar = -1;
 
  int splitDim = 0;
 
  for(int d = 0; d < tree.dimensions; d++) {
 
    // Sum of squared deviations is proportional to variance, so we
    // can compare it directly and skip the division
    final double var = tree.exSumSqDev[d];
 
    if (var > largestVar) {
 
      largestVar = var;
 
      splitDim = d;
 
    }
 
  }
 
 
 
  // Find mean as position for our split
 
  double splitValue = tree.exMean[splitDim];
 
 
 
  // Check that our split actually splits our data. This also lets
 
  // us bulk load exemplars into sub-trees, which is more likely
 
  // to keep optimal balance.
 
  final List<T> leftExs = new LinkedList<T>();
 
  final List<T> rightExs = new LinkedList<T>();
 
  for(T s : tree.exemplars) {
 
    if (s.domain[splitDim] <= splitValue)
 
      leftExs.add(s);
 
    else
 
      rightExs.add(s);
 
  }
 
  int leftSize = leftExs.size();
 
  final int treeSize = tree.exemplars.size();
 
  if (leftSize == treeSize || leftSize == 0) {
 
    System.err.println(
 
      "WARNING: Randomly splitting non-uniform tree");
 
    // We know the exemplars aren't all the same, so try picking
 
    // an exemplar and a dimension at random for our split point
 
 
 
    // This might take several tries, so we copy our exemplars to
 
    // an array to speed up process of picking a random point
 
    Object[] exs = tree.exemplars.toArray();
 
    while (leftSize == treeSize || leftSize == 0) {
 
      leftExs.clear();
 
      rightExs.clear();
 
 
 
      splitDim = (int)
 
        Math.floor(Math.random()*tree.dimensions);
 
      final int splitPtIdx = (int)
 
        Math.floor(Math.random()*exs.length);
 
      // Cast is inevitable consequence of java's inability to
 
      // create a generic array
 
      splitValue = ((T)exs[splitPtIdx]).domain[splitDim];
 
      for(T s : tree.exemplars) {
 
        if (s.domain[splitDim] <= splitValue)
 
          leftExs.add(s);
 
        else
 
          rightExs.add(s);
 
      }
 
      leftSize = leftExs.size();
 
    }
 
  }
 
 
 
  // We have found a valid split. Start building our sub-trees
 
  // Dimensions are inferred when the first exemplar is added
  final BucketKdTree<T> left =
    new BucketKdTree<T>(tree.bucketSize);
  final BucketKdTree<T> right =
    new BucketKdTree<T>(tree.bucketSize);
 
  left.addAll(leftExs);
 
  right.addAll(rightExs);
 
 
 
  // Finally, commit the split
 
  tree.splitDim = splitDim;
 
  tree.split = splitValue;
 
  tree.left = left;
 
  tree.right = right;
 
 
 
  // Let go of exemplars (and their running stats) held in this leaf
 
  tree.exemplars = null;
 
  tree.exMean = tree.exSumSqDev = null;
 
}
 
</pre>
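The choice of split point can also be tried on a plain list of points. The following is a sketch under simplifying assumptions: it recomputes the mean and the sum of squared deviations in batch rather than using the running values, and the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical batch version of the split choice made in split().
class MeanSplitSketch {
  // Mean of the points along one dimension.
  static double mean(List<double[]> points, int dim) {
    double sum = 0;
    for (double[] p : points) sum += p[dim];
    return sum / points.size();
  }

  // Index of the dimension with the largest sum of squared deviations.
  // Dividing by (n - 1) would not change which dimension wins.
  static int widestDimension(List<double[]> points) {
    final int dims = points.get(0).length;
    int best = 0;
    double bestSumSqDev = -1;
    for (int d = 0; d < dims; d++) {
      final double m = mean(points, d);
      double sumSqDev = 0;
      for (double[] p : points) sumSqDev += (p[d] - m) * (p[d] - m);
      if (sumSqDev > bestSumSqDev) {
        bestSumSqDev = sumSqDev;
        best = d;
      }
    }
    return best;
  }

  // Points whose coordinate along dim is <= split; the rest would
  // go to the right sub-tree.
  static List<double[]> leftOf(List<double[]> points, int dim,
                               double split) {
    final List<double[]> left = new ArrayList<double[]>();
    for (double[] p : points) {
      if (p[dim] <= split) left.add(p);
    }
    return left;
  }
}
```

A split is only useful if the left list is neither empty nor the whole input, which is the same check <code>split()</code> makes before committing.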
 
 
 
==Calculating distance==
 
We now need to define our distance metric. We will be using Euclidean distance. To speed up calculations, we'll skip the square root and just work with squared distances:
 
 
 
  <math>distance_{x\rightarrow y}^2 = \sum (x[d] - y[d])^2</math>.
 
 
 
To find the minimum distance between a tree and our query point, we need to find the point on the bounds of that tree closest to the query point. This is easy to do, even if not immediately obvious. Along each dimension,
 
 
 
  <math>nearestPoint[d]=\begin{cases}
 
min[d] & target[d]\le min[d]\\
 
target[d] & min[d]<target[d]<max[d]\\
 
max[d] & target[d]\ge max[d]\end{cases}</math>
 
 
 
After that, we can find the minimum distance between a sub-tree and a query point. The code here is optimised because it will be called a lot during searching. The most significant optimisation was not using <code>Math.pow</code>, which must do some extra work before it can return with a multiplication (see the source in [http://www.netlib.org/fdlibm/e_pow.c fdlibm]).
 
<pre>
 
// Gets distance from target of nearest point on hyper-rect defined
 
// by supplied min and max bounds
 
private static double
 
minDistanceSqFrom(double[] target, double[] min, double[] max) {
 
  // Note: profiling shows this is called lots of times, so it pays
 
  // to be well optimised
 
  double distanceSq = 0;
 
  for(int d = 0; d < target.length; d++) {
 
    final double coord = target[d];
 
    double nearCoord;
 
    if (((nearCoord = min[d]) > coord)
 
      || ((nearCoord = max[d]) < coord))
 
    {
 
      final double dst = nearCoord - coord;
 
      distanceSq += dst*dst;
 
    }
 
  }
 
  return distanceSq;
 
}
 
 
 
// Accessible to testing
 
static double
 
distanceSqFrom(double[] p1, double[] p2) {
 
  // Note: profiling shows this is called lots of times, so it pays
 
  // to be well optimised
 
  double dSq = 0;
 
  for(int d = 0; d < p1.length; d++) {
 
    final double dst = p1[d] - p2[d];
 
    if (dst != 0)
 
      dSq += dst*dst;
 
  }
 
  return dSq;
 
}
 
</pre>
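The nearest-point rule above amounts to clamping the query into the hyperrect along each dimension. The sketch below spells that out with <code>Math.max</code> and <code>Math.min</code>; it is a plainer, unoptimised restatement (the class name <code>HyperRectDistance</code> is hypothetical) that should agree with <code>minDistanceSqFrom()</code> whenever <code>min[d] <= max[d]</code>.

```java
// Hypothetical, unoptimised restatement of minDistanceSqFrom().
class HyperRectDistance {
  // Squared distance from target to the nearest point of the
  // hyperrect [min, max], found by clamping each coordinate.
  static double minDistanceSq(double[] target, double[] min,
                              double[] max) {
    double distanceSq = 0;
    for (int d = 0; d < target.length; d++) {
      // Clamp target[d] into [min[d], max[d]]
      final double nearest =
        Math.max(min[d], Math.min(target[d], max[d]));
      final double diff = nearest - target[d];
      // diff is zero when target[d] already lies within the bounds
      distanceSq += diff * diff;
    }
    return distanceSq;
  }
}
```

A query inside the hyperrect gets distance zero, so a sub-tree containing the query can never be pruned by this test.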
 
 
 
==Searching==
 
Before we dive into searching, we will introduce two helper classes.
 
* <code>SearchStackEntry</code> stores a tree along with its minimum distance from the query point. We put these on our tree-walking stack instead of just trees so we don't have to re-compute this minimum distance when we next pop the tree.
 
* <code>SearchState</code> holds some variables we want updated after each nearest neighbour is added. Its sole purpose is to let us separate searching sub-trees and leaves into different methods without sacrificing performance.
 
 
 
<pre>
 
 
 
//
 
// class SearchStackEntry
 
//
 
// Stores a precomputed distance so we don't have to do it again
 
// when we pop the tree off the search stack.
 
 
 
private static class SearchStackEntry<T extends Exemplar> {
 
  public final double minDFromQ;
 
  public final BucketKdTree<T> tree;
 
 
 
  public SearchStackEntry(double minDFromQ, BucketKdTree<T> tree) {
 
    this.minDFromQ = minDFromQ;
 
    this.tree = tree;
 
  }
 
}
 
 
 
//
 
// class SearchState
 
//
 
// Holds data about current state of the search. Used for live updating
 
// of pruning distance.
 
 
 
private static class SearchState {
 
  int nResults = 0;
 
  double maxDistance = Double.POSITIVE_INFINITY;
 
  int nExsAtMaxD = 0;
 
}
 
</pre>
 
 
 
Our search method continuously pops our search stack, decides whether it's looking at a sub-tree or a leaf and passes the call on as appropriate. We aggregate results in a TreeMap so users have access to both exemplars and their distance from the query point. You can get better speed with a custom heap which does less sorting, but then you'll be writing two trees instead of one.
 
 
 
<pre>
 
public SortedMap<Double, List<T>>
 
search(double[] query, int nMinResults) {
 
  // Forward to a static method to avoid accidental reference to
 
  // instance variables while descending the tree
 
  return search(this, query, nMinResults);
 
}
 
 
 
// May return more results than requested if multiple exemplars have
 
// same distance from target.
 
//
 
// Note: this function works with squared distances to avoid sqrt()
 
// operations
 
private static <T extends Exemplar> SortedMap<Double, List<T>>
 
search(BucketKdTree<T> tree, double[] query, int nMinResults) {
 
  // distance => list of points that distance away from query
 
  final NavigableMap<Double, List<T>> results =
 
    new TreeMap<Double, List<T>>();
 
 
 
  final SearchState state = new SearchState();
 
  final Deque<SearchStackEntry<T>> stack =
 
    new LinkedList<SearchStackEntry<T>>();
 
  stack.addFirst(new SearchStackEntry<T>(state.maxDistance, tree));
 
  while (!stack.isEmpty()) {
 
    final SearchStackEntry<T> entry = stack.removeFirst();
 
    final BucketKdTree<T> cur = entry.tree;
 
 
 
    if (cur.isTree()) {
 
      searchTree(query, nMinResults, cur, state, stack);
 
    } else if (entry.minDFromQ <= state.maxDistance
 
      || state.nResults < nMinResults)
 
    {
 
      searchLeaf(query, nMinResults, cur, state, results);
 
    }
 
  }
 
 
 
  return results;
 
}
 
</pre>
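How a <code>TreeMap</code> keyed by distance can hold at least a minimum number of results while still discarding the farthest ones is easier to see in isolation. The class below is a simplified sketch, not the tutorial's implementation: <code>NearestCollector</code> and its methods are hypothetical names, and it folds the bookkeeping that the tutorial spreads across <code>SearchState</code> and the leaf search into one place.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;

// Hypothetical helper; a simplified stand-in for the tutorial's
// result bookkeeping.
class NearestCollector<T> {
  private final int minResults;
  // squared distance => items at exactly that distance
  private final TreeMap<Double, List<T>> byDistance =
    new TreeMap<Double, List<T>>();
  private int size = 0;

  NearestCollector(int minResults) { this.minResults = minResults; }

  // Current pruning distance: infinite until minResults items are held.
  double maxDistance() {
    return size < minResults || byDistance.isEmpty()
      ? Double.POSITIVE_INFINITY : byDistance.lastKey();
  }

  void offer(double distanceSq, T item) {
    if (distanceSq > maxDistance()) return;  // outside search sphere
    List<T> bucket = byDistance.get(distanceSq);
    if (bucket == null) {
      bucket = new ArrayList<T>();
      byDistance.put(distanceSq, bucket);
    }
    bucket.add(item);
    size++;
    // Drop the whole farthest bucket if enough results remain
    // without it. Ties at the farthest distance are kept together.
    final int farthest = byDistance.lastEntry().getValue().size();
    if (size - farthest >= minResults) {
      byDistance.pollLastEntry();
      size -= farthest;
    }
  }

  TreeMap<Double, List<T>> results() { return byDistance; }

  int size() { return size; }
}
```

With a minimum of two results, offering distances 1, 4 and 4 keeps all three because the two farthest are tied; a later offer at distance 2 lets the tied pair be dropped together.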
 
 
 
===Searching a sub-tree===
 
Here is where we see two of our optimisations come into play. For each non-empty sub-tree, we compute the minimum distance from the smallest hyperrect bounding that sub-tree's exemplars to our query point. If this distance exceeds our maximum distance, we can rule out the entire sub-tree. Of course, if we haven't found our k-potential nearest neighbours yet, then both sub-trees may contain potential nearest neighbours so we will add them to the search stack anyway.
 
 
 
As a heuristic, we want to search the sub-tree with an edge closest to the query point first. This is more likely to lead to earlier contraction of our search distance. This is why we add the farthest tree to our stack first.
 
 
 
<pre>
 
private static <T extends Exemplar> void
 
searchTree(double[] query, int nMinResults, BucketKdTree<T> tree,
 
  SearchState searchState, Deque<SearchStackEntry<T>> stack)
 
{
 
  // Left is presumed near. This is verified further down.
 
  BucketKdTree<T> nearTree = tree.left, farTree = tree.right;
 
  // These variables let us skip empty sub-trees
 
  boolean nearEmpty = nearTree.minBounds == null;
 
  boolean farEmpty = farTree.minBounds == null;
 
 
 
  // Find distance from nearest possible point in each
 
  // sub-tree to query. If that is greater than max distance,
 
  // we can rule out that sub-tree.
 
  double nearD = nearEmpty ? Double.POSITIVE_INFINITY
 
    : minDistanceSqFrom(query,
 
      nearTree.minBounds, nearTree.maxBounds);
 
  double farD = farEmpty ? Double.POSITIVE_INFINITY
 
    : minDistanceSqFrom(query,
 
      farTree.minBounds, farTree.maxBounds);
 
 
 
  // Swap near and far if they're incorrect
 
  if (farD < nearD) {
 
    final double tmpD = nearD;
 
    final BucketKdTree<T> tmpTree = nearTree;
 
    final boolean tmpEmpty = nearEmpty;
 
 
 
    nearD = farD;
 
    nearTree = farTree;
 
    nearEmpty = farEmpty;
 
 
 
    farD = tmpD;
 
    farTree = tmpTree;
 
    farEmpty = tmpEmpty;
 
  }
 
 
 
  // Add nearest sub-tree to stack later so we descend it
 
  // first. This is likely to constrict our max distance
 
  // sooner, resulting in fewer visited nodes
 
  if (!farEmpty && (farD <= searchState.maxDistance
 
                    || searchState.nResults < nMinResults))
 
  {
 
    stack.addFirst(new SearchStackEntry<T>(farD, farTree));
 
  }
 
 
 
  if (!nearEmpty && (nearD <= searchState.maxDistance
 
                    || searchState.nResults < nMinResults))
 
  {
 
    stack.addFirst(new SearchStackEntry<T>(nearD, nearTree));
 
  }
 
}
 
</pre>
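The push order matters because the search stack is last-in, first-out. A tiny sketch using <code>java.util.ArrayDeque</code> with plain labels in place of trees (the class name <code>StackOrderSketch</code> is hypothetical) shows that pushing the far sub-tree first means the near one is visited first.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical demonstration with labels standing in for sub-trees.
class StackOrderSketch {
  // Pushes "far" then "near" and returns the order in which a
  // depth-first walk would pop them.
  static String visitOrder() {
    final Deque<String> stack = new ArrayDeque<String>();
    stack.addFirst("far");
    stack.addFirst("near");
    final StringBuilder order = new StringBuilder();
    while (!stack.isEmpty()) {
      order.append(stack.removeFirst()).append(' ');
    }
    return order.toString().trim();
  }
}
```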
 
 
 
===Searching a leaf===
 
When we reach a leaf, we finally have a chance to contract our search distance. If we haven't yet found our k-potential nearest neighbours, or if the exemplar we're checking lies on the edge of our max search distance, we blindly add exemplars and update our max search distance.
 
 
 
Otherwise, we check the exemplar's distance from our query point. If it's closer than our furthest potential match, we drop our furthest result (unless doing so would leave us with too few k-potential neighbours) to make room for our new neighbour. Finally, we update our max search distance so we can skip any trees or points outside our new search distance.
 
 
 
If all the exemplars in a leaf are uniform, we don't need to do this check for every exemplar. We just do it for the first one, and either add or reject them all.
 
 
 
<pre>
 
private static <T extends Exemplar> void
 
searchLeaf(double[] query, int nMinResults,
 
  BucketKdTree<T> leaf, SearchState searchState,
 
  NavigableMap<Double, List<T>> results)
 
{
 
  // Keep track of elements at max distance so we know
 
  // whether we can just drop entire list of furthest
 
  // exemplars
 
  final int nMinResultsBeforeAddition = nMinResults - 1;
 
  for(T ex : leaf.exemplars) {
 
    final double exD = distanceSqFrom(query, ex.domain);
 
    if (searchState.nResults < nMinResults
 
      || exD == searchState.maxDistance)
 
    {
 
      // Blindly add this exemplar
 
      List<T> exsAtD = getOrElseInit(results, exD);
 
      if (leaf.exemplarsAreUniform) {
 
        // No need to go through every one if all
 
        // exemplars are the same
 
        exsAtD.addAll(leaf.exemplars);
 
        searchState.nResults += leaf.exemplars.size();
 
      } else {
 
        exsAtD.add(ex);
 
        searchState.nResults++;
 
      }
 
 
 
      // Update information about furthest exemplars
 
      if (exD > searchState.maxDistance
 
        || searchState.maxDistance == Double.POSITIVE_INFINITY)
 
      {
 
        searchState.maxDistance = exD;
 
      }
 
 
 
      if (searchState.maxDistance == exD)
 
        searchState.nExsAtMaxD = exsAtD.size();
 
 
 
    } else if (exD < searchState.maxDistance) {
 
      // Point closer than furthest neighbour
 
 
 
      if (searchState.nResults - searchState.nExsAtMaxD
 
        >= nMinResultsBeforeAddition)
 
      {
 
        // Dropping furthest exemplars won't leave
 
        // us with too few to meet return
 
        // minimum
 
        results.remove(searchState.maxDistance);
 
        searchState.nResults -= searchState.nExsAtMaxD;
 
      }
 
 
 
      // Add new exemplar
 
      List<T> exsAtD = getOrElseInit(results, exD);
 
      if (leaf.exemplarsAreUniform) {
 
        // No need to go through every one if all
 
        // exemplars are the same
 
        exsAtD.addAll(leaf.exemplars);
 
        searchState.nResults += leaf.exemplars.size();
 
      } else {
 
        exsAtD.add(ex);
 
        searchState.nResults++;
 
      }
 
 
 
      // Update information about furthest exemplars
 
      Map.Entry<Double, List<T>> lastEnt = results.lastEntry();
 
      searchState.maxDistance = lastEnt.getKey();
 
      searchState.nExsAtMaxD = lastEnt.getValue().size();
 
    } // exD < maxDistance
 
  
    // No need to keep going if all exemplars are
+
Our distance calculation is optimised because it will be called often. We use squared distances to avoid an unnecessary <code>sqrt()</code> operation. We also don't use <code>Math.pow()</code> because the JRE's default implementation must do some extra work before it can return with a simple multiplication (see the source in [http://www.netlib.org/fdlibm/e_pow.c fdlibm]).
    // the same
 
    if (leaf.exemplarsAreUniform) break;
 
  } // for(T ex : cur.exemplars)
 
}
 
  
private static <T extends Exemplar> List<T>
+
private static double
getOrElseInit(Map<Double, List<T>> results, double dst) {
+
distanceSqFrom(double[] p1, double[] p2) {
  List<T> lst = results.get(dst);
+
  // Note: profiling shows this is called lots of times, so it pays
  if (lst == null) {
+
  // to be well optimised
    lst = new LinkedList<T>();
+
  double dSq = 0;
    results.put(dst, lst);
+
  for(int d = 0; d < p1.length; d++) {
  }
+
    final double dst = p1[d] - p2[d];
  return lst;
+
    if (dst != 0)
}
+
      dSq += dst*dst;
</pre>
+
   }
 
+
   return dSq;
==Assessment==
 
Initial performance benchmarks are promising. Randomised JUnit tests indicate (not very rigorously) that searching can be 2&ndash;7x faster than a linear search. Using the [[k-NN algorithm benchmark]] with the [http://homepages.ucalgary.ca/~agschult/robocode/gun-data-Diamond-vs-jk.mini.CunobelinDC%200.3.csv.gz Diamond vs CunobelinDC gun data] yields the following performance:
 
 
 
<pre>
 
K-NEAREST NEIGHBOURS ALGORITHMS BENCHMARK
 
-----------------------------------------
 
Reading data from gun-data-Diamond-vs-jk.mini.CunobelinDC 0.3.csv.gz
 
Read 25621 saves and 10300 searches.
 
 
 
Running 30 repetition(s) for k-nearest neighbours searching:
 
:: 13 dimension(s); 40 neighbour(s)
 
Warming up the JIT with 5 repetitions first...
 
 
 
Running tests... COMPLETED.
 
 
 
RESULT << k-nearest neighbours search with Voidious' Linear search >>
 
: Average searching time      = 0.57 miliseconds
 
: Average worst searching time = 14.992 miliseconds
 
: Average adding time          = 0.55 microseconds
 
: Accuracy                    = 100%
 
 
 
RESULT << k-nearest neighbours search with Simonton's Bucket PR k-d tree >>
 
: Average searching time      = 0.16 miliseconds
 
: Average worst searching time = 20.109 miliseconds
 
: Average adding time          = 1.41 microseconds
 
: Accuracy                    = 100%
 
 
 
RESULT << k-nearest neighbours search with Nat's Bucket PR k-d tree >>
 
: Average searching time      = 0.361 miliseconds
 
: Average worst searching time = 304.663 miliseconds
 
: Average adding time          = 66.64 microseconds
 
: Accuracy                    = 31%
 
 
 
RESULT << k-nearest neighbours search with Voidious' Bucket PR k-d tree >>
 
: Average searching time      = 0.216 miliseconds
 
: Average worst searching time = 100.556 miliseconds
 
: Average adding time          = 1.58 microseconds
 
: Accuracy                    = 100%
 
 
 
RESULT << k-nearest neighbours search with duyn's Bucket kd-tree >>
 
: Average searching time      = 0.079 miliseconds
 
: Average worst searching time = 17.476 miliseconds
 
: Average adding time          = 2.45 microseconds
 
: Accuracy                    = 100%
 
 
 
RESULT << k-nearest neighbours search with Rednaxela's Bucket kd-tree >>
 
: Average searching time      = 0.053 miliseconds
 
: Average worst searching time = 0.296 miliseconds
 
: Average adding time          = 1.69 microseconds
 
: Accuracy                    = 100%
 
 
 
 
 
BEST RESULT:
 
- #1 Rednaxela's Bucket kd-tree [0.0529]
 
- #2 duyn's Bucket kd-tree [0.0787]
 
- #3 Simonton's Bucket PR k-d tree [0.1598]
 
- #4 Voidious' Bucket PR k-d tree [0.2165]
 
- #5 Nat's Bucket PR k-d tree [0.3605]
 
- #6 Voidious' Linear search [0.5699]
 
 
 
Benchmark running time: 545.43 seconds
 
</pre>
 
 
 
This is not an exhaustive test, but it shows we're off to a nice start. The following code was used to run the benchmark:
 
 
 
* KNNBenchmark.java
 
  <pre>
 
public KNNBenchmark(final int dimension, final int numNeighbours, final SampleData[] samples, final int numReps) {
 
  final Class<?>[] searchAlgorithms = new Class<?>[] {
 
    FlatKNNSearch.class,
 
    SimontonTreeKNNSearch.class,
 
    NatTreeKNNSearch.class,
 
    VoidiousTreeKNNSearch.class,
 
+    DuynTreeImpl.class,
 
    RednaxelaTreeKNNSearch.class,
 
   };
 
   ...
 
 
  }
 
  }
  </pre>
 
 
* DuynTreeImpl.java
 
  <pre>
 
package net.robowiki.knn.implementations;
 
import net.robowiki.knn.util.KNNPoint;
 
import java.util.*;
 
// Namespace I use for my kd-tree
 
import dn.j1.algorithm.nearestneighbours.*;
 
 
public class DuynTreeImpl extends KNNImplementation {
 
  final BucketKdTree<KNNPointAdapter> tree;
 
  public DuynTreeImpl(final int dimension) {
 
    super(dimension);
 
    tree = new BucketKdTree<KNNPointAdapter>(10, dimension);
 
  }
 
 
  @Override public void
 
  addPoint(final double[] location, final String value) {
 
    tree.add(new KNNPointAdapter(location, value));
 
  }
 
 
  @Override public String
 
  getName() {
 
    return "duyn's Bucket kd-tree";
 
  }
 
 
  @Override public KNNPoint[]
 
  getNearestNeighbors(final double[] location, final int size) {
 
    final SortedMap<Double, List<KNNPointAdapter>> results = tree.search(location, size);
 
    final List<KNNPoint> justPoints = new LinkedList<KNNPoint>();
 
    for(final Map.Entry<Double, List<KNNPointAdapter>> entry : results.entrySet()) {
 
      final double distance = entry.getKey();
 
      for(final KNNPointAdapter pt : entry.getValue()) {
 
        justPoints.add(new KNNPoint(pt.value, distance));
 
      }
 
    }
 
    final KNNPoint[] retVal = new KNNPoint[justPoints.size()];
 
    return justPoints.toArray(retVal);
 
  }
 
 
  class KNNPointAdapter extends Exemplar {
 
    public final String value;
 
    public KNNPointAdapter(final double[] coords, final String value) {
 
      super(coords);
 
      this.value = value;
 
    }
 
  }
 
}
 
  </pre>
 
  
==Extensions==
+
And that's it. Barely 200 lines of code and we have ourselves a working ''k''d-tree.
This tree does not include any support for deletion or re-balancing. You can get away with this if you plan to re-build the entire tree regularly.
 
  
'''Deletion''' support would be easy to add to this tree. Since all exemplars are stored in leaves at the bottom, we need only walk the tree and remove the exemplar from the leaf containing it. We would then merge sub-trees if multiple deletions have left one of them empty. Unless our bucket width is really small, we don't need to merge sub-trees after each deletion since the addition of a subsequent exemplar is unlikely to result in choosing a significantly different split point.
+
===Evaluation===
 +
As a guide to performance, we will use the [[k-NN algorithm benchmark]] with the [http://homepages.ucalgary.ca/~agschult/robocode/gun-data-Diamond-vs-jk.mini.CunobelinDC%200.3.csv.gz Diamond vs CunobelinDC gun data]. For comparison, the benchmark included an optimised linear search to represent the lower bound of performance, and Rednaxela's tree&mdash;credited to be the fastest tree on the wiki.
  
'''Re-balancing''' with this tree is trickier. It does not currently store depth, so we don't know when a re-balance is needed. Because exemplars are only stored in leaves, collecting all the exemplars under a sub-tree would require exhaustively walking that sub-tree.
+
K-NEAREST NEIGHBOURS ALGORITHMS BENCHMARK
 +
-----------------------------------------
 +
Reading data from gun-data-Diamond-vs-jk.mini.CunobelinDC 0.3.csv.gz
 +
Read 25621 saves and 10300 searches.
 +
 +
Running 30 repetition(s) for k-nearest neighbours searching:
 +
:: 13 dimension(s); 40 neighbour(s)
 +
Warming up the JIT with 5 repetitions first...
 +
 +
Running tests... COMPLETED.
 +
 +
RESULT << k-nearest neighbours search with Voidious' Linear search >>
 +
: Average searching time      = 0.577 miliseconds
 +
: Average worst searching time = 11.123 miliseconds
 +
: Average adding time          = 0.55 microseconds
 +
: Accuracy                    = 100%
 +
 +
RESULT << k-nearest neighbours search with Rednaxela's Bucket kd-tree >>
 +
: Average searching time      = 0.056 miliseconds
 +
: Average worst searching time = 15.779 miliseconds
 +
: Average adding time          = 1.65 microseconds
 +
: Accuracy                    = 100%
 +
 +
RESULT << k-nearest neighbours search with duyn's basic kd-tree >>
 +
: Average searching time      = 0.404 miliseconds
 +
: Average worst searching time = 5.748 miliseconds
 +
: Average adding time          = 0.68 microseconds
 +
: Accuracy                    = 100%
 +
 +
 +
BEST RESULT:
 +
  - #1 Rednaxela's Bucket kd-tree [0.0557]
 +
  - #2 duyn's basic kd-tree [0.404]
 +
  - #3 Voidious' Linear search [0.5771]
 +
 +
Benchmark running time: 334.11 seconds
  
* If a depth is to be stored, each tree would need a pointer to its parent tree so we can propagate depth updates when a leaf is split.
+
This test run showed that the average add time was close to that of a linear search, while the average search time fell by (1 - 0.404/0.5771) ~= 30%. It is still roughly seven times slower than Rednaxela's included tree (0.404 ms against 0.0557 ms), which shows there is a lot to gain by selecting appropriate optimisations.
  
* We can avoid an exhaustive walk by storing each exemplar in every tree on its path from root to leaf. We then can re-balance easily by dropping both sub-trees and splitting the given tree anew. This approach has the down side of making inserts and deletes more complex since we have to update the exemplar list of every tree along the path from root to leaf.
+
To put it in context, this is how our tree performs compared to other implementations bundled with the benchmark:
  
A TreeMap does more processing than we need since it keeps all elements in order, when we only want access to the one with largest distance. Substituting a custom heap can significantly improve performance. Note that Java's standard PriorityQueue will not suffice since it does not handle duplicates the way we want.
+
K-NEAREST NEIGHBOURS ALGORITHMS BENCHMARK
 +
-----------------------------------------
 +
Reading data from gun-data-Diamond-vs-jk.mini.CunobelinDC 0.3.csv.gz
 +
Read 25621 saves and 10300 searches.
 +
 +
Running 30 repetition(s) for k-nearest neighbours searching:
 +
:: 13 dimension(s); 40 neighbour(s)
 +
Warming up the JIT with 5 repetitions first...
 +
 +
Running tests... COMPLETED.
 +
 +
RESULT << k-nearest neighbours search with Voidious' Linear search >>
 +
: Average searching time      = 0.573 miliseconds
 +
: Average worst searching time = 15.828 miliseconds
 +
: Average adding time          = 0.56 microseconds
 +
: Accuracy                    = 100%
 +
 +
RESULT << k-nearest neighbours search with Simonton's Bucket PR k-d tree >>
 +
: Average searching time      = 0.161 miliseconds
 +
: Average worst searching time = 19.711 miliseconds
 +
: Average adding time          = 1.47 microseconds
 +
: Accuracy                    = 100%
 +
 +
RESULT << k-nearest neighbours search with Nat's Bucket PR k-d tree >>
 +
: Average searching time      = 0.381 miliseconds
 +
: Average worst searching time = 315.549 miliseconds
 +
: Average adding time          = 63.06 microseconds
 +
: Accuracy                    = 31%
 +
 +
RESULT << k-nearest neighbours search with Voidious' Bucket PR k-d tree >>
 +
: Average searching time      = 0.218 miliseconds
 +
: Average worst searching time = 104.198 miliseconds
 +
: Average adding time          = 1.45 microseconds
 +
: Accuracy                    = 100%
 +
 +
RESULT << k-nearest neighbours search with Rednaxela's Bucket kd-tree >>
 +
: Average searching time      = 0.053 miliseconds
 +
: Average worst searching time = 0.51 miliseconds
 +
: Average adding time          = 1.7 microseconds
 +
: Accuracy                    = 100%
 +
 +
RESULT << k-nearest neighbours search with duyn's basic kd-tree >>
 +
: Average searching time      = 0.394 miliseconds
 +
: Average worst searching time = 8.718 miliseconds
 +
: Average adding time          = 0.68 microseconds
 +
: Accuracy                    = 100%
 +
 +
 +
BEST RESULT:
 +
  - #1 Rednaxela's Bucket kd-tree [0.0534]
 +
  - #2 Simonton's Bucket PR k-d tree [0.1611]
 +
  - #3 Voidious' Bucket PR k-d tree [0.2182]
 +
  - #4 Nat's Bucket PR k-d tree [0.3812]
 +
  - #5 duyn's basic kd-tree [0.3937]
 +
  - #6 Voidious' Linear search [0.5728]
 +
 +
Benchmark running time: 645.63 seconds
  
==Full Source Code==
+
==Optimisations to the Tree==
If you want to see all the code in one place, take a look at the tree behind this page&mdash;[[User:Duyn/BucketKdTree|duyn's Bucket kd-tree]]. The code there may have changed since this page was written.
 
  
==See also==
+
==A Better Results Heap==
* [[Kd-tree]]. For further information on kd-trees.
 
* [http://www.autonlab.org/autonweb/14665 Andrew Moore, 'An intoductory tutorial on kd']. A good introduction to kd-trees.
 
* [http://ilpubs.stanford.edu:8090/723/ Neal Sample, Matthew Haines, Mark Arnold, Timothy Purcell, 'Optimizing Search Strategies in k-d Trees']. Details some of the optimisations used in this tree.
 

Revision as of 10:16, 12 March 2010

This page will walk you through the implementation of a kd-tree. We start with a basic tree and progressively add levels of optimisation. Writing a kd-tree is not an easy task; having some details explained can make the process easier. There are already other kd-trees on this wiki; see some of the others for ideas.

Theory

A k-d tree is a binary tree which successively splits a k-dimensional space in half. This lets us speed up a nearest neighbour search by not examining points in partitions which are too far away. A k-nearest neighbour search can be implemented from a nearest neighbour algorithm by not shrinking the search radius until k items have been found. This page won't make a distinction between the two.

An Exemplar class

We will call each item in our k-d tree an exemplar. Each exemplar has a domain—its co-ordinates in the k-d tree's space. Each exemplar could also have an arbitrary payload, but our tree does not need to know about that. It will only handle storing exemplars based on their domain and returning them in a nearest neighbour search.

We might already have a class elsewhere called Point which handles 2D co-ordinates. Calling our items exemplars avoids a conflict with that.

public class Exemplar {
  public final double[] domain;

  public Exemplar(double[] domain) {
    this.domain = domain;
  }

  public final boolean
  collocated(final Exemplar other) {
    return Arrays.equals(domain, other.domain);
  }
}

While this class is fully usable as is, rarely will the domain of nearest neighbours be of any interest. Often, useful data (such as GuessFactors) will be loaded by sub-classing this Exemplar class. Our k-d tree will be parameterised based on this expectation.
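As a rough illustration of that sub-classing, the sketch below attaches a payload to an exemplar. The Exemplar class is repeated so the sketch compiles on its own; the GuessFactorExemplar name and its guessFactor field are hypothetical and not part of the tree.

```java
import java.util.Arrays;

// Copy of the Exemplar class from this page, repeated so the sketch
// is self-contained.
class Exemplar {
  public final double[] domain;

  public Exemplar(double[] domain) {
    this.domain = domain;
  }

  public final boolean
  collocated(final Exemplar other) {
    return Arrays.equals(domain, other.domain);
  }
}

// Hypothetical sub-class: the class name and the guessFactor field are
// assumptions made for illustration only.
class GuessFactorExemplar extends Exemplar {
  public final double guessFactor;

  public GuessFactorExemplar(double[] domain, double guessFactor) {
    super(domain);
    this.guessFactor = guessFactor;
  }
}
```

The tree only ever reads domain, so two exemplars with different payloads but equal co-ordinates still count as collocated.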

Basic Tree

First, we will build a basic kd-tree as described by Wikipedia's kd-tree page. We start off with a standard binary tree. Each tree is either a node with both left and right sub-trees defined, or a leaf carrying an exemplar. For simplicity, we won't allow nodes to contain any exemplars.

public class BasicKdTree<X extends Exemplar> {
  // Basic tree structure
  X data = null;
  BasicKdTree<X> left = null, right = null;

  // Only need to test one branch since we always populate both
  // branches at once
  private boolean
  isTree() { return left != null; }

  ...
}

Adding

Each of the public API functions defers to a private static method. This is to avoid accidentally referring to instance variables while we walk the tree. This is a common pattern we will use for much of the actual behaviour code for this tree.

public void
add(X ex) {
  BasicKdTree.addToTree(this, ex);
}

To add an exemplar, we traverse the tree from the top down until we find a leaf. If the leaf doesn't already have an exemplar, we can just stow it there. Otherwise, we split the leaf and put one of the two exemplars in each sub-tree.

To split a leaf, we cycle through split dimensions in order as in most descriptions of kd-trees. Because we only allow each leaf to hold a single exemplar, we can use a simple splitting strategy: smallest on the left, largest on the right. The split value is half way between the two points along the split dimension.

int splitDim = 0;
double split = Double.NaN;

private static <X extends Exemplar> void
addToTree(BasicKdTree<X> tree, X ex) {
  while(tree != null) {
    if (tree.isTree()) {
      // Traverse in search of a leaf
      tree = ex.domain[tree.splitDim] <= tree.split
        ? tree.left : tree.right;
    } else {
      if (tree.data == null) {
        tree.data = ex;
      } else {
        // Split tree and add
        // Find smallest exemplar to be our split point
        final int d = tree.splitDim;
        X leftX = ex, rightX = tree.data;
        if (rightX.domain[d] < leftX.domain[d]) {
          leftX = tree.data;
          rightX = ex;
        }
        tree.split = 0.5*(leftX.domain[d] + rightX.domain[d]);
        
        final int nextSplitDim =
          (tree.splitDim + 1)%tree.dimensions();
        
        tree.left = new BasicKdTree<X>();
        tree.left.splitDim = nextSplitDim;
        tree.left.data = leftX;
        
        tree.right = new BasicKdTree<X>();
        tree.right.splitDim = nextSplitDim;
        tree.right.data = rightX;
      }

      // Done.
      tree = null;
    }
  }
}

private int
dimensions() { return data.domain.length; }
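To make the splitting arithmetic concrete, here is a small stand-alone sketch of the three calculations used in addToTree(). The SplitSketch class and its method names are hypothetical helpers written for this example; they mirror the expressions in the code above rather than being part of the tree.

```java
// Hypothetical helper class; it only restates the arithmetic used by
// addToTree() so it can be tried in isolation.
class SplitSketch {
  // Split value: half way between two exemplars along dimension d.
  static double
  splitValue(double[] a, double[] b, int d) {
    return 0.5*(a[d] + b[d]);
  }

  // Children of a node that splits on splitDim use the next dimension,
  // wrapping around to 0 after the last one.
  static int
  nextSplitDim(int splitDim, int dimensions) {
    return (splitDim + 1)%dimensions;
  }

  // Same routing test as the traversal: values equal to the split go left.
  static boolean
  goesLeft(double[] ex, int splitDim, double split) {
    return ex[splitDim] <= split;
  }
}
```

For the points (1, 5) and (3, 2) split on dimension 0, the split value is 2; the first point is routed left and the second right, and their sub-trees split on dimension 1. When both exemplars share the same co-ordinate in the split dimension, the split equals that co-ordinate and the <= test sends later exemplars with that value to the left.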

Searching

Before we start coding a search method, we need a helper class to store search results along with their distance from the query point. This is called PrioNode because we will eventually re-use it to implement a custom priority queue.

public final class PrioNode<T> {
  public final double priority;
  public final T data;

  public PrioNode(double priority, T data) {
    this.priority = priority;
    this.data = data;
  }
}

As with our add() method, our search() method delegates to a static method to avoid introducing bugs by accidentally referring to member variables while we descend the tree.

public Iterable<? extends PrioNode<X>>
search(double[] query, int nResults) {
  return BasicKdTree.search(this, query, nResults);
}

To do a nearest neighbours search, we walk the tree, preferring to search sub-trees which are on the same side of the split as the query point first. Once we have found our initial candidates, we can contract our search sphere. We only search the other sub-tree if our search sphere might spill over onto the other side of the split.

Results are collected in a java.util.PriorityQueue so we can easily remove the farthest exemplars as we find closer ones.

private static <X extends Exemplar> Iterable<? extends PrioNode<X>>
search(BasicKdTree<X> tree, double[] query, int nResults) {
  final Queue<PrioNode<X>> results =
    new PriorityQueue<PrioNode<X>>(nResults,
      new Comparator<PrioNode<X>>() {

        // max-heap: larger distances sort first, so the farthest
        // result found so far sits at the head of the queue
        public int
        compare(PrioNode<X> o1, PrioNode<X> o2) {
          return o1.priority == o2.priority ? 0
            : o1.priority > o2.priority ? -1
            : 1;
        }

      }
    );
  final Deque<BasicKdTree<X>> stack
    = new LinkedList<BasicKdTree<X>>();
  stack.addLast(tree);
  while (!stack.isEmpty()) {
    tree = stack.removeLast();
    
    if (tree.isTree()) {
      // Guess nearest tree to query point
      BasicKdTree<X> nearTree = tree.left, farTree = tree.right;
      if (query[tree.splitDim] > tree.split) {
        nearTree = tree.right;
        farTree = tree.left;
      }
      
      // Only search far tree if our search sphere might
      // overlap with splitting plane
      if (results.size() < nResults
        || sq(query[tree.splitDim] - tree.split)
          <= results.peek().priority)
      {
        stack.addLast(farTree);
      }

      // Always search the nearest branch
      stack.addLast(nearTree);
    } else {
      final double dSq = distanceSqFrom(query, tree.data.domain);
      if (results.size() < nResults
        || dSq < results.peek().priority)
      {
        while (results.size() >= nResults) {
          results.poll();
        }

        results.offer(new PrioNode<X>(dSq, tree.data));
      }
    }
  }
  return results;
}
 
private static double
sq(double n) { return n*n; }
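The way the results queue is bounded can be tried in isolation. The sketch below applies the same offer/poll pattern to plain distances instead of PrioNodes; the NearestKSketch class is hypothetical, and Collections.reverseOrder() stands in for the custom comparator used above. It assumes k is at least 1, since PriorityQueue rejects an initial capacity below 1.

```java
import java.util.Collections;
import java.util.PriorityQueue;

// Hypothetical stand-alone version of the bounded results heap.
class NearestKSketch {
  // Keeps the k smallest distances seen. The queue is ordered largest
  // first, so peek() is always the farthest distance still retained.
  static PriorityQueue<Double>
  keepSmallest(double[] dists, int k) {
    final PriorityQueue<Double> heap =
      new PriorityQueue<Double>(k, Collections.<Double>reverseOrder());
    for (double d : dists) {
      if (heap.size() < k || d < heap.peek()) {
        // Make room by dropping the farthest distance first
        while (heap.size() >= k) {
          heap.poll();
        }
        heap.offer(d);
      }
    }
    return heap;
  }
}
```

Because the head is the farthest retained result, a single peek() gives the current search radius (squared), which is what the pruning test in search() compares against.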

Our distance calculation is optimised because it will be called often. We use squared distances to avoid an unnecessary sqrt() operation. We also don't use Math.pow() because the JRE's default implementation must do some extra work before it can return with a simple multiplication (see the source in fdlibm).

private static double
distanceSqFrom(double[] p1, double[] p2) {
  // Note: profiling shows this is called lots of times, so it pays
  // to be well optimised
  double dSq = 0;
  for(int d = 0; d < p1.length; d++) {
    final double dst = p1[d] - p2[d];
    if (dst != 0)
      dSq += dst*dst;
  }
  return dSq;
}
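Skipping sqrt() is safe because squaring preserves the order of non-negative numbers: if one point is closer than another, its squared distance is smaller too. A minimal sketch of that idea follows; the DistanceSketch class and its closer() helper are hypothetical names used only for this example.

```java
// Hypothetical demonstration that comparisons on squared distances
// agree with comparisons on true distances.
class DistanceSketch {
  // Squared Euclidean distance between two points of equal length.
  static double
  distanceSqFrom(double[] p1, double[] p2) {
    double dSq = 0;
    for (int d = 0; d < p1.length; d++) {
      final double dst = p1[d] - p2[d];
      dSq += dst*dst;
    }
    return dSq;
  }

  // True if a is strictly closer to query than b; no sqrt() needed.
  static boolean
  closer(double[] query, double[] a, double[] b) {
    return distanceSqFrom(query, a) < distanceSqFrom(query, b);
  }
}
```

With the query at the origin, (3, 4) has squared distance 25 and (6, 8) has squared distance 100, so the first is reported as closer, just as with the true distances 5 and 10. The catch is that every quantity compared against these values, such as the distance to the splitting plane, must be squared as well.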

And that's it. Barely 200 lines of code and we have ourselves a working kd-tree.

Evaluation

As a guide to performance, we will use the k-NN algorithm benchmark with the Diamond vs CunobelinDC gun data. For comparison, the benchmark included an optimised linear search to represent the lower bound of performance, and Rednaxela's tree—credited to be the fastest tree on the wiki.

K-NEAREST NEIGHBOURS ALGORITHMS BENCHMARK
-----------------------------------------
Reading data from gun-data-Diamond-vs-jk.mini.CunobelinDC 0.3.csv.gz
Read 25621 saves and 10300 searches.

Running 30 repetition(s) for k-nearest neighbours searching:
:: 13 dimension(s); 40 neighbour(s)
Warming up the JIT with 5 repetitions first...

Running tests... COMPLETED.

RESULT << k-nearest neighbours search with Voidious' Linear search >>
: Average searching time       = 0.577 miliseconds
: Average worst searching time = 11.123 miliseconds
: Average adding time          = 0.55 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Rednaxela's Bucket kd-tree >>
: Average searching time       = 0.056 miliseconds
: Average worst searching time = 15.779 miliseconds
: Average adding time          = 1.65 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with duyn's basic kd-tree >>
: Average searching time       = 0.404 miliseconds
: Average worst searching time = 5.748 miliseconds
: Average adding time          = 0.68 microseconds
: Accuracy                     = 100%


BEST RESULT: 
 - #1 Rednaxela's Bucket kd-tree [0.0557]
 - #2 duyn's basic kd-tree [0.404]
 - #3 Voidious' Linear search [0.5771]

Benchmark running time: 334.11 seconds

This test run showed that the average add time was close to that of a linear search, while the average search time fell by (1 - 0.404/0.5771) ~= 30%. It is still roughly seven times slower than Rednaxela's included tree (0.404 ms against 0.0557 ms), which shows there is a lot to gain by selecting appropriate optimisations.
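When testing a tree of your own, a brute-force search makes a handy reference to compare results against. The sketch below is not the linear search bundled with the benchmark; LinearSearchSketch and its method names are hypothetical, and it sorts every distance rather than trying to be fast.

```java
import java.util.Arrays;

// Hypothetical brute-force reference, written for clarity rather than
// speed. It is not the benchmark's linear search.
class LinearSearchSketch {
  static double
  distanceSqFrom(double[] p1, double[] p2) {
    double dSq = 0;
    for (int d = 0; d < p1.length; d++) {
      final double dst = p1[d] - p2[d];
      dSq += dst*dst;
    }
    return dSq;
  }

  // Squared distances of the k nearest points to query, in ascending
  // order. Returns fewer than k values if there are fewer points.
  static double[]
  nearestDistancesSq(double[][] points, double[] query, int k) {
    final double[] all = new double[points.length];
    for (int i = 0; i < points.length; i++) {
      all[i] = distanceSqFrom(query, points[i]);
    }
    Arrays.sort(all);
    return Arrays.copyOf(all, Math.min(k, all.length));
  }
}
```

Comparing the sorted priorities returned by the tree's search() with this array for a few random queries is a cheap sanity check.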

To put it in context, this is how our tree performs compared to other implementations bundled with the benchmark:

K-NEAREST NEIGHBOURS ALGORITHMS BENCHMARK
-----------------------------------------
Reading data from gun-data-Diamond-vs-jk.mini.CunobelinDC 0.3.csv.gz
Read 25621 saves and 10300 searches.

Running 30 repetition(s) for k-nearest neighbours searching:
:: 13 dimension(s); 40 neighbour(s)
Warming up the JIT with 5 repetitions first...

Running tests... COMPLETED.

RESULT << k-nearest neighbours search with Voidious' Linear search >>
: Average searching time       = 0.573 miliseconds
: Average worst searching time = 15.828 miliseconds
: Average adding time          = 0.56 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Simonton's Bucket PR k-d tree >>
: Average searching time       = 0.161 miliseconds
: Average worst searching time = 19.711 miliseconds
: Average adding time          = 1.47 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Nat's Bucket PR k-d tree >>
: Average searching time       = 0.381 miliseconds
: Average worst searching time = 315.549 miliseconds
: Average adding time          = 63.06 microseconds
: Accuracy                     = 31%

RESULT << k-nearest neighbours search with Voidious' Bucket PR k-d tree >>
: Average searching time       = 0.218 miliseconds
: Average worst searching time = 104.198 miliseconds
: Average adding time          = 1.45 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Rednaxela's Bucket kd-tree >>
: Average searching time       = 0.053 miliseconds
: Average worst searching time = 0.51 miliseconds
: Average adding time          = 1.7 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with duyn's basic kd-tree >>
: Average searching time       = 0.394 miliseconds
: Average worst searching time = 8.718 miliseconds
: Average adding time          = 0.68 microseconds
: Accuracy                     = 100%


BEST RESULT: 
 - #1 Rednaxela's Bucket kd-tree [0.0534]
 - #2 Simonton's Bucket PR k-d tree [0.1611]
 - #3 Voidious' Bucket PR k-d tree [0.2182]
 - #4 Nat's Bucket PR k-d tree [0.3812]
 - #5 duyn's basic kd-tree [0.3937]
 - #6 Voidious' Linear search [0.5728]

Benchmark running time: 645.63 seconds

Optimisations to the Tree

A Better Results Heap