Difference between revisions of "User:Duyn/BucketKdTree"

From Robowiki
Jump to navigation Jump to search
(Updated to latest version)
(Pointer to uploaded JAR instead.)
 
Line 1: Line 1:
Below is the full source code to a tree that will form the basis for a future kd tree tutorial I'm working on. This code is released under the [[RWPCL]]. It is decently optimised—theoretically, anyway. I don't specialise in writing fast code, but the algorithms used are designed to ensure whole trees are ruled out as early as possible.
+
There used to be a kd-tree on this page. It has since been packaged into a JAR file and is available at: [[File:Duyn_tutorial_kd_trees.jar]]. Its source is available under the [[RWPCL]].
 
 
===Exemplar.java===
 
Base class for items to be added into the tree. Sub-classes can carry useful data like guess factors.
 
 
 
<pre>
 
import java.util.Arrays;
 
/**
 
* A sample point in multi-dimensional space. Needed because each sample
 
* may contain an arbitrary payload.
 
*
 
* Note: this class does not make allowance for a payload. Sub-class if
 
* you want to store something more than just data points.
 
*
 
* @author dnn
 
*/
 
public class Exemplar {
 
  public final double[] domain;
 
 
 
  public Exemplar(double[] domain) {
 
    this.domain = domain;
 
  }
 
 
 
  public boolean domainEquals(final Exemplar other) {
 
    return Arrays.equals(domain, other.domain);
 
  }
 
 
 
  @Override public String
 
  toString() {
 
    return Arrays.toString(domain);
 
  }
 
}
 
</pre>
 
 
 
===BucketKdTree.java===
 
Tree with k-nearest neighbour search. Does not support deletion or rebalancing&mdash;a re-build is required if you want to do either.
 
 
 
<pre>
 
import java.util.*;
 
 
 
/**
 
* A k-dimensional binary partitioning tree which splits space on the
 
* mean of the dimension with the largest variance. Points are held in
 
* buckets so we can pick a better split point than whatever comes first.
 
*
 
* Does not store tree depth. If you want balance, re-build the tree
 
* periodically.
 
*
 
* Optimisations in this tree assume distance metric is euclidian distance.
 
* May work if retrofitted with other metrics, but that is purely
 
* accidental.
 
*
 
* Note: results can become unpredictable if values are different but so
 
* close together that rounding errors in computing their mean result in
 
* all exemplars being on one side of the mean. Performance degrades when
 
* this occurs. Nearest neighbour search tested to work up to range
 
* [1, 1 + 5e-16).
 
*
 
* Ideas for path ordering and bounds come from:
 
*  NEAL SAMPLE, MATTHEW HAINES, MARK ARNOLD, TIMOTHY PURCELL,
 
*    'Optimizing Search Strategies in k-d Trees'
 
*    http://ilpubs.stanford.edu:8090/723/
 
*
 
* Computation of variance from:
 
*  John Cook, 'Accurately computing running variance'
 
*    http://www.johndcook.com/standard_deviation.html
 
*
 
* Terminology note: points are called Exemplars. They must all be
 
* descended from Exemplar class. Position in k-d space is stored in each
 
* exemplar's domain member. This is to avoid conflicting with already
 
* existing classes referring to geometric points.
 
*
 
* Terminology comes from:
 
*  Andrew Moore, 'An intoductory tutorial on kd'
 
*    http://www.autonlab.org/autonweb/14665
 
*
 
* @author dnn
 
*/
 
public class BucketKdTree<T extends Exemplar> {
 
 
 
  // Only leaf nodes contain points
 
  private List<T> exemplars = new LinkedList<T>();
 
  // These aren't initialised until add() is called.
 
  private double[] exMean;
 
  private double[] exSumSqDev;
 
 
 
  // Optimisation when tree contains large number of duplicates
 
  private boolean exemplarsAreUniform = true;
 
 
 
  private int bucketSize;
 
  private BucketKdTree<T> left, right;
 
 
 
  private int splitDim;
 
  private double split;
 
 
 
  private int dimensions = 0;
 
 
 
  // Optimisation for searches. This lets us skip a node if its
 
  // scope intersects with a search hypersphere but it doesn't contain
 
  // any points that actually intersect.
 
  private double[] maxBounds;
 
  private double[] minBounds;
 
 
 
  public BucketKdTree(int bucketSize) {
 
    this.bucketSize = bucketSize;
 
  }
 
 
 
