User:Duyn/BucketKdTree

From Robowiki
< User:Duyn
Revision as of 19:33, 28 February 2010 by Duyn (talk | contribs) (Updated to latest version)
Jump to navigation Jump to search

Below is the full source code for a tree that will form the basis of a future kd-tree tutorial I'm working on. This code is released under the RWPCL. It is decently optimised—theoretically, anyway. I don't specialise in writing fast code, but the algorithms used are designed to rule out whole sub-trees as early as possible.

Exemplar.java

Base class for items added to the tree. Sub-classes can carry a useful payload, such as guess factors.

import java.util.Arrays;
/**
 * A sample point in multi-dimensional space. Needed because each sample
 * may contain an arbitrary payload.
 *
 * Note: this class does not make allowance for a payload. Sub-class if
 * you want to store something more than just data points.
 *
 * The domain array passed to the constructor is stored by reference, not
 * copied.
 *
 * @author dnn
 */
public class Exemplar {
  /** Coordinates of this sample, one entry per dimension. */
  public final double[] domain;

  public Exemplar(double[] domain) {
    this.domain = domain;
  }

  /**
   * Tests whether this exemplar sits at the same position as another.
   *
   * Coordinates are compared numerically rather than bitwise, so
   * {@code 0.0} equals {@code -0.0}. {@code NaN} is treated as equal to
   * {@code NaN}. A bitwise comparison would report points as different
   * even though no {@code <=} comparison can tell them apart, which
   * misleads code that partitions points with {@code <=}.
   *
   * @param other exemplar to compare with; must not be null
   * @return true if both domains are null, or both have the same length
   *     and pairwise-equal coordinates
   */
  public boolean domainEquals(final Exemplar other) {
    final double[] mine = domain;
    final double[] theirs = other.domain;
    if (mine == theirs) return true;
    if (mine == null || theirs == null || mine.length != theirs.length)
      return false;
    for (int d = 0; d < mine.length; d++) {
      final double x = mine[d], y = theirs[d];
      // x != y is also true when either side is NaN; only two NaNs match
      if (x != y && !(Double.isNaN(x) && Double.isNaN(y)))
        return false;
    }
    return true;
  }

  @Override public String
  toString() {
    return Arrays.toString(domain);
  }
}

BucketKdTree.java

A kd-tree with k-nearest-neighbour search. It does not support deletion or rebalancing—rebuild the tree if you need either.

import java.util.*;

/**
 * A k-dimensional binary partitioning tree which splits space on the
 * mean of the dimension with the largest variance. Points are held in
 * buckets so we can pick a better split point than whatever comes first.
 *
 * Does not store tree depth. If you want balance, re-build the tree
 * periodically.
 *
 * Optimisations in this tree assume distance metric is euclidian distance.
 * May work if retrofitted with other metrics, but that is purely
 * accidental.
 *
 * Note: results can become unpredictable if values are different but so
 * close together that rounding errors in computing their mean result in
 * all exemplars being on one side of the mean. Performance degrades when
 * this occurs. Nearest neighbour search tested to work up to range
 * [1, 1 + 5e-16).
 *
 * Ideas for path ordering and bounds come from:
 *   NEAL SAMPLE, MATTHEW HAINES, MARK ARNOLD, TIMOTHY PURCELL,
 *     'Optimizing Search Strategies in k-d Trees'
 *     http://ilpubs.stanford.edu:8090/723/
 *
 * Computation of variance from:
 *   John Cook, 'Accurately computing running variance'
 *     http://www.johndcook.com/standard_deviation.html
 *
 * Terminology note: points are called Exemplars. They must all be
 * descended from Exemplar class. Position in k-d space is stored in each
 * exemplar's domain member. This is to avoid conflicting with already
 * existing classes referring to geometric points.
 *
 * Terminology comes from:
 *   Andrew Moore, 'An intoductory tutorial on kd'
 *     http://www.autonlab.org/autonweb/14665
 * 
 * @author dnn
 */
public class BucketKdTree<T extends Exemplar> {

  // Only leaf nodes contain points
  private List<T> exemplars = new LinkedList<T>();
  // These aren't initialised until add() is called.
  private double[] exMean;
  private double[] exSumSqDev;

  // Optimisation when tree contains large number of duplicates
  private boolean exemplarsAreUniform = true;

  private int bucketSize;
  private BucketKdTree<T> left, right;

  private int splitDim;
  private double split;

  private int dimensions = 0;

  // Optimisation for searches. This lets us skip a node if its
  // scope intersects with a search hypersphere but it doesn't contain
  // any points that actually intersect.
  private double[] maxBounds;
  private double[] minBounds;

  public BucketKdTree(int bucketSize) {
    this.bucketSize = bucketSize;
  }

  //
  // PUBLIC METHODS
  //

