Difference between revisions of "User:Skilgannon/KDTree"

From Robowiki
Jump to navigation Jump to search
(My KDTree Implementation - it's super effective!)
 
m (Update, link to latest on Bitbucket)
 
(10 intermediate revisions by the same user not shown)
Line 1: Line 1:
 +
Latest version is available here: [https://bitbucket.org/jkflying/kd-tree]
 +
 +
A possibly outdated version is listed below:
 +
 
<code><syntaxhighlight>
 
<code><syntaxhighlight>
 
+
package jk.tree;
 
/*
 
/*
** KDTree.java by Julian Kent
+
** KDTree.java by Julian Kent
** Licenced under the  Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
+
**
** See full licencing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
+
** Licenced under the  Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
** For additional usage rights please contact jkflying@gmail.com
+
**
**
+
** Licence summary:
** Example usage is given in the main method, as well as benchmarking code against Rednaxela's Gen2 Tree
+
** Under this licence you are free to:
*/
+
**      Share : copy and redistribute the material in any medium or format
 
+
**      Adapt : remix, transform, and build upon the material
 +
**      The licensor cannot revoke these freedoms as long as you follow the license terms.
 +
**
 +
** Under the following terms:
 +
**      Attribution:
 +
**            You must give appropriate credit, provide a link to the license, and indicate
 +
**            if changes were made. You may do so in any reasonable manner, but not in any
 +
**            way that suggests the licensor endorses you or your use.
 +
**      NonCommercial:
 +
**            You may not use the material for commercial purposes.
 +
**      ShareAlike:
 +
**            If you remix, transform, or build upon the material, you must distribute your
 +
**            contributions under the same license as the original.
 +
**      No additional restrictions:
 +
**            You may not apply legal terms or technological measures that legally restrict
 +
**            others from doing anything the license permits.
 +
**
 +
** See full licencing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
 +
**
 +
** For additional licencing rights (including commercial) please contact jkflying@gmail.com
 +
**
 +
*/
  
package jk.mega;
 
import java.util.ArrayDeque;
 
 
import java.util.ArrayList;
 
import java.util.ArrayList;
 
import java.util.Arrays;
 
import java.util.Arrays;
//import ags.utils.*;
 
  
public class KDTree<T>{
+
public abstract class KDTree<T> {
 +
 
 +
    // use a big bucketSize so that we have less node bounds (for more cache
 +
    // hits) and better splits
 +
    // if you have lots of dimensions this should be big, and if you have few it can be small
 +
    private static final int _bucketSize = 50;
 +
 
 +
    private final int _dimensions;
 +
    private int _nodes;
 +
    private final Node root;
 +
    private final ArrayList<Node> nodeList = new ArrayList<Node>();
 +
 
 +
    // prevent GC from having to collect _bucketSize*dimensions*sizeof(double) bytes each
 +
    // time a leaf splits
 +
    private double[] mem_recycle;
 +
 
 +
    // the starting values for bounding boxes, for easy access
 +
    private final double[] bounds_template;
 +
 
 +
    // one big self-expanding array to keep all the node bounding boxes so that
 +
    // they stay in cache
 +
    // node bounds available at:
 +
    // low: 2 * _dimensions * node.index + 2 * dim
 +
    // high: 2 * _dimensions * node.index + 2 * dim + 1
 +
    private final ContiguousDoubleArrayList nodeMinMaxBounds;
 +
 
 +
    /**
     * Builds an empty tree for points that have {@code dimensions} coordinates.
     * The constructor is private; in the visible source it is invoked through
     * {@code super(dims)} by the nested metric subclasses.
     *
     * @param dimensions number of coordinates per point; it sizes every array
     *                   allocated here
     */
    private KDTree(int dimensions) {
 +
        _dimensions = dimensions;
 +
 
 +
        // initialise this big so that it ends up in 'old' memory
        // 512 * 1024 / 8 is the number of 8-byte doubles in 512 KiB, plus room
        // for one node's bounds. NOTE(review): the argument is presumably an
        // initial capacity - ContiguousDoubleArrayList is not shown here.
 +
        nodeMinMaxBounds = new ContiguousDoubleArrayList(512 * 1024 / 8 + 2 * _dimensions);
 +
        // spare bucket-sized buffer; split() hands it to one of the new child nodes
        mem_recycle = new double[_bucketSize * dimensions];
 +
 
 +
        // Bounds template, interleaved per dimension: even slots (low bound)
        // hold +infinity and odd slots (high bound) hold -infinity, so the
        // min/max in expandBounds tightens both on the first point added.
        bounds_template = new double[2 * _dimensions];
 +
        Arrays.fill(bounds_template, Double.NEGATIVE_INFINITY);
 +
        for (int i = 0, max = 2 * _dimensions; i < max; i += 2)
 +
            bounds_template[i] = Double.POSITIVE_INFINITY;
 +
 
 +
        // and.... start!
        // the no-arg Node constructor allocates a point buffer, so the root
        // begins as an empty bucket (pointLocations != null)
 +
        root = new Node();
 +
    }
 +
 
 +
    /**
     * Returns the tree's node counter. The counter is incremented each time a
     * Node is constructed and reduced by two when split() throws away a pair
     * of children because one of them ended up empty.
     *
     * @return current value of {@code _nodes}
     */
    public int nodes() {
 +
        return _nodes;
 +
    }
 +
 
 +
    /**
     * Returns the number of points stored in the tree. This is the root's
     * {@code entries} count, which addPoint bumps (through expandBounds) once
     * for every point inserted.
     *
     * @return {@code root.entries}
     */
    public int size() {
 +
        return root.entries;
 +
    }
 +
 
 +
    /**
     * Inserts one point and its payload into the tree.
     *
     * Starting at the root, the method descends through every node whose
     * {@code pointLocations} is null, choosing the child by comparing the
     * point's coordinate in {@code splitDim} against {@code splitVal}. Each
     * node on the way (and the final bucket node) has its bounding box and
     * entry count updated by expandBounds. The point is then appended to the
     * bucket, and the bucket is asked to split whenever its entry count is a
     * multiple of {@code _bucketSize}.
     *
     * @param location coordinates of the point; indexed by {@code splitDim}
     *                 during the descent (NOTE(review): presumably expected to
     *                 have {@code _dimensions} elements - not checked here)
     * @param payload  value stored alongside the point
     * @return {@code root.entries} after the insertion
     */
    public int addPoint(double[] location, T payload) {
 +
 
 +
        Node addNode = root;
 +
        // Do a Depth First Search to find the Node where 'location' should be
 +
        // stored
 +
        while (addNode.pointLocations == null) {
 +
            // every node on the path now contains this point, so grow its box
            addNode.expandBounds(location);
 +
            if (location[addNode.splitDim] < addNode.splitVal)
 +
                addNode = nodeList.get(addNode.lessIndex);
 +
            else
 +
                addNode = nodeList.get(addNode.moreIndex);
 +
        }
 +
        // the loop exits before touching the bucket node, so update it here
        addNode.expandBounds(location);
 +
 
 +
        // add() returns the node's entry count, already incremented by expandBounds
        int nodeSize = addNode.add(location, payload);
 +
 
 +
        if (nodeSize % _bucketSize == 0)
 +
            // try splitting again once every time the node passes a _bucketSize
 +
            // multiple
 +
            // in case it is full of points of the same location and won't split
 +
            addNode.split();
 +
 
 +
        return root.entries;
 +
    }
 +
 
 +
    public ArrayList<SearchResult<T>> nearestNeighbours(double[] searchLocation, int K) {
  
//use a big bucketSize so that we have less node bounds (for more cache hits) and better splits
+
        K = Math.min(K, size());
  private static final int  _bucketSize = 64;
 
  
  private final int _dimensions;
+
        ArrayList<SearchResult<T>> returnResults = new ArrayList<SearchResult<T>>(K);
  private int _nodes; 
 
  private Node root;
 
 
 
  //prevent GC from having to collect _bucketSize*dimensions*8 bytes each time a leaf splits
 
  private double[] mem_recycle;
 
 
 
  //the starting values for bounding boxes, for easy access
 
  private final double[] bounds_template;
 
  
  //one big self-expanding array to keep all the node bounding boxes so that they stay in cache
+
        if (K > 0) {
  // node bounds available at:
+
            IntStack stack = new IntStack();
  //low:  2 * _dimensions * node.index
+
            PrioQueue<T> results = new PrioQueue<T>(K, true);
  //high: 2 * _dimensions * node.index + _dimensions
+
 
  private ContiguousDoubleArrayList nodeMinMaxBounds;
+
            stack.push(root.index);
/*
+
 
  public static void main(String[] args){
+
            int added = 0;
      int dims = 12;
+
 
      int size = 20000;
+
            while (stack.size() > 0) {
      int testsize = 200;
+
                int nodeIndex = stack.pop();
      int k = 40;
+
                if (added < K || results.peekPrio() > pointRectDist(nodeIndex, searchLocation)) {
      int iterations = 3;
+
                    Node node = nodeList.get(nodeIndex);
      System.out.println(
+
                    if (node.pointLocations == null)
        "Config:\n"
+
                        node.search(searchLocation, stack);
        + "No JIT Warmup\n"
+
                    else
        + "Tested on random data.\n"
+
                        added += node.search(searchLocation, results);
        + "Training and testing points shared across iterations.\n"
+
                }
        + "Searches interleaved.");
+
            }
      System.out.println("Num points:    " + size);
+
 
      System.out.println("Num searches:  " + testsize);
+
            double[] priorities = results.priorities;
      System.out.println("Dimensions:    " + dims);
+
            Object[] elements = results.elements;
      System.out.println("Num Neighbours: " + k);
+
            for (int i = 0; i < K; i++) { // forward (closest first)
      System.out.println();
+
                SearchResult<T> s = new SearchResult<T>(priorities[i], (T) elements[i]);
      ArrayList<double[]> locs = new ArrayList<double[]>(size);
+
                returnResults.add(s);
      for(int i = 0; i < size; i++){
+
            }
        double[] loc = new double[dims];
+
        }
        for(int j = 0; j < dims; j++)
+
        return returnResults;
            loc[j] = Math.random();
+
    }
        locs.add(loc);
+
 
      }
+
    public ArrayList<T> ballSearch(double[] searchLocation, double radius) {
      ArrayList<double[]> testlocs = new ArrayList<double[]>(testsize);
+
        IntStack stack = new IntStack();
      for(int i = 0; i < testsize; i++){
+
        ArrayList<T> results = new ArrayList<T>();
        double[] loc = new double[dims];
+
 
        for(int j = 0; j < dims; j++)
+
        stack.push(root.index);
            loc[j] = Math.random();
+
 
        testlocs.add(loc);
+
        while (stack.size() > 0) {
      }
+
            int nodeIndex = stack.pop();
      for(int r = 0; r < iterations; r++){
+
            if (radius > pointRectDist(nodeIndex, searchLocation)) {
        long t1 = System.nanoTime();
+
                Node node = nodeList.get(nodeIndex);
        KDTree<double[]> t = new KDTree<double[]>(dims);// This tree
+
                if (node.pointLocations == null)
        for(int i = 0; i < size; i++){
+
                    stack.push(node.moreIndex).push(node.lessIndex);
            t.addPoint(locs.get(i),locs.get(i));
+
                else
        }
+
                    node.searchBall(searchLocation, radius, results);
        long t2 = System.nanoTime();
+
            }
        KdTree<double[]> rt = new KdTree.Euclidean<double[]>(dims,null); //Rednaxela Gen2
+
        }
        for(int i = 0; i < size; i++){
+
        return results;
            rt.addPoint(locs.get(i),locs.get(i));
+
    }
        }
+
 
        long t3 = System.nanoTime();
+
    public ArrayList<T> rectSearch(double[] mins, double[] maxs) {
     
+
        IntStack stack = new IntStack();
        long jtn = 0;
+
        ArrayList<T> results = new ArrayList<T>();
        long rtn = 0;
+
 
        long mjtn = 0;
+
        stack.push(root.index);
        long mrtn = 0;
+
 
     
+
        while (stack.size() > 0) {
        double dist1 = 0, dist2 = 0;
+
             int nodeIndex = stack.pop();
        for(int i = 0; i < testsize; i++){
+
             if (overlaps(mins, maxs, nodeIndex)) {
             long t4 = System.nanoTime();
+
                Node node = nodeList.get(nodeIndex);
             dist1 += t.nearestNeighbours(testlocs.get(i),k).iterator().next().distance;
+
                if (node.pointLocations == null)
            long t5 = System.nanoTime();
+
                    stack.push(node.moreIndex).push(node.lessIndex);
            dist2 += rt.nearestNeighbor(testlocs.get(i),k,true).iterator().next().distance;
+
                else
             long t6 = System.nanoTime();
+
                    node.searchRect(mins, maxs, results);
            long t7 = System.nanoTime();
+
             }
            jtn += t5 - t4 - (t7 - t6);
+
        }
            rtn += t6 - t5 - (t7 - t6);  
+
        return results;
            mjtn = Math.max(mjtn,t5 - t4 - (t7 - t6));
+
 
             mrtn = Math.max(mrtn,t6 - t5 - (t7 - t6));
+
    }
        }
+
 
     
+
    abstract double pointRectDist(int offset, final double[] location);
        System.out.println("Accuracy: " + (Math.abs(dist1-dist2) < 1e-10?"100%":"BROKEN!!!"));
+
 
        if(Math.abs(dist1-dist2) > 1e-10){
+
    abstract double pointDist(double[] arr, double[] location, int index);
            System.out.println("dist1: " + dist1 + "    dist2: " + dist2);
+
 
        }
+
    boolean contains(double[] arr, double[] mins, double[] maxs, int index) {
        long jts = t2 - t1;
+
 
        long rts = t3 - t2;
+
        int offset = (index + 1) * mins.length;
        System.out.println("Iteration:      " + (r+1) + "/" + iterations);
+
 
     
+
        for (int i = mins.length; i-- > 0;) {
        System.out.println("This tree add avg:  " + jts/size + " ns");
+
             double d = arr[--offset];
        System.out.println("Reds tree add avg:  " + rts/size + " ns");
+
            if (mins[i] > d | d > maxs[i])
     
+
                return false;
        System.out.println("This tree knn avg:  " + jtn/testsize + " ns");
+
        }
        System.out.println("Reds tree knn avg:  " + rtn/testsize + " ns");
+
        return true;
        System.out.println("This tree knn max:  " + mjtn + " ns");
+
    }
        System.out.println("Reds tree knn max:  " + mrtn + " ns");
+
 
        System.out.println();
+
    boolean overlaps(double[] mins, double[] maxs, int offset) {
      }
+
        offset *= (2 * maxs.length);
  }
+
        final double[] array = nodeMinMaxBounds.array;
  */
+
        for (int i = 0; i < maxs.length; i++, offset += 2) {
 +
            double bmin = array[offset], bmax = array[offset + 1];
 +
            if (mins[i] > bmax | maxs[i] < bmin)
 +
                return false;
 +
        }
 +
 
 +
        return true;
 +
    }
 +
 
 +
    public static class Euclidean<T> extends KDTree<T> {
 +
        public Euclidean(int dims) {
 +
            super(dims);
 +
        }
  
  public KDTree(int dimensions){
+
        double pointRectDist(int offset, final double[] location) {
      _dimensions = dimensions;
+
            offset *= (2 * super._dimensions);
 
+
            double distance = 0;
      nodeMinMaxBounds = new ContiguousDoubleArrayList(2 * dimensions);
+
            final double[] array = super.nodeMinMaxBounds.array;
      mem_recycle = new double[_bucketSize*dimensions];
+
            for (int i = 0; i < location.length; i++, offset += 2) {
 
 
      bounds_template = new double[2*_dimensions];
 
      Arrays.fill(bounds_template,0,_dimensions,Double.POSITIVE_INFINITY);
 
      Arrays.fill(bounds_template,_dimensions,2*_dimensions,Double.NEGATIVE_INFINITY);
 
 
 
  //and.... start!
 
      root = new Node();
 
  }
 
  public int nodes(){
 
      return _nodes;
 
  }
 
  public int addPoint(double[] location, T payload){
 
 
 
      Node addNode = root;
 
  //Do a Depth First Search to find the node where it should be stored
 
      while(addNode.pointLocations == null){
 
        addNode.expandBounds(location);
 
        if(location[addNode.splitDim] < addNode.splitVal)
 
            addNode = addNode.less;
 
        else
 
            addNode = addNode.more;
 
      }
 
      addNode.expandBounds(location);
 
 
 
      int nodeSize = addNode.add(location,payload);
 
 
 
      if(nodeSize % _bucketSize == 0)
 
      //try splitting again once every time the node passes a _bucketSize multiple
 
        addNode.split();
 
 
 
      return root.entries;
 
  }
 
  
  public SearchResult<T> nearestNeighbour(double[] searchLocation){
+
                double diff = 0;
 
+
                double bv = array[offset];
      Node searchNode = root;
+
                double lv = location[i];
      ArrayDeque<Node> stack = new ArrayDeque<Node>(50);
+
                if (bv > lv)
 
+
                    diff = bv - lv;
  //Do a Depth First Search to find the Node this location would be stored
+
                else {
      while(searchNode.pointLocations == null){
+
                    bv = array[offset + 1];
     
+
                    if (lv > bv)
        if(searchLocation[searchNode.splitDim] < searchNode.splitVal){
+
                        diff = lv - bv;
            stack.push(searchNode.more); 
+
                }
            searchNode = searchNode.less;
+
                distance += sqr(diff);
        }
 
        else{
 
            stack.push(searchNode.less);
 
            searchNode = searchNode.more;
 
        }
 
      }
 
      double minDist = Double.POSITIVE_INFINITY;
 
      T minValue = null;
 
 
 
      double[] array = searchNode.pointLocations.array;
 
 
 
  //Find the closest point in this Node and use as a solution
 
      for(int j = searchNode.entries; j-- > 0;){
 
     
 
        double distance = searchNode.pointDist(searchLocation,j);
 
        if(distance < minDist){
 
            minDist = distance;
 
            minValue = searchNode.pointPayloads.get(j);
 
        }
 
      }
 
 
 
  //backtrace stack
 
      while(stack.size() > 0){
 
        searchNode = stack.pop();
 
        if(searchNode.pointRectDist(searchLocation) < minDist){
 
            if(searchNode.pointLocations == null){
 
              if(searchLocation[searchNode.splitDim] < searchNode.splitVal){
 
                  stack.push(searchNode.more);
 
                  stack.push(searchNode.less);
 
              }
 
              else{
 
                  stack.push(searchNode.less);
 
                  stack.push(searchNode.more);
 
              }
 
 
             }
 
             }
             else{
+
             return distance;
              array = searchNode.pointLocations.array;
+
        }
              for(int j = searchNode.entries; j-- > 0;){
+
 
                  double distance = searchNode.pointDist(searchLocation,j);
+
        double pointDist(double[] arr, double[] location, int index) {
                  if(distance < minDist){
+
            double distance = 0;
                    minDist = distance;
+
            int offset = (index + 1) * super._dimensions;
                    minValue = searchNode.pointPayloads.get(j);
+
 
                  }
+
            for (int i = super._dimensions; i-- > 0;) {
              }
+
                distance += sqr(arr[--offset] - location[i]);
 
             }
 
             }
        }
+
            return distance;
      }
+
        }
 
+
 
      return new SearchResult(minDist,minValue); 
+
    }
  }
+
 
  public ArrayList<SearchResult<double[]>> nearestNeighbours(double[] searchLocation, int K){
+
    public static class Manhattan<T> extends KDTree<T> {
 
+
        public Manhattan(int dims) {
      Node searchNode = root;
+
            super(dims);
      ArrayDeque<Node> stack = new ArrayDeque<Node>(50);
+
        }
 
+
 
  //Do a Depth First Search to find the Node this location would be stored
+
        double pointRectDist(int offset, final double[] location) {
      while(searchNode.pointLocations == null){
+
             offset *= (2 * super._dimensions);
     
+
             double distance = 0;
        if(searchLocation[searchNode.splitDim] < searchNode.splitVal){
+
             final double[] array = super.nodeMinMaxBounds.array;
             stack.push(searchNode.more);  
+
             for (int i = 0; i < location.length; i++, offset += 2) {
             searchNode = searchNode.less;
+
 
        }
+
                double diff = 0;
        else{
+
                double bv = array[offset];
             stack.push(searchNode.less);
+
                double lv = location[i];
             searchNode = searchNode.more;
+
                if (bv > lv)
        }
+
                    diff = bv - lv;
      }
+
                else {
      PrioQueue<T> results = new PrioQueue<T>(K);
+
                    bv = array[offset + 1];
 
+
                    if (lv > bv)
      ArrayList<T> payloads = searchNode.pointPayloads;
+
                        diff = lv - bv;
      double[] temp = new double[_dimensions];
+
                }
  //Find the closest point in this Node and use as a solution
+
                distance += (diff);
      for(int j = searchNode.entries; j-- > 0;){
 
     
 
        double distance = searchNode.pointDist(searchLocation,j);
 
        results.offer(payloads.get(j),-distance);
 
      }
 
 
 
 
 
  //backtrace stack
 
      while(stack.size() > 0){
 
        searchNode = stack.pop();
 
        if( searchNode.pointRectDist(searchLocation) < -results.peekPrio()){
 
            if(searchNode.pointLocations == null){
 
              if(searchLocation[searchNode.splitDim] < searchNode.splitVal){
 
                  stack.push(searchNode.more);
 
                  stack.push(searchNode.less);//less will be popped first
 
              }
 
              else{
 
                  stack.push(searchNode.less);
 
                  stack.push(searchNode.more);
 
              }
 
 
             }
 
             }
             else{
+
             return distance;
              payloads = searchNode.pointPayloads;
+
        }
              for(int j = searchNode.entries; j-- > 0;){
+
 
                  double distance = searchNode.pointDist(searchLocation,j);
+
        double pointDist(double[] arr, double[] location, int index) {
                  results.offer(payloads.get(j),-distance);
+
            double distance = 0;
              }
+
            int offset = (index + 1) * super._dimensions;
             
+
 
 +
            for (int i = super._dimensions; i-- > 0;) {
 +
                distance += Math.abs(arr[--offset] - location[i]);
 
             }
 
             }
        }
+
            return distance;
      }
+
        }
 
+
    }
      ArrayList<SearchResult<double[]>> returnResults = new ArrayList<SearchResult<double[]>>(results.elements.size());
+
 
 
+
    public static class WeightedManhattan<T> extends KDTree<T> {
  //for(int i =0, j = results.elements.size(); i<j;i++){//Forward (closest first)
+
        private double[] weights;
      for(int i = results.elements.size(); i-- > 0;){//Reverse (Like Rednaxela Gen2)
 
        PrioQueue<T>.Element e = results.elements.get(i);
 
        SearchResult s = new SearchResult(-e.priority,e.contents);
 
        returnResults.add(s);
 
      }
 
      return returnResults;
 
  }
 
  
//NB! This Priority Queue keeps things with the HIGHEST priority.
+
        public WeightedManhattan(int dims) {
//If you want lowest priority items kept, negate your values
+
            super(dims);
  private static class PrioQueue<S>{
+
            weights = new double[dims];
      ArrayList<Element> elements;
+
            for (int i = 0; i < dims; i++)
      private double minPrio;
+
                weights[i] = 1.0;
      PrioQueue(int size){
+
        }
        elements = new ArrayList<Element>(size);
+
 
        while(size-->0){
+
        public void setWeights(double[] newWeights) {
             elements.add(new Element(null,Double.NEGATIVE_INFINITY));
+
            weights = newWeights;
        }
+
        }
        minPrio = Double.NEGATIVE_INFINITY;
+
 
      }
+
        double pointRectDist(int offset, final double[] location) {
      //uses O(log(n)) comparisons and one big shift of size O(N)
+
             offset *= (2 * super._dimensions);
      //and is MUCH simpler than a heap --> faster JIT
+
            double distance = 0;
      boolean offer(S value,double priority){
+
            final double[] array = super.nodeMinMaxBounds.array;
       
+
            for (int i = 0; i < location.length; i++, offset += 2) {
        //is this point worthy of joining the exulted ranks?
+
 
        if(priority > minPrio){
+
                double diff = 0;
       
+
                double bv = array[offset];
        //recycle object to avoid garbage collector stalls
+
                double lv = location[i];
            Element replace = elements.remove(elements.size() - 1);
+
                if (bv > lv)
             replace.update(value,priority);
+
                    diff = bv - lv;
             add(replace);
+
                else {
       
+
                    bv = array[offset + 1];
             return true;
+
                    if (lv > bv)
        }
+
                        diff = lv - bv;
        return false;
+
                }
      }
+
                distance += (diff) * weights[i];
      void add(Element e){
+
             }
     
+
            return distance;
        //find the right place with a binary search
+
        }
        int index = searchFor(e.priority);
+
 
        //and re-insert updated value (ArrayList automatically shifts other elements up)
+
        double pointDist(double[] arr, double[] location, int index) {
        elements.add(index,e);
+
            double distance = 0;
           
+
             int offset = (index + 1) * super._dimensions;
        minPrio = elements.get(elements.size() - 1).priority;
+
 
      }
+
             for (int i = super._dimensions; i-- > 0;) {
      int searchFor(double priority){
+
                distance += Math.abs(arr[--offset] - location[i]) * weights[i];
        int i = elements.size()-1;
+
            }
        int j = 0;  
+
            return distance;
        while(i>=j){
+
        }
             int index = (i+j)>>1;
+
    }
             if(elements.get(index).priority < priority)
+
 
              i = index-1;
+
    /**
     * One hit returned by a nearest-neighbour search: a payload together with
     * the distance value the search computed for it.
     *
     * @param <S> payload type
     */
    public static class SearchResult<S> {
 +
        // Distance as produced by the tree's pointDist metric. For the
        // Euclidean tree that is a sum of squared differences (no square root
        // is taken); for the Manhattan variants it is a sum of (weighted)
        // absolute differences.
        public double distance;
 +
        // the value that was stored with the matching point
        public S payload;
 +
 
 +
        // Package-private constructor: results are created by the tree's search code.
        SearchResult(double dist, S load) {
 +
            distance = dist;
 +
            payload = load;
 +
        }
 +
    }
 +
 
 +
    private class Node {
 +
 
 +
        // for accessing bounding box data
 +
        // - if trees weren't so unbalanced might be better to use an implicit
 +
        // heap?
 +
        int index;
 +
 
 +
        // keep track of size of subtree
 +
        int entries;
 +
 
 +
        // leaf
 +
        ContiguousDoubleArrayList pointLocations;
 +
        ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);
 +
 
 +
        // stem
 +
        // Node less, more;
 +
        int lessIndex, moreIndex;
 +
        int splitDim;
 +
        double splitVal;
 +
 
 +
        /**
         * Creates a node backed by a freshly allocated point buffer with room
         * for {@code _bucketSize} points of {@code _dimensions} coordinates.
         */
        Node() {
 +
            this(new double[_bucketSize * _dimensions]);
 +
        }
 +
 
 +
        /**
         * Creates a node whose point storage wraps the supplied array, and
         * registers the node with the enclosing tree.
         *
         * @param pointMemory array handed to the ContiguousDoubleArrayList that
         *                    holds this node's points (split() passes the
         *                    tree's recycled buffer here)
         */
        Node(double[] pointMemory) {
 +
            pointLocations = new ContiguousDoubleArrayList(pointMemory);
 +
            // take the next node number; it is used to locate this node's bounds
            index = _nodes++;
 +
            // NOTE(review): lookups use nodeList.get(index), so this assumes the
            // list position equals the counter value - confirm after discarded splits
            nodeList.add(this);
 +
             // append the template (+inf lows / -inf highs) as this node's initial bounds
             nodeMinMaxBounds.add(bounds_template);
 +
        }
 +
 
 +
        void search(double[] searchLocation, IntStack stack) {
 +
             if (searchLocation[splitDim] < splitVal)
 +
                stack.push(moreIndex).push(lessIndex); // less will be popped
 +
            // first
 
             else
 
             else
              j = index+1;
+
                stack.push(lessIndex).push(moreIndex); // more will be popped
        }
+
            // first
        return j;
+
        }
      }
+
 
      double peekPrio(){
+
        // returns number of points added to results
        return minPrio;
+
        int search(double[] searchLocation, PrioQueue<T> results) {
      }
+
            int updated = 0;
      /* //Methods for using it as a priority stack - leave them out for now
+
            for (int j = entries; j-- > 0;) {
      void push(S value, double priority){
+
                double distance = pointDist(pointLocations.array, searchLocation, j);
        Element insert = new Element(value,priority);
+
                if (results.peekPrio() > distance) {
        add(insert);
+
                    updated++;
      }
+
                    results.addNoGrow(pointPayloads.get(j), distance);
      S pop(){
+
                }
        Element remove = elements.remove(elements.size() - 1);
+
            }
        if(elements.size() == 0)
+
            return updated;
            minPrio = Double.NEGATIVE_INFINITY;
+
        }
        else
+
 
            minPrio = elements.get(elements.size() - 1).priority;
+
        void searchBall(double[] searchLocation, double radius, ArrayList<T> results) {
        return remove.contents;
+
 
      }
+
            for (int j = entries; j-- > 0;) {
      int size(){
+
                double distance = pointDist(pointLocations.array, searchLocation, j);
        return elements.size();
+
                if (radius >= distance) {
      }
+
                    results.add(pointPayloads.get(j));
      void trim(double newMinPrio){
+
                }
        if(newMinPrio > minPrio){
+
            }
             int index = searchFor(newMinPrio);
+
        }
             int size = elements.size();
+
 
            elements.subList(index,elements.size()).clear();
+
        void searchRect(double[] mins, double[] maxs, ArrayList<T> results) {
            if(elements.size() == 0)
+
 
              minPrio = Double.NEGATIVE_INFINITY;
+
            for (int j = entries; j-- > 0;)
            else
+
                if (contains(pointLocations.array, mins, maxs, j))
              minPrio = elements.get(elements.size() - 1).priority;
+
                    results.add(pointPayloads.get(j));
        }
+
 
      }
+
        }
    //  */
+
 
      class Element{
+
        void expandBounds(double[] location) {
        S contents;
+
            entries++;
        double priority;
+
            int mio = index * 2 * _dimensions;
       
+
            for (int i = 0; i < _dimensions; i++) {
        Element(S con, double prio){
+
                nodeMinMaxBounds.array[mio] = Math.min(nodeMinMaxBounds.array[mio], location[i]);
            contents = con;
+
                mio++;
            priority = prio;
+
                nodeMinMaxBounds.array[mio] = Math.max(nodeMinMaxBounds.array[mio], location[i]);
        }
+
                mio++;
        void update(S con, double prio){
+
            }
            contents = con;
+
        }
             priority = prio;
+
 
        }
+
        int add(double[] location, T load) {
      }
+
            pointLocations.add(location);
  }
+
            pointPayloads.add(load);
 +
            return entries;
 +
        }
 +
 
 +
        void split() {
 +
             int offset = index * 2 * _dimensions;
 +
 
 +
            double diff = 0;
 +
             for (int i = 0; i < _dimensions; i++) {
 +
                double min = nodeMinMaxBounds.array[offset];
 +
                double max = nodeMinMaxBounds.array[offset + 1];
 +
                if (max - min > diff) {
 +
                    double mean = 0;
 +
                    for (int j = 0; j < entries; j++)
 +
                        mean += pointLocations.array[i + _dimensions * j];
 +
 
 +
                    mean = mean / entries;
 +
                    double varianceSum = 0;
 +
 
 +
                    for (int j = 0; j < entries; j++)
 +
                        varianceSum += sqr(mean - pointLocations.array[i + _dimensions * j]);
 +
 
 +
                    if (varianceSum > diff * entries) {
 +
                        diff = varianceSum / entries;
 +
                        splitVal = mean;
 +
 
 +
                        splitDim = i;
 +
                    }
 +
                }
 +
                offset += 2;
 +
            }
 +
 
 +
            // kill all the nasties
 +
            if (splitVal == Double.POSITIVE_INFINITY)
 +
                splitVal = Double.MAX_VALUE;
 +
             else if (splitVal == Double.NEGATIVE_INFINITY)
 +
                splitVal = Double.MIN_VALUE;
 +
            else if (splitVal == nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim + 1])
 +
                splitVal = nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim];
  
 +
            Node less = new Node(mem_recycle); // recycle that memory!
 +
            Node more = new Node();
 +
            lessIndex = less.index;
 +
            moreIndex = more.index;
  
  public static class SearchResult<S>{
+
            // reduce garbage by factor of _bucketSize by recycling this array
      double distance;
+
            double[] pointLocation = new double[_dimensions];
      S payload;
+
            for (int i = 0; i < entries; i++) {
      SearchResult(double dist, S load){
+
                System.arraycopy(pointLocations.array, i * _dimensions, pointLocation, 0, _dimensions);
        distance = dist;
+
                T load = pointPayloads.get(i);
        payload = load;
 
      }
 
  }
 
  
  private class Node {
+
                if (pointLocation[splitDim] < splitVal) {
 
+
                    less.expandBounds(pointLocation);
  //for accessing bounding box data
+
                    less.add(pointLocation, load);
  // - if trees weren't so unbalanced might be better to use an implicit heap?
+
                }
      int index;
+
                else {
     
+
                    more.expandBounds(pointLocation);
  //keep track of size of subtree
+
                    more.add(pointLocation, load);
      int entries;
+
                }
 
 
  //leaf
 
      ContiguousDoubleArrayList pointLocations ;
 
      ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);
 
     
 
  //stem
 
      Node less, more;
 
      int splitDim;
 
      double splitVal;
 
 
 
      private Node(){
 
        this(new double[_bucketSize*_dimensions]);
 
      }
 
      private Node(double[] pointMemory){
 
        pointLocations = new ContiguousDoubleArrayList(pointMemory);
 
        index = _nodes++;
 
        nodeMinMaxBounds.add(bounds_template);
 
      }
 
      private final double pointRectDist(double[] location){
 
        int minOffset = 2*index*_dimensions;
 
        int maxOffset = minOffset+_dimensions;
 
        double distance=0;
 
        double[] array = nodeMinMaxBounds.array;
 
        for(int i = _dimensions; i-- > 0; ){
 
       
 
            double lowDist = array[i+minOffset] - location[i];
 
            if(lowDist > 0)
 
              distance += sqr(lowDist);
 
            else{
 
              double highDist = location[i] - array[i+maxOffset];
 
              if(highDist > 0)
 
                  distance += sqr(highDist);
 
 
             }
 
             }
        }
+
            if (less.entries * more.entries == 0) {
        return distance;
+
                // one of them was 0, so the split was worthless. throw it away.
      }
+
                _nodes -= 2; // recall that bounds memory
      private final double pointDist(double[] location, int index){
+
                nodeList.remove(moreIndex);
        double distance = 0;
+
                nodeList.remove(lessIndex);
        int offset = index*_dimensions;
 
        for(int i = _dimensions; i-- > 0;)
 
            distance += sqr(pointLocations.array[offset+i] - location[i]);
 
        return distance;
 
      }
 
 
 
      private void expandBounds(double[] location){
 
        entries++;
 
        int offset = index*2*_dimensions;
 
        for(int i = 0; i < _dimensions;i++){
 
            nodeMinMaxBounds.array[offset+i] = Math.min(nodeMinMaxBounds.array[offset+i],location[i]);
 
            nodeMinMaxBounds.array[offset+_dimensions+i] = Math.max(nodeMinMaxBounds.array[offset+_dimensions+i],location[i]);
 
        }
 
      }
 
 
 
      private int add(double[] location, T load){
 
        pointLocations.add(location);
 
        pointPayloads.add(load);
 
        return entries;
 
      }
 
      private void split(){
 
        double diff = 0;
 
        int offset = index*2*_dimensions;
 
        for(int i = 0; i < _dimensions; i++){
 
            double min = nodeMinMaxBounds.array[offset+i];
 
            double max = nodeMinMaxBounds.array[offset+_dimensions+i];
 
            if(max - min > diff){
 
              diff = max - min;
 
              splitVal = 0.5*(max + min);
 
              splitDim = i;
 
 
             }
 
             }
        }
+
            else {
     
+
 
     
+
                // we won't be needing that now, so keep it for the next split
        less = new Node(mem_recycle);//recycle that memory!
+
                // to reduce garbage
        more = new Node();
+
                mem_recycle = pointLocations.array;
       
+
 
        //reduce garbage by factor of _bucketSize by recycling this array
+
                pointLocations = null;
        double[] pointLocation = new double[_dimensions];
+
 
        for(int i = 0; i < entries; i++){
+
                pointPayloads.clear();
            System.arraycopy(pointLocations.array,i*_dimensions,pointLocation,0,_dimensions);
+
                pointPayloads = null;
            T load = pointPayloads.get(i);
 
       
 
            if(pointLocation[splitDim] < splitVal){
 
              less.expandBounds(pointLocation);
 
              less.add(pointLocation,load);
 
 
             }
 
             }
            else{
+
        }
              more.expandBounds(pointLocation);  
+
 
              more.add(pointLocation,load);
+
    }
 +
 
 +
    // NB! This Priority Queue keeps things with the LOWEST priority.
 +
    // If you want highest priority items kept, negate your values
 +
    private static class PrioQueue<S> {
 +
 
 +
        Object[] elements;
 +
        double[] priorities;
 +
        private double minPrio;
 +
        private int size;
 +
 
 +
        PrioQueue(int size, boolean prefill) {
 +
            elements = new Object[size];
 +
            priorities = new double[size];
 +
            Arrays.fill(priorities, Double.POSITIVE_INFINITY);
 +
            if (prefill) {
 +
                minPrio = Double.POSITIVE_INFINITY;
 +
                this.size = size;
 
             }
 
             }
        }
+
        }
        if(less.entries*more.entries == 0){
 
        //one of them was 0, so the split was worthless. throw it away.
 
            less = null;
 
            more = null;
 
        }
 
        else{
 
       
 
        //we won't be needing that now, so keep it for the next split to reduce garbage
 
            mem_recycle = pointLocations.array;
 
       
 
            pointLocations = null;
 
       
 
            pointPayloads.clear();
 
            pointPayloads = null;
 
        }
 
      }
 
 
 
  }
 
  
 +
        // uses O(log(n)) comparisons and one big shift of size O(N)
 +
        // and is MUCH simpler than a heap --> faster on small sets, faster JIT
  
  private static class ContiguousDoubleArrayList{
+
        void addNoGrow(S value, double priority) {
      double[] array;
+
            int index = searchFor(priority);
      int size;
+
            int nextIndex = index + 1;
      ContiguousDoubleArrayList(){
+
            int length = size - nextIndex;
        this(300);
+
             System.arraycopy(elements, index, elements, nextIndex, length);
      }
+
            System.arraycopy(priorities, index, priorities, nextIndex, length);
      ContiguousDoubleArrayList(int size){
+
            elements[index] = value;
        this(new double[size]);
+
            priorities[index] = priority;
      }
 
      ContiguousDoubleArrayList(double[] data){
 
        array = data;
 
      }
 
      ContiguousDoubleArrayList add(double[] da){
 
        if(size + da.length >= array.length){
 
             array = Arrays.copyOf(array,(array.length+da.length)*2);
 
        }
 
        System.arraycopy(da,0,array,size,da.length);
 
        size += da.length;
 
        return this;
 
      }
 
  }
 
  
  private static final double sqr(double d){
+
            minPrio = priorities[size - 1];
      return d*d;}
+
        }
 +
 
 +
        int searchFor(double priority) {
 +
            int i = size - 1;
 +
            int j = 0;
 +
            while (i >= j) {
 +
                int index = (i + j) >>> 1;
 +
                if (priorities[index] < priority)
 +
                    j = index + 1;
 +
                else
 +
                    i = index - 1;
 +
            }
 +
            return j;
 +
        }
 +
 
 +
        double peekPrio() {
 +
            return minPrio;
 +
        }
 +
    }
 +
 
 +
    private static class ContiguousDoubleArrayList {
 +
        double[] array;
 +
        int size;
 +
 
 +
        ContiguousDoubleArrayList(int size) {
 +
            this(new double[size]);
 +
        }
 +
 
 +
        ContiguousDoubleArrayList(double[] data) {
 +
            array = data;
 +
        }
 +
 
 +
        ContiguousDoubleArrayList add(double[] da) {
 +
            if (size + da.length > array.length)
 +
                array = Arrays.copyOf(array, (array.length + da.length) * 2);
 +
 
 +
            System.arraycopy(da, 0, array, size, da.length);
 +
            size += da.length;
 +
            return this;
 +
        }
 +
    }
 +
 
 +
    private static class IntStack {
 +
        int[] array;
 +
        int size;
 +
 
 +
        IntStack() {
 +
            this(64);
 +
        }
 +
 
 +
        IntStack(int size) {
 +
            this(new int[size]);
 +
        }
 +
 
 +
        IntStack(int[] data) {
 +
            array = data;
 +
        }
 +
 
 +
        IntStack push(int i) {
 +
            if (size >= array.length)
 +
                array = Arrays.copyOf(array, (array.length + 1) * 2);
 +
 
 +
            array[size++] = i;
 +
            return this;
 +
        }
 +
 
 +
        int pop() {
 +
            return array[--size];
 +
        }
 +
 
 +
        int size() {
 +
            return size;
 +
        }
 +
    }
 +
 
 +
    static final double sqr(double d) {
 +
        return d * d;
 +
    }
  
 
}
 
}
 
 
</syntaxhighlight></code>
 
