Difference between revisions of "User:Duyn/BucketKdTree"

From Robowiki
Jump to navigation Jump to search
(Source code for upcoming tutorial. Large enough to warrant own page.)
 
(Pointer to uploaded JAR instead.)
 
(2 intermediate revisions by the same user not shown)
Line 1: Line 1:
Below is the full source code to a tree that will form the basis for a future kd tree tutorial I'm working on. This code is released under the [[RWPCL]]. It is decently optimised—theoretically, anyway. I don't specialise in writing fast code, but the algorithms used are designed to ensure whole trees are ruled out as early as possible.
+
There used to be a kd-tree on this page. It has since been packaged into a JAR file and is available at: [[File:Duyn_tutorial_kd_trees.jar]]. Its source is available under the [[RWPCL]].
 
 
===Exemplar.java===
 
Base class for items to be added into the tree. Sub-classes can carry useful data like guess factors.
 
 
 
<pre>
 
/**
 
* A sample point in multi-dimensional space. Needed because each sample
 
* may contain an arbitrary payload.
 
*
 
* Note: this class does not make allowance for a payload. Sub-class if
 
* you want to store something more than just data points.
 
*
 
* @author duyn
 
*/
 
public class Exemplar {
 
  public final double[] domain;
 
 
 
  public Exemplar(final double[] coords) {
 
    this.domain = coords;
 
  }
 
 
 
  public boolean domainEquals(final Exemplar other) {
 
    return Arrays.equals(domain, other.domain);
 
  }
 
 
 
  @Override public String
 
  toString() {
 
    return Arrays.toString(domain);
 
  }
 
}
 
</pre>
 
 
 
===BucketKdTree.java===
 
Tree with k-nearest neighbour search. Does not support deletion or rebalancing&mdash;a re-build is required if you want to do either.
 
 
 
<pre>
 
import java.util.*;
 
 
 
/**
 
* A k-dimensional binary partitioning tree which splits space on the
 
* mean of the dimension with the largest variance. Points are held in
 
* buckets so we can pick a better split point than whatever comes first.
 
*
 
* Does not store tree depth. If you want balance, re-build the tree
 
* periodically.
 
*
 
* Optimisations in this tree assume distance metric is euclidian distance.
 
* May work if retrofitted with other metrics, but that is purely
 
* accidental.
 
*
 
* Note: results can become unpredictable if values are different but so
 
* close together that rounding errors in computing their mean result in
 
* all exemplars being on one side of the mean. Performance degrades when
 
* this occurs. Nearest neighbour search tested to work up to range
 
* [1, 1 + 5e-16).
 
*
 
* Ideas for path ordering and bounds come from:
 
*  NEAL SAMPLE, MATTHEW HAINES, MARK ARNOLD, TIMOTHY PURCELL,
 
*    'Optimizing Search Strategies in k-d Trees'
 
*    http://ilpubs.stanford.edu:8090/723/
 
*
 
* Computation of variance from:
 
*  John Cook, 'Accurately computing running variance'
 
*    http://www.johndcook.com/standard_deviation.html
 
*
 
* Terminology note: points are called Exemplars. They must all be
 
* descended from Exemplar class. Position in k-d space is stored in each
 
* exemplar's domain member. This is to avoid conflicting with already
 
* existing classes referring to geometric points.
 
*
 
* Terminology comes from:
 
*  Andrew Moore, 'An intoductory tutorial on kd'
 
*    http://www.autonlab.org/autonweb/14665
 
*
 
* @author duyn
 
*/
 
public class BucketKdTree<T extends Exemplar> {
 
 
 
  // Only leaf nodes contain points
 
  private List<T> exemplars = new LinkedList<T>();
 
  // These aren't initialised until add() is called.
 
  private double[] exMean;
 
  private double[] exSumSqDev;
 
 
 
  // Optimisation when tree contains large number of duplicates
 
  private boolean exemplarsAreUniform = true;
 
 
 
  private int bucketSize;
 
  private BucketKdTree<T> left, right;
 
 
 
  private int splitDim;
 
  private double split;
 
 
 
  private final int dimensions;
 
 
 
  // Optimisation for searches. This lets us skip a node if its
 
  // scope intersects with a search hypersphere but it doesn't contain
 
  // any points that actually intersect.
 
  private double[] maxBounds;
 
