Difference between revisions of "User:Skilgannon/KDTree"

From Robowiki
Jump to navigation Jump to search
(Add a weighted-manhattan distance metric for Wintermute)
m (Update, link to latest on Bitbucket)
 
Line 1: Line 1:
 +
Latest version is available here: [https://bitbucket.org/jkflying/kd-tree]
 +
 +
A possibly outdated version is listed below:
 +
 
<code><syntaxhighlight>
 
<code><syntaxhighlight>
 +
package jk.tree;
 
/*
 
/*
** KDTree.java by Julian Kent
+
** KDTree.java by Julian Kent
**
+
**
** Licenced under the  Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
+
** Licenced under the  Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
**
+
**
** Licence summary:
+
** Licence summary:
** Under this licence you are free to:
+
** Under this licence you are free to:
**     Share copy and redistribute the material in any medium or format
+
**     Share : copy and redistribute the material in any medium or format
**     Adapt remix, transform, and build upon the material
+
**     Adapt : remix, transform, and build upon the material
**     The licensor cannot revoke these freedoms as long as you follow the license terms.
+
**     The licensor cannot revoke these freedoms as long as you follow the license terms.
**  
+
**
** Under the following terms:
+
** Under the following terms:
**     Attribution   — You must give appropriate credit, provide a link to the license, and indicate  
+
**     Attribution:
**                     if changes were made. You may do so in any reasonable manner, but not in any  
+
**            You must give appropriate credit, provide a link to the license, and indicate
**                     way that suggests the licensor endorses you or your use.
+
**           if changes were made. You may do so in any reasonable manner, but not in any
**     NonCommercial You may not use the material for commercial purposes.
+
**           way that suggests the licensor endorses you or your use.
**     ShareAlike   — If you remix, transform, or build upon the material, you must distribute your  
+
**     NonCommercial:
**                     contributions under the same license as the original.
+
**            You may not use the material for commercial purposes.
**     No additional restrictions  
+
**     ShareAlike:
**                   — You may not apply legal terms or technological measures that legally restrict
+
**            If you remix, transform, or build upon the material, you must distribute your
**                     others from doing anything the license permits.
+
**           contributions under the same license as the original.
**
+
**     No additional restrictions:
** See full licencing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
+
**           You may not apply legal terms or technological measures that legally restrict
**
+
**           others from doing anything the license permits.
** For additional licencing rights please contact jkflying@gmail.com
+
**
**
+
** See full licencing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
*/
+
**
+
** For additional licencing rights (including commercial) please contact jkflying@gmail.com
+
**
package jk.mega;
+
*/
  
 
import java.util.ArrayList;
 
import java.util.ArrayList;
 
import java.util.Arrays;
 
import java.util.Arrays;
+
 
public abstract class KDTree<T>{
+
public abstract class KDTree<T> {
+
 
//use a big bucketSize so that we have less node bounds (for more cache hits) and better splits
+
    // use a big bucketSize so that we have less node bounds (for more cache
  private static final int _bucketSize = 50;
+
    // hits) and better splits
+
    // if you have lots of dimensions this should be big, and if you have few small
  private final int _dimensions;
+
    private static final int _bucketSize = 50;
  private int _nodes;  
+
 
  private final Node root;
+
    private final int _dimensions;
  private final ArrayList<Node> nodeList = new ArrayList<Node>();
+
    private int _nodes;
+
    private final Node root;
  //prevent GC from having to collect _bucketSize*dimensions*8 bytes each time a leaf splits
+
    private final ArrayList<Node> nodeList = new ArrayList<Node>();
  private double[] mem_recycle;
+
 
+
    // prevent GC from having to collect _bucketSize*dimensions*sizeof(double) bytes each
  //the starting values for bounding boxes, for easy access
+
    // time a leaf splits
  private final double[] bounds_template;
+
    private double[] mem_recycle;
+
 
  //one big self-expanding array to keep all the node bounding boxes so that they stay in cache
+
    // the starting values for bounding boxes, for easy access
  // node bounds available at:
+
    private final double[] bounds_template;
  //low: 2 * _dimensions * node.index + 2 * dim
+
 
  //high: 2 * _dimensions * node.index + 2 * dim + 1
+
    // one big self-expanding array to keep all the node bounding boxes so that
  private final ContiguousDoubleArrayList nodeMinMaxBounds;
+
    // they stay in cache
+
    // node bounds available at:
  private KDTree(int dimensions){
+
    // low: 2 * _dimensions * node.index + 2 * dim
      _dimensions = dimensions;
+
    // high: 2 * _dimensions * node.index + 2 * dim + 1
 
+
    private final ContiguousDoubleArrayList nodeMinMaxBounds;
  //initialise this big so that it ends up in 'old' memory
+
 
      nodeMinMaxBounds = new ContiguousDoubleArrayList(512 * 1024 / 8 + 2*_dimensions);
+
    private KDTree(int dimensions) {
      mem_recycle = new double[_bucketSize*dimensions];
+
        _dimensions = dimensions;
 
+
 
      bounds_template = new double[2*_dimensions];
+
        // initialise this big so that it ends up in 'old' memory
      Arrays.fill(bounds_template,Double.NEGATIVE_INFINITY);
+
        nodeMinMaxBounds = new ContiguousDoubleArrayList(512 * 1024 / 8 + 2 * _dimensions);
      for(int i = 0, max = 2*_dimensions; i < max; i+=2)
+
        mem_recycle = new double[_bucketSize * dimensions];
        bounds_template[i] = Double.POSITIVE_INFINITY;
+
 
 
+
        bounds_template = new double[2 * _dimensions];
  //and.... start!
+
        Arrays.fill(bounds_template, Double.NEGATIVE_INFINITY);
      root = new Node();
+
        for (int i = 0, max = 2 * _dimensions; i < max; i += 2)
  }
+
            bounds_template[i] = Double.POSITIVE_INFINITY;
  public int nodes(){
+
 
      return _nodes;
+
        // and.... start!
  }
+
        root = new Node();
  public int size(){
+
    }
      return root.entries;
+
 
  }
+
    public int nodes() {
  public int addPoint(double[] location, T payload){
+
        return _nodes;
 
+
    }
      Node addNode = root;
+
 
  //Do a Depth First Search to find the Node where 'location' should be stored
+
    public int size() {
      while(addNode.pointLocations == null){
+
        return root.entries;
        addNode.expandBounds(location);
+
    }
        if(location[addNode.splitDim] < addNode.splitVal)
+
 
            addNode = nodeList.get(addNode.lessIndex);
+
    public int addPoint(double[] location, T payload) {
        else
+
 
            addNode = nodeList.get(addNode.moreIndex);
+
        Node addNode = root;
      }
+
        // Do a Depth First Search to find the Node where 'location' should be
      addNode.expandBounds(location);
+
        // stored
 
+
        while (addNode.pointLocations == null) {
      int nodeSize = addNode.add(location,payload);
+
            addNode.expandBounds(location);
 
+
            if (location[addNode.splitDim] < addNode.splitVal)
      if(nodeSize % _bucketSize == 0)
+
                addNode = nodeList.get(addNode.lessIndex);
      //try splitting again once every time the node passes a _bucketSize multiple
 
      //in case it is full of points of the same location and won't split
 
        addNode.split();
 
 
 
      return root.entries;
 
  }
 
 
 
  public ArrayList<SearchResult<T>> nearestNeighbours(double[] searchLocation, int K){
 
      IntStack stack = new IntStack();
 
      PrioQueue<T> results = new PrioQueue<T>(K,true);
 
 
 
      stack.push(root.index);
 
 
 
      int added = 0;
 
 
 
      while(stack.size() > 0 ){
 
        int nodeIndex = stack.pop();
 
        if(added < K || results.peekPrio() > pointRectDist(nodeIndex,searchLocation)){
 
            Node node = nodeList.get(nodeIndex);
 
            if(node.pointLocations == null)
 
              node.search(searchLocation,stack);
 
 
             else
 
             else
              added += node.search(searchLocation,results);
+
                addNode = nodeList.get(addNode.moreIndex);
        }
+
        }
      }
+
        addNode.expandBounds(location);
 
+
 
      ArrayList<SearchResult<T>> returnResults = new ArrayList<SearchResult<T>>(K);
+
        int nodeSize = addNode.add(location, payload);
      double[] priorities = results.priorities;
+
 
      Object[] elements = results.elements;
+
        if (nodeSize % _bucketSize == 0)
      for(int i = 0; i < K; i++){//forward (closest first)
+
            // try splitting again once every time the node passes a _bucketSize
        SearchResult s = new SearchResult(priorities[i],(T)elements[i]);
+
            // multiple
        returnResults.add(s);
+
            // in case it is full of points of the same location and won't split
      }
+
            addNode.split();
      return returnResults;
+
 
  }
+
        return root.entries;
 
+
    }
  public ArrayList<T> ballSearch(double[] searchLocation, double radius){
+
 
      IntStack stack = new IntStack();
+
    public ArrayList<SearchResult<T>> nearestNeighbours(double[] searchLocation, int K) {
      ArrayList<T> results = new ArrayList<T>();
+
 
 
+
        K = Math.min(K, size());
      stack.push(root.index);
+
 
 
+
        ArrayList<SearchResult<T>> returnResults = new ArrayList<SearchResult<T>>(K);
      while(stack.size() > 0 ){
+
 
        int nodeIndex = stack.pop();
+
        if (K > 0) {
        if(radius > pointRectDist(nodeIndex, searchLocation)){
+
            IntStack stack = new IntStack();
             Node node = nodeList.get(nodeIndex);
+
            PrioQueue<T> results = new PrioQueue<T>(K, true);
            if(node.pointLocations == null)
+
 
              stack.push(node.moreIndex).push(node.lessIndex);
+
            stack.push(root.index);
 +
 
 +
            int added = 0;
 +
 
 +
            while (stack.size() > 0) {
 +
                int nodeIndex = stack.pop();
 +
                if (added < K || results.peekPrio() > pointRectDist(nodeIndex, searchLocation)) {
 +
                    Node node = nodeList.get(nodeIndex);
 +
                    if (node.pointLocations == null)
 +
                        node.search(searchLocation, stack);
 +
                    else
 +
                        added += node.search(searchLocation, results);
 +
                }
 +
            }
 +
 
 +
            double[] priorities = results.priorities;
 +
            Object[] elements = results.elements;
 +
            for (int i = 0; i < K; i++) { // forward (closest first)
 +
                SearchResult<T> s = new SearchResult<T>(priorities[i], (T) elements[i]);
 +
                returnResults.add(s);
 +
            }
 +
        }
 +
        return returnResults;
 +
    }
 +
 
 +
    public ArrayList<T> ballSearch(double[] searchLocation, double radius) {
 +
        IntStack stack = new IntStack();
 +
        ArrayList<T> results = new ArrayList<T>();
 +
 
 +
        stack.push(root.index);
 +
 
 +
        while (stack.size() > 0) {
 +
            int nodeIndex = stack.pop();
 +
            if (radius > pointRectDist(nodeIndex, searchLocation)) {
 +
                Node node = nodeList.get(nodeIndex);
 +
                if (node.pointLocations == null)
 +
                    stack.push(node.moreIndex).push(node.lessIndex);
 +
                else
 +
                    node.searchBall(searchLocation, radius, results);
 +
             }
 +
        }
 +
        return results;
 +
    }
 +
 
 +
    public ArrayList<T> rectSearch(double[] mins, double[] maxs) {
 +
        IntStack stack = new IntStack();
 +
        ArrayList<T> results = new ArrayList<T>();
 +
 
 +
        stack.push(root.index);
 +
 
 +
        while (stack.size() > 0) {
 +
            int nodeIndex = stack.pop();
 +
            if (overlaps(mins, maxs, nodeIndex)) {
 +
                Node node = nodeList.get(nodeIndex);
 +
                if (node.pointLocations == null)
 +
                    stack.push(node.moreIndex).push(node.lessIndex);
 +
                else
 +
                    node.searchRect(mins, maxs, results);
 +
            }
 +
        }
 +
        return results;
 +
 
 +
    }
 +
 
 +
    abstract double pointRectDist(int offset, final double[] location);
 +
 
 +
    abstract double pointDist(double[] arr, double[] location, int index);
 +
 
 +
    boolean contains(double[] arr, double[] mins, double[] maxs, int index) {
 +
 
 +
        int offset = (index + 1) * mins.length;
 +
 
 +
        for (int i = mins.length; i-- > 0;) {
 +
            double d = arr[--offset];
 +
            if (mins[i] > d | d > maxs[i])
 +
                return false;
 +
        }
 +
        return true;
 +
    }
 +
 
 +
    boolean overlaps(double[] mins, double[] maxs, int offset) {
 +
        offset *= (2 * maxs.length);
 +
        final double[] array = nodeMinMaxBounds.array;
 +
        for (int i = 0; i < maxs.length; i++, offset += 2) {
 +
            double bmin = array[offset], bmax = array[offset + 1];
 +
            if (mins[i] > bmax | maxs[i] < bmin)
 +
                return false;
 +
        }
 +
 
 +
        return true;
 +
    }
 +
 
 +
    public static class Euclidean<T> extends KDTree<T> {
 +
        public Euclidean(int dims) {
 +
            super(dims);
 +
        }
 +
 
 +
        double pointRectDist(int offset, final double[] location) {
 +
            offset *= (2 * super._dimensions);
 +
            double distance = 0;
 +
            final double[] array = super.nodeMinMaxBounds.array;
 +
            for (int i = 0; i < location.length; i++, offset += 2) {
 +
 
 +
                double diff = 0;
 +
                double bv = array[offset];
 +
                double lv = location[i];
 +
                if (bv > lv)
 +
                    diff = bv - lv;
 +
                else {
 +
                    bv = array[offset + 1];
 +
                    if (lv > bv)
 +
                        diff = lv - bv;
 +
                }
 +
                distance += sqr(diff);
 +
            }
 +
            return distance;
 +
        }
 +
 
 +
        double pointDist(double[] arr, double[] location, int index) {
 +
            double distance = 0;
 +
            int offset = (index + 1) * super._dimensions;
 +
 
 +
            for (int i = super._dimensions; i-- > 0;) {
 +
                distance += sqr(arr[--offset] - location[i]);
 +
            }
 +
            return distance;
 +
        }
 +
 
 +
    }
 +
 
 +
    public static class Manhattan<T> extends KDTree<T> {
 +
        public Manhattan(int dims) {
 +
            super(dims);
 +
        }
 +
 
 +
        double pointRectDist(int offset, final double[] location) {
 +
            offset *= (2 * super._dimensions);
 +
            double distance = 0;
 +
            final double[] array = super.nodeMinMaxBounds.array;
 +
            for (int i = 0; i < location.length; i++, offset += 2) {
 +
 
 +
                double diff = 0;
 +
                double bv = array[offset];
 +
                double lv = location[i];
 +
                if (bv > lv)
 +
                    diff = bv - lv;
 +
                else {
 +
                    bv = array[offset + 1];
 +
                    if (lv > bv)
 +
                        diff = lv - bv;
 +
                }
 +
                distance += (diff);
 +
            }
 +
            return distance;
 +
        }
 +
 
 +
        double pointDist(double[] arr, double[] location, int index) {
 +
            double distance = 0;
 +
            int offset = (index + 1) * super._dimensions;
 +
 
 +
            for (int i = super._dimensions; i-- > 0;) {
 +
                distance += Math.abs(arr[--offset] - location[i]);
 +
            }
 +
            return distance;
 +
        }
 +
    }
 +
 
 +
    public static class WeightedManhattan<T> extends KDTree<T> {
 +
        private double[] weights;
 +
 
 +
        public WeightedManhattan(int dims) {
 +
            super(dims);
 +
            weights = new double[dims];
 +
            for (int i = 0; i < dims; i++)
 +
                weights[i] = 1.0;
 +
        }
 +
 
 +
        public void setWeights(double[] newWeights) {
 +
            weights = newWeights;
 +
        }
 +
 
 +
        double pointRectDist(int offset, final double[] location) {
 +
            offset *= (2 * super._dimensions);
 +
            double distance = 0;
 +
            final double[] array = super.nodeMinMaxBounds.array;
 +
            for (int i = 0; i < location.length; i++, offset += 2) {
 +
 
 +
                double diff = 0;
 +
                double bv = array[offset];
 +
                double lv = location[i];
 +
                if (bv > lv)
 +
                    diff = bv - lv;
 +
                else {
 +
                    bv = array[offset + 1];
 +
                    if (lv > bv)
 +
                        diff = lv - bv;
 +
                }
 +
                distance += (diff) * weights[i];
 +
            }
 +
            return distance;
 +
        }
 +
 
 +
        double pointDist(double[] arr, double[] location, int index) {
 +
            double distance = 0;
 +
            int offset = (index + 1) * super._dimensions;
 +
 
 +
            for (int i = super._dimensions; i-- > 0;) {
 +
                distance += Math.abs(arr[--offset] - location[i]) * weights[i];
 +
            }
 +
            return distance;
 +
        }
 +
    }
 +
 
 +
    public static class SearchResult<S> {
 +
        public double distance;
 +
        public S payload;
 +
 
 +
        SearchResult(double dist, S load) {
 +
            distance = dist;
 +
            payload = load;
 +
        }
 +
    }
 +
 
 +
    private class Node {
 +
 
 +
        // for accessing bounding box data
 +
        // - if trees weren't so unbalanced might be better to use an implicit
 +
        // heap?
 +
        int index;
 +
 
 +
        // keep track of size of subtree
 +
        int entries;
 +
 
 +
        // leaf
 +
        ContiguousDoubleArrayList pointLocations;
 +
        ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);
 +
 
 +
        // stem
 +
        // Node less, more;
 +
        int lessIndex, moreIndex;
 +
        int splitDim;
 +
        double splitVal;
 +
 
 +
        Node() {
 +
            this(new double[_bucketSize * _dimensions]);
 +
        }
 +
 
 +
        Node(double[] pointMemory) {
 +
            pointLocations = new ContiguousDoubleArrayList(pointMemory);
 +
            index = _nodes++;
 +
            nodeList.add(this);
 +
            nodeMinMaxBounds.add(bounds_template);
 +
        }
 +
 
 +
        void search(double[] searchLocation, IntStack stack) {
 +
            if (searchLocation[splitDim] < splitVal)
 +
                stack.push(moreIndex).push(lessIndex); // less will be popped
 +
            // first
 
             else
 
             else
              node.searchBall(searchLocation, radius, results);
+
                stack.push(lessIndex).push(moreIndex); // more will be popped
        }
+
             // first
      }
+
        }
      return results;
+
 
  }
+
        // returns number of points added to results
  public ArrayList<T> rectSearch(double[] mins, double[] maxs){
+
        int search(double[] searchLocation, PrioQueue<T> results) {
      IntStack stack = new IntStack();
+
            int updated = 0;
      ArrayList<T> results = new ArrayList<T>();
+
            for (int j = entries; j-- > 0;) {
 
+
                double distance = pointDist(pointLocations.array, searchLocation, j);
      stack.push(root.index);
+
                if (results.peekPrio() > distance) {
 
+
                    updated++;
      while(stack.size() > 0 ){
+
                    results.addNoGrow(pointPayloads.get(j), distance);
        int nodeIndex = stack.pop();
+
                }
        if(overlaps(mins,maxs,nodeIndex)){
+
            }
            Node node = nodeList.get(nodeIndex);
+
            return updated;
            if(node.pointLocations == null)
+
        }
              stack.push(node.moreIndex).push(node.lessIndex);
+
 
             else
+
        void searchBall(double[] searchLocation, double radius, ArrayList<T> results) {
              node.searchRect(mins, maxs, results);
 
        }
 
      }
 
      return results;
 
 
 
  }
 
 
 
 
 
  abstract double pointRectDist(int offset, final double[] location);
 
  abstract double pointDist(double[] arr, double[] location, int index);
 
 
 
  boolean contains(double[] arr, double[] mins, double[] maxs, int index){
 
 
 
      int offset = (index+1)*mins.length;
 
       
 
      for(int i = mins.length; i-- > 0 ;){
 
        double d = arr[--offset];
 
        if(mins[i] > d | d > maxs[i])
 
            return false;
 
      }
 
      return true;
 
  }
 
 
 
  boolean overlaps(double[] mins, double[] maxs, int offset){
 
      offset *= (2*maxs.length);
 
      final double[] array = nodeMinMaxBounds.array;
 
      for(int i = 0; i < maxs.length; i++,offset += 2){
 
        double bmin = array[offset], bmax = array[offset+1];
 
        if(mins[i] > bmax | maxs[i] < bmin)
 
            return false;
 
      }
 
 
 
      return true;
 
  }
 
 
 
  
  public static class Euclidean<T> extends KDTree<T>{
+
            for (int j = entries; j-- > 0;) {
      public Euclidean(int dims){
+
                double distance = pointDist(pointLocations.array, searchLocation, j);
        super(dims);
+
                if (radius >= distance) {
      }
+
                    results.add(pointPayloads.get(j));
      double pointRectDist(int offset, final double[] location){
+
                }
        offset *= (2*super._dimensions);
 
        double distance=0;
 
        final double[] array = super.nodeMinMaxBounds.array;
 
        for(int i = 0; i < location.length; i++,offset += 2){
 
       
 
            double diff = 0;
 
            double bv = array[offset];
 
            double lv = location[i];
 
            if(bv > lv)
 
              diff = bv-lv;
 
            else{
 
              bv=array[offset+1];
 
              if(lv>bv)
 
                  diff = lv-bv;
 
 
             }
 
             }
            distance += sqr(diff);
+
        }
        }
+
 
        return distance;
+
        void searchRect(double[] mins, double[] maxs, ArrayList<T> results) {
      }
+
 
      double pointDist(double[] arr, double[] location, int index){
+
            for (int j = entries; j-- > 0;)
        double distance = 0;
+
                if (contains(pointLocations.array, mins, maxs, j))
        int offset = (index+1)*super._dimensions;
+
                    results.add(pointPayloads.get(j));
       
+
 
        for(int i = super._dimensions; i-- > 0 ;){
+
        }
            distance += sqr(arr[--offset] - location[i]);
+
 
        }
+
        void expandBounds(double[] location) {
        return distance;
+
            entries++;
      }
+
            int mio = index * 2 * _dimensions;
 
+
            for (int i = 0; i < _dimensions; i++) {
  }
+
                nodeMinMaxBounds.array[mio] = Math.min(nodeMinMaxBounds.array[mio], location[i]);
  public static class Manhattan<T> extends KDTree<T>{
+
                mio++;
      public Manhattan(int dims){
+
                nodeMinMaxBounds.array[mio] = Math.max(nodeMinMaxBounds.array[mio], location[i]);
        super(dims);
+
                mio++;
      }
 
      double pointRectDist(int offset, final double[] location){
 
        offset *= (2*super._dimensions);
 
        double distance=0;
 
        final double[] array = super.nodeMinMaxBounds.array;
 
        for(int i = 0; i < location.length; i++,offset += 2){
 
       
 
            double diff = 0;
 
            double bv = array[offset];
 
            double lv = location[i];
 
            if(bv > lv)
 
              diff = bv-lv;
 
            else{
 
              bv=array[offset+1];
 
              if(lv>bv)
 
                  diff = lv-bv;
 
 
             }
 
             }
            distance += (diff);
+
        }
        }
+
 
        return distance;
+
        int add(double[] location, T load) {
      }
+
            pointLocations.add(location);
      double pointDist(double[] arr, double[] location, int index){
+
             pointPayloads.add(load);
        double distance = 0;
+
            return entries;
        int offset = (index+1)*super._dimensions;
+
        }
       
+
 
        for(int i = super._dimensions; i-- > 0 ;){
+
        void split() {
             distance += Math.abs(arr[--offset] - location[i]);
+
            int offset = index * 2 * _dimensions;
        }
+
 
        return distance;
 
      }
 
  }
 
  public static class WeightedManhattan<T> extends KDTree<T>{
 
      double[] weights;
 
      public WeightedManhattan(int dims){
 
        super(dims);
 
      }
 
      public void setWeights(double[] newWeights){
 
        weights = newWeights;
 
      }
 
      double pointRectDist(int offset, final double[] location){
 
        offset *= (2*super._dimensions);
 
        double distance=0;
 
        final double[] array = super.nodeMinMaxBounds.array;
 
        for(int i = 0; i < location.length; i++,offset += 2){
 
       
 
 
             double diff = 0;
 
             double diff = 0;
             double bv = array[offset];
+
             for (int i = 0; i < _dimensions; i++) {
            double lv = location[i];
+
                double min = nodeMinMaxBounds.array[offset];
            if(bv > lv)
+
                double max = nodeMinMaxBounds.array[offset + 1];
              diff = bv-lv;
+
                if (max - min > diff) {
            else{
+
                    double mean = 0;
              bv=array[offset+1];
+
                    for (int j = 0; j < entries; j++)
              if(lv>bv)
+
                        mean += pointLocations.array[i + _dimensions * j];
                  diff = lv-bv;
+
 
 +
                    mean = mean / entries;
 +
                    double varianceSum = 0;
 +
 
 +
                    for (int j = 0; j < entries; j++)
 +
                        varianceSum += sqr(mean - pointLocations.array[i + _dimensions * j]);
 +
 
 +
                    if (varianceSum > diff * entries) {
 +
                        diff = varianceSum / entries;
 +
                        splitVal = mean;
 +
 
 +
                        splitDim = i;
 +
                    }
 +
                }
 +
                offset += 2;
 
             }
 
             }
             distance += (diff)*weights[i];
+
 
        }
+
             // kill all the nasties
        return distance;
+
            if (splitVal == Double.POSITIVE_INFINITY)
      }
+
                splitVal = Double.MAX_VALUE;
      double pointDist(double[] arr, double[] location, int index){
+
             else if (splitVal == Double.NEGATIVE_INFINITY)
        double distance = 0;
+
                splitVal = Double.MIN_VALUE;
        int offset = (index+1)*super._dimensions;
+
             else if (splitVal == nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim + 1])
       
+
                splitVal = nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim];
        for(int i = super._dimensions; i-- > 0 ;){
+
 
             distance += Math.abs(arr[--offset] - location[i])*weights[i];
+
            Node less = new Node(mem_recycle); // recycle that memory!
        }
+
             Node more = new Node();
        return distance;
+
             lessIndex = less.index;
      }
+
             moreIndex = more.index;
  }
+
 
+
            // reduce garbage by factor of _bucketSize by recycling this array
    //NB! This Priority Queue keeps things with the LOWEST priority.
+
            double[] pointLocation = new double[_dimensions];
//If you want highest priority items kept, negate your values
+
            for (int i = 0; i < entries; i++) {
  private static class PrioQueue<S>{
+
                System.arraycopy(pointLocations.array, i * _dimensions, pointLocation, 0, _dimensions);
 
+
                T load = pointPayloads.get(i);
      Object[] elements;
+
 
      double[] priorities;
+
                if (pointLocation[splitDim] < splitVal) {
      private double minPrio;
+
                    less.expandBounds(pointLocation);
      private int size;
+
                    less.add(pointLocation, load);
 
+
                }
      PrioQueue(int size, boolean prefill){
+
                else {
        elements = new Object[size];
+
                    more.expandBounds(pointLocation);
        priorities = new double[size];
+
                    more.add(pointLocation, load);
        Arrays.fill(priorities,Double.POSITIVE_INFINITY);
+
                }
        if(prefill){
 
            minPrio = Double.POSITIVE_INFINITY;
 
             this.size = size;
 
        }
 
      }
 
      //uses O(log(n)) comparisons and one big shift of size O(N)
 
      //and is MUCH simpler than a heap --> faster on small sets, faster JIT
 
 
 
      void addNoGrow(S value, double priority){
 
        int index = searchFor(priority);
 
        int nextIndex = index + 1;
 
        int length = size - index - 1;
 
        System.arraycopy(elements,index,elements,nextIndex,length);
 
        System.arraycopy(priorities,index,priorities,nextIndex,length);
 
        elements[index]=value;
 
        priorities[index]=priority;
 
     
 
        minPrio = priorities[size-1];
 
      }
 
 
 
      int searchFor(double priority){
 
        int i = size-1;
 
        int j = 0; 
 
        while(i>=j){
 
             int index = (i+j)>>>1;
 
             if( priorities[index] < priority)
 
              j = index+1;
 
             else
 
              i = index-1;
 
        }
 
        return j;
 
      }
 
      double peekPrio(){
 
        return minPrio;
 
      }
 
  }
 
 
  public static class SearchResult<S>{
 
      public double distance;
 
      public S payload;
 
      SearchResult(double dist, S load){
 
        distance = dist;
 
        payload = load;
 
      }
 
  }
 
 
  private class Node {
 
 
 
  //for accessing bounding box data
 
  // - if trees weren't so unbalanced might be better to use an implicit heap?
 
      int index;
 
 
 
  //keep track of size of subtree
 
      int entries;
 
 
 
  //leaf
 
      ContiguousDoubleArrayList pointLocations ;
 
      ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);
 
 
 
  //stem
 
      //Node less, more;
 
      int lessIndex, moreIndex;
 
      int splitDim;
 
      double splitVal;
 
 
 
      Node(){
 
        this(new double[_bucketSize*_dimensions]);
 
      }
 
      Node(double[] pointMemory){
 
        pointLocations = new ContiguousDoubleArrayList(pointMemory);
 
        index = _nodes++;
 
        nodeList.add(this);
 
        nodeMinMaxBounds.add(bounds_template);
 
      }
 
 
 
 
 
      void search(double[] searchLocation, IntStack stack){
 
        if(searchLocation[splitDim] < splitVal)
 
            stack.push(moreIndex).push(lessIndex);//less will be popped first
 
        else
 
            stack.push(lessIndex).push(moreIndex);//more will be popped first
 
      }
 
     
 
      //returns number of points added to results
 
      int search(double[] searchLocation, PrioQueue<T> results){
 
        int updated = 0;
 
        for(int j = entries; j-- > 0;){
 
            double distance = pointDist(pointLocations.array,searchLocation,j);
 
            if(results.peekPrio() > distance){
 
              updated++;
 
              results.addNoGrow(pointPayloads.get(j),distance);
 
 
             }
 
             }
        }
+
            if (less.entries * more.entries == 0) {
        return updated;
+
                // one of them was 0, so the split was worthless. throw it away.
      }
+
                _nodes -= 2; // recall that bounds memory
     
+
                nodeList.remove(moreIndex);
      void searchBall(double[] searchLocation, double radius, ArrayList<T> results){
+
                nodeList.remove(lessIndex);
       
 
        for(int j = entries; j-- > 0;){
 
            double distance = pointDist(pointLocations.array,searchLocation,j);
 
            if(radius >= distance){
 
              results.add(pointPayloads.get(j));
 
 
             }
 
             }
        }
+
            else {
      }
+
 
     
+
                // we won't be needing that now, so keep it for the next split
      void searchRect(double[] mins, double[] maxs, ArrayList<T> results){
+
                // to reduce garbage
     
+
                mem_recycle = pointLocations.array;
        for(int j = entries; j-- > 0;)
+
 
            if(contains(pointLocations.array,mins,maxs,j))
+
                pointLocations = null;
              results.add(pointPayloads.get(j));
+
 
     
+
                pointPayloads.clear();
      } 
+
                pointPayloads = null;
 
 
      void expandBounds(double[] location){
 
        entries++;
 
        int mio = index*2*_dimensions;
 
        for(int i = 0; i < _dimensions;i++){
 
            nodeMinMaxBounds.array[mio] = Math.min(nodeMinMaxBounds.array[mio++],location[i]);
 
            nodeMinMaxBounds.array[mio] = Math.max(nodeMinMaxBounds.array[mio++],location[i]);
 
        }
 
      }
 
 
 
      int add(double[] location, T load){
 
        pointLocations.add(location);
 
        pointPayloads.add(load);
 
        return entries;
 
      }
 
      void split(){
 
        int offset = index*2*_dimensions;
 
     
 
        double diff = 0;
 
        for(int i = 0; i < _dimensions; i++){
 
            double min = nodeMinMaxBounds.array[offset];
 
            double max = nodeMinMaxBounds.array[offset+1];
 
            if(max-min>diff){
 
              double mean = 0;
 
              for(int j = 0; j < entries; j++)
 
                  mean += pointLocations.array[i+_dimensions*j];
 
           
 
              mean = mean/entries;
 
              double varianceSum = 0;
 
           
 
              for(int j = 0; j < entries; j++)
 
                  varianceSum += sqr(mean-pointLocations.array[i+_dimensions*j]);
 
           
 
              if(varianceSum>diff*entries){
 
                  diff = varianceSum/entries;
 
                  splitVal = mean;
 
             
 
                  splitDim = i;
 
              }
 
 
             }
 
             }
            offset += 2;
+
        }
        }
+
 
     
+
    }
        //kill all the nasties
+
 
        if(splitVal == Double.POSITIVE_INFINITY)
+
    // NB! This Priority Queue keeps things with the LOWEST priority.
            splitVal = Double.MAX_VALUE;
+
    // If you want highest priority items kept, negate your values
        else if(splitVal == Double.NEGATIVE_INFINITY)
+
    private static class PrioQueue<S> {
            splitVal = Double.MIN_VALUE;
+
 
        else if(splitVal == nodeMinMaxBounds.array[index*2*_dimensions + 2*splitDim + 1])
+
        Object[] elements;
            splitVal = nodeMinMaxBounds.array[index*2*_dimensions + 2*splitDim];  
+
        double[] priorities;
     
+
        private double minPrio;
        Node less = new Node(mem_recycle);//recycle that memory!
+
        private int size;
        Node more = new Node();
+
 
        lessIndex = less.index;
+
        PrioQueue(int size, boolean prefill) {
        moreIndex = more.index;
+
            elements = new Object[size];
     
+
            priorities = new double[size];
        //reduce garbage by factor of _bucketSize by recycling this array
+
             Arrays.fill(priorities, Double.POSITIVE_INFINITY);
        double[] pointLocation = new double[_dimensions];
+
             if (prefill) {
        for(int i = 0; i < entries; i++){
+
                minPrio = Double.POSITIVE_INFINITY;
             System.arraycopy(pointLocations.array,i*_dimensions,pointLocation,0,_dimensions);
+
                this.size = size;
            T load = pointPayloads.get(i);
 
       
 
             if(pointLocation[splitDim] < splitVal){
 
              less.expandBounds(pointLocation);
 
              less.add(pointLocation,load);
 
 
             }
 
             }
             else{
+
        }
              more.expandBounds(pointLocation);  
+
 
              more.add(pointLocation,load);
+
        // uses O(log(n)) comparisons and one big shift of size O(N)
 +
        // and is MUCH simpler than a heap --> faster on small sets, faster JIT
 +
 
 +
        void addNoGrow(S value, double priority) {
 +
             int index = searchFor(priority);
 +
            int nextIndex = index + 1;
 +
            int length = size - nextIndex;
 +
            System.arraycopy(elements, index, elements, nextIndex, length);
 +
            System.arraycopy(priorities, index, priorities, nextIndex, length);
 +
            elements[index] = value;
 +
            priorities[index] = priority;
 +
 
 +
            minPrio = priorities[size - 1];
 +
        }
 +
 
 +
        int searchFor(double priority) {
 +
            int i = size - 1;
 +
            int j = 0;
 +
            while (i >= j) {
 +
                int index = (i + j) >>> 1;
 +
                if (priorities[index] < priority)
 +
                    j = index + 1;
 +
                else
 +
                    i = index - 1;
 
             }
 
             }
        }
+
             return j;
        if(less.entries*more.entries == 0){
+
        }
        //one of them was 0, so the split was worthless. throw it away.
+
 
             _nodes -= 2;//recall that bounds memory
+
        double peekPrio() {
            nodeList.remove(moreIndex);
+
             return minPrio;
            nodeList.remove(lessIndex);
+
        }
        }
+
    }
        else{
+
 
       
+
    private static class ContiguousDoubleArrayList {
        //we won't be needing that now, so keep it for the next split to reduce garbage
+
        double[] array;
            mem_recycle = pointLocations.array;
+
        int size;
       
+
 
            pointLocations = null;
+
        ContiguousDoubleArrayList(int size) {
       
+
            this(new double[size]);
            pointPayloads.clear();
+
        }
             pointPayloads = null;
+
 
        }
+
        ContiguousDoubleArrayList(double[] data) {
      }
+
            array = data;
 
+
        }
  }