</syntaxhighlight></code>

Latest revision as of 21:46, 20 September 2017

Latest version is available here: [1]

A possibly outdated version is listed below:

package jk.tree;
/*
 ** KDTree.java by Julian Kent
 **
 ** Licenced under the  Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
 **
 ** Licence summary:
 ** Under this licence you are free to:
 **      Share : copy and redistribute the material in any medium or format
 **      Adapt : remix, transform, and build upon the material
 **      The licensor cannot revoke these freedoms as long as you follow the license terms.
 **
 ** Under the following terms:
 **      Attribution:
 **            You must give appropriate credit, provide a link to the license, and indicate
 **            if changes were made. You may do so in any reasonable manner, but not in any
 **            way that suggests the licensor endorses you or your use.
 **      NonCommercial:
 **            You may not use the material for commercial purposes.
 **      ShareAlike:
 **            If you remix, transform, or build upon the material, you must distribute your
 **            contributions under the same license as the original.
 **      No additional restrictions:
 **            You may not apply legal terms or technological measures that legally restrict
 **            others from doing anything the license permits.
 **
 ** See full licencing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
 **
 ** For additional licencing rights (including commercial) please contact jkflying@gmail.com
 **
 */

import java.util.ArrayList;
import java.util.Arrays;

public abstract class KDTree<T> {

