Difference between revisions of "User:Skilgannon/KDTree"

From Robowiki
Jump to navigation Jump to search
(2 step search, little changes)
m (Update, link to latest on Bitbucket)
 
(7 intermediate revisions by the same user not shown)
Line 1: Line 1:
 +
Latest version is available here: [https://bitbucket.org/jkflying/kd-tree]
 +
 +
A possibly outdated version is listed below:
 +
 
<code><syntaxhighlight>
 
<code><syntaxhighlight>
 
+
package jk.tree;
 
/*
 
/*
** KDTree.java by Julian Kent
+
** KDTree.java by Julian Kent
** Licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
+
**
** See full licensing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
+
** Licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
** For additional licensing rights please contact jkflying@gmail.com
+
**
**
+
** Licence summary:
** Example usage is given in the main method, as well as benchmarking code against Rednaxela's Gen2 Tree
+
** Under this licence you are free to:
*/
+
**      Share : copy and redistribute the material in any medium or format
 
+
**      Adapt : remix, transform, and build upon the material
 +
**      The licensor cannot revoke these freedoms as long as you follow the license terms.
 +
**
 +
** Under the following terms:
 +
**      Attribution:
 +
**            You must give appropriate credit, provide a link to the license, and indicate
 +
**            if changes were made. You may do so in any reasonable manner, but not in any
 +
**            way that suggests the licensor endorses you or your use.
 +
**      NonCommercial:
 +
**            You may not use the material for commercial purposes.
 +
**      ShareAlike:
 +
**            If you remix, transform, or build upon the material, you must distribute your
 +
**            contributions under the same license as the original.
 +
**      No additional restrictions:
 +
**            You may not apply legal terms or technological measures that legally restrict
 +
**            others from doing anything the license permits.
 +
**
 +
** See full licensing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
 +
**
 +
** For additional licensing rights (including commercial) please contact jkflying@gmail.com
 +
**
 +
*/
  
package jk.mega;
 
import java.util.ArrayDeque;
 
 
import java.util.ArrayList;
 
import java.util.ArrayList;
 
import java.util.Arrays;
 
import java.util.Arrays;
//import ags.utils.*;
 
//import ags.utils.dataStructures.*;
 
  
public class KDTree<T>{
+
public abstract class KDTree<T> {
 +
 
 +
    // use a big bucketSize so that we have fewer node bounds (for more cache
 +
    // hits) and better splits
 +
    // if you have lots of dimensions this should be big, and if you have few it should be small
 +
    private static final int _bucketSize = 50;
 +
 
 +
    private final int _dimensions;
 +
    private int _nodes;
 +
    private final Node root;
 +
    private final ArrayList<Node> nodeList = new ArrayList<Node>();
 +
 
 +
    // prevent GC from having to collect _bucketSize*dimensions*sizeof(double) bytes each
 +
    // time a leaf splits
 +
    private double[] mem_recycle;
 +
 
 +
    // the starting values for bounding boxes, for easy access
 +
    private final double[] bounds_template;
 +
 
 +
    // one big self-expanding array to keep all the node bounding boxes so that
 +
    // they stay in cache
 +
    // node bounds available at:
 +
    // low: 2 * _dimensions * node.index + 2 * dim
 +
    // high: 2 * _dimensions * node.index + 2 * dim + 1
 +
    private final ContiguousDoubleArrayList nodeMinMaxBounds;
 +
 
 +
    private KDTree(int dimensions) {
 +
        _dimensions = dimensions;
 +
 
 +
        // initialise this big so that it ends up in 'old' memory
 +
        nodeMinMaxBounds = new ContiguousDoubleArrayList(512 * 1024 / 8 + 2 * _dimensions);
 +
        mem_recycle = new double[_bucketSize * dimensions];
 +
 
 +
        bounds_template = new double[2 * _dimensions];
 +
        Arrays.fill(bounds_template, Double.NEGATIVE_INFINITY);
 +
        for (int i = 0, max = 2 * _dimensions; i < max; i += 2)
 +
            bounds_template[i] = Double.POSITIVE_INFINITY;
 +
 
 +
        // and.... start!
 +
        root = new Node();
 +
    }
 +
 
 +
    public int nodes() {
 +
        return _nodes;
 +
    }
 +
 
 +
    public int size() {
 +
        return root.entries;
 +
    }
 +
 
 +
    public int addPoint(double[] location, T payload) {
 +
 
 +
        Node addNode = root;
 +
        // Do a Depth First Search to find the Node where 'location' should be
 +
        // stored
 +
        while (addNode.pointLocations == null) {
 +
            addNode.expandBounds(location);
 +
            if (location[addNode.splitDim] < addNode.splitVal)
 +
                addNode = nodeList.get(addNode.lessIndex);
 +
            else
 +
                addNode = nodeList.get(addNode.moreIndex);
 +
        }
 +
        addNode.expandBounds(location);
 +
 
 +
        int nodeSize = addNode.add(location, payload);
 +
 
 +
        if (nodeSize % _bucketSize == 0)
 +
            // try splitting again once every time the node passes a _bucketSize
 +
            // multiple
 +
            // in case it is full of points of the same location and won't split
 +
            addNode.split();
 +
 
 +
        return root.entries;
 +
    }
 +
 
 +
    public ArrayList<SearchResult<T>> nearestNeighbours(double[] searchLocation, int K) {
 +
 
 +
        K = Math.min(K, size());
  
//use a big bucketSize so that we have fewer node bounds (for more cache hits) and better splits
+
        ArrayList<SearchResult<T>> returnResults = new ArrayList<SearchResult<T>>(K);
  private static final int  _bucketSize = 50;
 
  
  private final int _dimensions;
+
        if (K > 0) {
  private int _nodes;  
+
            IntStack stack = new IntStack();
  private Node root;
+
            PrioQueue<T> results = new PrioQueue<T>(K, true);
 
 
  //prevent GC from having to collect _bucketSize*dimensions*8 bytes each time a leaf splits
 
  private double[] mem_recycle;
 
 
 
  //the starting values for bounding boxes, for easy access
 
  private final double[] bounds_template;
 
  
  //one big self-expanding array to keep all the node bounding boxes so that they stay in cache
+
            stack.push(root.index);
  // node bounds available at:
+
 
  //low:  2 * _dimensions * node.index + 2 * dim
+
            int added = 0;
  //high: 2 * _dimensions * node.index + 2 * dim + 1
+
 
  private ContiguousDoubleArrayList nodeMinMaxBounds;
+
            while (stack.size() > 0) {
/*
+
                int nodeIndex = stack.pop();
  public static void main(String[] args){
+
                if (added < K || results.peekPrio() > pointRectDist(nodeIndex, searchLocation)) {
      int dims = 1;
+
                    Node node = nodeList.get(nodeIndex);
      int size = 2000000;
+
                    if (node.pointLocations == null)
      int testsize = 1;
+
                        node.search(searchLocation, stack);
      int k = 40;
+
                    else
      int iterations = 1;
+
                        added += node.search(searchLocation, results);
      System.out.println(
+
                }
        "Config:\n"
+
            }
        + "No JIT Warmup\n"
+
 
        + "Tested on random data.\n"
+
            double[] priorities = results.priorities;
        + "Training and testing points shared across iterations.\n"
+
            Object[] elements = results.elements;
        + "Searches interleaved.");
+
            for (int i = 0; i < K; i++) { // forward (closest first)
      System.out.println("Num points:    " + size);
+
                SearchResult<T> s = new SearchResult<T>(priorities[i], (T) elements[i]);
      System.out.println("Num searches:  " + testsize);
+
                returnResults.add(s);
      System.out.println("Dimensions:    " + dims);
+
            }
      System.out.println("Num Neighbours: " + k);
+
        }
      System.out.println();
+
        return returnResults;
      ArrayList<double[]> locs = new ArrayList<double[]>(size);
+
    }
      for(int i = 0; i < size; i++){
+
 
        double[] loc = new double[dims];
+
    public ArrayList<T> ballSearch(double[] searchLocation, double radius) {
        for(int j = 0; j < dims; j++)
+
        IntStack stack = new IntStack();
            loc[j] = Math.random();
+
        ArrayList<T> results = new ArrayList<T>();
        locs.add(loc);
+
 
      }
+
        stack.push(root.index);
      ArrayList<double[]> testlocs = new ArrayList<double[]>(testsize);
+
 
      for(int i = 0; i < testsize; i++){
+
        while (stack.size() > 0) {
        double[] loc = new double[dims];
+
            int nodeIndex = stack.pop();
        for(int j = 0; j < dims; j++)
+
            if (radius > pointRectDist(nodeIndex, searchLocation)) {
            loc[j] = Math.random();
+
                Node node = nodeList.get(nodeIndex);
        testlocs.add(loc);
+
                if (node.pointLocations == null)
      }
+
                    stack.push(node.moreIndex).push(node.lessIndex);
      for(int r = 0; r < iterations; r++){
+
                else
        long t1 = System.nanoTime();
+
                    node.searchBall(searchLocation, radius, results);
        KDTree<double[]> t = new KDTree<double[]>(dims);// This tree
+
            }
        for(int i = 0; i < size; i++){
+
        }
            t.addPoint(locs.get(i),locs.get(i));
+
        return results;
        }
+
    }
        long t2 = System.nanoTime();
+
 
        KdTree<double[]> rt = new KdTree.Euclidean<double[]>(dims,null); //Rednaxela Gen2
+
    public ArrayList<T> rectSearch(double[] mins, double[] maxs) {
        for(int i = 0; i < size; i++){
+
        IntStack stack = new IntStack();
            rt.addPoint(locs.get(i),locs.get(i));
+
        ArrayList<T> results = new ArrayList<T>();
        }
+
 
        long t3 = System.nanoTime();
+
        stack.push(root.index);
     
+
 
        long jtn = 0;
+
        while (stack.size() > 0) {
        long rtn = 0;
+
            int nodeIndex = stack.pop();
        long mjtn = 0;
+
            if (overlaps(mins, maxs, nodeIndex)) {
        long mrtn = 0;
+
                Node node = nodeList.get(nodeIndex);
     
+
                if (node.pointLocations == null)
        double dist1 = 0, dist2 = 0;
+
                    stack.push(node.moreIndex).push(node.lessIndex);
        for(int i = 0; i < testsize; i++){
+
                else
             long t4 = System.nanoTime();
+
                    node.searchRect(mins, maxs, results);
             dist1 += t.nearestNeighbours(testlocs.get(i),k).iterator().next().distance;
+
            }
            long t5 = System.nanoTime();
+
        }
            dist2 += rt.nearestNeighbor(testlocs.get(i),k,true).iterator().next().distance;
+
        return results;
            long t6 = System.nanoTime();
+
 
            long t7 = System.nanoTime();
+
    }
            jtn += t5 - t4 - (t7 - t6);
+
 
             rtn += t6 - t5 - (t7 - t6);  
+
    abstract double pointRectDist(int offset, final double[] location);
             mjtn = Math.max(mjtn,t5 - t4 - (t7 - t6));
+
 
            mrtn = Math.max(mrtn,t6 - t5 - (t7 - t6));
+
    abstract double pointDist(double[] arr, double[] location, int index);
        }
+
 
     
+
    boolean contains(double[] arr, double[] mins, double[] maxs, int index) {
        System.out.println("Accuracy: " + (Math.abs(dist1-dist2) < 1e-10?"100%":"BROKEN!!!"));
+
 
        if(Math.abs(dist1-dist2) > 1e-10){
+
        int offset = (index + 1) * mins.length;
             System.out.println("dist1: " + dist1 + "    dist2: " + dist2);
+
 
        }
+
        for (int i = mins.length; i-- > 0;) {
        long jts = t2 - t1;
+
             double d = arr[--offset];
        long rts = t3 - t2;
+
             if (mins[i] > d | d > maxs[i])
        System.out.println("Iteration:      " + (r+1) + "/" + iterations);
+
                return false;
     
+
        }
        System.out.println("This tree add avg:  " + jts/size + " ns");
+
        return true;
        System.out.println("Reds tree add avg:  " + rts/size + " ns");
+
    }
     
+
 
        System.out.println("This tree knn avg:  " + jtn/testsize + " ns");
+
    boolean overlaps(double[] mins, double[] maxs, int offset) {
        System.out.println("Reds tree knn avg:  " + rtn/testsize + " ns");
+
        offset *= (2 * maxs.length);
        System.out.println("This tree knn max:  " + mjtn + " ns");
+
        final double[] array = nodeMinMaxBounds.array;
        System.out.println("Reds tree knn max:  " + mrtn + " ns");
+
        for (int i = 0; i < maxs.length; i++, offset += 2) {
        System.out.println();
+
             double bmin = array[offset], bmax = array[offset + 1];
      }
+
             if (mins[i] > bmax | maxs[i] < bmin)
  }
+
                return false;
  // */
+
        }
 +
 
 +
        return true;
 +
    }
 +
 
 +
    public static class Euclidean<T> extends KDTree<T> {
 +
        public Euclidean(int dims) {
 +
             super(dims);
 +
        }
 +
 
 +
        double pointRectDist(int offset, final double[] location) {
 +
            offset *= (2 * super._dimensions);
 +
            double distance = 0;
 +
            final double[] array = super.nodeMinMaxBounds.array;
 +
            for (int i = 0; i < location.length; i++, offset += 2) {
  
  public KDTree(int dimensions){
+
                double diff = 0;
      _dimensions = dimensions;
+
                double bv = array[offset];
 
+
                double lv = location[i];
  //initialise this so that it ends up in 'old' memory
+
                if (bv > lv)
      nodeMinMaxBounds = new ContiguousDoubleArrayList(512 * 1024 / 8 + 2*_dimensions);
+
                    diff = bv - lv;
      mem_recycle = new double[_bucketSize*dimensions];
+
                else {
 
+
                    bv = array[offset + 1];
      bounds_template = new double[2*_dimensions];
+
                    if (lv > bv)
      Arrays.fill(bounds_template,Double.NEGATIVE_INFINITY);
+
                        diff = lv - bv;
      for(int i = 0, max = 2*_dimensions; i < max; i+=2)
+
                }
        bounds_template[i] = Double.POSITIVE_INFINITY;
+
                distance += sqr(diff);
 
+
            }
  //and.... start!
+
            return distance;
      root = new Node();
+
        }
  }
 
  public int nodes(){
 
      return _nodes;
 
  }
 
  public int addPoint(double[] location, T payload){
 
 
 
      Node addNode = root;
 
  //Do a Depth First Search to find the Node where 'location' should be stored
 
      while(addNode.pointLocations == null){
 
        addNode.expandBounds(location);
 
        if(location[addNode.splitDim] < addNode.splitVal)
 
            addNode = addNode.less;
 
        else
 
            addNode = addNode.more;
 
      }
 
      addNode.expandBounds(location);
 
 
 
      int nodeSize = addNode.add(location,payload);
 
 
 
      if(nodeSize % _bucketSize == 0)
 
      //try splitting again once every time the node passes a _bucketSize multiple
 
        addNode.split();
 
 
 
      return root.entries;
 
  }
 
  
 +
        double pointDist(double[] arr, double[] location, int index) {
 +
            double distance = 0;
 +
            int offset = (index + 1) * super._dimensions;
  
  public ArrayList<SearchResult<T>> nearestNeighbours(double[] searchLocation, int K){
+
             for (int i = super._dimensions; i-- > 0;) {
      ArrayDeque<Node> stack = new ArrayDeque<Node>(50);
+
                distance += sqr(arr[--offset] - location[i]);
      PrioQueue<T> results = new PrioQueue<T>(K,true);
+
            }
 
+
            return distance;
      stack.push(root);
+
        }
 
 
      int added = 0;
 
      while(added < K )
 
        added += stack.pop().search(searchLocation,stack,results);
 
           
 
      double bestDist = -results.peekPrio();
 
      while(stack.size() > 0 ){
 
        Node searchNode = stack.poll();
 
        if(bestDist >= searchNode.pointRectDist(searchLocation)){
 
            searchNode.search(searchLocation,stack,results);
 
             bestDist = -results.peekPrio();
 
        }
 
      }
 
     
 
      ArrayList<SearchResult<T>> returnResults = new ArrayList<SearchResult<T>>(K);
 
 
 
      for(int i = K; i-- > 0;){//Reverse (furthest first, like Rednaxela Gen2)
 
        SearchResult s = new SearchResult(-results.priorities[i],results.elements[i]);
 
        returnResults.add(s);
 
      }
 
      return returnResults;
 
  }
 
 
 
  
    //NB! This Priority Queue keeps things with the HIGHEST priority.
+
    }
//If you want lowest priority items kept, negate your values
 
  private static class PrioQueue<S>{
 
 
 
      Object[] elements;
 
      double[] priorities;
 
      private double minPrio;
 
      private int size;
 
 
 
      PrioQueue(int size, boolean prefill){
 
        elements = new Object[size];
 
        priorities = new double[size];
 
        Arrays.fill(priorities,Double.NEGATIVE_INFINITY);
 
        if(prefill){
 
            minPrio = Double.NEGATIVE_INFINITY;
 
            this.size = size;
 
        }
 
      }
 
      //uses O(log(n)) comparisons and one big shift of size O(N)
 
      //and is MUCH simpler than a heap --> faster on small sets, faster JIT
 
       
 
      void addNoGrow(S value, double priority){
 
        int index = searchFor(priority);
 
        int nextIndex = index + 1;
 
        int length = size - index - 1;//remove dependency on nextIndex
 
        System.arraycopy(elements,index,elements,nextIndex,length);
 
        System.arraycopy(priorities,index,priorities,nextIndex,length);
 
        elements[index]=value;
 
        priorities[index]=priority;
 
           
 
        minPrio = priorities[size-1];
 
      }
 
 
 
      int searchFor(double priority){
 
        int i = size-1;
 
        int j = 0; 
 
        while(i>=j){
 
            int index = (i+j)>>>1;
 
       
 
            if( priorities[index] > priority)
 
              j = index+1;
 
            else
 
              i = index-1;
 
        }
 
        return j;
 
      }
 
      double peekPrio(){
 
        return minPrio;
 
      }
 
      /*
 
      //Methods for using it as a priority stack - leave them out for now
 
      void push(S value, double priority){
 
        if(++size > elements.length){
 
            elements = Arrays.copyOf(elements,size*2);
 
            priorities = Arrays.copyOf(priorities,size*2);
 
            Arrays.fill(priorities,size,size*2,Double.NEGATIVE_INFINITY);
 
            System.out.println("Expanding PrioQueue to " + elements.length);
 
        }
 
        addNoGrow(value,priority);
 
      }
 
      void pushTop(S value, double priority){
 
        if(++size > elements.length){
 
            elements = Arrays.copyOf(elements,size*2);
 
            priorities = Arrays.copyOf(priorities,size*2);
 
            Arrays.fill(priorities,size,size*2,Double.NEGATIVE_INFINITY);
 
            System.out.println("Expanding PrioQueue to " + elements.length);
 
        }
 
     
 
        elements[size-1] = value;
 
        priorities[size-1] = priority;
 
        minPrio = priority;
 
       
 
      }
 
     
 
      S pop(){
 
        Object value = elements[--size];
 
        priorities[size] = Double.NEGATIVE_INFINITY;
 
        if(size == 0)
 
            minPrio = Double.NEGATIVE_INFINITY;
 
        else
 
            minPrio = priorities[size-1];
 
        return (S)value;
 
      }
 
      int size(){
 
        return size-min;
 
      }
 
    //  */
 
  }
 
  
 +
    public static class Manhattan<T> extends KDTree<T> {
 +
        public Manhattan(int dims) {
 +
            super(dims);
 +
        }
  
  public static class SearchResult<S>{
+
        double pointRectDist(int offset, final double[] location) {
      public double distance;
+
            offset *= (2 * super._dimensions);
      public S payload;
+
            double distance = 0;
      SearchResult(double dist, S load){
+
            final double[] array = super.nodeMinMaxBounds.array;
        distance = dist;
+
            for (int i = 0; i < location.length; i++, offset += 2) {
        payload = load;
 
      }
 
  }
 
  
  private class Node {
+
                double diff = 0;
 
+
                double bv = array[offset];
  //for accessing bounding box data
+
                double lv = location[i];
  // - if trees weren't so unbalanced might be better to use an implicit heap?
+
                if (bv > lv)
      int index;
+
                    diff = bv - lv;
     
+
                else {
  //keep track of size of subtree
+
                    bv = array[offset + 1];
      int entries;
+
                    if (lv > bv)
 
+
                        diff = lv - bv;
  //leaf
+
                }
      ContiguousDoubleArrayList pointLocations ;
+
                distance += (diff);
      ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);
+
            }
     
+
            return distance;
  //stem
+
        }
      Node less, more;
+
 
      int splitDim;
+
        double pointDist(double[] arr, double[] location, int index) {
      double splitVal;
+
            double distance = 0;
 
+
            int offset = (index + 1) * super._dimensions;
      private Node(){
+
 
        this(new double[_bucketSize*_dimensions]);
+
            for (int i = super._dimensions; i-- > 0;) {
      }
+
                distance += Math.abs(arr[--offset] - location[i]);
      private Node(double[] pointMemory){
 
        pointLocations = new ContiguousDoubleArrayList(pointMemory);
 
        index = _nodes++;
 
        nodeMinMaxBounds.add(bounds_template);
 
      }
 
      private final double pointRectDist(double[] location){
 
        int offset = (2*_dimensions)*(index+1)-2;
 
        double distance=0;
 
        double[] array = nodeMinMaxBounds.array;
 
        for(int i = _dimensions; i-- > 0; offset -= 2){
 
       
 
            double diff = 0;
 
            double bv = array[offset];
 
            double lv = location[i];
 
            if(bv > lv)
 
              diff = bv-lv;
 
            else{
 
              bv=array[offset+1];
 
              if(lv>bv)
 
                  diff = lv-bv;
 
 
             }
 
             }
             distance += sqr(diff);
+
             return distance;
        }
+
        }
        return distance;
+
    }
      }
+
 
      private final double pointDist(double[] location, int index){
+
    public static class WeightedManhattan<T> extends KDTree<T> {
        double[] arr = pointLocations.array;
+
        private double[] weights;
        double distance = 0;
+
 
        int offset = (index+1)*_dimensions;
+
        public WeightedManhattan(int dims) {
        for(int i = _dimensions; i-- > 0;)
+
            super(dims);
            distance += sqr(arr[--offset] - location[i]);
+
            weights = new double[dims];
        return distance;
+
            for (int i = 0; i < dims; i++)
      }
+
                weights[i] = 1.0;
      //returns number of points added to results
+
        }
      private int search(double[] searchLocation, ArrayDeque<Node> stack, PrioQueue<T> results){
+
 
     
+
        public void setWeights(double[] newWeights) {
        if(pointLocations == null){
+
            weights = newWeights;
            if(searchLocation[splitDim] < splitVal){
+
        }
              stack.push(more);
+
 
              stack.push(less);//less will be popped first
+
        double pointRectDist(int offset, final double[] location) {
 +
            offset *= (2 * super._dimensions);
 +
            double distance = 0;
 +
            final double[] array = super.nodeMinMaxBounds.array;
 +
            for (int i = 0; i < location.length; i++, offset += 2) {
 +
 
 +
                double diff = 0;
 +
                double bv = array[offset];
 +
                double lv = location[i];
 +
                if (bv > lv)
 +
                    diff = bv - lv;
 +
                else {
 +
                    bv = array[offset + 1];
 +
                    if (lv > bv)
 +
                        diff = lv - bv;
 +
                }
 +
                distance += (diff) * weights[i];
 
             }
 
             }
             else{
+
             return distance;
              stack.push(less);
+
        }
              stack.push(more);//more will be popped first
+
 
 +
        double pointDist(double[] arr, double[] location, int index) {
 +
            double distance = 0;
 +
            int offset = (index + 1) * super._dimensions;
 +
 
 +
            for (int i = super._dimensions; i-- > 0;) {
 +
                distance += Math.abs(arr[--offset] - location[i]) * weights[i];
 
             }
 
             }
             return 0;
+
             return distance;
        }
+
        }
        else{
+
    }
             double minD = results.peekPrio();
+
 
 +
    public static class SearchResult<S> {
 +
        public double distance;
 +
        public S payload;
 +
 
 +
        SearchResult(double dist, S load) {
 +
            distance = dist;
 +
            payload = load;
 +
        }
 +
    }
 +
 
 +
    private class Node {
 +
 
 +
        // for accessing bounding box data
 +
        // - if trees weren't so unbalanced might be better to use an implicit
 +
        // heap?
 +
        int index;
 +
 
 +
        // keep track of size of subtree
 +
        int entries;
 +
 
 +
        // leaf
 +
        ContiguousDoubleArrayList pointLocations;
 +
        ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);
 +
 
 +
        // stem
 +
        // Node less, more;
 +
        int lessIndex, moreIndex;
 +
        int splitDim;
 +
        double splitVal;
 +
 
 +
        Node() {
 +
             this(new double[_bucketSize * _dimensions]);
 +
        }
 +
 
 +
        Node(double[] pointMemory) {
 +
            pointLocations = new ContiguousDoubleArrayList(pointMemory);
 +
            index = _nodes++;
 +
            nodeList.add(this);
 +
            nodeMinMaxBounds.add(bounds_template);
 +
        }
 +
 
 +
        void search(double[] searchLocation, IntStack stack) {
 +
            if (searchLocation[splitDim] < splitVal)
 +
                stack.push(moreIndex).push(lessIndex); // less will be popped
 +
            // first
 +
            else
 +
                stack.push(lessIndex).push(moreIndex); // more will be popped
 +
            // first
 +
        }
 +
 
 +
        // returns number of points added to results
 +
        int search(double[] searchLocation, PrioQueue<T> results) {
 
             int updated = 0;
 
             int updated = 0;
             for(int j = entries; j-- > 0;){
+
             for (int j = entries; j-- > 0;) {
              double neg_distance = -pointDist(searchLocation,j);
+
                double distance = pointDist(pointLocations.array, searchLocation, j);
              if(minD < neg_distance){
+
                if (results.peekPrio() > distance) {
                  results.addNoGrow(pointPayloads.get(j),neg_distance);
+
                    updated++;
                  minD = results.peekPrio();
+
                    results.addNoGrow(pointPayloads.get(j), distance);
                  updated++;
+
                }
              }
 
 
             }
 
             }
 
             return updated;
 
             return updated;
        }
+
        }
      }
+
 
      private void expandBounds(double[] location){
+
        void searchBall(double[] searchLocation, double radius, ArrayList<T> results) {
        entries++;
+
 
        int mio = index*2*_dimensions;
+
            for (int j = entries; j-- > 0;) {
        for(int i = 0; i < _dimensions;i++){
+
                double distance = pointDist(pointLocations.array, searchLocation, j);
            nodeMinMaxBounds.array[mio] = Math.min(nodeMinMaxBounds.array[mio++],location[i]);
+
                if (radius >= distance) {
            nodeMinMaxBounds.array[mio] = Math.max(nodeMinMaxBounds.array[mio++],location[i]);
+
                    results.add(pointPayloads.get(j));
        }
+
                }
      }
+
            }
 
+
        }
      private int add(double[] location, T load){
+
 
        pointLocations.add(location);
+
        void searchRect(double[] mins, double[] maxs, ArrayList<T> results) {
        pointPayloads.add(load);
+
 
        return entries;
+
            for (int j = entries; j-- > 0;)
      }
+
                if (contains(pointLocations.array, mins, maxs, j))
      private void split(){
+
                    results.add(pointPayloads.get(j));
        int offset = index*2*_dimensions;
+
 
       
+
        }
        double diff = 0;
+
 
        for(int i = 0; i < _dimensions; i++){
+
        void expandBounds(double[] location) {
            double min = nodeMinMaxBounds.array[offset];
+
            entries++;
            double max = nodeMinMaxBounds.array[offset+1];
+
            int mio = index * 2 * _dimensions;
            if(max-min>diff){
+
            for (int i = 0; i < _dimensions; i++) {
              double mean = 0;
+
                nodeMinMaxBounds.array[mio] = Math.min(nodeMinMaxBounds.array[mio], location[i]);
              for(int j = 0; j < entries; j++)
+
                mio++;
                  mean += pointLocations.array[i+_dimensions*j];
+
                nodeMinMaxBounds.array[mio] = Math.max(nodeMinMaxBounds.array[mio], location[i]);
           
+
                mio++;
              mean = mean/entries;
+
            }
              double varianceSum = 0;
+
        }
           
+
 
              for(int j = 0; j < entries; j++)
+
        int add(double[] location, T load) {
                  varianceSum += sqr(mean-pointLocations.array[i+_dimensions*j]);
+
            pointLocations.add(location);
           
+
            pointPayloads.add(load);
              if(varianceSum>diff*entries){
+
            return entries;
                  diff = varianceSum/entries;
+
        }
                  splitVal = mean;
+
 
             
+
        void split() {
                  splitDim = i;
+
            int offset = index * 2 * _dimensions;
              }
+
 
 +
            double diff = 0;
 +
            for (int i = 0; i < _dimensions; i++) {
 +
                double min = nodeMinMaxBounds.array[offset];
 +
                double max = nodeMinMaxBounds.array[offset + 1];
 +
                if (max - min > diff) {
 +
                    double mean = 0;
 +
                    for (int j = 0; j < entries; j++)
 +
                        mean += pointLocations.array[i + _dimensions * j];
 +
 
 +
                    mean = mean / entries;
 +
                    double varianceSum = 0;
 +
 
 +
                    for (int j = 0; j < entries; j++)
 +
                        varianceSum += sqr(mean - pointLocations.array[i + _dimensions * j]);
 +
 
 +
                    if (varianceSum > diff * entries) {
 +
                        diff = varianceSum / entries;
 +
                        splitVal = mean;
 +
 
 +
                        splitDim = i;
 +
                    }
 +
                }
 +
                offset += 2;
 +
            }
 +
 
 +
            // kill all the nasties
 +
            if (splitVal == Double.POSITIVE_INFINITY)
 +
                splitVal = Double.MAX_VALUE;
 +
            else if (splitVal == Double.NEGATIVE_INFINITY)
 +
                splitVal = Double.MIN_VALUE;
 +
            else if (splitVal == nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim + 1])
 +
                splitVal = nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim];
 +
 
 +
            Node less = new Node(mem_recycle); // recycle that memory!
 +
            Node more = new Node();
 +
            lessIndex = less.index;
 +
            moreIndex = more.index;
 +
 
 +
            // reduce garbage by factor of _bucketSize by recycling this array
 +
            double[] pointLocation = new double[_dimensions];
 +
            for (int i = 0; i < entries; i++) {
 +
                System.arraycopy(pointLocations.array, i * _dimensions, pointLocation, 0, _dimensions);
 +
                T load = pointPayloads.get(i);
 +
 
 +
                if (pointLocation[splitDim] < splitVal) {
 +
                    less.expandBounds(pointLocation);
 +
                    less.add(pointLocation, load);
 +
                }
 +
                else {
 +
                    more.expandBounds(pointLocation);
 +
                    more.add(pointLocation, load);
 +
                }
 
             }
 
             }
             offset += 2;
+
             if (less.entries * more.entries == 0) {
        }
+
                // one of them was 0, so the split was worthless. throw it away.
     
+
                _nodes -= 2; // recall that bounds memory
        //kill all the nasties
+
                nodeList.remove(moreIndex);
        if(splitVal == Double.POSITIVE_INFINITY)
+
                nodeList.remove(lessIndex);
            splitVal = Double.MAX_VALUE;
 
        else if(splitVal == Double.NEGATIVE_INFINITY)
 
            splitVal = Double.MIN_VALUE;
 
        else if(splitVal == nodeMinMaxBounds.array[index*2*_dimensions + 2*splitDim + 1])
 
            splitVal = nodeMinMaxBounds.array[index*2*_dimensions + 2*splitDim]; 
 
     
 
        less = new Node(mem_recycle);//recycle that memory!
 
        more = new Node();
 
       
 
        //reduce garbage by factor of _bucketSize by recycling this array
 
        double[] pointLocation = new double[_dimensions];
 
        for(int i = 0; i < entries; i++){
 
            System.arraycopy(pointLocations.array,i*_dimensions,pointLocation,0,_dimensions);
 
            T load = pointPayloads.get(i);
 
       
 
            if(pointLocation[splitDim] < splitVal){
 
              less.expandBounds(pointLocation);
 
              less.add(pointLocation,load);
 
 
             }
 
             }
             else{
+
             else {
              more.expandBounds(pointLocation);  
+
 
              more.add(pointLocation,load);
+
                // we won't be needing that now, so keep it for the next split
 +
                // to reduce garbage
 +
                mem_recycle = pointLocations.array;
 +
 
 +
                pointLocations = null;
 +
 
 +
                pointPayloads.clear();
 +
                pointPayloads = null;
 
             }
 
             }
        }
+
        }
        if(less.entries*more.entries == 0){
+
 
        //one of them was 0, so the split was worthless. throw it away.
+
    }
            less = null;
 
            more = null;
 
        }
 
        else{
 
       
 
        //we won't be needing that now, so keep it for the next split to reduce garbage
 
            mem_recycle = pointLocations.array;
 
       
 
            pointLocations = null;
 
       
 
            pointPayloads.clear();
 
            pointPayloads = null;
 
        }
 
      }
 
 
 
  }
 
  
 +
    // NB! This Priority Queue keeps things with the LOWEST priority.
 +
    // If you want highest priority items kept, negate your values
 +
    private static class PrioQueue<S> {
  
  private static class ContiguousDoubleArrayList{
+
        Object[] elements;
      double[] array;
+
        double[] priorities;
      int size;
+
        private double minPrio;
      ContiguousDoubleArrayList(){
+
        private int size;
        this(300);
 
      }
 
      ContiguousDoubleArrayList(int size){
 
        this(new double[size]);
 
      }
 
      ContiguousDoubleArrayList(double[] data){
 
        array = data;
 
      }
 
      ContiguousDoubleArrayList add(double[] da){
 
        if(size + da.length > array.length){
 
            array = Arrays.copyOf(array,(array.length+da.length)*2);
 
            //System.out.println("Doubling!");
 
        }
 
       
 
        System.arraycopy(da,0,array,size,da.length);
 
        size += da.length;
 
        return this;
 
      }
 
  }
 
  
  private static final double sqr(double d){
+
        PrioQueue(int size, boolean prefill) {
      return d*d;}
+
            elements = new Object[size];
 +
            priorities = new double[size];
 +
            Arrays.fill(priorities, Double.POSITIVE_INFINITY);
 +
            if (prefill) {
 +
                minPrio = Double.POSITIVE_INFINITY;
 +
                this.size = size;
 +
            }
 +
        }
 +
 
 +
        // uses O(log(n)) comparisons and one big shift of size O(N)
 +
        // and is MUCH simpler than a heap --> faster on small sets, faster JIT
 +
 
 +
        void addNoGrow(S value, double priority) {
 +
            int index = searchFor(priority);
 +
            int nextIndex = index + 1;
 +
            int length = size - nextIndex;
 +
            System.arraycopy(elements, index, elements, nextIndex, length);
 +
            System.arraycopy(priorities, index, priorities, nextIndex, length);
 +
            elements[index] = value;
 +
            priorities[index] = priority;
 +
 
 +
            minPrio = priorities[size - 1];
 +
        }
 +
 
 +
        int searchFor(double priority) {
 +
            int i = size - 1;
 +
            int j = 0;
 +
            while (i >= j) {
 +
                int index = (i + j) >>> 1;
 +
                if (priorities[index] < priority)
 +
                    j = index + 1;
 +
                else
 +
                    i = index - 1;
 +
            }
 +
            return j;
 +
        }
 +
 
 +
        double peekPrio() {
 +
            return minPrio;
 +
        }
 +
    }
 +
 
 +
    private static class ContiguousDoubleArrayList {
 +
        double[] array;
 +
        int size;
 +
 
 +
        ContiguousDoubleArrayList(int size) {
 +
            this(new double[size]);
 +
        }
 +
 
 +
        ContiguousDoubleArrayList(double[] data) {
 +
            array = data;
 +
        }
 +
 
 +
        ContiguousDoubleArrayList add(double[] da) {
 +
            if (size + da.length > array.length)
 +
                array = Arrays.copyOf(array, (array.length + da.length) * 2);
 +
 
 +
            System.arraycopy(da, 0, array, size, da.length);
 +
            size += da.length;
 +
            return this;
 +
        }
 +
    }
 +
 
 +
    private static class IntStack {
 +
        int[] array;
 +
        int size;
 +
 
 +
        IntStack() {
 +
            this(64);
 +
        }
 +
 
 +
        IntStack(int size) {
 +
            this(new int[size]);
 +
        }
 +
 
 +
        IntStack(int[] data) {
 +
            array = data;
 +
        }
 +
 
 +
        IntStack push(int i) {
 +
            if (size >= array.length)
 +
                array = Arrays.copyOf(array, (array.length + 1) * 2);
 +
 
 +
            array[size++] = i;
 +
            return this;
 +
        }
 +
 
 +
        int pop() {
 +
            return array[--size];
 +
        }
 +
 
 +
        int size() {
 +
            return size;
 +
        }
 +
    }
 +
 
 +
    static final double sqr(double d) {
 +
        return d * d;
 +
    }
  
 
}
 
}
 
 
</syntaxhighlight></code>
 
