Difference between revisions of "User:Duyn/kd-tree Tutorial"

From Robowiki
Jump to navigation Jump to search
(→‎Searching a Tree: Progress save.)
(→‎Searching: Added missing public method.)
Line 695: Line 695:
 
Our search method does a depth-first search using a stack. Each time we pop our stack, we do a bounds check if necessary, then pass the call on to the appropriate method.
 
Our search method does a depth-first search using a stack. Each time we pop our stack, we do a bounds check if necessary, then pass the call on to the appropriate method.
  
 +
public Iterable<PrioNode<X>>
 +
search(double[] query, int nResults) {
 +
  // Forward to a static method to avoid accidental reference to
 +
  // instance variables while descending the tree
 +
  return search(this, query, nResults);
 +
}
 +
 
  private static <X extends Exemplar> Iterable<PrioNode<X>>
 
  private static <X extends Exemplar> Iterable<PrioNode<X>>
 
  search(OptKdTree<X> tree, double[] query, int nResults) {
 
  search(OptKdTree<X> tree, double[] query, int nResults) {

Revision as of 17:49, 12 March 2010

This page will walk you through the implementation of a kd-tree. We start with a basic tree and progressively add levels of optimisation. Writing a kd-tree is not an easy task; having some details explained can make the process easier. There are already other kd-trees on this wiki, see some of the others for ideas.

Theory

A k-d tree is a binary tree which successively splits a k-dimensional space in half. This lets us speed up a nearest neighbour search by not examining points in partitions which are too far away. A k-nearest neighbour search can be implemented from a nearest neighbour algorithm by not shrinking the search radius until k items have been found. This page won't make a distinction between the two.

An Exemplar class

We will call each item in our k-d tree an exemplar. Each exemplar has a domain—its co-ordinates in the k-d tree's space. Each exemplar could also have an arbitrary payload, but our tree does not need to know about that. It will only handle storing exemplars based on their domain and returning them in a nearest neighbour search.

We might already have a class elsewere called Point which handles 2D co-ordinates. This terminology avoids conflict with that.

public class Exemplar {
  public final double[] domain;

  public Exemplar(double[] domain) {
    this.domain = domain;
  }

  public final boolean
  collocated(final Exemplar other) {
    return Arrays.equals(domain, other.domain);
  }
}

While this class is fully usable as is, rarely will the domain of nearest neighbours be of any interest. Often, useful data (such as GuessFactors) will be loaded by sub-classing this Exemplar class. Our k-d tree will be parameterised based on this expectation.

Basic Tree

First, we will build a basic kd-tree as described by Wikipedia's kd-tree page. We start off with a standard binary tree. Each tree is either a node with both left and right sub-trees defined, or a leaf carrying an exemplar. For simplicity, we won't allow nodes to contain any exemplars.

public class BasicKdTree<X extends Exemplar> {
  // Basic tree structure
  X data = null;
  BasicKdTree<X> left = null, right = null;

  // Only need to test one branch since we always populate both
  // branches at once
  private boolean
  isTree() { return left != null; }

  ...
}

Adding

Each of the public API functions defers to a private static method. This is to avoid accidentally referring to instance variables while we walk the tree. This is a common pattern we will use for much of the actual behaviour code for this tree.

public void
add(X ex) {
  BasicKdTree.addToTree(this, ex);
}

To add an exemplar, we traverse the tree from top down until we find a leaf. If the leaf doesn't already have an exemplar, we can just stow it there. Otherwise, we split the leaf and put the free exemplars in each sub-tree.

To split a leaf, we cycle through split dimensions in order as in most descriptions of kd-trees. Because we only allow each leaf to hold a single exemplar, we can use a simple splitting strategy: smallest on the left, largest on the right. The split value is half way between the two points along the split dimension.

int splitDim = 0;
double split = Double.NaN;

private static <X extends Exemplar> void
addToTree(BasicKdTree<X> tree, X ex) {
  while(tree != null) {
    if (tree.isTree()) {
      // Traverse in search of a leaf
      tree = ex.domain[tree.splitDim] <= tree.split
        ? tree.left : tree.right;
    } else {
      if (tree.data == null) {
        tree.data = ex;
      } else {
        // Split tree and add
        // Find smallest exemplar to be our split point
        final int d = tree.splitDim;
        X leftX = ex, rightX = tree.data;
        if (rightX.domain[d] < leftX.domain[d]) {
          leftX = tree.data;
          rightX = ex;
        }
        tree.split = 0.5*(leftX.domain[d] + rightX.domain[d]);
        
        final int nextSplitDim =
          (tree.splitDim + 1)%tree.dimensions();
        
        tree.left = new BasicKdTree<X>();
        tree.left.splitDim = nextSplitDim;
        tree.left.data = leftX;
        
        tree.right = new BasicKdTree<X>();
        tree.right.splitDim = nextSplitDim;
        tree.right.data = rightX;
      }

      // Done.
      tree = null;
    }
  }
}

private int
dimensions() { return data.domain.length; }

Searching

Before we start coding a search method, we need a helper class to store search results along with their distance from the query point. This is called PrioNode because we will eventually re-use it to implement a custom priority queue.

public final class PrioNode<T> {
  public final double priority;
  public final T data;

  public PrioNode(double priority, T data) {
    this.priority = priority;
    this.data = data;
  }
}

Like with our add() method, our search() method delegates to a static method to avoid introducing bugs by accidentally referring to member variables while we descend the tree.

public Iterable<? extends PrioNode<X>>
search(double[] query, int nResults) {
  return BasicKdTree.search(this, query, nResults);
}

To do a nearest neighbours search, we walk the tree, preferring to search sub-trees which are on the same side of the split as the query point first. Once we have found our initial candidates, we can contract our search sphere. We only search the other sub-tree if our search sphere might spill over onto the other side of the split.

Results are collected in a java.util.PriorityQueue so we can easily remove the farthest exemplars as we find closer ones.

private static <X extends Exemplar> Iterable<? extends PrioNode<X>>
search(BasicKdTree<X> tree, double[] query, int nResults) {
  final Queue<PrioNode<X>> results =
    new PriorityQueue<PrioNode<X>>(nResults,
      new Comparator<PrioNode<X>>() {

        // min-heap
        public int
        compare(PrioNode<X> o1, PrioNode<X> o2) {
          return o1.priority == o2.priority ? 0
            : o1.priority > o2.priority ? -1
            : 1;
        }

      }
    );
  final Deque<BasicKdTree<X>> stack
    = new LinkedList<BasicKdTree<X>>();
  stack.addLast(tree);
  while (!stack.isEmpty()) {
    tree = stack.removeLast();
    
    if (tree.isTree()) {
      // Guess nearest tree to query point
      BasicKdTree<X> nearTree = tree.left, farTree = tree.right;
      if (query[tree.splitDim] > tree.split) {
        nearTree = tree.right;
        farTree = tree.left;
      }
      
      // Only search far tree if our search sphere might
      // overlap with splitting plane
      if (results.size() < nResults
        || sq(query[tree.splitDim] - tree.split)
          <= results.peek().priority)
      {
        stack.addLast(farTree);
      }

      // Always search the nearest branch
      stack.addLast(nearTree);
    } else {
      final double dSq = distanceSqFrom(query, tree.data.domain);
      if (results.size() < nResults
        || dSq < results.peek().priority)
      {
        while (results.size() >= nResults) {
          results.poll();
        }

        results.offer(new PrioNode<X>(dSq, tree.data));
      }
    }
  }
  return results;
}
 
private static double
sq(double n) { return n*n; }

Our distance calculation is optimised because it will be called often. We use squared distances to avoid an unnecessary sqrt() operation. We also don't use Math.pow() because the JRE's default implementation must do some extra work before it can return with a simple multiplication (see the source in fdlibm).

private static double
distanceSqFrom(double[] p1, double[] p2) {
  // Note: profiling shows this is called lots of times, so it pays
  // to be well optimised
  double dSq = 0;
  for(int d = 0; d < p1.length; d++) {
    final double dst = p1[d] - p2[d];
    if (dst != 0)
      dSq += dst*dst;
  }
  return dSq;
}

And that's it. Barely 200 lines of code and we have ourselves a working kd-tree.

Evaluation

As a guide to performance, we will use the k-NN algorithm benchmark with the Diamond vs CunobelinDC gun data. For comparison, the benchmark included an optimised linear search to represent the lower bound of performance, and Rednaxela's tree—credited to be the fastest tree on the wiki.

K-NEAREST NEIGHBOURS ALGORITHMS BENCHMARK
-----------------------------------------
Reading data from gun-data-Diamond-vs-jk.mini.CunobelinDC 0.3.csv.gz
Read 25621 saves and 10300 searches.

Running 30 repetition(s) for k-nearest neighbours searching:
:: 13 dimension(s); 40 neighbour(s)
Warming up the JIT with 5 repetitions first...

Running tests... COMPLETED.

RESULT << k-nearest neighbours search with Voidious' Linear search >>
: Average searching time       = 0.577 miliseconds
: Average worst searching time = 11.123 miliseconds
: Average adding time          = 0.55 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Rednaxela's Bucket kd-tree >>
: Average searching time       = 0.056 miliseconds
: Average worst searching time = 15.779 miliseconds
: Average adding time          = 1.65 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with duyn's basic kd-tree >>
: Average searching time       = 0.404 miliseconds
: Average worst searching time = 5.748 miliseconds
: Average adding time          = 0.68 microseconds
: Accuracy                     = 100%


BEST RESULT: 
 - #1 Rednaxela's Bucket kd-tree [0.0557]
 - #2 duyn's basic kd-tree [0.404]
 - #3 Voidious' Linear search [0.5771]

Benchmark running time: 334.11 seconds

This test run showed that average add time was close to a linear search while searching performance improved by (0.404/0.5771 - 1) ~= 30%. It's still an order of magnitude slower than Rednaxela's included tree, which shows there is a lot to gain by selecting appropriate optimisations.

To put it in context, this is how our tree performs compared to other implementations bundled with the benchmark:

K-NEAREST NEIGHBOURS ALGORITHMS BENCHMARK
-----------------------------------------
Reading data from gun-data-Diamond-vs-jk.mini.CunobelinDC 0.3.csv.gz
Read 25621 saves and 10300 searches.

Running 30 repetition(s) for k-nearest neighbours searching:
:: 13 dimension(s); 40 neighbour(s)
Warming up the JIT with 5 repetitions first...

Running tests... COMPLETED.

RESULT << k-nearest neighbours search with Voidious' Linear search >>
: Average searching time       = 0.573 miliseconds
: Average worst searching time = 15.828 miliseconds
: Average adding time          = 0.56 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Simonton's Bucket PR k-d tree >>
: Average searching time       = 0.161 miliseconds
: Average worst searching time = 19.711 miliseconds
: Average adding time          = 1.47 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Nat's Bucket PR k-d tree >>
: Average searching time       = 0.381 miliseconds
: Average worst searching time = 315.549 miliseconds
: Average adding time          = 63.06 microseconds
: Accuracy                     = 31%

RESULT << k-nearest neighbours search with Voidious' Bucket PR k-d tree >>
: Average searching time       = 0.218 miliseconds
: Average worst searching time = 104.198 miliseconds
: Average adding time          = 1.45 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with Rednaxela's Bucket kd-tree >>
: Average searching time       = 0.053 miliseconds
: Average worst searching time = 0.51 miliseconds
: Average adding time          = 1.7 microseconds
: Accuracy                     = 100%

RESULT << k-nearest neighbours search with duyn's basic kd-tree >>
: Average searching time       = 0.394 miliseconds
: Average worst searching time = 8.718 miliseconds
: Average adding time          = 0.68 microseconds
: Accuracy                     = 100%


BEST RESULT: 
 - #1 Rednaxela's Bucket kd-tree [0.0534]
 - #2 Simonton's Bucket PR k-d tree [0.1611]
 - #3 Voidious' Bucket PR k-d tree [0.2182]
 - #4 Nat's Bucket PR k-d tree [0.3812]
 - #5 duyn's basic kd-tree [0.3937]
 - #6 Voidious' Linear search [0.5728]

Benchmark running time: 645.63 seconds

The following code was used in the benchmark:

  • KNNBenchmark.java
public KNNBenchmark(final int dimension, final int numNeighbours, final SampleData[] samples, final int numReps) {
  final Class<?>[] searchAlgorithms = new Class<?>[] {
    FlatKNNSearch.class,
    SimontonTreeKNNSearch.class,
    NatTreeKNNSearch.class,
    VoidiousTreeKNNSearch.class,
+   DuynBasicKNNSearch.class,
    RednaxelaTreeKNNSearch.class,
  };
  ...
}
 
  • DuynBasicKNNSearch.java
public class DuynBasicKNNSearch extends KNNImplementation {
  final BasicKdTree<StringExemplar> tree;
  public DuynBasicKNNSearch(final int dimension) {
    super(dimension);
    tree = new BasicKdTree<StringExemplar>();
  }

  @Override public void
  addPoint(final double[] location, final String value) {
    tree.add(new StringExemplar(location, value));
  }

  @Override public String
  getName() {
    return "duyn's basic kd-tree";
  }

  @Override public KNNPoint[]
  getNearestNeighbors(final double[] location, final int size) {
    final List<KNNPoint> justPoints = new LinkedList<KNNPoint>();
    for(PrioNode<StringExemplar> sr : tree.search(location, size)) {
      final double distance = sr.priority;
      final StringExemplar pt = sr.data;
      justPoints.add(new KNNPoint(pt.value, distance));
    }
    final KNNPoint[] retVal = new KNNPoint[justPoints.size()];
    return justPoints.toArray(retVal);
  }

  class StringExemplar extends Exemplar {
    public final String value;
    public StringExemplar(final double[] coords, final String value) {
      super(coords);
      this.value = value;
    }
  }
}

Optimisations to the Tree

In this section, we will introduce the following optimisations to our tree:

  • Bucket leaves. Leaves will store multiple exemplars, only being split when they overflow. This makes our split points less dependent on the order in which we receive points and gives us more opportunity to choose a better split point.
  • Bounds-overlap-ball testing. Each tree stores the bounds of the smallest hyperrect which will contain all the data in that tree. This is smaller—maybe significantly so—than the bounds which that sub-tree occupy, giving us more opportunities to skip sub-trees. We will check that the minimum distance from each tree's content bounds overlaps with our search hypersphere. This lets us avoid walking all the way to a leaf if a tree's content bounds overlap with our search hypersphere but none of its children's content bounds do.
  • Splitting choice. We will split on the mean of the dimension with the largest variance. Splitting on the mean does not guarantee that our tree will be balanced, but is easier for us since we would have already calculated the mean in finding the dimension with largest variance. This will add more overhead to our add times since we will compute running means and variances each time an exemplar is added.
  • Singularity tracking. We will actively check whether all the exemplars in a leaf have the same domain. This is necessary to avoid getting ourselves caught in an infinite loop trying to repeatedly split an unsplittable leaf. It also lets us avoid repeated distance calculations when searching for nearest neighbours. Practically, unless all your dimensions are limited, discrete values, this optimisation is unlikely to make much impact.
  • Dropping PriorityQueue. PriorityQueue is intensely slow, despite the fact that it implements an often practically efficient algorithm. In the development of this tutorial, switching to a PriorityQueue has sometimes pushed an implementation into the next order of magnitude. It seems that comparing generic types in Java is inescapably slow since type erasure means Comparator<T> must receive two Objects, cast them, then call the real code.

Basic Tree

Once again, we start off with a basic binary tree. Each tree must store its own bucket size since it will not have a reference to its parent Bucket size is not static because we might want to build multiple trees with different bucket sizes.

public final class OptKdTree<X extends Exemplar> {
  final Queue<X> data;
  OptKdTree<X> left = null, right = null;

  // Number of exemplars to hold in a leaf before splitting
  private final int bucketSize;
  private static final int DEFAULT_BUCKET_SIZE = 10;

  public OptKdTree() {
    this(DEFAULT_BUCKET_SIZE);
  }

  public OptKdTree(int bucketSize) {
    this.bucketSize = bucketSize;
    this.data = new LinkedList<X>();
  }

  private final boolean
  isTree() { return left != null; }

  [...]
}

Adding

Now that our code is a little more sophisticated, it becomes profitable to have a dedicated method to bulk-load our tree. This also makes it easier to add dynamic rebalance support later if we need it. To take advantage the additional information we get from bulk loading, we decide whether to split a leaf only after the add has been completed.

public void
add(X ex) {
  OptKdTree<X> tree = addNoSplit(this, ex);
  if (shouldSplit(tree)) {
    split(tree);
  }
}

public void
addAll(Collection<X> exs) {
  final Set<OptKdTree<X>> modTrees =
    new HashSet<OptKdTree<X>>();
  for(X ex : exs) {
    modTrees.add(addNoSplit(this, ex));
  }

  for(OptKdTree<X> tree : modTrees) {
    if (shouldSplit(tree)) {
      split(tree);
    }
  }
}

To add an exemplar, we traverse the tree from top down until we find a leaf. Then we add the exemplar to the list at that leaf. For now, there is no special ordering to the list of exemplars in a leaf. To decide which sub-tree to traverse, each tree stores two values—a splitting dimension and a splitting value. If our new exemplar's domain along the splitting dimension is greater than the tree's splitting value, we put it in the right sub-tree. Otherwise, it goes in the left one.

Since adding takes little time compared to searching, we take this opportunity to make some optimisations:

  • We keep track of the actual hyperrect the points in this tree occupy. This lets us rule out a tree, even though its space may intersect with our search sphere, if it has no chance of containing any points within our search sphere. This hyperrect is fully defined by just storing its two extreme corners.
  • To save us having to do a full iteration when we come to split a leaf, we compute the running mean and variance for each dimension using Welford's method:
   <math>
   M_1 = x_1,\qquad S_1 = 0</math>
   <math>
   M_k = M_{k-1} + {x_k - M_{k-1} \over k} \qquad(exMean)</math>
   <math>
   S_k = S_{k-1} + (x_k - M_{k-1})\times(x_k - M_k) \qquad(exSumSqDev)</math>
  • We keep track of whether all exemplars in this leaf have the same domain. If they do, we know that comparisons on one exemplar apply to all exemplars in that leaf.
// These aren't initialised until add() is called.
private double[] exMean = null, exSumSqDev = null;

// Optimisation when sub-tree contains only duplicates
private boolean singularity = true;

// Optimisation for searches. This lets us skip a node if its
// scope intersects with a search hypersphere but it doesn't contain
// any points that actually intersect.
private double[] contentMax = null, contentMin = null;

private int
dimensions() { return contentMax.length; }

// Addition

// Adds an exemplar without splitting overflowing leaves.
// Returns leaf to which exemplar was added.
private static <X extends Exemplar> OptKdTree<X>
addNoSplit(OptKdTree<X> tree, X ex) {
  // Some spurious function calls. Optimised for readability over
  // efficiency.
  OptKdTree<X> cursor = tree;
  while (cursor != null) {
    updateBounds(cursor, ex);
    if (cursor.isTree()) {
      // Sub-tree
      cursor = ex.domain[cursor.splitDim] <= cursor.split
        ? cursor.left : cursor.right;
    } else {
      // Leaf

      // Add exemplar to leaf
      cursor.data.add(ex);

      // Calculate running mean and sum of squared deviations
      final int nExs = cursor.data.size();
      final int dims = cursor.dimensions();
      if (nExs == 1) {
        cursor.exMean = Arrays.copyOf(ex.domain, dims);
        cursor.exSumSqDev = new double[dims];
      } else {
        for(int d = 0; d < dims; d++) {
          final double coord = ex.domain[d];
          final double oldMean = cursor.exMean[d], newMean;
          cursor.exMean[d] = newMean =
            oldMean + (coord - oldMean)/nExs;
          cursor.exSumSqDev[d] = cursor.exSumSqDev[d]
            + (coord - oldMean)*(coord - newMean);
        }
      }

      // Check that data are still uniform
      if (cursor.singularity) {
        final Queue<X> cExs = cursor.data;
        if (cExs.size() > 0 && !ex.collocated(cExs.peek()))
          cursor.singularity = false;
      }

      // Finished walking
      return cursor;
    }
  }
  throw new RuntimeException("Walked tree without adding anything");
}

To update the bounding hyperrect for the contents of a tree, we iterate through each dimension, extending (if necessary) the bounds in that dimension to contain the new exemplar.

private static <T extends Exemplar> void
updateBounds(OptKdTree<T> tree, Exemplar ex) {
  final int dims = ex.domain.length;
  if (tree.contentMax == null) {
    tree.contentMax = Arrays.copyOf(ex.domain, dims);
    tree.contentMin = Arrays.copyOf(ex.domain, dims);
  } else {
    for(int d = 0; d < dims; d++) {
      final double dimVal = ex.domain[d];
      if (dimVal > tree.contentMax[d])
        tree.contentMax[d] = dimVal;
      else if (dimVal < tree.contentMin[d])
        tree.contentMin[d] = dimVal;
    }
  }
}

Splitting

We only split when a leaf's exemplars exceed its bucket size. It is only worth splitting if the exemplars don't all have the same domain.

private static <T extends Exemplar> boolean
shouldSplit(OptKdTree<T> tree) {
  return tree.data.size() > tree.bucketSize
    && !tree.singularity;
}

Thanks to our pre-computation, splitting is usually straight-forward. We iterate through each dimension to find the one with the largest variance (skip the unnecessary division), then we can directly look up the mean of that dimension. This strategy might not successfully split the tree if exemplars are so close together that rounding errors in computing the mean result in the mean not lying strictly between the exemplars in a leaf. When this happens, we resort to repeatedly trying a dimension and an exemplar at random for our splitting point. We only call split if the leaf is not a singularity, so eventually we will find a point that does divide our exemplars.

To make sure the point actually does divide our data, we separate our data into two lists destined for each sub-tree. Once we have found a successful split point, we bulk load our sub-trees, store information about our split point and let go of the exemplars stored in the tree.

@SuppressWarnings("unchecked") private static <T extends Exemplar> void
split(OptKdTree<T> tree) {
  assert !tree.singularity;
  // Find dimension with largest variance to split on
  double largestVar = -1;
  int splitDim = 0;
  for(int d = 0; d < tree.dimensions(); d++) {
    // Don't need to divide by number of data to find largest
    // variance
    final double var = tree.exSumSqDev[d];
    if (var > largestVar) {
      largestVar = var;
      splitDim = d;
    }
  }

  // Find mean as position for our split
  double splitValue = tree.exMean[splitDim];

  // Check that our split actually splits our data. This also lets
  // us bulk load data into sub-trees, which is more likely
  // to keep optimal balance.
  final Queue<T> leftExs = new LinkedList<T>();
  final Queue<T> rightExs = new LinkedList<T>();
  for(T s : tree.data) {
    if (s.domain[splitDim] <= splitValue)
      leftExs.add(s);
    else
      rightExs.add(s);
  }
  int leftSize = leftExs.size();
  final int treeSize = tree.data.size();
  if (leftSize == treeSize || leftSize == 0) {
    System.err.println(
      "WARNING: Randomly splitting non-uniform tree");
    // We know the data aren't all the same, so try picking
    // an exemplar and a dimension at random for our split point

    // This might take several tries, so we copy our data to
    // an array to speed up process of picking a random point
    Object[] exs = tree.data.toArray();
    while (leftSize == treeSize || leftSize == 0) {
      leftExs.clear();
      rightExs.clear();

      splitDim = (int)
        Math.floor(Math.random()*tree.dimensions());
      final int splitPtIdx = (int)
        Math.floor(Math.random()*exs.length);
      // Cast is inevitable consequence of java's inability to
      // create a generic array
      splitValue = ((T)exs[splitPtIdx]).domain[splitDim];
      for(T s : tree.data) {
        if (s.domain[splitDim] <= splitValue)
          leftExs.add(s);
        else
          rightExs.add(s);
      }
      leftSize = leftExs.size();
    }
  }

  // We have found a valid split. Start building our sub-trees
  final OptKdTree<T> left = new OptKdTree<T>(tree.bucketSize);
  final OptKdTree<T> right = new OptKdTree<T>(tree.bucketSize);
  left.addAll(leftExs);
  right.addAll(rightExs);

  // Finally, commit the split
  tree.splitDim = splitDim;
  tree.split = splitValue;
  tree.left = left;
  tree.right = right;

  // Let go of data (and their running stats) held in this leaf
  tree.data.clear();
  tree.exMean = tree.exSumSqDev = null;
}

Distance calculations

Before we can start searching, we need to define our distance metric. The distance calculation has already been introduced before.

Squared Euclidean distance:
<math>distance_{x\rightarrow y}^2 = \sum (x[d] - y[d])^2</math>

This will be called for every point examined, so we should optimise it well.

private static double
distanceSqFrom(double[] p1, double[] p2) {
  double dSq = 0;
  for(int d = 0; d < p1.length; d++) {
    final double dst = p1[d] - p2[d];
    if (dst != 0)
      dSq += dst*dst;
  }
  return dSq;
}

We will also need to find the minimum distance from a hyperrect to a query point. This will let us rule out trees whose content hyperrects are outside our search sphere. To find the minimum distance between a hyperrect and a given point, we need first find the closest point along bounds of that hyperrect to the point. This is easy to do, even if not immediately obvious.

Along each dimension,
<math>nearestPoint[d]=\begin{cases}

min[d] & target[d]\le min[d]\\ target[d] & min<target[d]<max[d]\\ max[d] & target[d]\ge max[d]\end{cases}</math>

After that, we just calculate the distance from our query point to this nearest point on the hyperrect. The code below has been optimised because it gets called a lot during a search.

// Gets distance from target of nearest point on hyper-rect defined
// by supplied min and max bounds
private static double
minDistanceSqFrom(double[] target, double[] min, double[] max) {
  double distanceSq = 0;
  for(int d = 0; d < target.length; d++) {
    if (target[d] < min[d]) {
      final double dst = min[d] - target[d];
      distanceSq += dst*dst;
    } else if (target[d] > max[d]) {
      final double dst = max[d] - target[d];
      distanceSq += dst*dst;
    }
  }
  return distanceSq;
}

Searching

Before we dive into searching, we will introduce two helper classes.

  • SearchStackEntry stores a tree along any data which will be useful to keep on the search stack. Currently, this additional information consists only of whether a bounds check should be done.
//
// class SearchStackEntry
//

private static class SearchStackEntry<T extends Exemplar> {
  public final boolean needBoundsCheck;
  public final OptKdTree<T> tree;

  public SearchStackEntry(boolean needBoundsCheck,
    OptKdTree<T> tree) {
    this.needBoundsCheck = needBoundsCheck;
    this.tree = tree;
  }
}
  • SearchState holds some variables we want to be updated in after each nearest neighbour is added. Its sole purpose is to let us separate searching sub-trees and leaves into different methods without sacrificing performance. This breaks up a long method and makes profiling easier. This class also shows that we have also switched from using a PriorityQueue to a TreeMap to collect our results. This change alone results in a large performance boost.
//
// class SearchState
//
// Holds data about current state of the search. Used for live updating
// of pruning distance.

private static class SearchState<X extends Exemplar> {
  final int nResults;
  final NavigableMap<Double, Queue<PrioNode<X>>> results;
  int resultSize = 0;

  public SearchState(int nResults) {
    this.nResults = nResults;
    results = new TreeMap<Double, Queue<PrioNode<X>>>();
  }
}

Our search method does a depth-first search using a stack. Each time we pop our stack, we do a bounds check if necessary, then pass the call on to the appropriate method.

public Iterable<PrioNode<X>>
search(double[] query, int nResults) {
  // Forward to a static method to avoid accidental reference to
  // instance variables while descending the tree
  return search(this, query, nResults);
}

private static <X extends Exemplar> Iterable<PrioNode<X>>
search(OptKdTree<X> tree, double[] query, int nResults) {
  final SearchState<X> state = new SearchState<X>(nResults);
  final Deque<SearchStackEntry<X>> stack =
    new LinkedList<SearchStackEntry<X>>();
  if (tree.contentMin != null)
    stack.addLast(new SearchStackEntry<X>(false, tree));
TREE_WALK:
  while (!stack.isEmpty()) {
    final SearchStackEntry<X> entry = stack.removeLast();
    final OptKdTree<X> cur = entry.tree;

    if (entry.needBoundsCheck && state.results.size() >= nResults) {
      final double d = minDistanceSqFrom(query,
        cur.contentMin, cur.contentMax);
      if (d > state.results.lastKey())
        continue TREE_WALK;
    }

    if (cur.isTree()) {
      searchTree(query, cur, stack);
    } else {
      searchLeaf(query, cur, state);
    }
  }

  // Collect results
  final List<PrioNode<X>> results = new LinkedList<PrioNode<X>>();
  for(Queue<PrioNode<X>> nodes : state.results.values()) {
    results.addAll(nodes);
  }
  return results;
}

Searching a Tree

We basically want to add both sub-trees to the stack, if they are not empty. As a heuristic, we think the sub-tree on the same side of the split as our query point is more likely to contain closer points, so we search that sub-tree first by adding it to the stack last. You might try to avoid a distance calculation by not requiring a bounds check for the heuristically nearer sub-tree, but that does not seem to significantly improve performance.

private static <X extends Exemplar> void
searchTree(double[] query, OptKdTree<X> tree,
  Deque<SearchStackEntry<X>> stack)
{
  OptKdTree<X> nearTree = tree.left, farTree = tree.right;
  if (query[tree.splitDim] > tree.split) {
    nearTree = tree.right;
    farTree = tree.left;
  }

  // These variables let us skip empty sub-trees
  boolean nearEmpty = nearTree.contentMin == null;
  boolean farEmpty = farTree.contentMin == null;

  // Add nearest sub-tree to stack later so we descend it
  // first. This is likely to constrict our max distance
  // sooner, resulting in less visited nodes
  if (!farEmpty) {
    stack.addLast(new SearchStackEntry<X>(true, farTree));
  }

  if (!nearEmpty) {
    stack.addLast(new SearchStackEntry<X>(true, nearTree));
  }
}

Searching a Leaf

A Better Results Heap

Extensions

This tree does not include any support for deletion or re-balancing. You can get away with this if you plan to re-build the entire tree regularly.

Deletion support would be easy to add to this tree. Since all exemplars are stored in leaves at the bottom, we need only walk the tree and remove the exemplar from the leaf containing it. We would then merge sub-trees if multiple deletions has left one of them empty. Unless our bucket width is really small, we don't need to merge sub-trees after each deletion since the addition of a subsequent exemplar is unlikely to result in choosing a significantly different split point.

Re-balancing with this tree is trickier. It does not currently store depth, so we don't know when a re-balance is needed. Because exemplars are only stored in leaves, collecting all the exemplars under a sub-tree would require exhaustively walking that sub-tree.

  • If a depth is to be stored, each tree would need a pointer to its parent tree so we propogate depth updates when a leaf is split.
  • We can avoid an exhaustive walk by storing each exemplar in every tree on its path from root to leaf. We then can re-balance easily by dropping both sub-trees and splitting the given tree anew. This approach has the down side of making inserts and deletes more complex since we have to update the exemplar list of every tree along the path from root to leaf.

A TreeMap does more processing than we need since it keeps all elements in order, when we only want access to the one with largest distance. Substituting a custom heap can significantly improve performance. Note that Java's standard PriorityQueue will not suffice since it does not handle duplicates the way we want.

See also