  //
 
  // PUBLIC METHODS
 
  //
 
 
 
  public void
 
  add(T ex) {
 
    BucketKdTree<T> tree = addNoSplit(this, ex);
 
    if (shouldSplit(tree)) {
 
      split(tree);
 
    }
 
  }
 
 
 
  public void
 
  addAll(Collection<T> exs) {
 
    // Some spurious function calls. Optimised for readability over
 
    // efficiency.
 
    final Set<BucketKdTree<T>> modTrees =
 
      new HashSet<BucketKdTree<T>>();
 
    for(T ex : exs) {
 
      modTrees.add(addNoSplit(this, ex));
 
    }
 
 
 
    for(BucketKdTree<T> tree : modTrees) {
 
      if (shouldSplit(tree)) {
 
        split(tree);
 
      }
 
    }
 
  }
 
 
 
  public SortedMap<Double, List<T>>
 
  search(double[] query, int nMinResults) {
 
    // Forward to a static method to avoid accidental reference to
 
    // instance variables while descending the tree
 
    return search(this, query, nMinResults);
 
  }
 
 
 
  @Override public String
 
  toString() {
 
    return toString("");
 
  }
 
 
 
  //
 
  // IMPLEMENTATION DETAILS
 
  //
 
 
 
  private boolean
 
  isTree() { return left != null; }
 
 
 
  // Addition
 
 
 
  // Adds an exemplar without splitting overflowing leaves.
 
  // Returns leaf to which exemplar was added.
 
  private static <T extends Exemplar> BucketKdTree<T>
 
  addNoSplit(BucketKdTree<T> tree, T ex) {
 
    // Some spurious function calls. Optimised for readability over
 
    // efficiency.
 
    BucketKdTree<T> cursor = tree;
 
    while (cursor != null) {
 
      updateBounds(cursor, ex);
 
      if (cursor.isTree()) {
 
        // Sub-tree
 
        cursor = ex.domain[cursor.splitDim] <= cursor.split
 
          ? cursor.left : cursor.right;
 
      } else {
 
        // Leaf
 
 
 
        // Infer dimensions if we haven't already
 
        if (cursor.dimensions == 0)
 
          cursor.dimensions = ex.domain.length;
 
 
 
        // Add exemplar to leaf
 
        cursor.exemplars.add(ex);
 
 
 
        // Calculate running mean and sum of squared deviations
 
        final int nExs = cursor.exemplars.size();
 
        final int dims = cursor.dimensions;
 
        if (nExs == 1) {
 
          cursor.exMean = Arrays.copyOf(ex.domain, dims);
 
          cursor.exSumSqDev = new double[dims];
 
        } else {
 
          for(int d = 0; d < dims; d++) {
 
            final double coord = ex.domain[d];
 
            final double oldMean = cursor.exMean[d], newMean;
 
            cursor.exMean[d] = newMean =
 
              oldMean + (coord - oldMean)/nExs;
 
            cursor.exSumSqDev[d] = cursor.exSumSqDev[d]
 
              + (coord - oldMean)*(coord - newMean);
 
          }
 
        }
 
 
 
        // Check that exemplars are still uniform
 
        if (cursor.exemplarsAreUniform) {
 
          final List<T> cExs = cursor.exemplars;
 
          if (cExs.size() > 0 && !ex.domainEquals(cExs.get(0)))
 
            cursor.exemplarsAreUniform = false;
 
        }
 
 
 
        // Finished walking
 
        return cursor;
 
      }
 
    }
 
    throw new RuntimeException("Walked tree without adding anything");
 
  }
 
 
 
  private static <T extends Exemplar> void
 
  updateBounds(BucketKdTree<T> tree, Exemplar ex) {
 
    final int dims = ex.domain.length;
 
    if (tree.maxBounds == null) {
 
      tree.maxBounds = Arrays.copyOf(ex.domain, dims);
 
      tree.minBounds = Arrays.copyOf(ex.domain, dims);
 
    } else {
 
      for(int d = 0; d < dims; d++) {
 
        final double dimVal = ex.domain[d];
 
        if (dimVal > tree.maxBounds[d])
 
          tree.maxBounds[d] = dimVal;
 
        else if (dimVal < tree.minBounds[d])
 
          tree.minBounds[d] = dimVal;
 
      }
 
    }
 
  }
 
 
 