  public void
  add(T ex) {
    BucketKdTree<T> tree = addNoSplit(this, ex);
    if (shouldSplit(tree)) {
      split(tree);
    }
  }

  public void
  addAll(Collection<T> exs) {
    // Some spurious function calls. Optimised for readability over
    // efficiency.
    final Set<BucketKdTree<T>> modTrees =
      new HashSet<BucketKdTree<T>>();
    for(T ex : exs) {
      modTrees.add(addNoSplit(this, ex));
    }

    for(BucketKdTree<T> tree : modTrees) {
      if (shouldSplit(tree)) {
        split(tree);
      }
    }
  }

  public SortedMap<Double, List<T>>
  search(double[] query, int nMinResults) {
    // Forward to a static method to avoid accidental reference to
    // instance variables while descending the tree
    return search(this, query, nMinResults);
  }
  
  @Override public String
  toString() {
    return toString("");
  }

  //
  // IMPLEMENTATION DETAILS
  //

  private boolean
  isTree() { return left != null; }

  // Addition

  // Adds an exemplar without splitting overflowing leaves.
  // Returns leaf to which exemplar was added.
  private static <T extends Exemplar> BucketKdTree<T>
  addNoSplit(BucketKdTree<T> tree, T ex) {
    // Some spurious function calls. Optimised for readability over
    // efficiency.
    BucketKdTree<T> cursor = tree;
    while (cursor != null) {
      updateBounds(cursor, ex);
      if (cursor.isTree()) {
        // Sub-tree
        cursor = ex.domain[cursor.splitDim] <= cursor.split
          ? cursor.left : cursor.right;
      } else {
        // Leaf

        // Infer dimensions if we haven't already
        if (cursor.dimensions == 0)
          cursor.dimensions = ex.domain.length;

        // Add exemplar to leaf
        cursor.exemplars.add(ex);

        // Calculate running mean and sum of squared deviations
        final int nExs = cursor.exemplars.size();
        final int dims = cursor.dimensions;
        if (nExs == 1) {
          cursor.exMean = Arrays.copyOf(ex.domain, dims);
          cursor.exSumSqDev = new double[dims];
        } else {
          for(int d = 0; d < dims; d++) {
            final double coord = ex.domain[d];
            final double oldMean = cursor.exMean[d], newMean;
            cursor.exMean[d] = newMean =
              oldMean + (coord - oldMean)/nExs;
            cursor.exSumSqDev[d] = cursor.exSumSqDev[d]
              + (coord - oldMean)*(coord - newMean);
          }
        }

        // Check that exemplars are still uniform
        if (cursor.exemplarsAreUniform) {
          final List<T> cExs = cursor.exemplars;
          if (cExs.size() > 0 && !ex.domainEquals(cExs.get(0)))
            cursor.exemplarsAreUniform = false;
        }

        // Finished walking
        return cursor;
      }
    }
    throw new RuntimeException("Walked tree without adding anything");
  }

  private static <T extends Exemplar> void
  updateBounds(BucketKdTree<T> tree, Exemplar ex) {
    final int dims = ex.domain.length;
    if (tree.maxBounds == null) {
      tree.maxBounds = Arrays.copyOf(ex.domain, dims);
      tree.minBounds = Arrays.copyOf(ex.domain, dims);
    } else {
      for(int d = 0; d < dims; d++) {
        final double dimVal = ex.domain[d];
        if (dimVal > tree.maxBounds[d])
          tree.maxBounds[d] = dimVal;
        else if (dimVal < tree.minBounds[d])
          tree.minBounds[d] = dimVal;
      }
    }
  }

  // Splitting (internal operation)

  private static <T extends Exemplar> boolean
  shouldSplit(BucketKdTree<T> tree) {
    return tree.exemplars.size() > tree.bucketSize
      && !tree.exemplarsAreUniform;
  }