  private double[] minBounds;
 
 
 
  public BucketKdTree(int bucketSize, int dimensions) {
 
    this.bucketSize = bucketSize;
 
    this.dimensions = dimensions;
 
  }
 
 
 
  //
 
  // PUBLIC METHODS
 
  //
 
 
 
  public void
 
  add(T ex) {
 
    BucketKdTree<T> tree = addNoSplit(this, ex);
 
    if (shouldSplit(tree)) {
 
      split(tree);
 
    }
 
  }
 
 
 
  public void
 
  addAll(Collection<T> exs) {
 
    // Some spurious function calls. Optimised for readability over
 
    // efficiency.
 
    final Set<BucketKdTree<T>> modTrees =
 
      new HashSet<BucketKdTree<T>>();
 
    for(T ex : exs) {
 
      modTrees.add(addNoSplit(this, ex));
 
    }
 
 
 
    for(BucketKdTree<T> tree : modTrees) {
 
      if (shouldSplit(tree)) {
 
        split(tree);
 
      }
 
    }
 
  }
 
 
 
  public SortedMap<Double, List<T>>
 
  search(double[] query, int nMinResults) {
 
    // Forward to a static method to avoid accidental reference to
 
    // instance variables while descending the tree
 
    return search(this, query, nMinResults);
 
  }
 
 
 
  @Override public String
 
  toString() {
 
    return toString("");
 
  }
 
 
 
  //
 
  // IMPLEMENTATION DETAILS
 
  //
 
 
 
  private boolean
 
  isTree() { return left != null; }
 
 
 
  // Addition
 
 
 
  // Adds an exemplar without splitting overflowing leaves.
 
  // Returns leaf to which exemplar was added.
 
  private static <T extends Exemplar> BucketKdTree<T>
 
  addNoSplit(BucketKdTree<T> tree, T ex) {
 
    // Some spurious function calls. Optimised for readability over
 
    // efficiency.
 
    BucketKdTree<T> cursor = tree;
 
    while (cursor != null) {
 
      updateBounds(cursor, ex);
 
      if (cursor.isTree()) {
 
        // Sub-tree
 
        cursor = ex.domain[cursor.splitDim] <= cursor.split
 
          ? cursor.left : cursor.right;
 
      } else {
 
        // Leaf
 
        cursor.exemplars.add(ex);
 
        final int nExs = cursor.exemplars.size();
 
        if (nExs == 1) {
 
          cursor.exMean =
 
            Arrays.copyOf(ex.domain, cursor.dimensions);
 
          cursor.exSumSqDev = new double[cursor.dimensions];
 
        } else {
 
          for(int d = 0; d < cursor.dimensions; d++) {
 
            final double coord = ex.domain[d];
 
 
 
            final double oldExMean = cursor.exMean[d];
 
            final double newMean = cursor.exMean[d] =
 
              oldExMean + (coord - oldExMean)/nExs;
 
 
 
            final double oldSumSqDev = cursor.exSumSqDev[d];
 
            cursor.exSumSqDev[d] = oldSumSqDev
 
              + (coord - oldExMean)*(coord - newMean);
 
          }
 
        }
 
        if (cursor.exemplarsAreUniform) {
 
          final List<T> cExs = cursor.exemplars;
 
          if (cExs.size() > 0 && !ex.domainEquals(cExs.get(0)))
 
            cursor.exemplarsAreUniform = false;
 
        }
 
        return cursor;
 
      }
 
    }
 
    throw new RuntimeException("Walked tree without adding anything");
 
  }
 
 
 
  private static <T extends Exemplar> void
 
  updateBounds(BucketKdTree<T> tree, Exemplar ex) {
 
    final int dims = tree.dimensions;
 
    if (tree.maxBounds == null) {
 
      tree.maxBounds = Arrays.copyOf(ex.domain, dims);
 
      tree.minBounds = Arrays.copyOf(ex.domain, dims);
 
    } else {
 
      for(int d = 0; d < dims; d++) {
 
        final double dimVal = ex.domain[d];
 
        if (dimVal > tree.maxBounds[d])
 
          tree.maxBounds[d] = dimVal;
 
        else if (dimVal < tree.minBounds[d])
 
          tree.minBounds[d] = dimVal;
 
      }
 
    }
 
  }
 
 
 
  // Splitting (internal operation)
 
 
 