  // Splitting (internal operation)
 
 
 
  private static <T extends Exemplar> boolean
 
  shouldSplit(BucketKdTree<T> tree) {
 
    return tree.exemplars.size() > tree.bucketSize
 
      && !tree.exemplarsAreUniform;
 
  }
 
 
 
  @SuppressWarnings("unchecked") private static <T extends Exemplar> void
 
  split(BucketKdTree<T> tree) {
 
    assert !tree.exemplarsAreUniform;
 
    // Find dimension with largest variance to split on
 
    double largestVar = -1;
 
    int splitDim = 0;
 
    for(int d = 0; d < tree.dimensions; d++) {
 
      // Don't need to divide by number of exemplars to find largest
 
      // variance
 
      final double var = tree.exSumSqDev[d];
 
      if (var > largestVar) {
 
        largestVar = var;
 
        splitDim = d;
 
      }
 
    }
 
 
 
    // Find mean as position for our split
 
    double splitValue = tree.exMean[splitDim];
 
 
 
    // Check that our split actually splits our data. This also lets
 
    // us bulk load exemplars into sub-trees, which is more likely
 
    // to keep optimal balance.
 
    final List<T> leftExs = new LinkedList<T>();
 
    final List<T> rightExs = new LinkedList<T>();
 
    for(T s : tree.exemplars) {
 
      if (s.domain[splitDim] <= splitValue)
 
        leftExs.add(s);
 
      else
 
        rightExs.add(s);
 
    }
 
    int leftSize = leftExs.size();
 
    final int treeSize = tree.exemplars.size();
 
    if (leftSize == treeSize || leftSize == 0) {
 
      System.err.println(
 
        "WARNING: Randomly splitting non-uniform tree");
 
      // We know the exemplars aren't all the same, so try picking
 
      // an exemplar and a dimension at random for our split point
 
 
 
      // This might take several tries, so we copy our exemplars to
 
      // an array to speed up process of picking a random point
 
      Object[] exs = tree.exemplars.toArray();
 
      while (leftSize == treeSize || leftSize == 0) {
 
        leftExs.clear();
 
        rightExs.clear();
 
 
 
        splitDim = (int)
 
          Math.floor(Math.random()*tree.dimensions);
 
        final int splitPtIdx = (int)
 
          Math.floor(Math.random()*exs.length);
 
        // Cast is inevitable consequence of java's inability to
 
        // create a generic array
 
        splitValue = ((T)exs[splitPtIdx]).domain[splitDim];
 
        for(T s : tree.exemplars) {
 
          if (s.domain[splitDim] <= splitValue)
 
            leftExs.add(s);
 
          else
 
            rightExs.add(s);
 
        }
 
        leftSize = leftExs.size();
 
      }
 
    }
 
 
 
    // We have found a valid split. Start building our sub-trees
 
    final BucketKdTree<T> left =
 
      new BucketKdTree<T>(tree.bucketSize);
 
    final BucketKdTree<T> right =
 
      new BucketKdTree<T>(tree.bucketSize);
 
    left.addAll(leftExs);
 
    right.addAll(rightExs);
 
 
 
    // Finally, commit the split
 
    tree.splitDim = splitDim;
 
    tree.split = splitValue;
 
    tree.left = left;
 
    tree.right = right;
 
 
 
    // Let go of exemplars (and their running stats) held in this leaf
 
    tree.exemplars = null;
 
    tree.exMean = tree.exSumSqDev = null;
 
  }
 
 
 
  // Searching
 
 
 
  // May return more results than requested if multiple exemplars have
 
  // same distance from target.
 
  //
 
  // Note: this function works with squared distances to avoid sqrt()
 
  // operations
 
  private static <T extends Exemplar> SortedMap<Double, List<T>>
 
  search(BucketKdTree<T> tree, double[] query, int nMinResults) {
 
    // distance => list of points that distance away from query
 
    final NavigableMap<Double, List<T>> results =
 
      new TreeMap<Double, List<T>>();
 
 
 
    final SearchState state = new SearchState();
 
    final Deque<SearchStackEntry<T>> stack =
 
      new LinkedList<SearchStackEntry<T>>();
 
    stack.addFirst(new SearchStackEntry<T>(state.maxDistance, tree));
 
    while (!stack.isEmpty()) {
 
      final SearchStackEntry<T> entry = stack.removeFirst();
 
      final BucketKdTree<T> cur = entry.tree;
 
 
 
      if (cur.isTree()) {
 
        searchTree(query, nMinResults, cur, state, stack);
 
      } else if (entry.minDFromQ <= state.maxDistance
 
        || state.nResults < nMinResults)
 
      {
 
        searchLeaf(query, nMinResults, cur, state, results);
 
      }
 
    }
 
 
 
    return results;
 
  }
 
 
 