</syntaxhighlight></code>

Latest revision as of 22:46, 20 September 2017

Latest version is available here: https://bitbucket.org/jkflying/kd-tree

A possibly outdated version is listed below:

package jk.tree;
/*
 ** KDTree.java by Julian Kent
 **
 ** Licenced under the  Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
 **
 ** Licence summary:
 ** Under this licence you are free to:
 **      Share : copy and redistribute the material in any medium or format
 **      Adapt : remix, transform, and build upon the material
 **      The licensor cannot revoke these freedoms as long as you follow the license terms.
 **
 ** Under the following terms:
 **      Attribution:
 **            You must give appropriate credit, provide a link to the license, and indicate
 **            if changes were made. You may do so in any reasonable manner, but not in any
 **            way that suggests the licensor endorses you or your use.
 **      NonCommercial:
 **            You may not use the material for commercial purposes.
 **      ShareAlike:
 **            If you remix, transform, or build upon the material, you must distribute your
 **            contributions under the same license as the original.
 **      No additional restrictions:
 **            You may not apply legal terms or technological measures that legally restrict
 **            others from doing anything the license permits.
 **
 ** See full licencing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
 **
 ** For additional licencing rights (including commercial) please contact jkflying@gmail.com
 **
 */

import java.util.ArrayList;
import java.util.Arrays;