  @SuppressWarnings("unchecked") private static <T extends Exemplar> void
  split(BucketKdTree<T> tree) {
    assert !tree.exemplarsAreUniform;
    // Find dimension with largest variance to split on
    double largestVar = -1;
    int splitDim = 0;
    for(int d = 0; d < tree.dimensions; d++) {
      // Don't need to divide by number of exemplars to find largest
      // variance
      final double var = tree.exSumSqDev[d];
      if (var > largestVar) {
        largestVar = var;
        splitDim = d;
      }
    }

    // Find mean as position for our split
    double splitValue = tree.exMean[splitDim];

    // Check that our split actually splits our data. This also lets
    // us bulk load exemplars into sub-trees, which is more likely
    // to keep optimal balance.
    final List<T> leftExs = new LinkedList<T>();
    final List<T> rightExs = new LinkedList<T>();
    for(T s : tree.exemplars) {
      if (s.domain[splitDim] <= splitValue)
        leftExs.add(s);
      else
        rightExs.add(s);
    }
    int leftSize = leftExs.size();
    final int treeSize = tree.exemplars.size();
    if (leftSize == treeSize || leftSize == 0) {
      System.err.println(
        "WARNING: Randomly splitting non-uniform tree");
      // We know the exemplars aren't all the same, so try picking
      // an exemplar and a dimension at random for our split point

      // This might take several tries, so we copy our exemplars to
      // an array to speed up process of picking a random point
      Object[] exs = tree.exemplars.toArray();
      while (leftSize == treeSize || leftSize == 0) {
        leftExs.clear();
        rightExs.clear();

        splitDim = (int)
          Math.floor(Math.random()*tree.dimensions);
        final int splitPtIdx = (int)
          Math.floor(Math.random()*exs.length);
        // Cast is inevitable consequence of java's inability to
        // create a generic array
        splitValue = ((T)exs[splitPtIdx]).domain[splitDim];
        for(T s : tree.exemplars) {
          if (s.domain[splitDim] <= splitValue)
            leftExs.add(s);
          else
            rightExs.add(s);
        }
        leftSize = leftExs.size();
      }
    }

    // We have found a valid split. Start building our sub-trees
    final BucketKdTree<T> left =
      new BucketKdTree<T>(tree.bucketSize);
    final BucketKdTree<T> right =
      new BucketKdTree<T>(tree.bucketSize);
    left.addAll(leftExs);
    right.addAll(rightExs);

    // Finally, commit the split
    tree.splitDim = splitDim;
    tree.split = splitValue;
    tree.left = left;
    tree.right = right;

    // Let go of exemplars (and their running stats) held in this leaf
    tree.exemplars = null;
    tree.exMean = tree.exSumSqDev = null;
  }

  // Searching

  // May return more results than requested if multiple exemplars have
  // same distance from target.
  //
  // Note: this function works with squared distances to avoid sqrt()
  // operations
  private static <T extends Exemplar> SortedMap<Double, List<T>>
  search(BucketKdTree<T> tree, double[] query, int nMinResults) {
    // distance => list of points that distance away from query
    final NavigableMap<Double, List<T>> results =
      new TreeMap<Double, List<T>>();

    final SearchState state = new SearchState();
    final Deque<SearchStackEntry<T>> stack =
      new LinkedList<SearchStackEntry<T>>();
    stack.addFirst(new SearchStackEntry<T>(state.maxDistance, tree));
    while (!stack.isEmpty()) {
      final SearchStackEntry<T> entry = stack.removeFirst();
      final BucketKdTree<T> cur = entry.tree;

      if (cur.isTree()) {
        searchTree(query, nMinResults, cur, state, stack);
      } else if (entry.minDFromQ <= state.maxDistance
        || state.nResults < nMinResults)
      {
        searchLeaf(query, nMinResults, cur, state, results);
      }
    }

    return results;
  }

  private static <T extends Exemplar> void
  searchTree(double[] query, int nMinResults, BucketKdTree<T> tree,
    SearchState searchState, Deque<SearchStackEntry<T>> stack)
  {
    // Left is presumed near. This is verified further down.
    BucketKdTree<T> nearTree = tree.left, farTree = tree.right;
    // These variables let us skip empty sub-trees
    boolean nearEmpty = nearTree.minBounds == null;
    boolean farEmpty = farTree.minBounds == null;

    // Find distance from nearest possible point in each
    // sub-tree to query. If that is greater than max distance,
    // we can rule out that sub-tree.
    double nearD = nearEmpty ? Double.POSITIVE_INFINITY
      : minDistanceSqFrom(query,
        nearTree.minBounds, nearTree.maxBounds);
    double farD = farEmpty ? Double.POSITIVE_INFINITY
      : minDistanceSqFrom(query,
        farTree.minBounds, farTree.maxBounds);

    // Swap near and far if they're incorrect
    if (farD < nearD) {
      final double tmpD = nearD;
      final BucketKdTree<T> tmpTree = nearTree;
      final boolean tmpEmpty = nearEmpty;

      nearD = farD;
      nearTree = farTree;
      nearEmpty = farEmpty;

      farD = tmpD;
      farTree = tmpTree;
      farEmpty = tmpEmpty;
    }

    // Add nearest sub-tree to stack later so we descend it
    // first. This is likely to constrict our max distance
    // sooner, resulting in less visited nodes
    if (!farEmpty && (farD <= searchState.maxDistance
                      || searchState.nResults < nMinResults))
    {
      stack.addFirst(new SearchStackEntry<T>(farD, farTree));
    }

    if (!nearEmpty && (nearD <= searchState.maxDistance
                       || searchState.nResults < nMinResults))
    {
      stack.addFirst(new SearchStackEntry<T>(nearD, nearTree));
    }
  }