  private static <T extends Exemplar> void
 
  searchTree(double[] query, int nMinResults, BucketKdTree<T> tree,
 
    SearchState searchState, Deque<SearchStackEntry<T>> stack)
 
  {
 
    // Left is presumed near. This is verified further down.
 
    BucketKdTree<T> nearTree = tree.left, farTree = tree.right;
 
    // These variables let us skip empty sub-trees
 
    boolean nearEmpty = nearTree.minBounds == null;
 
    boolean farEmpty = farTree.minBounds == null;
 
 
 
    // Find distance from nearest possible point in each
 
    // sub-tree to query. If that is greater than max distance,
 
    // we can rule out that sub-tree.
 
    double nearD = nearEmpty ? Double.POSITIVE_INFINITY
 
      : minDistanceSqFrom(query,
 
        nearTree.minBounds, nearTree.maxBounds);
 
    double farD = farEmpty ? Double.POSITIVE_INFINITY
 
      : minDistanceSqFrom(query,
 
        farTree.minBounds, farTree.maxBounds);
 
 
 
    // Swap near and far if they're incorrect
 
    if (farD < nearD) {
 
      final double tmpD = nearD;
 
      final BucketKdTree<T> tmpTree = nearTree;
 
      final boolean tmpEmpty = nearEmpty;
 
 
 
      nearD = farD;
 
      nearTree = farTree;
 
      nearEmpty = farEmpty;
 
 
 
      farD = tmpD;
 
      farTree = tmpTree;
 
      farEmpty = tmpEmpty;
 
    }
 
 
 
    // Add nearest sub-tree to stack later so we descend it
 
    // first. This is likely to constrict our max distance
 
    // sooner, resulting in less visited nodes
 
    if (!farEmpty && (farD <= searchState.maxDistance
 
                      || searchState.nResults < nMinResults))
 
    {
 
      stack.addFirst(new SearchStackEntry<T>(farD, farTree));
 
    }
 
 
 
    if (!nearEmpty && (nearD <= searchState.maxDistance
 
                      || searchState.nResults < nMinResults))
 
    {
 
      stack.addFirst(new SearchStackEntry<T>(nearD, nearTree));
 
    }
 
  }
 
 
 
  private static <T extends Exemplar> void
 
  searchLeaf(double[] query, int nMinResults,
 
    BucketKdTree<T> leaf, SearchState searchState,
 
    NavigableMap<Double, List<T>> results)
 
  {
 
    // Keep track of elements at max distance so we know
 
    // whether we can just drop entire list of furthest
 
    // exemplars
 
    final int nMinResultsBeforeAddition = nMinResults - 1;
 
    for(T ex : leaf.exemplars) {
 
      final double exD = distanceSqFrom(query, ex.domain);
 
      if (searchState.nResults < nMinResults
 
        || exD == searchState.maxDistance)
 
      {
 
        // Blindly add this exemplar
 
        List<T> exsAtD = getOrElseInit(results, exD);
 
        if (leaf.exemplarsAreUniform) {
 
          // No need to go through every one if all
 
          // exemplars are the same
 
          exsAtD.addAll(leaf.exemplars);
 
          searchState.nResults += leaf.exemplars.size();
 
        } else {
 
          exsAtD.add(ex);
 
          searchState.nResults++;
 
        }
 
 
 
        // Update information about furthest exemplars
 
        if (exD > searchState.maxDistance
 
          || searchState.maxDistance == Double.POSITIVE_INFINITY)
 
        {
 
          searchState.maxDistance = exD;
 
        }
 
 
 
        if (searchState.maxDistance == exD)
 
          searchState.nExsAtMaxD = exsAtD.size();
 
 
 
      } else if (exD < searchState.maxDistance) {
 
        // Point closer than furthest neighbour
 
 
 
        if (searchState.nResults - searchState.nExsAtMaxD
 
          >= nMinResultsBeforeAddition)
 
        {
 
          // Dropping furthest exemplars won't leave
 
          // us with too little to meet return
 
          // minimum
 
          results.remove(searchState.maxDistance);
 
          searchState.nResults -= searchState.nExsAtMaxD;
 
        }
 
 
 
        // Add new exemplar
 
        List<T> exsAtD = getOrElseInit(results, exD);
 
        if (leaf.exemplarsAreUniform) {
 
          // No need to go through every one if all
 
          // exemplars are the same
 
          exsAtD.addAll(leaf.exemplars);
 
          searchState.nResults += leaf.exemplars.size();
 
        } else {
 
          exsAtD.add(ex);
 
          searchState.nResults++;
 
        }
 
 
 
        // Update information about furthest exemplars
 
        Map.Entry<Double, List<T>> lastEnt = results.lastEntry();
 
        searchState.maxDistance = lastEnt.getKey();
 
        searchState.nExsAtMaxD = lastEnt.getValue().size();
 
      } // exD < maxDistance
 
 
 
      // No need to keep going if all exemplars are
 
      // the same
 
      if (leaf.exemplarsAreUniform) break;
 
    } // for(T ex : cur.exemplars)
 
  }
 
 
 