/**
 * Bucketed k-d tree that maps points in a fixed number of dimensions to
 * payloads of type {@code T}.
 *
 * <p>Points are stored in leaf buckets; a leaf is split along the dimension with
 * the largest variance once it reaches a multiple of {@code _bucketSize}
 * entries. The bounding boxes of all nodes live in one contiguous array. The
 * distance metric is supplied by the concrete nested subclasses
 * ({@link Euclidean}, {@link Manhattan}, {@link WeightedManhattan}).
 *
 * <p>All distances handled by this class (search results, the {@code radius}
 * of {@link #ballSearch}) are in the units of the subclass's
 * {@code pointDist}; for {@link Euclidean} that is the <em>squared</em>
 * distance, since no square root is taken.
 *
 * @param <T> payload type stored with each point
 */
public abstract class KDTree<T> {

    // use a big bucketSize so that we have less node bounds (for more cache
    // hits) and better splits
    // if you have lots of dimensions this should be big, and if you have few small
    private static final int _bucketSize = 50;

    private final int _dimensions;
    private int _nodes;
    private final Node root;
    private final ArrayList<Node> nodeList = new ArrayList<Node>();

    // prevent GC from having to collect _bucketSize*dimensions*sizeof(double) bytes each
    // time a leaf splits
    private double[] mem_recycle;