+
 
+
        ContiguousDoubleArrayList add(double[] da) {
+
            if (size + da.length > array.length)
  private static class ContiguousDoubleArrayList{
+
                array = Arrays.copyOf(array, (array.length + da.length) * 2);
      double[] array;
+
 
      int size;
+
            System.arraycopy(da, 0, array, size, da.length);
      ContiguousDoubleArrayList(){this(300);}
+
            size += da.length;
      ContiguousDoubleArrayList(int size){this(new double[size]);}
+
            return this;
      ContiguousDoubleArrayList(double[] data){array = data;}
+
        }
     
+
    }
      ContiguousDoubleArrayList add(double[] da){
+
 
        if(size + da.length > array.length)
+
    private static class IntStack {
            array = Arrays.copyOf(array,(array.length+da.length)*2);
+
        int[] array;
       
+
        int size;
        System.arraycopy(da,0,array,size,da.length);
+
 
        size += da.length;
+
        IntStack() {
        return this;
+
            this(64);
      }
+
        }
  }
+
 
  private static class IntStack{
+
        IntStack(int size) {
      int[] array;
+
            this(new int[size]);
      int size;
+
        }
      IntStack(){this(64);}
+
 
      IntStack(int size){this(new int[size]);}
+
        IntStack(int[] data) {
      IntStack(int[] data){array = data;}
+
            array = data;
     
+
        }
      IntStack push(int i){
+
 
        if(size>= array.length)
+
        IntStack push(int i) {
            array = Arrays.copyOf(array,(array.length+1)*2);
+
            if (size >= array.length)
       
+
                array = Arrays.copyOf(array, (array.length + 1) * 2);
        array[size++] = i;
+
 
        return this;
+
            array[size++] = i;
      }
+
            return this;
      int pop(){
+
        }
        return array[--size];
+
 
      }
+
        int pop() {
      int size(){
+
            return array[--size];
        return size;
+
        }
      }
+
 
  }