  private static <T extends Exemplar> List<T>
 
  getOrElseInit(Map<Double, List<T>> results, double dst) {
 
    List<T> lst = results.get(dst);
 
    if (lst == null) {
 
      lst = new LinkedList<T>();
 
      results.put(dst, lst);
 
    }
 
    return lst;
 
  }
 
 
 
  // Dumping to string (debug)
 
 
 
  private String
 
  toString(String indent) {
 
    if (isTree()){
 
      return String.format("%s{|%d|%f|\n%s\n%s} ", indent,
 
        splitDim, split, left.toString(indent + "\t"),
 
        right.toString(indent + "\t"));
 
    } else {
 
      return String.format("%sL%s", indent,
 
        Arrays.toString(exemplars.toArray()));
 
    }
 
  }
 
 
 
  // Distance calculations
 
 
 
  // Gets distance from target of nearest point on hyper-rect defined
 
  // by supplied min and max bounds
 
  private static double
 
  minDistanceSqFrom(double[] target, double[] min, double[] max) {
 
    // Note: profiling shows this is called lots of times, so it pays
 
    // to be well optimised
 
    double distanceSq = 0;
 
    for(int d = 0; d < target.length; d++) {
 
      final double coord = target[d];
 
      double nearCoord;
 
      if (((nearCoord = min[d]) > coord)
 
        || ((nearCoord = max[d]) < coord))
 
      {
 
        final double dst = nearCoord - coord;
 
        distanceSq += dst*dst;
 
      }
 
    }
 
    return distanceSq;
 
  }
 
 
 
  // Accessible to testing
 
  static double
 
  distanceSqFrom(double[] p1, double[] p2) {
 
    // Note: profiling shows this is called lots of times, so it pays
 
    // to be well optimised
 
    double dSq = 0;
 
    for(int d = 0; d < p1.length; d++) {
 
      final double dst = p1[d] - p2[d];
 
      if (dst != 0)
 
        dSq += dst*dst;
 
    }
 
    return dSq;
 
  }
 
 
 
  //
 
  // class SearchStackEntry
 
  //
 
  // Stores a precomputed distance so we don't have to do it again
 
  // when we pop the tree off the search stack.
 
 
 
  private static class SearchStackEntry<T extends Exemplar> {
 
    public final double minDFromQ;
 
    public final BucketKdTree<T> tree;
 
 
 
    public SearchStackEntry(double minDFromQ, BucketKdTree<T> tree) {
 
      this.minDFromQ = minDFromQ;
 
      this.tree = tree;
 
    }
 
  }
 
 
 
  //
 
  // class SearchState
 
  //
 
  // Holds data about current state of the search. Used for live updating
 
  // of pruning distance.
 
 
 
  private static class SearchState {
 
    int nResults = 0;
 
    double maxDistance = Double.POSITIVE_INFINITY;
 
    int nExsAtMaxD = 0;
 
  }
 
}
 
</pre>
 

Latest revision as of 13:23, 13 March 2010

There used to be a kd-tree on this page. It has since been packaged into a JAR file and is available at: File:Duyn tutorial kd trees.jar. Its source is available under the RWPCL.