  private static <T extends Exemplar> boolean
 
  shouldSplit(BucketKdTree<T> tree) {
 
    return tree.exemplars.size() > tree.bucketSize
 
      && !tree.exemplarsAreUniform;
 
  }
 
 
 
  @SuppressWarnings("unchecked") private static <T extends Exemplar> void
 
  split(BucketKdTree<T> tree) {
 
    assert !tree.exemplarsAreUniform;
 
    // Find dimension with largest variance to split on
 
    double largestVar = -1;
 
    int splitDim = 0;
 
    for(int d = 0; d < tree.dimensions; d++) {
 
      final double var =
 
        tree.exSumSqDev[d]/(tree.exemplars.size() - 1);
 
      if (var > largestVar) {
 
        largestVar = var;
 
        splitDim = d;
 
      }
 
    }
 
 
 
    // Find mean as position for our split
 
    double splitValue = tree.exMean[splitDim];
 
 
 
    // Check that our split actually splits our data. This also lets
 
    // us bulk load exemplars into sub-trees, which is more likely
 
    // to keep optimal balance.
 
    final List<T> leftExs = new LinkedList<T>();
 
    final List<T> rightExs = new LinkedList<T>();
 
    for(T s : tree.exemplars) {
 
      if (s.domain[splitDim] <= splitValue)
 
        leftExs.add(s);
 
      else
 
        rightExs.add(s);
 
    }
 
    int leftSize = leftExs.size();
 
    final int treeSize = tree.exemplars.size();
 
    if (leftSize == treeSize || leftSize == 0) {
 
      System.err.println(
 
        "WARNING: Randomly splitting non-uniform tree");
 
      // We know the exemplars aren't all the same, so try picking
 
      // an exemplar and a dimension at random for our split point
 
 
 
      // This might take several tries, so we copy our exemplars to
 
      // an array to speed up process of picking a random point
 
      Object[] exs = tree.exemplars.toArray();
 
      while (leftSize == treeSize || leftSize == 0) {
 
        leftExs.clear();
 
        rightExs.clear();
 
 
 
        splitDim = (int)
 
          Math.floor(Math.random()*tree.dimensions);
 
        final int splitPtIdx = (int)
 
          Math.floor(Math.random()*exs.length);
 
        // Cast is inevitable consequence of java's inability to
 
        // create a generic array
 
        splitValue = ((T)exs[splitPtIdx]).domain[splitDim];
 
        for(T s : tree.exemplars) {
 
          if (s.domain[splitDim] <= splitValue)
 
            leftExs.add(s);
 
          else
 
            rightExs.add(s);
 
        }
 
        leftSize = leftExs.size();
 
      }
 
    }
 
 
 
    // We have found a valid split. Start building our sub-trees
 
    final BucketKdTree<T> left =
 
      new BucketKdTree<T>(tree.bucketSize, tree.dimensions);
 
    final BucketKdTree<T> right =
 
      new BucketKdTree<T>(tree.bucketSize, tree.dimensions);
 
    left.addAll(leftExs);
 
    right.addAll(rightExs);
 
 
 
    // Finally, commit the split
 
    tree.splitDim = splitDim;
 
    tree.split = splitValue;
 
    tree.left = left;
 
    tree.right = right;
 
 
 
    // Let go of exemplars (and their running stats) held in this leaf
 
    tree.exemplars = null;
 
    tree.exMean = tree.exSumSqDev = null;
 
  }
 
 
 
  // Searching
 
 
 
  // May return more results than requested if multiple exemplars have
 
  // same distance from target.
 
  //
 
  // Note: this function works with squared distances to avoid sqrt()
 
  // operations
 
  private static <T extends Exemplar> SortedMap<Double, List<T>>
 
  search(BucketKdTree<T> tree, double[] query, int nMinResults) {
 
    // distance => list of points that distance away from query
 
    final NavigableMap<Double, List<T>> results =
 
      new TreeMap<Double, List<T>>();
 
 
 
    final SearchState state = new SearchState();
 
    final Deque<SearchStackEntry<T>> stack =
 
      new LinkedList<SearchStackEntry<T>>();
 
    stack.addFirst(new SearchStackEntry<T>(state.maxDistance, tree));
 
    while (!stack.isEmpty()) {
 
      final SearchStackEntry<T> entry = stack.removeFirst();
 
      final BucketKdTree<T> cur = entry.tree;
 
 
 
      if (cur.isTree()) {
 
        searchTree(query, nMinResults, cur, state, stack);
 
      } else if (entry.minDFromQ <= state.maxDistance
 
        || state.nResults < nMinResults)
 
      {
 
        searchLeaf(query, nMinResults, cur, state, results);
 
      }
 
    }
 
 
 
    return results;
 
  }
 
 
 