    // the starting values for bounding boxes, for easy access
    // (low = +infinity, high = -infinity for every dimension)
    private final double[] bounds_template;

    // one big self-expanding array to keep all the node bounding boxes so that
    // they stay in cache
    // node bounds available at:
    // low: 2 * _dimensions * node.index + 2 * dim
    // high: 2 * _dimensions * node.index + 2 * dim + 1
    private final ContiguousDoubleArrayList nodeMinMaxBounds;

    /**
     * Creates an empty tree with a single (leaf) root node.
     *
     * @param dimensions number of coordinates per point
     */
    private KDTree(int dimensions) {
        _dimensions = dimensions;

        // initialise this big so that it ends up in 'old' memory
        nodeMinMaxBounds = new ContiguousDoubleArrayList(512 * 1024 / 8 + 2 * _dimensions);
        mem_recycle = new double[_bucketSize * dimensions];

        bounds_template = new double[2 * _dimensions];
        Arrays.fill(bounds_template, Double.NEGATIVE_INFINITY);
        for (int i = 0, max = 2 * _dimensions; i < max; i += 2)
            bounds_template[i] = Double.POSITIVE_INFINITY;

        // and.... start!
        root = new Node();
    }

    /**
     * @return the number of nodes (leaves and stems) currently in the tree
     */
    public int nodes() {
        return _nodes;
    }