    // use a big bucketSize so that we have less node bounds (for more cache
    // hits) and better splits
    // if you have lots of dimensions this should be big, and if you have few small
    private static final int _bucketSize = 50;

    private final int _dimensions;
    private int _nodes;
    private final Node root;
    private final ArrayList<Node> nodeList = new ArrayList<Node>();

    // prevent GC from having to collect _bucketSize*dimensions*sizeof(double) bytes each
    // time a leaf splits
    private double[] mem_recycle;

    // the starting values for bounding boxes, for easy access
    private final double[] bounds_template;

    // one big self-expanding array to keep all the node bounding boxes so that
    // they stay in cache
    // node bounds available at:
    // low: 2 * _dimensions * node.index + 2 * dim
    // high: 2 * _dimensions * node.index + 2 * dim + 1
    private final ContiguousDoubleArrayList nodeMinMaxBounds;

    private KDTree(int dimensions) {
        _dimensions = dimensions;

        // initialise this big so that it ends up in 'old' memory
        nodeMinMaxBounds = new ContiguousDoubleArrayList(512 * 1024 / 8 + 2 * _dimensions);
        mem_recycle = new double[_bucketSize * dimensions];

        bounds_template = new double[2 * _dimensions];
        Arrays.fill(bounds_template, Double.NEGATIVE_INFINITY);
        for (int i = 0, max = 2 * _dimensions; i < max; i += 2)
            bounds_template[i] = Double.POSITIVE_INFINITY;

        // and.... start!
        root = new Node();
    }

