User:Skilgannon/KDTree
Revision as of 16:49, 20 July 2013

<syntaxhighlight lang="java">
/*
** KDTree.java by Julian Kent
** Licenced under the Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
** See full licencing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
** For additional licencing rights please contact jkflying@gmail.com
**
** Example usage is given in the main method, as well as benchmarking code against Rednaxela's Gen2 Tree
*/
package jk.mega;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
//import ags.utils.*;
//import ags.utils.dataStructures.*;
public class KDTree<T>{
//use a big bucketSize so that we have less node bounds (for more cache hits) and better splits
private static final int _bucketSize = 50;
private final int _dimensions;
private int _nodes;
private final Node root;
//prevent GC from having to collect _bucketSize*dimensions*8 bytes each time a leaf splits
private double[] mem_recycle;
//the starting values for bounding boxes, for easy access
private final double[] bounds_template;
//one big self-expanding array to keep all the node bounding boxes so that they stay in cache
// node bounds available at:
//low: 2 * _dimensions * node.index + 2 * dim
//high: 2 * _dimensions * node.index + 2 * dim + 1
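// e.g. with _dimensions == 2 and node.index == 3: dim 0 bounds are at array[12] (low) and array[13] (high),
//      dim 1 bounds at array[14] and array[15]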
private final ContiguousDoubleArrayList nodeMinMaxBounds;
/*
public static void main(String[] args){
int dims = 1;
int size = 2000000;
int testsize = 1;
int k = 40;
int iterations = 1;
System.out.println(
"Config:\n"
+ "No JIT Warmup\n"
+ "Tested on random data.\n"
+ "Training and testing points shared across iterations.\n"
+ "Searches interleaved.");
System.out.println("Num points: " + size);
System.out.println("Num searches: " + testsize);
System.out.println("Dimensions: " + dims);
System.out.println("Num Neighbours: " + k);
System.out.println();
ArrayList<double[]> locs = new ArrayList<double[]>(size);
for(int i = 0; i < size; i++){
double[] loc = new double[dims];
for(int j = 0; j < dims; j++)
loc[j] = Math.random();
locs.add(loc);
}
ArrayList<double[]> testlocs = new ArrayList<double[]>(testsize);
for(int i = 0; i < testsize; i++){
double[] loc = new double[dims];
for(int j = 0; j < dims; j++)
loc[j] = Math.random();
testlocs.add(loc);
}
for(int r = 0; r < iterations; r++){
long t1 = System.nanoTime();
KDTree<double[]> t = new KDTree<double[]>(dims);// This tree
for(int i = 0; i < size; i++){
t.addPoint(locs.get(i),locs.get(i));
}
long t2 = System.nanoTime();
KdTree<double[]> rt = new KdTree.Euclidean<double[]>(dims,null); //Rednaxela Gen2
for(int i = 0; i < size; i++){
rt.addPoint(locs.get(i),locs.get(i));
}
long t3 = System.nanoTime();
long jtn = 0;
long rtn = 0;
long mjtn = 0;
long mrtn = 0;
double dist1 = 0, dist2 = 0;
for(int i = 0; i < testsize; i++){
long t4 = System.nanoTime();
dist1 += t.nearestNeighbours(testlocs.get(i),k).iterator().next().distance;
long t5 = System.nanoTime();
dist2 += rt.nearestNeighbor(testlocs.get(i),k,true).iterator().next().distance;
long t6 = System.nanoTime();
long t7 = System.nanoTime();
jtn += t5 - t4 - (t7 - t6);
rtn += t6 - t5 - (t7 - t6);
mjtn = Math.max(mjtn,t5 - t4 - (t7 - t6));
mrtn = Math.max(mrtn,t6 - t5 - (t7 - t6));
}
System.out.println("Accuracy: " + (Math.abs(dist1-dist2) < 1e-10?"100%":"BROKEN!!!"));
if(Math.abs(dist1-dist2) > 1e-10){
System.out.println("dist1: " + dist1 + " dist2: " + dist2);
}
long jts = t2 - t1;
long rts = t3 - t2;
System.out.println("Iteration: " + (r+1) + "/" + iterations);
System.out.println("This tree add avg: " + jts/size + " ns");
System.out.println("Reds tree add avg: " + rts/size + " ns");
System.out.println("This tree knn avg: " + jtn/testsize + " ns");
System.out.println("Reds tree knn avg: " + rtn/testsize + " ns");
System.out.println("This tree knn max: " + mjtn + " ns");
System.out.println("Reds tree knn max: " + mrtn + " ns");
System.out.println();
}
}
// */
public KDTree(int dimensions){
_dimensions = dimensions;
//initialise this so that it ends up in 'old' memory
nodeMinMaxBounds = new ContiguousDoubleArrayList(512 * 1024 / 8 + 2*_dimensions);
mem_recycle = new double[_bucketSize*dimensions];
bounds_template = new double[2*_dimensions];
Arrays.fill(bounds_template,Double.NEGATIVE_INFINITY);
for(int i = 0, max = 2*_dimensions; i < max; i+=2)
bounds_template[i] = Double.POSITIVE_INFINITY;
//and.... start!
root = new Node();
}
public int nodes(){
return _nodes;
}
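// add a point with its payload; returns the total number of entries now in the tree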
public int addPoint(double[] location, T payload){
Node addNode = root;
//Do a Depth First Search to find the Node where 'location' should be stored
while(addNode.pointLocations == null){
addNode.expandBounds(location);
if(location[addNode.splitDim] < addNode.splitVal)
addNode = addNode.less;
else
addNode = addNode.more;
}
addNode.expandBounds(location);
int nodeSize = addNode.add(location,payload);
if(nodeSize % _bucketSize == 0)
//try splitting again once every time the node passes a _bucketSize multiple
addNode.split();
return root.entries;
}
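// Two-phase k-NN search: first descend greedily until at least K candidates have been collected,
// then drain the remaining stack, skipping any node whose bounding box is already further away
// than the current K-th best (squared) distance.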
public ArrayList<SearchResult<T>> nearestNeighbours(double[] searchLocation, int K){
ArrayDeque<Node> stack = new ArrayDeque<Node>(50);
PrioQueue<T> results = new PrioQueue<T>(K,true);
stack.push(root);
int added = 0;
while(added < K )
added += stack.pollFirst().search(searchLocation,stack,results);
while(stack.size() > 0 ){
Node searchNode = stack.pollFirst();
if(results.peekPrio() > searchNode.pointRectDist(searchLocation))
searchNode.search(searchLocation,stack,results);
}
ArrayList<SearchResult<T>> returnResults = new ArrayList<SearchResult<T>>(K);
for(int i = 0; i < K; i++){//forward (closest first)
SearchResult s = new SearchResult(results.priorities[i],results.elements[i]);
returnResults.add(s);
}
return returnResults;
}
//NB! This Priority Queue keeps things with the LOWEST priority.
//If you want highest priority items kept, negate your values
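// (squared distances already rank "smaller is better", so they are used directly as priorities here)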
private static class PrioQueue<S>{
Object[] elements;
double[] priorities;
private double minPrio;
private int size;
PrioQueue(int size, boolean prefill){
elements = new Object[size];
priorities = new double[size];
Arrays.fill(priorities,Double.POSITIVE_INFINITY);
if(prefill){
minPrio = Double.POSITIVE_INFINITY;
this.size = size;
}
}
//uses O(log(n)) comparisons and one big shift of size O(N)
//and is MUCH simpler than a heap --> faster on small sets, faster JIT
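// invariant: priorities[0..size-1] is sorted ascending, and minPrio == priorities[size-1]
// is the worst of the K results kept so far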
void addNoGrow(S value, double priority){
int index = searchFor(priority);
int nextIndex = index + 1;
int length = size - index - 1;//remove dependency on nextIndex
System.arraycopy(elements,index,elements,nextIndex,length);
System.arraycopy(priorities,index,priorities,nextIndex,length);
elements[index]=value;
priorities[index]=priority;
minPrio = priorities[size-1];
}
int searchFor(double priority){
int i = size-1;
int j = 0;
while(i>=j){
int index = (i+j)>>>1;
if( priorities[index] < priority)
j = index+1;
else
i = index-1;
}
return j;
}
double peekPrio(){
return minPrio;
}
/*
//Methods for using it as a priority stack - leave them out for now
void push(S value, double priority){
if(++size > elements.length){
elements = Arrays.copyOf(elements,size*2);
priorities = Arrays.copyOf(priorities,size*2);
Arrays.fill(priorities,size,size*2,Double.NEGATIVE_INFINITY);
System.out.println("Expanding PrioQueue to " + elements.length);
}
addNoGrow(value,priority);
}
void pushTop(S value, double priority){
if(++size > elements.length){
elements = Arrays.copyOf(elements,size*2);
priorities = Arrays.copyOf(priorities,size*2);
Arrays.fill(priorities,size,size*2,Double.NEGATIVE_INFINITY);
System.out.println("Expanding PrioQueue to " + elements.length);
}
elements[size-1] = value;
priorities[size-1] = priority;
minPrio = priority;
}
S pop(){
Object value = elements[--size];
priorities[size] = Double.NEGATIVE_INFINITY;
if(size == 0)
minPrio = Double.NEGATIVE_INFINITY;
else
minPrio = priorities[size-1];
return (S)value;
}
int size(){
return size;
}
// */
}
public static class SearchResult<S>{
public double distance;
public S payload;
SearchResult(double dist, S load){
distance = dist;
payload = load;
}
}
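// A Node starts out as a leaf (pointLocations != null); once split() succeeds it becomes a stem
// that only keeps splitDim/splitVal and the less/more children, and its point storage is dropped.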
private class Node {
//for accessing bounding box data
// - if trees weren't so unbalanced might be better to use an implicit heap?
int index;
//keep track of size of subtree
int entries;
//leaf
ContiguousDoubleArrayList pointLocations;
ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);
//stem
Node less, more;
int splitDim;
double splitVal;
private Node(){
this(new double[_bucketSize*_dimensions]);
}
private Node(double[] pointMemory){
pointLocations = new ContiguousDoubleArrayList(pointMemory);
index = _nodes++;
nodeMinMaxBounds.add(bounds_template);
}
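// minimum squared distance from 'location' to this node's axis-aligned bounding box (0 if inside)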
private final double pointRectDist(double[] location){
int offset = (2*_dimensions)*(index);
double distance=0;
double[] array = nodeMinMaxBounds.array;
for(int i = 0; i < location.length; i++,offset += 2){
double diff = 0;
double bv = array[offset];
double lv = location[i];
if(bv > lv)
diff = bv-lv;
else{
bv=array[offset+1];
if(lv>bv)
diff = lv-bv;
}
distance += (diff*diff);
}
return distance;
}
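// squared Euclidean distance from 'location' to the index-th point stored in this leaf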
private final double pointDist(double[] location, int index){
double[] arr = pointLocations.array;
double distance = 0;
int offset = (index+1)*_dimensions;
for(int i = _dimensions; i-- > 0 ;){
double d;
distance += (d = arr[--offset] - location[i])*d;
}
return distance;
}
//returns number of points added to results
private int search(double[] searchLocation, ArrayDeque<Node> stack, PrioQueue<T> results){
if(pointLocations == null){
if(searchLocation[splitDim] < splitVal){
stack.addFirst(more);
stack.addFirst(less);//less will be popped first
}
else{
stack.addFirst(less);
stack.addFirst(more);//more will be popped first
}
}
else{
int updated = 0;
for(int j = entries; j-- > 0;){
double distance = pointDist(searchLocation,j);
if(results.peekPrio() > distance){
updated++;
results.addNoGrow(pointPayloads.get(j),distance);
}
}
return updated;
}
return 0;
}
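// grow this node's bounding box to cover 'location', and count the new entry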
private void expandBounds(double[] location){
entries++;
int mio = index*2*_dimensions;
for(int i = 0; i < _dimensions;i++){
nodeMinMaxBounds.array[mio] = Math.min(nodeMinMaxBounds.array[mio++],location[i]);
nodeMinMaxBounds.array[mio] = Math.max(nodeMinMaxBounds.array[mio++],location[i]);
}
}
private int add(double[] location, T load){
pointLocations.add(location);
pointPayloads.add(load);
return entries;
}
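// Split on the dimension with the largest variance (the bounding-box extent acts as a cheap
// pre-filter), placing the split plane at that dimension's mean. If every point lands on one
// side, the split is abandoned and the leaf simply keeps growing.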
private void split(){
int offset = index*2*_dimensions;
double diff = 0;
for(int i = 0; i < _dimensions; i++){
double min = nodeMinMaxBounds.array[offset];
double max = nodeMinMaxBounds.array[offset+1];
if(max-min>diff){
double mean = 0;
for(int j = 0; j < entries; j++)
mean += pointLocations.array[i+_dimensions*j];
mean = mean/entries;
double varianceSum = 0;
for(int j = 0; j < entries; j++)
varianceSum += sqr(mean-pointLocations.array[i+_dimensions*j]);
if(varianceSum>diff*entries){
diff = varianceSum/entries;
splitVal = mean;
splitDim = i;
}
}
offset += 2;
}
//kill all the nasties
if(splitVal == Double.POSITIVE_INFINITY)
splitVal = Double.MAX_VALUE;
else if(splitVal == Double.NEGATIVE_INFINITY)
splitVal = Double.MIN_VALUE;
else if(splitVal == nodeMinMaxBounds.array[index*2*_dimensions + 2*splitDim + 1])
splitVal = nodeMinMaxBounds.array[index*2*_dimensions + 2*splitDim];
less = new Node(mem_recycle);//recycle that memory!
more = new Node();
//reduce garbage by factor of _bucketSize by recycling this array
double[] pointLocation = new double[_dimensions];
for(int i = 0; i < entries; i++){
System.arraycopy(pointLocations.array,i*_dimensions,pointLocation,0,_dimensions);
T load = pointPayloads.get(i);
if(pointLocation[splitDim] < splitVal){
less.expandBounds(pointLocation);
less.add(pointLocation,load);
}
else{
more.expandBounds(pointLocation);
more.add(pointLocation,load);
}
}
if(less.entries*more.entries == 0){
//one of them was 0, so the split was worthless. throw it away.
less = null;
more = null;
_nodes -= 2;//recall that bounds memory
}
else{
//we won't be needing that now, so keep it for the next split to reduce garbage
mem_recycle = pointLocations.array;
pointLocations = null;
pointPayloads.clear();
pointPayloads = null;
}
}
}
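// minimal growable double array: add() appends a fixed-length block, doubling capacity when needed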
private static class ContiguousDoubleArrayList{
double[] array;
int size;
ContiguousDoubleArrayList(){
this(300);
}
ContiguousDoubleArrayList(int size){
this(new double[size]);
}
ContiguousDoubleArrayList(double[] data){
array = data;
}
ContiguousDoubleArrayList add(double[] da){
if(size + da.length > array.length){
array = Arrays.copyOf(array,(array.length+da.length)*2);
//System.out.println("Doubling!");
}
System.arraycopy(da,0,array,size,da.length);
size += da.length;
return this;
}
}
private static final double sqr(double d){
return d*d;}
}
</syntaxhighlight>
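
A minimal usage sketch, assuming the listing above is compiled as <code>jk.mega.KDTree</code> (as its package statement declares); the demo class name, payload strings and point counts are illustrative only. Results come back closest first, and <code>SearchResult.distance</code> holds the squared Euclidean distance.

<syntaxhighlight lang="java">
import java.util.ArrayList;

import jk.mega.KDTree;

public class KDTreeDemo {
    public static void main(String[] args) {
        int dims = 3;
        KDTree<String> tree = new KDTree<String>(dims);

        // insert some random points, each tagged with a String payload
        for (int i = 0; i < 1000; i++) {
            double[] p = new double[dims];
            for (int j = 0; j < dims; j++)
                p[j] = Math.random();
            tree.addPoint(p, "point-" + i);
        }

        // ask for the 5 nearest neighbours of the centre of the unit cube
        double[] query = {0.5, 0.5, 0.5};
        ArrayList<KDTree.SearchResult<String>> nearest = tree.nearestNeighbours(query, 5);

        // results are ordered closest first; 'distance' is the squared Euclidean distance
        for (KDTree.SearchResult<String> r : nearest)
            System.out.println(r.payload + "  d^2 = " + r.distance);
    }
}
</syntaxhighlight>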