    /**
     * @return the number of points stored in the tree
     */
    public int size() {
        return root.entries;
    }

    /**
     * Adds a point to the tree.
     *
     * @param location coordinates of the point; the values are copied into the
     *                 tree's own storage
     * @param payload  value to associate with the point
     * @return the number of points in the tree after the insertion
     */
    public int addPoint(double[] location, T payload) {

        Node addNode = root;
        // Do a Depth First Search to find the Node where 'location' should be
        // stored; every node on the way has its bounds and entry count updated
        while (addNode.pointLocations == null) {
            addNode.expandBounds(location);
            if (location[addNode.splitDim] < addNode.splitVal)
                addNode = nodeList.get(addNode.lessIndex);
            else
                addNode = nodeList.get(addNode.moreIndex);
        }
        addNode.expandBounds(location);

        int nodeSize = addNode.add(location, payload);

        if (nodeSize % _bucketSize == 0)
            // try splitting again once every time the node passes a _bucketSize
            // multiple
            // in case it is full of points of the same location and won't split
            addNode.split();

        return root.entries;
    }

    /**
     * Finds the {@code K} stored points closest to {@code searchLocation}.
     *
     * @param searchLocation query coordinates
     * @param K              number of neighbours wanted; clamped to {@link #size()}
     * @return the neighbours ordered closest first, each with its distance in
     *         the metric's units
     */
    public ArrayList<SearchResult<T>> nearestNeighbours(double[] searchLocation, int K) {

        K = Math.min(K, size());

        ArrayList<SearchResult<T>> returnResults = new ArrayList<SearchResult<T>>(K);

        if (K > 0) {
            IntStack stack = new IntStack();
            PrioQueue<T> results = new PrioQueue<T>(K, true);

            stack.push(root.index);

            int added = 0;

            while (stack.size() > 0) {
                int nodeIndex = stack.pop();
                // only visit a node if we still need results or its box could
                // hold something closer than the current worst result
                if (added < K || results.peekPrio() > pointRectDist(nodeIndex, searchLocation)) {
                    Node node = nodeList.get(nodeIndex);
                    if (node.pointLocations == null)
                        node.search(searchLocation, stack);
                    else
                        added += node.search(searchLocation, results);
                }
            }

            double[] priorities = results.priorities;
            Object[] elements = results.elements;
            for (int i = 0; i < K; i++) { // forward (closest first)
                // only payloads of type T are ever put into the queue
                @SuppressWarnings("unchecked")
                T payload = (T) elements[i];
                returnResults.add(new SearchResult<T>(priorities[i], payload));
            }
        }
        return returnResults;
    }