    public int nodes() {
        return _nodes;
    }

    public int size() {
        return root.entries;
    }

    public int addPoint(double[] location, T payload) {

        Node addNode = root;
        // Do a Depth First Search to find the Node where 'location' should be
        // stored
        while (addNode.pointLocations == null) {
            addNode.expandBounds(location);
            if (location[addNode.splitDim] < addNode.splitVal)
                addNode = nodeList.get(addNode.lessIndex);
            else
                addNode = nodeList.get(addNode.moreIndex);
        }
        addNode.expandBounds(location);

        int nodeSize = addNode.add(location, payload);

        if (nodeSize % _bucketSize == 0)
            // try splitting again once every time the node passes a _bucketSize
            // multiple
            // in case it is full of points of the same location and won't split
            addNode.split();

        return root.entries;
    }

    public ArrayList<SearchResult<T>> nearestNeighbours(double[] searchLocation, int K) {

        K = Math.min(K, size());

        ArrayList<SearchResult<T>> returnResults = new ArrayList<SearchResult<T>>(K);

        if (K > 0) {
            IntStack stack = new IntStack();
            PrioQueue<T> results = new PrioQueue<T>(K, true);

            stack.push(root.index);

            int added = 0;

            while (stack.size() > 0) {
                int nodeIndex = stack.pop();
                if (added < K || results.peekPrio() > pointRectDist(nodeIndex, searchLocation)) {
                    Node node = nodeList.get(nodeIndex);
                    if (node.pointLocations == null)
                        node.search(searchLocation, stack);
                    else
                        added += node.search(searchLocation, results);
                }
            }

            double[] priorities = results.priorities;
            Object[] elements = results.elements;
            for (int i = 0; i < K; i++) { // forward (closest first)
                SearchResult<T> s = new SearchResult<T>(priorities[i], (T) elements[i]);
                returnResults.add(s);
            }
        }
        return returnResults;
    }