  private static <T extends Exemplar> void
  searchLeaf(double[] query, int nMinResults,
    BucketKdTree<T> leaf, SearchState searchState,
    NavigableMap<Double, List<T>> results)
  {
    // Keep track of elements at max distance so we know
    // whether we can just drop entire list of furthest
    // exemplars
    final int nMinResultsBeforeAddition = nMinResults - 1;
    for(T ex : leaf.exemplars) {
      final double exD = distanceSqFrom(query, ex.domain);
      if (searchState.nResults < nMinResults
        || exD == searchState.maxDistance)
      {
        // Blindly add this exemplar
        List<T> exsAtD = getOrElseInit(results, exD);
        if (leaf.exemplarsAreUniform) {
          // No need to go through every one if all
          // exemplars are the same
          exsAtD.addAll(leaf.exemplars);
          searchState.nResults += leaf.exemplars.size();
        } else {
          exsAtD.add(ex);
          searchState.nResults++;
        }

        // Update information about furthest exemplars
        if (exD > searchState.maxDistance
          || searchState.maxDistance == Double.POSITIVE_INFINITY)
        {
          searchState.maxDistance = exD;
        }

        if (searchState.maxDistance == exD)
          searchState.nExsAtMaxD = exsAtD.size();

      } else if (exD < searchState.maxDistance) {
        // Point closer than furthest neighbour

        if (searchState.nResults - searchState.nExsAtMaxD
          >= nMinResultsBeforeAddition)
        {
          // Dropping furthest exemplars won't leave
          // us with too little to meet return
          // minimum
          results.remove(searchState.maxDistance);
          searchState.nResults -= searchState.nExsAtMaxD;
        }

        // Add new exemplar
        List<T> exsAtD = getOrElseInit(results, exD);
        if (leaf.exemplarsAreUniform) {
          // No need to go through every one if all
          // exemplars are the same
          exsAtD.addAll(leaf.exemplars);
          searchState.nResults += leaf.exemplars.size();
        } else {
          exsAtD.add(ex);
          searchState.nResults++;
        }

        // Update information about furthest exemplars
        Map.Entry<Double, List<T>> lastEnt = results.lastEntry();
        searchState.maxDistance = lastEnt.getKey();
        searchState.nExsAtMaxD = lastEnt.getValue().size();
      } // exD < maxDistance

      // No need to keep going if all exemplars are
      // the same
      if (leaf.exemplarsAreUniform) break;
    } // for(T ex : cur.exemplars)
  }

  private static <T extends Exemplar> List<T>
  getOrElseInit(Map<Double, List<T>> results, double dst) {
    List<T> lst = results.get(dst);
    if (lst == null) {
      lst = new LinkedList<T>();
      results.put(dst, lst);
    }
    return lst;
  }

  // Dumping to string (debug)

  private String
  toString(String indent) {
    if (isTree()){
      return String.format("%s{|%d|%f|\n%s\n%s} ", indent,
        splitDim, split, left.toString(indent + "\t"),
        right.toString(indent + "\t"));
    } else {
      return String.format("%sL%s", indent,
        Arrays.toString(exemplars.toArray()));
    }
  }

  // Distance calculations

  // Gets distance from target of nearest point on hyper-rect defined
  // by supplied min and max bounds
  private static double
  minDistanceSqFrom(double[] target, double[] min, double[] max) {
    // Note: profiling shows this is called lots of times, so it pays
    // to be well optimised
    double distanceSq = 0;
    for(int d = 0; d < target.length; d++) {
      final double coord = target[d];
      double nearCoord;
      if (((nearCoord = min[d]) > coord)
        || ((nearCoord = max[d]) < coord))
      {
        final double dst = nearCoord - coord;
        distanceSq += dst*dst;
      }
    }
    return distanceSq;
  }

  // Accessible to testing
  static double
  distanceSqFrom(double[] p1, double[] p2) {
    // Note: profiling shows this is called lots of times, so it pays
    // to be well optimised
    double dSq = 0;
    for(int d = 0; d < p1.length; d++) {
      final double dst = p1[d] - p2[d];
      if (dst != 0)
        dSq += dst*dst;
    }
    return dSq;
  }

  //
  // class SearchStackEntry
  //
  // Stores a precomputed distance so we don't have to do it again
  // when we pop the tree off the search stack.

  private static class SearchStackEntry<T extends Exemplar> {
    public final double minDFromQ;
    public final BucketKdTree<T> tree;

    public SearchStackEntry(double minDFromQ, BucketKdTree<T> tree) {
      this.minDFromQ = minDFromQ;
      this.tree = tree;
    }
  }

  //
  // class SearchState
  //
  // Holds data about current state of the search. Used for live updating
  // of pruning distance.

  private static class SearchState {
    int nResults = 0;
    double maxDistance = Double.POSITIVE_INFINITY;
    int nExsAtMaxD = 0;
  }
}