+
        int size() {
+
            return size;
  static final double sqr(double d){
+
        }
      return d*d;}
+
    }
+
 
 +
    static final double sqr(double d) {
 +
        return d * d;
 +
    }
 +
 
 
}
 
}
 
</syntaxhighlight></code>
 
</syntaxhighlight></code>

Latest revision as of 21:46, 20 September 2017

Latest version is available here: [1]

A possibly outdated version is listed below:

package jk.tree;
/*
 ** KDTree.java by Julian Kent
 **
 ** Licenced under the  Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
 **
 ** Licence summary:
 ** Under this licence you are free to:
 **      Share : copy and redistribute the material in any medium or format
 **      Adapt : remix, transform, and build upon the material
 **      The licensor cannot revoke these freedoms as long as you follow the license terms.
 **
 ** Under the following terms:
 **      Attribution:
 **            You must give appropriate credit, provide a link to the license, and indicate
 **            if changes were made. You may do so in any reasonable manner, but not in any
 **            way that suggests the licensor endorses you or your use.
 **      NonCommercial:
 **            You may not use the material for commercial purposes.
 **      ShareAlike:
 **            If you remix, transform, or build upon the material, you must distribute your
 **            contributions under the same license as the original.
 **      No additional restrictions:
 **            You may not apply legal terms or technological measures that legally restrict
 **            others from doing anything the license permits.
 **
 ** See full licencing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
 **
 ** For additional licencing rights (including commercial) please contact jkflying@gmail.com
 **
 */