  private static <T extends Exemplar> void
 
  searchTree(double[] query, int nMinResults, BucketKdTree<T> tree,
 
    SearchState searchState, Deque<SearchStackEntry<T>> stack)
 
  {
 
    // Left is presumed near. This is verified further down.
 
    BucketKdTree<T> nearTree = tree.left, farTree = tree.right;
 
    // These variables let us skip empty sub-trees
 
    boolean nearEmpty = nearTree.minBounds == null;
 
    boolean farEmpty = farTree.minBounds == null;
 
 
 
    // Find distance from nearest possible point in each
 
    // sub-tree to query. If that is greater than max distance,
 
    // we can rule out that sub-tree.
 
    double nearD = nearEmpty ? Double.POSITIVE_INFINITY
 
      : minDistanceSqFrom(query,
 
        nearTree.minBounds, nearTree.maxBounds);
 
    double farD = farEmpty ? Double.POSITIVE_INFINITY
 
      : minDistanceSqFrom(query,
 
        farTree.minBounds, farTree.maxBounds);
 
 
 
    // Swap near and far if they're incorrect
 
    if (farD < nearD) {
 
      final double tmpD = nearD;
 
      final BucketKdTree<T> tmpTree = nearTree;
 
      final boolean tmpEmpty = nearEmpty;
 
 
 
      nearD = farD;
 
      nearTree = farTree;
 
      nearEmpty = farEmpty;
 
 
 
      farD = tmpD;
 
      farTree = tmpTree;
 
      farEmpty = tmpEmpty;
 
    }
 
 
 
    // Add nearest sub-tree to stack later so we descend it
 
    // first. This is likely to constrict our max distance
 
    // sooner, resulting in less visited nodes
 
    if (!farEmpty && (farD <= searchState.maxDistance
 
                      || searchState.nResults < nMinResults))
 
    {
 
      stack.addFirst(new SearchStackEntry<T>(farD, farTree));
 
    }
 
 
 
    if (!nearEmpty && (nearD <= searchState.maxDistance
 
                      || searchState.nResults < nMinResults))
 
    {
 
      stack.addFirst(new SearchStackEntry<T>(nearD, nearTree));
 
    }
 
  }
 
 
 
  private static <T extends Exemplar> void
 
  searchLeaf(double[] query, int nMinResults,
 
    BucketKdTree<T> leaf, SearchState searchState,
 
    NavigableMap<Double, List<T>> results)
 
  {
 
    // Keep track of elements at max distance so we know
 
    // whether we can just drop entire list of furthest
 
    // exemplars
 
    final int nMinResultsBeforeAddition = nMinResults - 1;
 
    for(T ex : leaf.exemplars) {
 
      final double exD = distanceSqFrom(query, ex.domain);
 
      if (searchState.nResults < nMinResults
 
        || exD == searchState.maxDistance)
 
      {
 
        // Blindly add this exemplar
 
        List<T> exsAtD = getOrElseInit(results, exD);
 
        if (leaf.exemplarsAreUniform) {
 
          // No need to go through every one if all
 
          // exemplars are the same
 
          exsAtD.addAll(leaf.exemplars);
 
          searchState.nResults += leaf.exemplars.size();
 
        } else {
 
          exsAtD.add(ex);
 
          searchState.nResults++;
 
        }
 
 
 
        // Update information about furthest exemplars
 
        if (exD > searchState.maxDistance
 
          || searchState.maxDistance == Double.POSITIVE_INFINITY)
 
        {
 
          searchState.maxDistance = exD;
 
        }
 
 
 
        if (searchState.maxDistance == exD)
 
          searchState.nExsAtMaxD = exsAtD.size();
 
 
 
      } else if (exD < searchState.maxDistance) {
 
        // Point closer than furthest neighbour
 
 
 
        if (searchState.nResults - searchState.nExsAtMaxD
 
          >= nMinResultsBeforeAddition)
 
        {
 
          // Dropping furthest exemplars won't leave
 
          // us with too little to meet return
 
          // minimum
 
          results.remove(searchState.maxDistance);
 
          searchState.nResults -= searchState.nExsAtMaxD;
 
        }
 
 
 
        // Add new exemplar
 
        List<T> exsAtD = getOrElseInit(results, exD);
 
        if (leaf.exemplarsAreUniform) {
 
          // No need to go through every one if all
 
          // exemplars are the same
 
          exsAtD.addAll(leaf.exemplars);
 
          searchState.nResults += leaf.exemplars.size();
 
        } else {
 
          exsAtD.add(ex);
 
          searchState.nResults++;
 
        }
 
 
 
        // Update information about furthest exemplars
 
        Map.Entry<Double, List<T>> lastEnt = results.lastEntry();
 
        searchState.maxDistance = lastEnt.getKey();
 
        searchState.nExsAtMaxD = lastEnt.getValue().size();
 
      } // exD < maxDistance
 
 
 
      // No need to keep going if all exemplars are
 
      // the same
 
      if (leaf.exemplarsAreUniform) break;
 
    } // for(T ex : cur.exemplars)
 
  }
 
 
 