    /**
     * Collects the payloads of all points whose distance to
     * {@code searchLocation} is at most {@code radius} (inclusive).
     *
     * @param searchLocation query coordinates
     * @param radius         maximum distance, in the metric's units (squared
     *                       distance for {@link Euclidean})
     * @return matching payloads, in no particular order
     */
    public ArrayList<T> ballSearch(double[] searchLocation, double radius) {
        IntStack stack = new IntStack();
        ArrayList<T> results = new ArrayList<T>();

        stack.push(root.index);

        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            // inclusive, to agree with the inclusive per-point test in
            // Node.searchBall: a box exactly 'radius' away can still hold a
            // point exactly 'radius' away
            if (radius >= pointRectDist(nodeIndex, searchLocation)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    stack.push(node.moreIndex).push(node.lessIndex);
                else
                    node.searchBall(searchLocation, radius, results);
            }
        }
        return results;
    }

    /**
     * Collects the payloads of all points inside the axis-aligned box
     * {@code [mins, maxs]} (both ends inclusive).
     *
     * @param mins lower corner of the box, one value per dimension
     * @param maxs upper corner of the box, one value per dimension
     * @return matching payloads, in no particular order
     */
    public ArrayList<T> rectSearch(double[] mins, double[] maxs) {
        IntStack stack = new IntStack();
        ArrayList<T> results = new ArrayList<T>();

        stack.push(root.index);

        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            if (overlaps(mins, maxs, nodeIndex)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    stack.push(node.moreIndex).push(node.lessIndex);
                else
                    node.searchRect(mins, maxs, results);
            }
        }
        return results;

    }

    /**
     * Distance from {@code location} to the bounding box of a node.
     *
     * @param offset   index of the node (not an array offset)
     * @param location query coordinates
     * @return the distance in the metric's units; 0 if the location is inside
     *         the box
     */
    abstract double pointRectDist(int offset, final double[] location);

    /**
     * Distance from {@code location} to the {@code index}-th point packed in
     * {@code arr}.
     *
     * @param arr      flat array of point coordinates, one point after another
     * @param location query coordinates
     * @param index    index of the point within {@code arr}
     * @return the distance in the metric's units
     */
    abstract double pointDist(double[] arr, double[] location, int index);

    /**
     * Tests whether the {@code index}-th point packed in {@code arr} lies
     * inside the box {@code [mins, maxs]} (inclusive). The point stride is
     * taken from {@code mins.length}.
     */
    boolean contains(double[] arr, double[] mins, double[] maxs, int index) {

        int offset = (index + 1) * mins.length;

        // walk the point's coordinates backwards
        for (int i = mins.length; i-- > 0;) {
            double d = arr[--offset];
            if (mins[i] > d | d > maxs[i])
                return false;
        }
        return true;
    }

    /**
     * Tests whether the box {@code [mins, maxs]} intersects the bounding box
     * of the node with index {@code offset}.
     */
    boolean overlaps(double[] mins, double[] maxs, int offset) {
        offset *= (2 * maxs.length);
        final double[] array = nodeMinMaxBounds.array;
        for (int i = 0; i < maxs.length; i++, offset += 2) {
            double bmin = array[offset], bmax = array[offset + 1];
            if (mins[i] > bmax | maxs[i] < bmin)
                return false;
        }

        return true;
    }

    /**
     * Tree using squared Euclidean distance (sum of squared coordinate
     * differences; no square root is taken).
     */
    public static class Euclidean<T> extends KDTree<T> {
        public Euclidean(int dims) {
            super(dims);
        }

        @Override
        double pointRectDist(int offset, final double[] location) {
            offset *= (2 * super._dimensions);
            double distance = 0;
            final double[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {

                // per-dimension gap between the location and the box, 0 if inside
                double diff = 0;
                double bv = array[offset];
                double lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += sqr(diff);
            }
            return distance;
        }

        @Override
        double pointDist(double[] arr, double[] location, int index) {
            double distance = 0;
            int offset = (index + 1) * super._dimensions;

            for (int i = super._dimensions; i-- > 0;) {
                distance += sqr(arr[--offset] - location[i]);
            }
            return distance;
        }

    }

    /**
     * Tree using Manhattan distance (sum of absolute coordinate differences).
     */
    public static class Manhattan<T> extends KDTree<T> {
        public Manhattan(int dims) {
            super(dims);
        }

        @Override
        double pointRectDist(int offset, final double[] location) {
            offset *= (2 * super._dimensions);
            double distance = 0;
            final double[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {

                // per-dimension gap between the location and the box, 0 if inside
                double diff = 0;
                double bv = array[offset];
                double lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += (diff);
            }
            return distance;
        }

        @Override
        double pointDist(double[] arr, double[] location, int index) {
            double distance = 0;
            int offset = (index + 1) * super._dimensions;

            for (int i = super._dimensions; i-- > 0;) {
                distance += Math.abs(arr[--offset] - location[i]);
            }
            return distance;
        }
    }

    /**
     * Tree using Manhattan distance where each dimension's contribution is
     * multiplied by a weight. All weights start at 1.0.
     */
    public static class WeightedManhattan<T> extends KDTree<T> {
        private double[] weights;

        public WeightedManhattan(int dims) {
            super(dims);
            weights = new double[dims];
            for (int i = 0; i < dims; i++)
                weights[i] = 1.0;
        }

        /**
         * Replaces the per-dimension weights. The array is stored by reference
         * (not copied) and is indexed by dimension during searches.
         *
         * @param newWeights one weight per dimension
         */
        public void setWeights(double[] newWeights) {
            weights = newWeights;
        }

        @Override
        double pointRectDist(int offset, final double[] location) {
            offset *= (2 * super._dimensions);
            double distance = 0;
            final double[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {

                // per-dimension gap between the location and the box, 0 if inside
                double diff = 0;
                double bv = array[offset];
                double lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += (diff) * weights[i];
            }
            return distance;
        }

        @Override
        double pointDist(double[] arr, double[] location, int index) {
            double distance = 0;
            int offset = (index + 1) * super._dimensions;

            for (int i = super._dimensions; i-- > 0;) {
                distance += Math.abs(arr[--offset] - location[i]) * weights[i];
            }
            return distance;
        }
    }

    /**
     * One nearest-neighbour hit: a payload and its distance from the query in
     * the metric's units.
     */
    public static class SearchResult<S> {
        public double distance;
        public S payload;

        SearchResult(double dist, S load) {
            distance = dist;
            payload = load;
        }
    }

    /**
     * A tree node. A node is a leaf while {@code pointLocations} is non-null
     * and becomes a stem (with {@code lessIndex}/{@code moreIndex} children)
     * after a successful {@link #split()}.
     */
    private class Node {

        // for accessing bounding box data
        // - if trees weren't so unbalanced might be better to use an implicit
        // heap?
        int index;

        // keep track of size of subtree
        int entries;

        // leaf
        ContiguousDoubleArrayList pointLocations;
        ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);

        // stem
        // Node less, more;
        int lessIndex, moreIndex;
        int splitDim;
        double splitVal;

        Node() {
            this(new double[_bucketSize * _dimensions]);
        }

        /**
         * Creates a leaf that stores its points in {@code pointMemory},
         * registers it in {@code nodeList} and appends a fresh (empty)
         * bounding box for it to {@code nodeMinMaxBounds}.
         */
        Node(double[] pointMemory) {
            pointLocations = new ContiguousDoubleArrayList(pointMemory);
            index = _nodes++;
            nodeList.add(this);
            nodeMinMaxBounds.add(bounds_template);
        }

        /**
         * Stem search step: pushes both children so that the one on the
         * query's side of the split is popped first.
         */
        void search(double[] searchLocation, IntStack stack) {
            if (searchLocation[splitDim] < splitVal)
                stack.push(moreIndex).push(lessIndex); // less will be popped
            // first
            else
                stack.push(lessIndex).push(moreIndex); // more will be popped
            // first
        }

        // returns number of points added to results
        int search(double[] searchLocation, PrioQueue<T> results) {
            int updated = 0;
            for (int j = entries; j-- > 0;) {
                double distance = pointDist(pointLocations.array, searchLocation, j);
                if (results.peekPrio() > distance) {
                    updated++;
                    results.addNoGrow(pointPayloads.get(j), distance);
                }
            }
            return updated;
        }

        /**
         * Leaf ball search: adds every payload whose point is within
         * {@code radius} (inclusive) of {@code searchLocation}.
         */
        void searchBall(double[] searchLocation, double radius, ArrayList<T> results) {

            for (int j = entries; j-- > 0;) {
                double distance = pointDist(pointLocations.array, searchLocation, j);
                if (radius >= distance) {
                    results.add(pointPayloads.get(j));
                }
            }
        }

        /**
         * Leaf box search: adds every payload whose point lies inside
         * {@code [mins, maxs]}.
         */
        void searchRect(double[] mins, double[] maxs, ArrayList<T> results) {

            for (int j = entries; j-- > 0;)
                if (contains(pointLocations.array, mins, maxs, j))
                    results.add(pointPayloads.get(j));

        }

        /**
         * Counts one more entry in this subtree and grows this node's bounding
         * box to include {@code location}.
         */
        void expandBounds(double[] location) {
            entries++;
            int mio = index * 2 * _dimensions;
            for (int i = 0; i < _dimensions; i++) {
                nodeMinMaxBounds.array[mio] = Math.min(nodeMinMaxBounds.array[mio], location[i]);
                mio++;
                nodeMinMaxBounds.array[mio] = Math.max(nodeMinMaxBounds.array[mio], location[i]);
                mio++;
            }
        }

        /**
         * Appends a point and payload to this leaf. Does not touch
         * {@code entries}; callers call {@link #expandBounds} first.
         *
         * @return the current entry count of this node
         */
        int add(double[] location, T load) {
            pointLocations.add(location);
            pointPayloads.add(load);
            return entries;
        }

        /**
         * Tries to turn this leaf into a stem by splitting at the mean of the
         * dimension with the largest variance. If all points end up on one
         * side the split is discarded and the node stays a leaf.
         */
        void split() {
            int offset = index * 2 * _dimensions;

            // best variance found so far; a dimension is only examined if its
            // extent exceeds it
            double diff = 0;
            for (int i = 0; i < _dimensions; i++) {
                double min = nodeMinMaxBounds.array[offset];
                double max = nodeMinMaxBounds.array[offset + 1];
                if (max - min > diff) {
                    double mean = 0;
                    for (int j = 0; j < entries; j++)
                        mean += pointLocations.array[i + _dimensions * j];

                    mean = mean / entries;
                    double varianceSum = 0;

                    for (int j = 0; j < entries; j++)
                        varianceSum += sqr(mean - pointLocations.array[i + _dimensions * j]);

                    if (varianceSum > diff * entries) {
                        diff = varianceSum / entries;
                        splitVal = mean;

                        splitDim = i;
                    }
                }
                offset += 2;
            }

            // kill all the nasties
            // NOTE(review): Double.MIN_VALUE is the smallest POSITIVE double,
            // not the most negative one; -Double.MAX_VALUE may have been
            // intended. Left unchanged because with -Double.MAX_VALUE no finite
            // coordinate could ever fall on the 'less' side.
            if (splitVal == Double.POSITIVE_INFINITY)
                splitVal = Double.MAX_VALUE;
            else if (splitVal == Double.NEGATIVE_INFINITY)
                splitVal = Double.MIN_VALUE;
            else if (splitVal == nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim + 1])
                splitVal = nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim];

            Node less = new Node(mem_recycle); // recycle that memory!
            Node more = new Node();
            lessIndex = less.index;
            moreIndex = more.index;

            // reduce garbage by factor of _bucketSize by recycling this array
            double[] pointLocation = new double[_dimensions];
            for (int i = 0; i < entries; i++) {
                System.arraycopy(pointLocations.array, i * _dimensions, pointLocation, 0, _dimensions);
                T load = pointPayloads.get(i);

                if (pointLocation[splitDim] < splitVal) {
                    less.expandBounds(pointLocation);
                    less.add(pointLocation, load);
                }
                else {
                    more.expandBounds(pointLocation);
                    more.add(pointLocation, load);
                }
            }
            if (less.entries * more.entries == 0) {
                // one of them was 0, so the split was worthless. throw it away.
                _nodes -= 2; // recall that bounds memory
                // also rewind the bounds list: otherwise the next nodes to
                // reuse these two indices would start from the discarded
                // children's stale boxes, and the array would grow on every
                // failed split
                nodeMinMaxBounds.size -= 2 * 2 * _dimensions;
                // remove the higher index first so the lower one stays valid
                nodeList.remove(moreIndex);
                nodeList.remove(lessIndex);
            }
            else {

                // we won't be needing that now, so keep it for the next split
                // to reduce garbage
                mem_recycle = pointLocations.array;

                pointLocations = null;

                pointPayloads.clear();
                pointPayloads = null;
            }
        }

    }

    // NB! This Priority Queue keeps things with the LOWEST priority.
    // If you want highest priority items kept, negate your values
    /**
     * Fixed-capacity queue kept sorted by ascending priority in two parallel
     * arrays. When {@code prefill} is true it starts "full" of placeholder
     * slots with priority +infinity; otherwise it starts empty and fills up to
     * its capacity.
     */
    private static class PrioQueue<S> {

        Object[] elements;
        double[] priorities;
        // priority of the last (worst) slot once the queue is full; +infinity
        // while there is still room, so that any candidate is accepted
        private double minPrio;
        private int size;

        PrioQueue(int size, boolean prefill) {
            elements = new Object[size];
            priorities = new double[size];
            Arrays.fill(priorities, Double.POSITIVE_INFINITY);
            minPrio = Double.POSITIVE_INFINITY;
            if (prefill)
                this.size = size;
        }

        // uses O(log(n)) comparisons and one big shift of size O(N)
        // and is MUCH simpler than a heap --> faster on small sets, faster JIT

        /**
         * Inserts {@code value} at its sorted position. If the queue is full
         * the worst entry is dropped; a value that is not better than the
         * current worst entry of a full queue is ignored.
         */
        void addNoGrow(S value, double priority) {
            int index = searchFor(priority);
            if (size < elements.length)
                size++; // still filling: claim one more slot
            else if (index >= size)
                return; // full and no better than the worst: nothing to do
            int length = size - index - 1;
            System.arraycopy(elements, index, elements, index + 1, length);
            System.arraycopy(priorities, index, priorities, index + 1, length);
            elements[index] = value;
            priorities[index] = priority;

            if (size == elements.length)
                minPrio = priorities[size - 1];
        }

        /**
         * Binary search for the first slot whose priority is not less than
         * {@code priority}.
         */
        int searchFor(double priority) {
            int i = size - 1;
            int j = 0;
            while (i >= j) {
                int index = (i + j) >>> 1;
                if (priorities[index] < priority)
                    j = index + 1;
                else
                    i = index - 1;
            }
            return j;
        }

        /**
         * @return the threshold a candidate's priority must be below to be
         *         worth adding
         */
        double peekPrio() {
            return minPrio;
        }
    }

    /**
     * Growable flat list of doubles to which whole arrays are appended.
     */
    private static class ContiguousDoubleArrayList {
        double[] array;
        int size;

        ContiguousDoubleArrayList(int size) {
            this(new double[size]);
        }

        // wraps 'data' without copying; the logical size starts at 0
        ContiguousDoubleArrayList(double[] data) {
            array = data;
        }

        /**
         * Appends all values of {@code da}, growing the backing array if
         * needed.
         *
         * @return this list, for chaining
         */
        ContiguousDoubleArrayList add(double[] da) {
            if (size + da.length > array.length)
                array = Arrays.copyOf(array, (array.length + da.length) * 2);

            System.arraycopy(da, 0, array, size, da.length);
            size += da.length;
            return this;
        }
    }

    /**
     * Minimal growable stack of ints.
     */
    private static class IntStack {
        int[] array;
        int size;

        IntStack() {
            this(64);
        }

        IntStack(int size) {
            this(new int[size]);
        }

        // wraps 'data' without copying; the logical size starts at 0
        IntStack(int[] data) {
            array = data;
        }

        /**
         * Pushes a value, growing the backing array if needed.
         *
         * @return this stack, for chaining
         */
        IntStack push(int i) {
            if (size >= array.length)
                array = Arrays.copyOf(array, (array.length + 1) * 2);

            array[size++] = i;
            return this;
        }

        // removes and returns the top value; no emptiness check is done
        int pop() {
            return array[--size];
        }

        int size() {
            return size;
        }
    }

    /**
     * @return {@code d} squared
     */
    static final double sqr(double d) {
        return d * d;
    }

}