import java.util.ArrayList;
import java.util.Arrays;

public abstract class KDTree<T> {

    // use a big bucketSize so that we have less node bounds (for more cache
    // hits) and better splits
    // if you have lots of dimensions this should be big, and if you have few small
    private static final int _bucketSize = 50;

    private final int _dimensions;
    private int _nodes;
    private final Node root;
    private final ArrayList<Node> nodeList = new ArrayList<Node>();

    // prevent GC from having to collect _bucketSize*dimensions*sizeof(double) bytes each
    // time a leaf splits
    private double[] mem_recycle;

    // the starting values for bounding boxes, for easy access
    private final double[] bounds_template;

    // one big self-expanding array to keep all the node bounding boxes so that
    // they stay in cache
    // node bounds available at:
    // low: 2 * _dimensions * node.index + 2 * dim
    // high: 2 * _dimensions * node.index + 2 * dim + 1
    private final ContiguousDoubleArrayList nodeMinMaxBounds;

    private KDTree(int dimensions) {
        _dimensions = dimensions;

        // initialise this big so that it ends up in 'old' memory
        nodeMinMaxBounds = new ContiguousDoubleArrayList(512 * 1024 / 8 + 2 * _dimensions);
        mem_recycle = new double[_bucketSize * dimensions];

        bounds_template = new double[2 * _dimensions];
        Arrays.fill(bounds_template, Double.NEGATIVE_INFINITY);
        for (int i = 0, max = 2 * _dimensions; i < max; i += 2)
            bounds_template[i] = Double.POSITIVE_INFINITY;

        // and.... start!
        root = new Node();
    }