    public ArrayList<T> ballSearch(double[] searchLocation, double radius) {
        IntStack stack = new IntStack();
        ArrayList<T> results = new ArrayList<T>();

        stack.push(root.index);

        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            if (radius > pointRectDist(nodeIndex, searchLocation)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    stack.push(node.moreIndex).push(node.lessIndex);
                else
                    node.searchBall(searchLocation, radius, results);
            }
        }
        return results;
    }

    public ArrayList<T> rectSearch(double[] mins, double[] maxs) {
        IntStack stack = new IntStack();
        ArrayList<T> results = new ArrayList<T>();

        stack.push(root.index);

        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            if (overlaps(mins, maxs, nodeIndex)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    stack.push(node.moreIndex).push(node.lessIndex);
                else
                    node.searchRect(mins, maxs, results);
            }
        }
        return results;

    }

    abstract double pointRectDist(int offset, final double[] location);

    abstract double pointDist(double[] arr, double[] location, int index);

    boolean contains(double[] arr, double[] mins, double[] maxs, int index) {

        int offset = (index + 1) * mins.length;

        for (int i = mins.length; i-- > 0;) {
            double d = arr[--offset];
            if (mins[i] > d | d > maxs[i])
                return false;
        }
        return true;
    }

    boolean overlaps(double[] mins, double[] maxs, int offset) {
        offset *= (2 * maxs.length);
        final double[] array = nodeMinMaxBounds.array;
        for (int i = 0; i < maxs.length; i++, offset += 2) {
            double bmin = array[offset], bmax = array[offset + 1];
            if (mins[i] > bmax | maxs[i] < bmin)
                return false;
        }

        return true;
    }

    public static class Euclidean<T> extends KDTree<T> {
        public Euclidean(int dims) {
            super(dims);
        }

        double pointRectDist(int offset, final double[] location) {
            offset *= (2 * super._dimensions);
            double distance = 0;
            final double[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {

                double diff = 0;
                double bv = array[offset];
                double lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += sqr(diff);
            }
            return distance;
        }

        double pointDist(double[] arr, double[] location, int index) {
            double distance = 0;
            int offset = (index + 1) * super._dimensions;

            for (int i = super._dimensions; i-- > 0;) {
                distance += sqr(arr[--offset] - location[i]);
            }
            return distance;
        }

    }

    public static class Manhattan<T> extends KDTree<T> {
        public Manhattan(int dims) {
            super(dims);
        }

        double pointRectDist(int offset, final double[] location) {
            offset *= (2 * super._dimensions);
            double distance = 0;
            final double[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {

                double diff = 0;
                double bv = array[offset];
                double lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += (diff);
            }
            return distance;
        }

        double pointDist(double[] arr, double[] location, int index) {
            double distance = 0;
            int offset = (index + 1) * super._dimensions;

            for (int i = super._dimensions; i-- > 0;) {
                distance += Math.abs(arr[--offset] - location[i]);
            }
            return distance;
        }
    }

    public static class WeightedManhattan<T> extends KDTree<T> {
        private double[] weights;

        public WeightedManhattan(int dims) {
            super(dims);
            weights = new double[dims];
            for (int i = 0; i < dims; i++)
                weights[i] = 1.0;
        }

        public void setWeights(double[] newWeights) {
            weights = newWeights;
        }

        double pointRectDist(int offset, final double[] location) {
            offset *= (2 * super._dimensions);
            double distance = 0;
            final double[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {

                double diff = 0;
                double bv = array[offset];
                double lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += (diff) * weights[i];
            }
            return distance;
        }

        double pointDist(double[] arr, double[] location, int index) {
            double distance = 0;
            int offset = (index + 1) * super._dimensions;

            for (int i = super._dimensions; i-- > 0;) {
                distance += Math.abs(arr[--offset] - location[i]) * weights[i];
            }
            return distance;
        }
    }

    public static class SearchResult<S> {
        public double distance;
        public S payload;

        SearchResult(double dist, S load) {
            distance = dist;
            payload = load;
        }
    }

    private class Node {

        // for accessing bounding box data
        // - if trees weren't so unbalanced might be better to use an implicit
        // heap?
        int index;

        // keep track of size of subtree
        int entries;

        // leaf
        ContiguousDoubleArrayList pointLocations;
        ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);

        // stem
        // Node less, more;
        int lessIndex, moreIndex;
        int splitDim;
        double splitVal;

        Node() {
            this(new double[_bucketSize * _dimensions]);
        }

        Node(double[] pointMemory) {
            pointLocations = new ContiguousDoubleArrayList(pointMemory);
            index = _nodes++;
            nodeList.add(this);
            nodeMinMaxBounds.add(bounds_template);
        }

        void search(double[] searchLocation, IntStack stack) {
            if (searchLocation[splitDim] < splitVal)
                stack.push(moreIndex).push(lessIndex); // less will be popped
            // first
            else
                stack.push(lessIndex).push(moreIndex); // more will be popped
            // first
        }

        // returns number of points added to results
        int search(double[] searchLocation, PrioQueue<T> results) {
            int updated = 0;
            for (int j = entries; j-- > 0;) {
                double distance = pointDist(pointLocations.array, searchLocation, j);
                if (results.peekPrio() > distance) {
                    updated++;
                    results.addNoGrow(pointPayloads.get(j), distance);
                }
            }
            return updated;
        }

        void searchBall(double[] searchLocation, double radius, ArrayList<T> results) {

            for (int j = entries; j-- > 0;) {
                double distance = pointDist(pointLocations.array, searchLocation, j);
                if (radius >= distance) {
                    results.add(pointPayloads.get(j));
                }
            }
        }

        void searchRect(double[] mins, double[] maxs, ArrayList<T> results) {

            for (int j = entries; j-- > 0;)
                if (contains(pointLocations.array, mins, maxs, j))
                    results.add(pointPayloads.get(j));

        }

        void expandBounds(double[] location) {
            entries++;
            int mio = index * 2 * _dimensions;
            for (int i = 0; i < _dimensions; i++) {
                nodeMinMaxBounds.array[mio] = Math.min(nodeMinMaxBounds.array[mio], location[i]);
                mio++;
                nodeMinMaxBounds.array[mio] = Math.max(nodeMinMaxBounds.array[mio], location[i]);
                mio++;
            }
        }

        int add(double[] location, T load) {
            pointLocations.add(location);
            pointPayloads.add(load);
            return entries;
        }

        void split() {
            int offset = index * 2 * _dimensions;

            double diff = 0;
            for (int i = 0; i < _dimensions; i++) {
                double min = nodeMinMaxBounds.array[offset];
                double max = nodeMinMaxBounds.array[offset + 1];
                if (max - min > diff) {
                    double mean = 0;
                    for (int j = 0; j < entries; j++)
                        mean += pointLocations.array[i + _dimensions * j];

                    mean = mean / entries;
                    double varianceSum = 0;

                    for (int j = 0; j < entries; j++)
                        varianceSum += sqr(mean - pointLocations.array[i + _dimensions * j]);

                    if (varianceSum > diff * entries) {
                        diff = varianceSum / entries;
                        splitVal = mean;

                        splitDim = i;
                    }
                }
                offset += 2;
            }

            // kill all the nasties
            if (splitVal == Double.POSITIVE_INFINITY)
                splitVal = Double.MAX_VALUE;
            else if (splitVal == Double.NEGATIVE_INFINITY)
                splitVal = Double.MIN_VALUE;
            else if (splitVal == nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim + 1])
                splitVal = nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim];

            Node less = new Node(mem_recycle); // recycle that memory!
            Node more = new Node();
            lessIndex = less.index;
            moreIndex = more.index;

            // reduce garbage by factor of _bucketSize by recycling this array
            double[] pointLocation = new double[_dimensions];
            for (int i = 0; i < entries; i++) {
                System.arraycopy(pointLocations.array, i * _dimensions, pointLocation, 0, _dimensions);
                T load = pointPayloads.get(i);

                if (pointLocation[splitDim] < splitVal) {
                    less.expandBounds(pointLocation);
                    less.add(pointLocation, load);
                }
                else {
                    more.expandBounds(pointLocation);
                    more.add(pointLocation, load);
                }
            }
            if (less.entries * more.entries == 0) {
                // one of them was 0, so the split was worthless. throw it away.
                _nodes -= 2; // recall that bounds memory
                nodeList.remove(moreIndex);
                nodeList.remove(lessIndex);
            }
            else {

                // we won't be needing that now, so keep it for the next split
                // to reduce garbage
                mem_recycle = pointLocations.array;

                pointLocations = null;

                pointPayloads.clear();
                pointPayloads = null;
            }
        }

    }

    // NB! This Priority Queue keeps things with the LOWEST priority.
    // If you want highest priority items kept, negate your values
    private static class PrioQueue<S> {

        Object[] elements;
        double[] priorities;
        private double minPrio;
        private int size;

        PrioQueue(int size, boolean prefill) {
            elements = new Object[size];
            priorities = new double[size];
            Arrays.fill(priorities, Double.POSITIVE_INFINITY);
            if (prefill) {
                minPrio = Double.POSITIVE_INFINITY;
                this.size = size;
            }
        }

        // uses O(log(n)) comparisons and one big shift of size O(N)
        // and is MUCH simpler than a heap --> faster on small sets, faster JIT

        void addNoGrow(S value, double priority) {
            int index = searchFor(priority);
            int nextIndex = index + 1;
            int length = size - nextIndex;
            System.arraycopy(elements, index, elements, nextIndex, length);
            System.arraycopy(priorities, index, priorities, nextIndex, length);
            elements[index] = value;
            priorities[index] = priority;

            minPrio = priorities[size - 1];
        }

        int searchFor(double priority) {
            int i = size - 1;
            int j = 0;
            while (i >= j) {
                int index = (i + j) >>> 1;
                if (priorities[index] < priority)
                    j = index + 1;
                else
                    i = index - 1;
            }
            return j;
        }

        double peekPrio() {
            return minPrio;
        }
    }

    private static class ContiguousDoubleArrayList {
        double[] array;
        int size;

        ContiguousDoubleArrayList(int size) {
            this(new double[size]);
        }

        ContiguousDoubleArrayList(double[] data) {
            array = data;
        }

        ContiguousDoubleArrayList add(double[] da) {
            if (size + da.length > array.length)
                array = Arrays.copyOf(array, (array.length + da.length) * 2);

            System.arraycopy(da, 0, array, size, da.length);
            size += da.length;
            return this;
        }
    }

    private static class IntStack {
        int[] array;
        int size;

        IntStack() {
            this(64);
        }

        IntStack(int size) {
            this(new int[size]);
        }

        IntStack(int[] data) {
            array = data;
        }

        IntStack push(int i) {
            if (size >= array.length)
                array = Arrays.copyOf(array, (array.length + 1) * 2);

            array[size++] = i;
            return this;
        }

        int pop() {
            return array[--size];
        }

        int size() {
            return size;
        }
    }

    static final double sqr(double d) {
        return d * d;
    }

}