  private static <T extends Exemplar> List<T>
 
  getOrElseInit(Map<Double, List<T>> results, double dst) {
 
    List<T> lst = results.get(dst);
 
    if (lst == null) {
 
      lst = new LinkedList<T>();
 
      results.put(dst, lst);
 
    }
 
    return lst;
 
  }
 
 
 
  // Dumping to string (debug)
 
 
 
  private String
 
  toString(String indent) {
 
    if (isTree()){
 
      return String.format("%s{|%d|%f|\n%s\n%s} ", indent,
 
        splitDim, split, left.toString(indent + "\t"),
 
        right.toString(indent + "\t"));
 
    } else {
 
      return String.format("%sL%s", indent,
 
        Arrays.toString(exemplars.toArray()));
 
    }
 
  }
 
 
 
  // Distance calculations
 
 
 
  // Gets distance from target of nearest point on hyper-rect defined
 
  // by supplied min and max bounds
 
  private static double
 
  minDistanceSqFrom(double[] target, double[] min, double[] max) {
 
    // Note: profiling shows this is called lots of times, so it pays
 
    // to be well optimised
 
    double distanceSq = 0;
 
    for(int d = 0; d < target.length; d++) {
 
      final double coord = target[d];
 
      double nearCoord;
 
      if (((nearCoord = min[d]) > coord)
 
        || ((nearCoord = max[d]) < coord))
 
      {
 
        final double dst = nearCoord - coord;
 
        distanceSq += dst*dst;
 
      }
 
    }
 
    return distanceSq;
 
  }
 
 
 
  // Accessible to testing
 
  static double
 
  distanceSqFrom(double[] p1, double[] p2) {
 
    // Note: profiling shows this is called lots of times, so it pays
 
    // to be well optimised
 
    double dSq = 0;
 
    for(int d = 0; d < p1.length; d++) {
 
      final double dst = p1[d] - p2[d];
 
      if (dst != 0)
 
        dSq += dst*dst;
 
    }
 
    return dSq;
 
  }
 
 
 
  //
 
  // class SearchStackEntry
 
  //
 
  // Stores a precomputed distance so we don't have to do it again
 
  // when we pop the tree off the search stack.
 
 
 
  private static class SearchStackEntry<T extends Exemplar> {
 
    public final double minDFromQ;
 
    public final BucketKdTree<T> tree;
 
 
 
    public SearchStackEntry(double minDFromQ, BucketKdTree<T> tree) {
 
      this.minDFromQ = minDFromQ;
 
      this.tree = tree;
 
    }
 
  }
 
 
 
  //
 
  // class SearchState
 
  //
 
  // Holds data about current state of the search. Used for live updating
 
  // of pruning distance.
 
 
 
  private static class SearchState {
 
    int nResults = 0;
 
    double maxDistance = Double.POSITIVE_INFINITY;
 
    int nExsAtMaxD = 0;
 
  }
 
}
 
</pre>
 

Latest revision as of 13:23, 13 March 2010

There used to be a kd-tree on this page. It has since been packaged into a JAR file and is available at: File:Duyn tutorial kd trees.jar. Its source is available under the RWPCL.