    public int nodes() {
        return _nodes;
    }

    public int size() {
        return root.entries;
    }

    public int addPoint(double[] location, T payload) {

        Node addNode = root;
        // Do a Depth First Search to find the Node where 'location' should be
        // stored
        while (addNode.pointLocations == null) {
            addNode.expandBounds(location);
            if (location[addNode.splitDim] < addNode.splitVal)
                addNode = nodeList.get(addNode.lessIndex);
            else
                addNode = nodeList.get(addNode.moreIndex);
        }
        addNode.expandBounds(location);

        int nodeSize = addNode.add(location, payload);

        if (nodeSize % _bucketSize == 0)
            // try splitting again once every time the node passes a _bucketSize
            // multiple
            // in case it is full of points of the same location and won't split
            addNode.split();

        return root.entries;
    }

    public ArrayList<SearchResult<T>> nearestNeighbours(double[] searchLocation, int K) {

        K = Math.min(K, size());

        ArrayList<SearchResult<T>> returnResults = new ArrayList<SearchResult<T>>(K);

        if (K > 0) {
            IntStack stack = new IntStack();
            PrioQueue<T> results = new PrioQueue<T>(K, true);

            stack.push(root.index);

            int added = 0;

            while (stack.size() > 0) {
                int nodeIndex = stack.pop();
                if (added < K || results.peekPrio() > pointRectDist(nodeIndex, searchLocation)) {
                    Node node = nodeList.get(nodeIndex);
                    if (node.pointLocations == null)
                        node.search(searchLocation, stack);
                    else
                        added += node.search(searchLocation, results);
                }
            }

            double[] priorities = results.priorities;
            Object[] elements = results.elements;
            for (int i = 0; i < K; i++) { // forward (closest first)
                SearchResult<T> s = new SearchResult<T>(priorities[i], (T) elements[i]);
                returnResults.add(s);
            }
        }
        return returnResults;
    }

    public ArrayList<T> ballSearch(double[] searchLocation, double radius) {
        IntStack stack = new IntStack();
        ArrayList<T> results = new ArrayList<T>();

        stack.push(root.index);

        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            if (radius > pointRectDist(nodeIndex, searchLocation)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    stack.push(node.moreIndex).push(node.lessIndex);
                else
                    node.searchBall(searchLocation, radius, results);
            }
        }
        return results;
    }

    public ArrayList<T> rectSearch(double[] mins, double[] maxs) {
        IntStack stack = new IntStack();
        ArrayList<T> results = new ArrayList<T>();

        stack.push(root.index);

        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            if (overlaps(mins, maxs, nodeIndex)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    stack.push(node.moreIndex).push(node.lessIndex);
                else
                    node.searchRect(mins, maxs, results);
            }
        }
        return results;

    }

    abstract double pointRectDist(int offset, final double[] location);

    abstract double pointDist(double[] arr, double[] location, int index);

    boolean contains(double[] arr, double[] mins, double[] maxs, int index) {

        int offset = (index + 1) * mins.length;

        for (int i = mins.length; i-- > 0;) {
            double d = arr[--offset];
            if (mins[i] > d | d > maxs[i])
                return false;
        }
        return true;
    }

    boolean overlaps(double[] mins, double[] maxs, int offset) {
        offset *= (2 * maxs.length);
        final double[] array = nodeMinMaxBounds.array;
        for (int i = 0; i < maxs.length; i++, offset += 2) {
            double bmin = array[offset], bmax = array[offset + 1];
            if (mins[i] > bmax | maxs[i] < bmin)
                return false;
        }

        return true;
    }

    public static class Euclidean<T> extends KDTree<T> {
        public Euclidean(int dims) {
            super(dims);
        }

        double pointRectDist(int offset, final double[] location) {
            offset *= (2 * super._dimensions);
            double distance = 0;
            final double[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {

                double diff = 0;
                double bv = array[offset];
                double lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += sqr(diff);
            }
            return distance;
        }

        double pointDist(double[] arr, double[] location, int index) {
            double distance = 0;
            int offset = (index + 1) * super._dimensions;

            for (int i = super._dimensions; i-- > 0;) {
                distance += sqr(arr[--offset] - location[i]);
            }
            return distance;
        }

    }

    public static class Manhattan<T> extends KDTree<T> {
        public Manhattan(int dims) {
            super(dims);
        }

        double pointRectDist(int offset, final double[] location) {
            offset *= (2 * super._dimensions);
            double distance = 0;
            final double[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {

                double diff = 0;
                double bv = array[offset];
                double lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += (diff);
            }
            return distance;
        }

        double pointDist(double[] arr, double[] location, int index) {
            double distance = 0;
            int offset = (index + 1) * super._dimensions;

            for (int i = super._dimensions; i-- > 0;) {
                distance += Math.abs(arr[--offset] - location[i]);
            }
            return distance;
        }
    }

    public static class WeightedManhattan<T> extends KDTree<T> {
        private double[] weights;

        public WeightedManhattan(int dims) {
            super(dims);
            weights = new double[dims];
            for (int i = 0; i < dims; i++)
                weights[i] = 1.0;
        }

        public void setWeights(double[] newWeights) {
            weights = newWeights;
        }

        double pointRectDist(int offset, final double[] location) {
            offset *= (2 * super._dimensions);
            double distance = 0;
            final double[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {

                double diff = 0;
                double bv = array[offset];
                double lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += (diff) * weights[i];
            }
            return distance;
        }

        double pointDist(double[] arr, double[] location, int index) {
            double distance = 0;
            int offset = (index + 1) * super._dimensions;

            for (int i = super._dimensions; i-- > 0;) {
                distance += Math.abs(arr[--offset] - location[i]) * weights[i];
            }
            return distance;
        }
    }

    public static class SearchResult<S> {
        public double distance;
        public S payload;

        SearchResult(double dist, S load) {
            distance = dist;
            payload = load;
        }
    }

    private class Node {

        // for accessing bounding box data
        // - if trees weren't so unbalanced might be better to use an implicit
        // heap?
        int index;

        // keep track of size of subtree
        int entries;

        // leaf
        ContiguousDoubleArrayList pointLocations;
        ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);

        // stem
        // Node less, more;
        int lessIndex, moreIndex;
        int splitDim;
        double splitVal;

        Node() {
            this(new double[_bucketSize * _dimensions]);
        }

        Node(double[] pointMemory) {
            pointLocations = new ContiguousDoubleArrayList(pointMemory);
            index = _nodes++;
            nodeList.add(this);
            nodeMinMaxBounds.add(bounds_template);
        }

        void search(double[] searchLocation, IntStack stack) {
            if (searchLocation[splitDim] < splitVal)
                stack.push(moreIndex).push(lessIndex); // less will be popped
            // first
            else
                stack.push(lessIndex).push(moreIndex); // more will be popped
            // first
        }

        // returns number of points added to results
        int search(double[] searchLocation, PrioQueue<T> results) {
            int updated = 0;
            for (int j = entries; j-- > 0;) {
                double distance = pointDist(pointLocations.array, searchLocation, j);
                if (results.peekPrio() > distance) {
                    updated++;
                    results.addNoGrow(pointPayloads.get(j), distance);
                }
            }
            return updated;
        }

        void searchBall(double[] searchLocation, double radius, ArrayList<T> results) {

            for (int j = entries; j-- > 0;) {
                double distance = pointDist(pointLocations.array, searchLocation, j);
                if (radius >= distance) {
                    results.add(pointPayloads.get(j));
                }
            }
        }

        void searchRect(double[] mins, double[] maxs, ArrayList<T> results) {

            for (int j = entries; j-- > 0;)
                if (contains(pointLocations.array, mins, maxs, j))
                    results.add(pointPayloads.get(j));

        }

        void expandBounds(double[] location) {
            entries++;
            int mio = index * 2 * _dimensions;
            for (int i = 0; i < _dimensions; i++) {
                nodeMinMaxBounds.array[mio] = Math.min(nodeMinMaxBounds.array[mio], location[i]);
                mio++;
                nodeMinMaxBounds.array[mio] = Math.max(nodeMinMaxBounds.array[mio], location[i]);
                mio++;
            }
        }

        int add(double[] location, T load) {
            pointLocations.add(location);
            pointPayloads.add(load);
            return entries;
        }

        void split() {
            int offset = index * 2 * _dimensions;

            double diff = 0;
            for (int i = 0; i < _dimensions; i++) {
                double min = nodeMinMaxBounds.array[offset];
                double max = nodeMinMaxBounds.array[offset + 1];
                if (max - min > diff) {
                    double mean = 0;
                    for (int j = 0; j < entries; j++)
                        mean += pointLocations.array[i + _dimensions * j];

                    mean = mean / entries;
                    double varianceSum = 0;

                    for (int j = 0; j < entries; j++)
                        varianceSum += sqr(mean - pointLocations.array[i + _dimensions * j]);

                    if (varianceSum > diff * entries) {
                        diff = varianceSum / entries;
                        splitVal = mean;

                        splitDim = i;
                    }
                }
                offset += 2;
            }

            // kill all the nasties
            if (splitVal == Double.POSITIVE_INFINITY)
                splitVal = Double.MAX_VALUE;
            else if (splitVal == Double.NEGATIVE_INFINITY)
                splitVal = Double.MIN_VALUE;
            else if (splitVal == nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim + 1])
                splitVal = nodeMinMaxBounds.array[index * 2 * _dimensions + 2 * splitDim];

            Node less = new Node(mem_recycle); // recycle that memory!
            Node more = new Node();
            lessIndex = less.index;
            moreIndex = more.index;

            // reduce garbage by factor of _bucketSize by recycling this array
            double[] pointLocation = new double[_dimensions];
            for (int i = 0; i < entries; i++) {
                System.arraycopy(pointLocations.array, i * _dimensions, pointLocation, 0, _dimensions);
                T load = pointPayloads.get(i);

                if (pointLocation[splitDim] < splitVal) {
                    less.expandBounds(pointLocation);
                    less.add(pointLocation, load);
                }
                else {
                    more.expandBounds(pointLocation);
                    more.add(pointLocation, load);
                }
            }
            if (less.entries * more.entries == 0) {
                // one of them was 0, so the split was worthless. throw it away.
                _nodes -= 2; // recall that bounds memory
                nodeList.remove(moreIndex);
                nodeList.remove(lessIndex);
            }
            else {

                // we won't be needing that now, so keep it for the next split
                // to reduce garbage
                mem_recycle = pointLocations.array;

                pointLocations = null;

                pointPayloads.clear();
                pointPayloads = null;
            }
        }

    }

    // NB! This Priority Queue keeps things with the LOWEST priority.
    // If you want highest priority items kept, negate your values
    private static class PrioQueue<S> {

        Object[] elements;
        double[] priorities;
        private double minPrio;
        private int size;

        PrioQueue(int size, boolean prefill) {
            elements = new Object[size];
            priorities = new double[size];
            Arrays.fill(priorities, Double.POSITIVE_INFINITY);
            if (prefill) {
                minPrio = Double.POSITIVE_INFINITY;
                this.size = size;
            }
        }

        // uses O(log(n)) comparisons and one big shift of size O(N)
        // and is MUCH simpler than a heap --> faster on small sets, faster JIT

        void addNoGrow(S value, double priority) {
            int index = searchFor(priority);
            int nextIndex = index + 1;
            int length = size - nextIndex;
            System.arraycopy(elements, index, elements, nextIndex, length);
            System.arraycopy(priorities, index, priorities, nextIndex, length);
            elements[index] = value;
            priorities[index] = priority;

            minPrio = priorities[size - 1];
        }

        int searchFor(double priority) {
            int i = size - 1;
            int j = 0;
            while (i >= j) {
                int index = (i + j) >>> 1;
                if (priorities[index] < priority)
                    j = index + 1;
                else
                    i = index - 1;
            }
            return j;
        }

        double peekPrio() {
            return minPrio;
        }
    }

    private static class ContiguousDoubleArrayList {
        double[] array;
        int size;

        ContiguousDoubleArrayList(int size) {
            this(new double[size]);
        }

        ContiguousDoubleArrayList(double[] data) {
            array = data;
        }

        ContiguousDoubleArrayList add(double[] da) {
            if (size + da.length > array.length)
                array = Arrays.copyOf(array, (array.length + da.length) * 2);

            System.arraycopy(da, 0, array, size, da.length);
            size += da.length;
            return this;
        }
    }

    private static class IntStack {
        int[] array;
        int size;

        IntStack() {
            this(64);
        }

        IntStack(int size) {
            this(new int[size]);
        }

        IntStack(int[] data) {
            array = data;
        }

        IntStack push(int i) {
            if (size >= array.length)
                array = Arrays.copyOf(array, (array.length + 1) * 2);

            array[size++] = i;
            return this;
        }

        int pop() {
            return array[--size];
        }

        int size() {
            return size;
        }
    }

    static final double sqr(double d) {
        return d * d;
    }

}