Difference between revisions of "Help:Help/Nat/Kd-Tree"

From Robowiki
Jump to navigation Jump to search
(ok, I found it now.)
(latest code that I've problem)
 
Line 1: Line 1:
 +
Here is my code that I have problem state in [[Talk:Kd-Tree]]. My ''M'' class can be found at [[User:Nat/Free code]].
 +
 +
== Tree ==
 +
<pre>
 +
package nat.tree;
 +
 +
import java.io.Serializable;
 +
import java.util.ArrayList;
 +
import java.util.Arrays;
 +
import java.util.Collection;
 +
import java.util.Iterator;
 +
import java.util.LinkedList;
 +
import java.util.PriorityQueue;
 +
import java.util.Queue;
 +
 +
import nat.util.M;
 +
 +
/**
 +
*
 +
* Implementation of bucket PR k-d tree.
 +
*
 +
* @author Nat Pavasant
 +
*
 +
* @param <V>
 +
*            The type of data to store
 +
*/
 +
public class PRKdBucketTree<V> implements Serializable {
 +
 +
private static final long serialVersionUID = 1L;
 +
 +
public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer();
 +
public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer();
 +
 +
private final PRKdBucketTree<V>[] children;
 +
private final Queue<KdEntry<V>> data;
 +
private final Distancer distancer;
 +
private final double[] lowerBound, upperBound;
 +
private final int[] numChildren;
 +
private final int allDimensions;
 +
private final int dimension;
 +
private final int maxDepth, maxDensity;
 +
private final double splitMedian;
 +
 +
private boolean isLeaf = true;
 +
 +
/**
 +
* Create new Bucket PR k-d tree.
 +
*
 +
* @param allDimensions
 +
*            number of dimensions in the tree
 +
* @param lowerBound
 +
*            the minimum value of the location of each dimension
 +
* @param upperBound
 +
*            the maximum value of the location of each dimension
 +
* @param numChildren
 +
*            number of children in each dimension
 +
* @param maxDepth
 +
*            the max depth of the tree
 +
* @param maxDensity
 +
*            size of bucket in each leaf
 +
* @param distancer
 +
*            distance measurer
 +
*/
 +
@SuppressWarnings("unchecked")
 +
public PRKdBucketTree(int allDimensions, double[] lowerBound,
 +
double[] upperBound, int[] numChildren, int maxDepth,
 +
int maxDensity, Distancer distancer) {
 +
if (allDimensions < 1 || maxDensity < 1)
 +
throw new IllegalArgumentException(
 +
"Either dimension or density isn't positive integer.");
 +
 +
if (lowerBound.length != allDimensions
 +
|| upperBound.length != allDimensions
 +
|| numChildren.length != allDimensions)
 +
throw new IllegalArgumentException(
 +
"Either bounds or children amount is more or less than dimension count.");
 +
 +
for (double a : lowerBound) {
 +
if (a < 0)
 +
throw new IllegalArgumentException(
 +
"Can't set lower bound to negative number.");
 +
}
 +
 +
for (int i = 0; i < lowerBound.length; i++) {
 +
if (lowerBound[i] > upperBound[i])
 +
throw new IllegalArgumentException(
 +
"Upper bound must have a value higer than lower bound.");
 +
}
 +
 +
this.allDimensions = allDimensions;
 +
this.lowerBound = lowerBound;
 +
this.upperBound = upperBound;
 +
this.numChildren = numChildren;
 +
this.distancer = distancer;
 +
this.maxDepth = maxDepth;
 +
this.maxDensity = maxDensity;
 +
this.dimension = maxDepth % allDimensions;
 +
this.data = new LinkedList<KdEntry<V>>();
 +
this.children = new PRKdBucketTree[numChildren[this.dimension]];
 +
this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension])
 +
/ numChildren[this.dimension];
 +
}
 +
 +
// Another constructor
 +
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
 +
double lowerBound, double upperBound, int numChildren,
 +
int maxDepth, int maxDensity, Distancer distance) {
 +
double[] lowerBounds = new double[dimension];
 +
double[] upperBounds = new double[dimension];
 +
int[] numChildrens = new int[dimension];
 +
Arrays.fill(lowerBounds, lowerBound);
 +
Arrays.fill(upperBounds, upperBound);
 +
Arrays.fill(numChildrens, numChildren);
 +
return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds,
 +
numChildrens, maxDepth, maxDensity, distance);
 +
}
 +
 +
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
 +
double upperBound, int numChildren, int maxDepth, int maxDensity,
 +
Distancer distance) {
 +
return getTree(dataType, dimension, 0, upperBound, numChildren,
 +
maxDepth, maxDensity, distance);
 +
}
 +
 +
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
 +
double upperBound, int numChildren, int maxDepth, int maxDensity) {
 +
return getTree(dataType, dimension, 0, upperBound, numChildren,
 +
maxDepth, maxDensity, EUCLIDIAN);
 +
}
 +
 +
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
 +
double upperBound, int maxDepth, int maxDensity) {
 +
return getTree(dataType, dimension, 0, upperBound, 2, maxDepth,
 +
maxDensity, EUCLIDIAN);
 +
}
 +
 +
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
 +
double upperBound, int numChildren) {
 +
return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8,
 +
EUCLIDIAN);
 +
}
 +
 +
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
 +
double upperBound) {
 +
return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN);
 +
}
 +
 +
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) {
 +
return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN);
 +
}
 +
 +
/**
 +
* Add new point to the tree
 +
*
 +
* @param value
 +
*            the stored value
 +
* @param location
 +
*            location of the point
 +
* @return removed point if any
 +
*/
 +
public KdEntry<V> addPoint(V value, double[] location) {
 +
if (location.length != allDimensions) {
 +
throw new IllegalArgumentException(
 +
"Provided location have either more or less dimensions than the tree.");
 +
}
 +
 +
KdEntry<V> entry = new KdEntry<V>(value, location);
 +
 +
return addPoint(entry);
 +
}
 +
 +
/**
 +
* Get the n-nearest neighbor.
 +
*
 +
* @param size
 +
*            number of neighbors
 +
* @param center
 +
*            center of the cluster
 +
* @param weight
 +
*            weighting for the distancer
 +
* @return
 +
*/
 +
public KdCluster<V> getNearestNeighbor(int size, double[] center,
 +
double[] weight) {
 +
KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer);
 +
nearestNeighborSearch(cluster);
 +
return cluster;
 +
}
 +
 +
/**
 +
* Back-end implementation of the entry adding.
 +
*
 +
* @param entry
 +
*            new entry to the tree
 +
* @return removed point if any
 +
*/
 +
private KdEntry<V> addPoint(KdEntry<V> entry) {
 +
 +
if (isLeaf) {
 +
 +
// Still has spaces, add the data
 +
if (data.size() < maxDensity) {
 +
data.add(entry);
 +
return null;
 +
}
 +
 +
// final leaf, unsplitable, remove element and add new one.
 +
if (maxDepth <= 1) {
 +
data.add(entry);
 +
return data.poll();
 +
}
 +
 +
// if we reached here, we need to split this leaf to a branch.
 +
isLeaf = false;
 +
for (KdEntry<V> p : data) {
 +
passToChildren(p);
 +
}
 +
 +
data.clear();
 +
}
 +
 +
// we are branch, pass to children
 +
return passToChildren(entry);
 +
}
 +
 +
/**
 +
* Perform n-nearest neighbor search
 +
*
 +
* @param cluster
 +
*            current working cluster
 +
*/
 +
private void nearestNeighborSearch(KdCluster<V> cluster) {
 +
if (isLeaf) {
 +
for (KdEntry<V> p : data) {
 +
cluster.consider(p);
 +
}
 +
} else {
 +
for (PRKdBucketTree<V> child : children) {
 +
if (child != null && cluster.isViable(child))
 +
child.nearestNeighborSearch(cluster);
 +
}
 +
}
 +
}
 +
 +
/**
 +
* Pass the entry to correct children
 +
*
 +
* @param p
 +
*            entry
 +
* @return removed point if any
 +
*/
 +
private KdEntry<V> passToChildren(KdEntry<V> p) {
 +
int i = getChildrenIndex(p.getLocation(dimension));
 +
 +
if (children[i] == null)
 +
children[i] = createChildTree(i);
 +
 +
return children[i].addPoint(p);
 +
}
 +
 +
/**
 +
* Return index of the children which will contains data with that value.
 +
*
 +
* @param value
 +
*            the value
 +
* @return index of children
 +
*/
 +
private int getChildrenIndex(double value) {
 +
return (int) M.limit(0, Math.floor(value / splitMedian),
 +
numChildren[dimension] - 1);
 +
}
 +
 +
/**
 +
* Create children tree with correct bound.
 +
*
 +
* @param i
 +
*            the index of the children
 +
* @return created tree
 +
*/
 +
private PRKdBucketTree<V> createChildTree(int i) {
 +
double[] upperBound = this.upperBound.clone();
 +
double[] lowerBound = this.lowerBound.clone();
 +
lowerBound[dimension] = splitMedian * i;
 +
upperBound[dimension] = splitMedian * (i + 1);
 +
return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound,
 +
numChildren, maxDepth - 1, maxDensity, distancer);
 +
}
 +
 +
public static class KdCluster<K> implements Iterable<KdPoint<K>> {
 +
private final PriorityQueue<KdPoint<K>> points;
 +
private final double[] center;
 +
private final double[] weight;
 +
private final int size;
 +
private final Distancer distancer;
 +
 +
public KdCluster(int size, double[] center, double[] weight,
 +
Distancer distancer) {
 +
points = new PriorityQueue<KdPoint<K>>();
 +
this.size = size;
 +
this.center = center;
 +
this.weight = weight;
 +
this.distancer = distancer;
 +
}
 +
 +
public void consider(KdEntry<K> k) {
 +
KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer);
 +
 +
if (points.size() < size) {
 +
points.add(p);
 +
System.out.println(points.toString());
 +
} else if (points.peek().getDistanceToCenter() > p
 +
.getDistanceToCenter()) {
 +
System.out.println(points.toString());
 +
points.poll();
 +
points.add(p);
 +
}
 +
}
 +
 +
public boolean isViable(PRKdBucketTree<K> tree) {
 +
 +
if (points.size() < size)
 +
return true;
 +
 +
double[] testPoints = new double[center.length];
 +
 +
for (int i = 0; i < center.length; i++)
 +
testPoints[i] = M.limit(tree.lowerBound[i], center[i],
 +
tree.upperBound[i]);
 +
 +
return points.peek().getDistanceToCenter() > distancer.getDistance(
 +
center, testPoints, weight);
 +
}
 +
 +
@Override
 +
public Iterator<KdPoint<K>> iterator() {
 +
return points.iterator();
 +
}
 +
 +
public Collection<KdPoint<K>> getValues() {
 +
Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points
 +
.size());
 +
 +
for (KdPoint<K> p : points) {
 +
collect.add(p);
 +
}
 +
 +
return collect;
 +
}
 +
 +
@Override
 +
public int hashCode() {
 +
final int prime = 31;
 +
int result = 1;
 +
result = prime * result + Arrays.hashCode(center);
 +
result = prime * result
 +
+ ((distancer == null) ? 0 : distancer.hashCode());
 +
result = prime * result
 +
+ ((points == null) ? 0 : points.hashCode());
 +
result = prime * result + size;
 +
result = prime * result + Arrays.hashCode(weight);
 +
return result;
 +
}
 +
 +
@SuppressWarnings("unchecked")
 +
@Override
 +
public boolean equals(Object obj) {
 +
if (this == obj)
 +
return true;
 +
if (obj == null)
 +
return false;
 +
if (!(obj instanceof KdCluster))
 +
return false;
 +
KdCluster other = (KdCluster) obj;
 +
if (!Arrays.equals(center, other.center))
 +
return false;
 +
if (distancer == null) {
 +
if (other.distancer != null)
 +
return false;
 +
} else if (!distancer.equals(other.distancer))
 +
return false;
 +
if (points == null) {
 +
if (other.points != null)
 +
return false;
 +
} else if (!points.equals(other.points))
 +
return false;
 +
if (size != other.size)
 +
return false;
 +
if (!Arrays.equals(weight, other.weight))
 +
return false;
 +
return true;
 +
}
 +
}
 +
 +
private static class KdEntry<K> implements Serializable {
 +
private static final long serialVersionUID = 1L;
 +
 +
private final K value;
 +
private final double[] location;
 +
 +
public KdEntry(K value, double[] location) {
 +
super();
 +
this.value = value;
 +
this.location = location;
 +
}
 +
 +
public K getValue() {
 +
return value;
 +
}
 +
 +
public double[] getLocation() {
 +
return location;
 +
}
 +
 +
public double getLocation(int a) {
 +
return location[a];
 +
}
 +
 +
@Override
 +
public int hashCode() {
 +
final int prime = 31;
 +
int result = 1;
 +
result = prime * result + Arrays.hashCode(location);
 +
result = prime * result + ((value == null) ? 0 : value.hashCode());
 +
return result;
 +
}
 +
 +
@SuppressWarnings("unchecked")
 +
@Override
 +
public boolean equals(Object obj) {
 +
if (this == obj)
 +
return true;
 +
if (obj == null)
 +
return false;
 +
if (!(obj instanceof KdEntry))
 +
return false;
 +
KdEntry other = (KdEntry) obj;
 +
if (!Arrays.equals(location, other.location))
 +
return false;
 +
if (value == null) {
 +
if (other.value != null)
 +
return false;
 +
} else if (!value.equals(other.value))
 +
return false;
 +
return true;
 +
}
 +
}
 +
 +
public static class KdPoint<K> extends KdEntry<K> implements Serializable,
 +
Comparable<KdPoint<K>> {
 +
private static final long serialVersionUID = 1L;
 +
private final double distanceToCenter;
 +
 +
public KdPoint(KdEntry<K> p, double[] center, double[] weight,
 +
Distancer distancer) {
 +
super(p.getValue(), p.getLocation());
 +
distanceToCenter = distancer.getDistance(center, getLocation(),
 +
weight);
 +
}
 +
 +
public double getDistanceToCenter() {
 +
return distanceToCenter;
 +
}
 +
 +
@Override
 +
public String toString() {
 +
return (new Double(distanceToCenter)).toString();
 +
}
 +
 +
@Override
 +
public int compareTo(KdPoint<K> o) {
 +
return (int) Math.signum(o.distanceToCenter - distanceToCenter);
 +
}
 +
 +
@Override
 +
public int hashCode() {
 +
final int prime = 31;
 +
int result = super.hashCode();
 +
long temp;
 +
temp = Double.doubleToLongBits(distanceToCenter);
 +
result = prime * result + (int) (temp ^ (temp >>> 32));
 +
return result;
 +
}
 +
 +
@SuppressWarnings("unchecked")
 +
@Override
 +
public boolean equals(Object obj) {
 +
if (this == obj)
 +
return true;
 +
if (!super.equals(obj))
 +
return false;
 +
if (!(obj instanceof KdPoint))
 +
return false;
 +
KdPoint other = (KdPoint) obj;
 +
if (Double.doubleToLongBits(distanceToCenter) != Double
 +
.doubleToLongBits(other.distanceToCenter))
 +
return false;
 +
return true;
 +
}
 +
}
 +
 +
public static abstract class Distancer {
 +
public double getDistance(double[] p1, double[] p2, double[] weight) {
 +
if (p1.length != p2.length)
 +
throw new IllegalArgumentException();
 +
return getPointDistance(p1, p2, weight);
 +
}
 +
 +
public abstract double getPointDistance(double[] p1, double[] p2,
 +
double[] weight);
 +
 +
public static class EuclidianDistancer extends Distancer {
 +
@Override
 +
public double getPointDistance(double[] p1, double[] p2,
 +
double[] weight) {
 +
double result = 0;
 +
for (int i = 0; i < p1.length; i++) {
 +
result += M.sqr(p1[i] - p2[i]) * weight[i];
 +
}
 +
return M.sqrt(result);
 +
}
 +
}
 +
 +
public static class ManhattanDistancer extends Distancer {
 +
@Override
 +
public double getPointDistance(double[] p1, double[] p2,
 +
double[] weight) {
 +
double result = 0;
 +
for (int i = 0; i < p1.length; i++) {
 +
result += M.abs(p1[i] - p2[i]) * weight[i];
 +
}
 +
return result;
 +
}
 +
}
 +
}
 +
 +
@Override
 +
public int hashCode() {
 +
final int prime = 31;
 +
int result = 1;
 +
result = prime * result + allDimensions;
 +
result = prime * result + Arrays.hashCode(children);
 +
result = prime * result + ((data == null) ? 0 : data.hashCode());
 +
result = prime * result + dimension;
 +
result = prime * result + (isLeaf ? 1231 : 1237);
 +
result = prime * result + Arrays.hashCode(lowerBound);
 +
result = prime * result + maxDensity;
 +
result = prime * result + maxDepth;
 +
result = prime * result + Arrays.hashCode(numChildren);
 +
long temp;
 +
temp = Double.doubleToLongBits(splitMedian);
 +
result = prime * result + (int) (temp ^ (temp >>> 32));
 +
result = prime * result + Arrays.hashCode(upperBound);
 +
return result;
 +
}
 +
 +
@SuppressWarnings("unchecked")
 +
@Override
 +
public boolean equals(Object obj) {
 +
if (this == obj)
 +
return true;
 +
if (obj == null)
 +
return false;
 +
if (!(obj instanceof PRKdBucketTree))
 +
return false;
 +
PRKdBucketTree other = (PRKdBucketTree) obj;
 +
if (allDimensions != other.allDimensions)
 +
return false;
 +
if (!Arrays.equals(children, other.children))
 +
return false;
 +
if (data == null) {
 +
if (other.data != null)
 +
return false;
 +
} else if (!data.equals(other.data))
 +
return false;
 +
if (dimension != other.dimension)
 +
return false;
 +
if (isLeaf != other.isLeaf)
 +
return false;
 +
if (!Arrays.equals(lowerBound, other.lowerBound))
 +
return false;
 +
if (maxDensity != other.maxDensity)
 +
return false;
 +
if (maxDepth != other.maxDepth)
 +
return false;
 +
if (!Arrays.equals(numChildren, other.numChildren))
 +
return false;
 +
if (Double.doubleToLongBits(splitMedian) != Double
 +
.doubleToLongBits(other.splitMedian))
 +
return false;
 +
if (!Arrays.equals(upperBound, other.upperBound))
 +
return false;
 +
return true;
 +
}
 +
 +
}
 +
</pre>
 +
 +
== Tester ==
 +
<pre>
 +
package nat.tree;
 +
 +
import java.util.ArrayList;
 +
import java.util.Arrays;
 +
import java.util.Collection;
 +
import java.util.PriorityQueue;
 +
import java.util.Random;
 +
 +
public class Test {
 +
public static void main(String[] args) {
 +
final int numTest = 100;
 +
final int clusterSize = 15;
 +
 +
long linearTime, tree2time, tree3time, tree4time;
 +
String[] answer = new String[clusterSize];
 +
 +
System.out.println("Starting Bucket PR k-d tree performance test...");
 +
System.out.println("Generating points...");
 +
 +
ArrayList<String> input = new ArrayList<String>();
 +
ArrayList<double[]> location = new ArrayList<double[]>();
 +
 +
for (int i = 0; i < numTest; i++) {
 +
input.add(generateRandomString());
 +
double[] p = new double[3];
 +
p[0] = Math.random();
 +
p[1] = Math.random();
 +
p[2] = Math.random();
 +
location.add(p);
 +
}
 +
 +
PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2);
 +
PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3);
 +
PRKdBucketTree<String> tree4 = PRKdBucketTree.getTree("", 3, 1, 4);
 +
 +
for (int i = 0; i < numTest; i++) {
 +
tree2.addPoint(input.get(i), location.get(i));
 +
tree3.addPoint(input.get(i), location.get(i));
 +
tree4.addPoint(input.get(i), location.get(i));
 +
}
 +
 +
double[] center = new double[3];
 +
center[0] = Math.random();
 +
center[1] = Math.random();
 +
center[2] = Math.random();
 +
 +
System.out.println("Data generated.");
 +
System.out.println("Performing linear search...");
 +
 +
linearTime = -System.nanoTime();
 +
PriorityQueue<Compare> pq = new PriorityQueue<Compare>();
 +
 +
for (int i = 0; i < numTest; i++) {
 +
double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center,
 +
location.get(i), new double[] { 1, 1, 1 });
 +
 +
pq.add(new Compare(input.get(i), distance));
 +
}
 +
 +
double distance = -1;
 +
for (int i = 0; i < clusterSize; i++) {
 +
if (distance == -1)
 +
distance = pq.peek().val;
 +
answer[i] = pq.poll().data;
 +
}
 +
linearTime += System.nanoTime();
 +
Arrays.sort(answer);
 +
System.out.println("Linear search complete; time = "
 +
+ (linearTime / 1E9));
 +
 +
 +
System.out.println("Performing binary k-d tree search...");
 +
 +
tree2time = -System.nanoTime();
 +
 +
PRKdBucketTree.KdCluster<String> tree2r = tree2
 +
.getNearestNeighbor(clusterSize, center,
 +
new double[] { 1, 1, 1 });
 +
Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues();
 +
 +
tree2time += System.nanoTime();
 +
 +
System.out.println("Binary tree search complete; time = "
 +
+ (tree2time / 1E9));
 +
int correct = 0;
 +
int j = 0;
 +
String[] tree2o = new String[clusterSize];
 +
for (PRKdBucketTree.KdPoint<String> p : tree2a) {
 +
tree2o[j++] = p.getValue();
 +
}
 +
Arrays.sort(tree2o);
 +
for (int i = 0; i < tree2o.length; i++) {
 +
if (tree2o[i].equals(answer[i]))
 +
correct++;
 +
}
 +
 +
System.out.println(": accuracy = " + ((double) correct / clusterSize));
 +
 +
System.out.println("Performing ternary k-d tree search...");
 +
 +
tree3time = -System.nanoTime();
 +
 +
PRKdBucketTree.KdCluster<String> tree3r = tree3
 +
.getNearestNeighbor(clusterSize, center,
 +
new double[] { 1, 1, 1 });
 +
Collection<PRKdBucketTree.KdPoint<String>> tree3a = tree3r.getValues();
 +
 +
tree3time += System.nanoTime();
 +
 +
System.out.println("Ternary tree search complete; time = "
 +
+ (tree3time / 1E9));
 +
correct = 0;
 +
j = 0;
 +
String[] tree3o = new String[clusterSize];
 +
for (PRKdBucketTree.KdPoint<String> p : tree3a) {
 +
tree3o[j++] = p.getValue();
 +
}
 +
Arrays.sort(tree3o);
 +
for (int i = 0; i < tree3o.length; i++) {
 +
if (tree3o[i].equals(answer[i]))
 +
correct++;
 +
}
 +
 +
System.out.println(": accuracy = " + ((double) correct / clusterSize));
 +
 +
System.out.println("Performing quaternary k-d tree search...");
 +
 +
tree4time = -System.nanoTime();
 +
 +
PRKdBucketTree.KdCluster<String> tree4r = tree4
 +
.getNearestNeighbor(clusterSize, center,
 +
new double[] { 1, 1, 1 });
 +
Collection<PRKdBucketTree.KdPoint<String>> tree4a = tree4r.getValues();
 +
 +
tree4time += System.nanoTime();
 +
 +
System.out.println("Quaternary tree search complete; time = "
 +
+ (tree4time / 1E9));
 +
correct = 0;
 +
j = 0;
 +
String[] tree4o = new String[clusterSize];
 +
for (PRKdBucketTree.KdPoint<String> p : tree4a) {
 +
tree4o[j++] = p.getValue();
 +
}
 +
Arrays.sort(tree4o);
 +
for (int i = 0; i < tree4o.length; i++) {
 +
if (tree4o[i].equals(answer[i]))
 +
correct++;
 +
}
 +
 +
System.out.println(": accuracy = " + ((double) correct / clusterSize));
 +
}
 +
 +
private static class Compare implements Comparable<Compare> {
 +
String data;
 +
double val;
 +
 +
@Override
 +
public int compareTo(Compare o) {
 +
return (int) -Math.signum(o.val - val);
 +
}
 +
 +
public Compare(String data, double val) {
 +
this.data = data;
 +
this.val = val;
 +
}
 +
 +
}
 +
 +
private static String generateRandomString() {
 +
String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
 +
Random r = new Random();
 +
char[] buf = new char[25];
 +
 +
for (int i = 0; i < buf.length; i++) {
 +
buf[i] = chars.charAt(r.nextInt(chars.length()));
 +
}
 +
 +
return new String(buf);
 +
}
 +
}
 +
 +
</pre>
 +
 +
== Old ==
 +
----
 +
 
It took me last night to create and all day today to debug it but I still can't get it right. Anyone please help me!
 
It took me last night to create and all day today to debug it but I still can't get it right. Anyone please help me!
  

Latest revision as of 19:12, 15 August 2009

Here is my code that I have problem state in Talk:Kd-Tree. My M class can be found at User:Nat/Free code.

Tree

package nat.tree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.Queue;

import nat.util.M;

/**
 * 
 * Implementation of bucket PR k-d tree.
 * 
 * @author Nat Pavasant
 * 
 * @param <V>
 *            The type of data to store
 */
public class PRKdBucketTree<V> implements Serializable {

	private static final long serialVersionUID = 1L;

	public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer();
	public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer();

	private final PRKdBucketTree<V>[] children;
	private final Queue<KdEntry<V>> data;
	private final Distancer distancer;
	private final double[] lowerBound, upperBound;
	private final int[] numChildren;
	private final int allDimensions;
	private final int dimension;
	private final int maxDepth, maxDensity;
	private final double splitMedian;

	private boolean isLeaf = true;

	/**
	 * Create new Bucket PR k-d tree.
	 * 
	 * @param allDimensions
	 *            number of dimensions in the tree
	 * @param lowerBound
	 *            the minimum value of the location of each dimension
	 * @param upperBound
	 *            the maximum value of the location of each dimension
	 * @param numChildren
	 *            number of children in each dimension
	 * @param maxDepth
	 *            the max depth of the tree
	 * @param maxDensity
	 *            size of bucket in each leaf
	 * @param distancer
	 *            distance measurer
	 */
	@SuppressWarnings("unchecked")
	public PRKdBucketTree(int allDimensions, double[] lowerBound,
			double[] upperBound, int[] numChildren, int maxDepth,
			int maxDensity, Distancer distancer) {
		if (allDimensions < 1 || maxDensity < 1)
			throw new IllegalArgumentException(
					"Either dimension or density isn't positive integer.");

		if (lowerBound.length != allDimensions
				|| upperBound.length != allDimensions
				|| numChildren.length != allDimensions)
			throw new IllegalArgumentException(
					"Either bounds or children amount is more or less than dimension count.");

		for (double a : lowerBound) {
			if (a < 0)
				throw new IllegalArgumentException(
						"Can't set lower bound to negative number.");
		}

		for (int i = 0; i < lowerBound.length; i++) {
			if (lowerBound[i] > upperBound[i])
				throw new IllegalArgumentException(
						"Upper bound must have a value higer than lower bound.");
		}

		this.allDimensions = allDimensions;
		this.lowerBound = lowerBound;
		this.upperBound = upperBound;
		this.numChildren = numChildren;
		this.distancer = distancer;
		this.maxDepth = maxDepth;
		this.maxDensity = maxDensity;
		this.dimension = maxDepth % allDimensions;
		this.data = new LinkedList<KdEntry<V>>();
		this.children = new PRKdBucketTree[numChildren[this.dimension]];
		this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension])
				/ numChildren[this.dimension];
	}

	// Another constructor
	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double lowerBound, double upperBound, int numChildren,
			int maxDepth, int maxDensity, Distancer distance) {
		double[] lowerBounds = new double[dimension];
		double[] upperBounds = new double[dimension];
		int[] numChildrens = new int[dimension];
		Arrays.fill(lowerBounds, lowerBound);
		Arrays.fill(upperBounds, upperBound);
		Arrays.fill(numChildrens, numChildren);
		return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds,
				numChildrens, maxDepth, maxDensity, distance);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren, int maxDepth, int maxDensity,
			Distancer distance) {
		return getTree(dataType, dimension, 0, upperBound, numChildren,
				maxDepth, maxDensity, distance);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren, int maxDepth, int maxDensity) {
		return getTree(dataType, dimension, 0, upperBound, numChildren,
				maxDepth, maxDensity, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int maxDepth, int maxDensity) {
		return getTree(dataType, dimension, 0, upperBound, 2, maxDepth,
				maxDensity, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren) {
		return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8,
				EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound) {
		return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) {
		return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN);
	}

	/**
	 * Add new point to the tree
	 * 
	 * @param value
	 *            the stored value
	 * @param location
	 *            location of the point
	 * @return removed point if any
	 */
	public KdEntry<V> addPoint(V value, double[] location) {
		if (location.length != allDimensions) {
			throw new IllegalArgumentException(
					"Provided location have either more or less dimensions than the tree.");
		}

		KdEntry<V> entry = new KdEntry<V>(value, location);

		return addPoint(entry);
	}

	/**
	 * Get the n-nearest neighbor.
	 * 
	 * @param size
	 *            number of neighbors
	 * @param center
	 *            center of the cluster
	 * @param weight
	 *            weighting for the distancer
	 * @return
	 */
	public KdCluster<V> getNearestNeighbor(int size, double[] center,
			double[] weight) {
		KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer);
		nearestNeighborSearch(cluster);
		return cluster;
	}

	/**
	 * Back-end implementation of the entry adding.
	 * 
	 * @param entry
	 *            new entry to the tree
	 * @return removed point if any
	 */
	private KdEntry<V> addPoint(KdEntry<V> entry) {

		if (isLeaf) {

			// Still has spaces, add the data
			if (data.size() < maxDensity) {
				data.add(entry);
				return null;
			}

			// final leaf, unsplitable, remove element and add new one.
			if (maxDepth <= 1) {
				data.add(entry);
				return data.poll();
			}

			// if we reached here, we need to split this leaf to a branch.
			isLeaf = false;
			for (KdEntry<V> p : data) {
				passToChildren(p);
			}

			data.clear();
		}

		// we are branch, pass to children
		return passToChildren(entry);
	}

	/**
	 * Perform n-nearest neighbor search
	 * 
	 * @param cluster
	 *            current working cluster
	 */
	private void nearestNeighborSearch(KdCluster<V> cluster) {
		if (isLeaf) {
			for (KdEntry<V> p : data) {
				cluster.consider(p);
			}
		} else {
			for (PRKdBucketTree<V> child : children) {
				if (child != null && cluster.isViable(child))
					child.nearestNeighborSearch(cluster);
			}
		}
	}

	/**
	 * Pass the entry to correct children
	 * 
	 * @param p
	 *            entry
	 * @return removed point if any
	 */
	private KdEntry<V> passToChildren(KdEntry<V> p) {
		int i = getChildrenIndex(p.getLocation(dimension));

		if (children[i] == null)
			children[i] = createChildTree(i);

		return children[i].addPoint(p);
	}

	/**
	 * Return index of the children which will contains data with that value.
	 * 
	 * @param value
	 *            the value
	 * @return index of children
	 */
	private int getChildrenIndex(double value) {
		return (int) M.limit(0, Math.floor(value / splitMedian),
				numChildren[dimension] - 1);
	}

	/**
	 * Create children tree with correct bound.
	 * 
	 * @param i
	 *            the index of the children
	 * @return created tree
	 */
	private PRKdBucketTree<V> createChildTree(int i) {
		double[] upperBound = this.upperBound.clone();
		double[] lowerBound = this.lowerBound.clone();
		lowerBound[dimension] = splitMedian * i;
		upperBound[dimension] = splitMedian * (i + 1);
		return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound,
				numChildren, maxDepth - 1, maxDensity, distancer);
	}

	public static class KdCluster<K> implements Iterable<KdPoint<K>> {
		private final PriorityQueue<KdPoint<K>> points;
		private final double[] center;
		private final double[] weight;
		private final int size;
		private final Distancer distancer;

		public KdCluster(int size, double[] center, double[] weight,
				Distancer distancer) {
			points = new PriorityQueue<KdPoint<K>>();
			this.size = size;
			this.center = center;
			this.weight = weight;
			this.distancer = distancer;
		}

		public void consider(KdEntry<K> k) {
			KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer);

			if (points.size() < size) {
				points.add(p);
				System.out.println(points.toString());
			} else if (points.peek().getDistanceToCenter() > p
					.getDistanceToCenter()) {
				System.out.println(points.toString());
				points.poll();
				points.add(p);
			}
		}

		public boolean isViable(PRKdBucketTree<K> tree) {

			if (points.size() < size)
				return true;

			double[] testPoints = new double[center.length];

			for (int i = 0; i < center.length; i++)
				testPoints[i] = M.limit(tree.lowerBound[i], center[i],
						tree.upperBound[i]);

			return points.peek().getDistanceToCenter() > distancer.getDistance(
					center, testPoints, weight);
		}

		@Override
		public Iterator<KdPoint<K>> iterator() {
			return points.iterator();
		}

		public Collection<KdPoint<K>> getValues() {
			Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points
					.size());

			for (KdPoint<K> p : points) {
				collect.add(p);
			}

			return collect;
		}

		@Override
		public int hashCode() {
			final int prime = 31;
			int result = 1;
			result = prime * result + Arrays.hashCode(center);
			result = prime * result
					+ ((distancer == null) ? 0 : distancer.hashCode());
			result = prime * result
					+ ((points == null) ? 0 : points.hashCode());
			result = prime * result + size;
			result = prime * result + Arrays.hashCode(weight);
			return result;
		}

		@SuppressWarnings("unchecked")
		@Override
		public boolean equals(Object obj) {
			if (this == obj)
				return true;
			if (obj == null)
				return false;
			if (!(obj instanceof KdCluster))
				return false;
			KdCluster other = (KdCluster) obj;
			if (!Arrays.equals(center, other.center))
				return false;
			if (distancer == null) {
				if (other.distancer != null)
					return false;
			} else if (!distancer.equals(other.distancer))
				return false;
			if (points == null) {
				if (other.points != null)
					return false;
			} else if (!points.equals(other.points))
				return false;
			if (size != other.size)
				return false;
			if (!Arrays.equals(weight, other.weight))
				return false;
			return true;
		}
	}

	private static class KdEntry<K> implements Serializable {
		private static final long serialVersionUID = 1L;

		private final K value;
		private final double[] location;

		public KdEntry(K value, double[] location) {
			super();
			this.value = value;
			this.location = location;
		}

		public K getValue() {
			return value;
		}

		public double[] getLocation() {
			return location;
		}

		public double getLocation(int a) {
			return location[a];
		}

		@Override
		public int hashCode() {
			final int prime = 31;
			int result = 1;
			result = prime * result + Arrays.hashCode(location);
			result = prime * result + ((value == null) ? 0 : value.hashCode());
			return result;
		}

		@SuppressWarnings("unchecked")
		@Override
		public boolean equals(Object obj) {
			if (this == obj)
				return true;
			if (obj == null)
				return false;
			if (!(obj instanceof KdEntry))
				return false;
			KdEntry other = (KdEntry) obj;
			if (!Arrays.equals(location, other.location))
				return false;
			if (value == null) {
				if (other.value != null)
					return false;
			} else if (!value.equals(other.value))
				return false;
			return true;
		}
	}

	public static class KdPoint<K> extends KdEntry<K> implements Serializable,
			Comparable<KdPoint<K>> {
		private static final long serialVersionUID = 1L;
		private final double distanceToCenter;

		public KdPoint(KdEntry<K> p, double[] center, double[] weight,
				Distancer distancer) {
			super(p.getValue(), p.getLocation());
			distanceToCenter = distancer.getDistance(center, getLocation(),
					weight);
		}

		public double getDistanceToCenter() {
			return distanceToCenter;
		}
		
		@Override
		public String toString() {
			return (new Double(distanceToCenter)).toString();
		}

		@Override
		public int compareTo(KdPoint<K> o) {
			return (int) Math.signum(o.distanceToCenter - distanceToCenter);
		}

		@Override
		public int hashCode() {
			final int prime = 31;
			int result = super.hashCode();
			long temp;
			temp = Double.doubleToLongBits(distanceToCenter);
			result = prime * result + (int) (temp ^ (temp >>> 32));
			return result;
		}

		@SuppressWarnings("unchecked")
		@Override
		public boolean equals(Object obj) {
			if (this == obj)
				return true;
			if (!super.equals(obj))
				return false;
			if (!(obj instanceof KdPoint))
				return false;
			KdPoint other = (KdPoint) obj;
			if (Double.doubleToLongBits(distanceToCenter) != Double
					.doubleToLongBits(other.distanceToCenter))
				return false;
			return true;
		}
	}

	public static abstract class Distancer {
		public double getDistance(double[] p1, double[] p2, double[] weight) {
			if (p1.length != p2.length)
				throw new IllegalArgumentException();
			return getPointDistance(p1, p2, weight);
		}

		public abstract double getPointDistance(double[] p1, double[] p2,
				double[] weight);

		public static class EuclidianDistancer extends Distancer {
			@Override
			public double getPointDistance(double[] p1, double[] p2,
					double[] weight) {
				double result = 0;
				for (int i = 0; i < p1.length; i++) {
					result += M.sqr(p1[i] - p2[i]) * weight[i];
				}
				return M.sqrt(result);
			}
		}

		public static class ManhattanDistancer extends Distancer {
			@Override
			public double getPointDistance(double[] p1, double[] p2,
					double[] weight) {
				double result = 0;
				for (int i = 0; i < p1.length; i++) {
					result += M.abs(p1[i] - p2[i]) * weight[i];
				}
				return result;
			}
		}
	}

	@Override
	public int hashCode() {
		final int prime = 31;
		int result = 1;
		result = prime * result + allDimensions;
		result = prime * result + Arrays.hashCode(children);
		result = prime * result + ((data == null) ? 0 : data.hashCode());
		result = prime * result + dimension;
		result = prime * result + (isLeaf ? 1231 : 1237);
		result = prime * result + Arrays.hashCode(lowerBound);
		result = prime * result + maxDensity;
		result = prime * result + maxDepth;
		result = prime * result + Arrays.hashCode(numChildren);
		long temp;
		temp = Double.doubleToLongBits(splitMedian);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		result = prime * result + Arrays.hashCode(upperBound);
		return result;
	}

	@SuppressWarnings("unchecked")
	@Override
	public boolean equals(Object obj) {
		if (this == obj)
			return true;
		if (obj == null)
			return false;
		if (!(obj instanceof PRKdBucketTree))
			return false;
		PRKdBucketTree other = (PRKdBucketTree) obj;
		if (allDimensions != other.allDimensions)
			return false;
		if (!Arrays.equals(children, other.children))
			return false;
		if (data == null) {
			if (other.data != null)
				return false;
		} else if (!data.equals(other.data))
			return false;
		if (dimension != other.dimension)
			return false;
		if (isLeaf != other.isLeaf)
			return false;
		if (!Arrays.equals(lowerBound, other.lowerBound))
			return false;
		if (maxDensity != other.maxDensity)
			return false;
		if (maxDepth != other.maxDepth)
			return false;
		if (!Arrays.equals(numChildren, other.numChildren))
			return false;
		if (Double.doubleToLongBits(splitMedian) != Double
				.doubleToLongBits(other.splitMedian))
			return false;
		if (!Arrays.equals(upperBound, other.upperBound))
			return false;
		return true;
	}

}

Tester

package nat.tree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.PriorityQueue;
import java.util.Random;

public class Test {
	public static void main(String[] args) {
		final int numTest = 100;
		final int clusterSize = 15;

		long linearTime, tree2time, tree3time, tree4time;
		String[] answer = new String[clusterSize];

		System.out.println("Starting Bucket PR k-d tree performance test...");
		System.out.println("Generating points...");

		ArrayList<String> input = new ArrayList<String>();
		ArrayList<double[]> location = new ArrayList<double[]>();

		for (int i = 0; i < numTest; i++) {
			input.add(generateRandomString());
			double[] p = new double[3];
			p[0] = Math.random();
			p[1] = Math.random();
			p[2] = Math.random();
			location.add(p);
		}

		PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2);
		PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3);
		PRKdBucketTree<String> tree4 = PRKdBucketTree.getTree("", 3, 1, 4);

		for (int i = 0; i < numTest; i++) {
			tree2.addPoint(input.get(i), location.get(i));
			tree3.addPoint(input.get(i), location.get(i));
			tree4.addPoint(input.get(i), location.get(i));
		}

		double[] center = new double[3];
		center[0] = Math.random();
		center[1] = Math.random();
		center[2] = Math.random();

		System.out.println("Data generated.");
		System.out.println("Performing linear search...");

		linearTime = -System.nanoTime();
		PriorityQueue<Compare> pq = new PriorityQueue<Compare>();

		for (int i = 0; i < numTest; i++) {
			double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center,
					location.get(i), new double[] { 1, 1, 1 });

			pq.add(new Compare(input.get(i), distance));
		}

		double distance = -1;
		for (int i = 0; i < clusterSize; i++) {
			if (distance == -1)
				distance = pq.peek().val;
			answer[i] = pq.poll().data;
		}
		linearTime += System.nanoTime();
		Arrays.sort(answer);
		System.out.println("Linear search complete; time = "
				+ (linearTime / 1E9));
		
		
		System.out.println("Performing binary k-d tree search...");

		tree2time = -System.nanoTime();

		PRKdBucketTree.KdCluster<String> tree2r = tree2
				.getNearestNeighbor(clusterSize, center,
						new double[] { 1, 1, 1 });
		Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues();

		tree2time += System.nanoTime();

		System.out.println("Binary tree search complete; time = "
				+ (tree2time / 1E9));
		int correct = 0;
		int j = 0;
		String[] tree2o = new String[clusterSize];
		for (PRKdBucketTree.KdPoint<String> p : tree2a) {
			tree2o[j++] = p.getValue();
		}
		Arrays.sort(tree2o);
		for (int i = 0; i < tree2o.length; i++) {
			if (tree2o[i].equals(answer[i]))
				correct++;
		}

		System.out.println(": accuracy = " + ((double) correct / clusterSize));
		
		System.out.println("Performing ternary k-d tree search...");

		tree3time = -System.nanoTime();

		PRKdBucketTree.KdCluster<String> tree3r = tree3
				.getNearestNeighbor(clusterSize, center,
						new double[] { 1, 1, 1 });
		Collection<PRKdBucketTree.KdPoint<String>> tree3a = tree3r.getValues();

		tree3time += System.nanoTime();

		System.out.println("Ternary tree search complete; time = "
				+ (tree3time / 1E9));
		correct = 0;
		j = 0;
		String[] tree3o = new String[clusterSize];
		for (PRKdBucketTree.KdPoint<String> p : tree3a) {
			tree3o[j++] = p.getValue();
		}
		Arrays.sort(tree3o);
		for (int i = 0; i < tree3o.length; i++) {
			if (tree3o[i].equals(answer[i]))
				correct++;
		}

		System.out.println(": accuracy = " + ((double) correct / clusterSize));
		
		System.out.println("Performing quaternary k-d tree search...");

		tree4time = -System.nanoTime();

		PRKdBucketTree.KdCluster<String> tree4r = tree4
				.getNearestNeighbor(clusterSize, center,
						new double[] { 1, 1, 1 });
		Collection<PRKdBucketTree.KdPoint<String>> tree4a = tree4r.getValues();

		tree4time += System.nanoTime();

		System.out.println("Quaternary tree search complete; time = "
				+ (tree4time / 1E9));
		correct = 0;
		j = 0;
		String[] tree4o = new String[clusterSize];
		for (PRKdBucketTree.KdPoint<String> p : tree4a) {
			tree4o[j++] = p.getValue();
		}
		Arrays.sort(tree4o);
		for (int i = 0; i < tree4o.length; i++) {
			if (tree4o[i].equals(answer[i]))
				correct++;
		}

		System.out.println(": accuracy = " + ((double) correct / clusterSize));
	}

	private static class Compare implements Comparable<Compare> {
		String data;
		double val;

		@Override
		public int compareTo(Compare o) {
			return (int) -Math.signum(o.val - val);
		}

		public Compare(String data, double val) {
			this.data = data;
			this.val = val;
		}

	}

	private static String generateRandomString() {
		String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
		Random r = new Random();
		char[] buf = new char[25];

		for (int i = 0; i < buf.length; i++) {
			buf[i] = chars.charAt(r.nextInt(chars.length()));
		}

		return new String(buf);
	}
}

Old


It took me last night to create and all day today to debug it but I still can't get it right. Anyone please help me!

package nat.tree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.Queue;

import nat.util.M;

/**
 * 
 * Implementation of bucket PR k-d tree.
 * 
 * @author Nat Pavasant
 * 
 * @param <V>
 *            The type of data to store
 */
public class PRKdBucketTree<V> implements Serializable {
	private static final long serialVersionUID = 1L;

	public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer();
	public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer();

	private final PRKdBucketTree<V>[] children;
	private final Queue<KdEntry<V>> data;
	private final Distancer distancer;
	private final double[] lowerBound, upperBound;
	private final int[] numChildren;
	private final int allDimensions;
	private final int dimension;
	private final int maxDepth, maxDensity;
	private final double splitMedian;

	private boolean isLeaf = true;

	/**
	 * Create new Bucket PR k-d tree.
	 * 
	 * @param allDimensions
	 *            number of dimensions in the tree
	 * @param lowerBound
	 *            the minimum value of the location of each dimension
	 * @param upperBound
	 *            the maximum value of the location of each dimension
	 * @param numChildren
	 *            number of children in each dimension
	 * @param maxDepth
	 *            the max depth of the tree
	 * @param maxDensity
	 *            size of bucket in each leaf
	 * @param distancer
	 *            distance measurer
	 */
	@SuppressWarnings("unchecked")
	public PRKdBucketTree(int allDimensions, double[] lowerBound,
			double[] upperBound, int[] numChildren, int maxDepth,
			int maxDensity, Distancer distancer) {
		if (allDimensions < 1 || maxDensity < 1)
			throw new IllegalArgumentException(
					"Either dimension or density isn't positive integer.");

		if (lowerBound.length != allDimensions
				|| upperBound.length != allDimensions
				|| numChildren.length != allDimensions)
			throw new IllegalArgumentException(
					"Either bounds or children amount is more or less than dimension count.");

		for (double a : lowerBound) {
			if (a < 0)
				throw new IllegalArgumentException(
						"Can't set lower bound to negative number.");
		}

		for (int i = 0; i < lowerBound.length; i++) {
			if (lowerBound[i] > upperBound[i])
				throw new IllegalArgumentException(
						"Upper bound must have a value higer than lower bound.");
		}

		this.allDimensions = allDimensions;
		this.lowerBound = lowerBound;
		this.upperBound = upperBound;
		this.numChildren = numChildren;
		this.distancer = distancer;
		this.maxDepth = maxDepth;
		this.maxDensity = maxDensity;
		this.dimension = maxDepth % allDimensions;
		this.data = new LinkedList<KdEntry<V>>();
		this.children = new PRKdBucketTree[numChildren[this.dimension]];
		this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension])
				/ numChildren[this.dimension];
	}

	// Another constructor
	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double lowerBound, double upperBound, int numChildren,
			int maxDepth, int maxDensity, Distancer distance) {
		double[] lowerBounds = new double[dimension];
		double[] upperBounds = new double[dimension];
		int[] numChildrens = new int[dimension];
		Arrays.fill(lowerBounds, lowerBound);
		Arrays.fill(upperBounds, upperBound);
		Arrays.fill(numChildrens, numChildren);
		return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds,
				numChildrens, maxDepth, maxDensity, distance);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren, int maxDepth, int maxDensity,
			Distancer distance) {
		return getTree(dataType, dimension, 0, upperBound, numChildren,
				maxDepth, maxDensity, distance);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren, int maxDepth, int maxDensity) {
		return getTree(dataType, dimension, 0, upperBound, numChildren,
				maxDepth, maxDensity, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int maxDepth, int maxDensity) {
		return getTree(dataType, dimension, 0, upperBound, 2, maxDepth,
				maxDensity, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren) {
		return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8,
				EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound) {
		return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) {
		return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN);
	}

	/**
	 * Add new point to the tree
	 * 
	 * @param value
	 *            the stored value
	 * @param location
	 *            location of the point
	 * @return removed point if any
	 */
	public KdEntry<V> addPoint(V value, double[] location) {
		if (location.length != allDimensions) {
			throw new IllegalArgumentException(
					"Provided location have either more or less dimensions than the tree.");
		}

		KdEntry<V> entry = new KdEntry<V>(value, location);

		return addPoint(entry);
	}

	/**
	 * Get the n-nearest neighbor.
	 * 
	 * @param size
	 *            number of neighbors
	 * @param center
	 *            center of the cluster
	 * @param weight
	 *            weighting for the distancer
	 * @return
	 */
	public KdCluster<V> getNearestNeighbor(int size, double[] center,
			double[] weight) {
		KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer);
		nearestNeighborSearch(cluster);
		return cluster;
	}

	/**
	 * Back-end implementation of the entry adding.
	 * 
	 * @param entry
	 *            new entry to the tree
	 * @return removed point if any
	 */
	private KdEntry<V> addPoint(KdEntry<V> entry) {

		if (isLeaf) {

			// Still has spaces, add the data
			if (data.size() < maxDensity) {
				data.add(entry);
				return null;
			}

			// final leaf, unsplitable, remove element and add new one.
			if (maxDepth <= 1) {
				data.add(entry);
				return data.poll();
			}

			// if we reached here, we need to split this leaf to a branch.
			isLeaf = false;
			for (KdEntry<V> p : data) {
				passToChildren(p);
			}

			data.clear();
		}

		// we are branch, pass to children
		return passToChildren(entry);
	}

	/**
	 * Perform n-nearest neighbor search
	 * 
	 * @param cluster
	 *            current working cluster
	 */
	private void nearestNeighborSearch(KdCluster<V> cluster) {
		if (isLeaf) {
			for (KdEntry<V> p : data) {
				cluster.consider(p);
			}
		} else {
			for (PRKdBucketTree<V> child : children) {
				if (child != null && cluster.isViable(child))
					child.nearestNeighborSearch(cluster);
			}
		}
	}

	/**
	 * Pass the entry to correct children
	 * 
	 * @param p
	 *            entry
	 * @return removed point if any
	 */
	private KdEntry<V> passToChildren(KdEntry<V> p) {
		int i = getChildrenIndex(p.getLocation(dimension));

		if (children[i] == null)
			children[i] = createChildTree(i);

		return children[i].addPoint(p);
	}

	/**
	 * Return index of the children which will contains data with that value.
	 * 
	 * @param value
	 *            the value
	 * @return index of children
	 */
	private int getChildrenIndex(double value) {
		return (int) M.limit(0, Math.floor(value / splitMedian),
				numChildren[dimension] - 1);
	}

	/**
	 * Create children tree with correct bound.
	 * 
	 * @param i
	 *            the index of the children
	 * @return created tree
	 */
	private PRKdBucketTree<V> createChildTree(int i) {
		double[] upperBound = this.upperBound.clone();
		double[] lowerBound = this.lowerBound.clone();
		lowerBound[i] = splitMedian * i;
		upperBound[i] = splitMedian * (i + 1);
		return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound,
				numChildren, maxDepth - 1, maxDensity, distancer);
	}

	public static class KdCluster<K> implements Iterable<KdPoint<K>> {
		private final PriorityQueue<KdPoint<K>> points;
		private final double[] center;
		private final double[] weight;
		private final int size;
		private final Distancer distancer;

		public KdCluster(int size, double[] center, double[] weight, Distancer distancer) {
			points = new PriorityQueue<KdPoint<K>>();
			this.size = size;
			this.center = center;
			this.weight = weight;
			this.distancer = distancer;
		}

		public void consider(KdEntry<K> k) {
			KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer);

			if (points.size() < size) {
				points.add(p);
			} else if (points.peek().isFurtherThan(p)) {
				points.poll();
				points.add(p);
			}
		}

		public boolean isViable(PRKdBucketTree<K> tree) {

			if (points.size() < size)
				return true;

			double[] testPoints = new double[center.length];

			for (int i = 0; i < center.length; i++)
				testPoints[i] = M.limit(tree.lowerBound[i], center[i],
						tree.upperBound[i]);

			return points.peek().isFurtherThan(
					distancer.getDistance(center, testPoints, weight));
		}

		@Override
		public Iterator<KdPoint<K>> iterator() {
			return points.iterator();
		}

		public Collection<KdPoint<K>> getValues() {
			Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points
					.size());

			for (KdPoint<K> p : points) {
				collect.add(p);
			}

			return collect;
		}
	}

	private static class KdEntry<K> implements Serializable {
		private static final long serialVersionUID = 1L;

		private final K value;
		private final double[] location;

		public KdEntry(K value, double[] location) {
			super();
			this.value = value;
			this.location = location;
		}

		public K getValue() {
			return value;
		}

		public double[] getLocation() {
			return location;
		}

		public double getLocation(int a) {
			return location[a];
		}
	}

	public static class KdPoint<K> extends KdEntry<K> implements Serializable,
			Comparable<KdPoint<K>> {
		private static final long serialVersionUID = 1L;
		private final double distanceToCenter;

		public KdPoint(KdEntry<K> p, double[] center, double[] weight, Distancer distancer) {
			super(p.getValue(), p.getLocation());
			distanceToCenter = distancer.getDistance(center, getLocation(),
					weight);
		}

		public double getDistanceToCenter() {
			return distanceToCenter;
		}

		@Override
		public int compareTo(KdPoint<K> o) {
			return (int) Math.signum(o.distanceToCenter - distanceToCenter);
		}

		public boolean isFurtherThan(KdPoint<K> p) {
			return compareTo(p) == -1;
		}

		public boolean isFurtherThan(double distance) {
			return distanceToCenter > distance;
		}
	}

	public static abstract class Distancer {
		public double getDistance(double[] p1, double[] p2, double[] weight) {
			if (p1.length != p2.length)
				throw new IllegalArgumentException();
			return getPointDistance(p1, p2, weight);
		}

		public abstract double getPointDistance(double[] p1, double[] p2,
				double[] weight);

		public static class EuclidianDistancer extends Distancer {
			@Override
			public double getPointDistance(double[] p1, double[] p2,
					double[] weight) {
				double result = 0;
				for (int i = 0; i < p1.length; i++) {
					result += M.sqr(p1[i] - p2[i]) * weight[i];
				}
				return M.sqrt(result);
			}
		}

		public static class ManhattanDistancer extends Distancer {
			@Override
			public double getPointDistance(double[] p1, double[] p2,
					double[] weight) {
				double result = 0;
				for (int i = 0; i < p1.length; i++) {
					result += M.abs(p1[i] - p2[i]) * weight[i];
				}
				return result;
			}
		}
	}

	
}

I use this code to check:

package nat.tree;

import java.util.ArrayList;
import java.util.Collection;
import java.util.PriorityQueue;
import java.util.Random;

import nat.tree.PRKdBucketTree.KdPoint;

public class Test {
	public static void main(String[] args) {
		final int numTest = 13000;
		final int clusterSize = 10;

		long linearTime, tree2time, tree3time;
		String[] answer = new String[100];

		System.out.println("Starting Bucket PR k-d tree performance test...");
		System.out.println("Generating points...");

		ArrayList<String> input = new ArrayList<String>();
		ArrayList<double[]> location = new ArrayList<double[]>();

		for (int i = 0; i < numTest; i++) {
			input.add(generateRandomString());
			double[] p = new double[3];
			p[0] = Math.random();
			p[1] = Math.random();
			p[2] = Math.random();
			location.add(p);
		}

		PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2);
		PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3);

		for (int i = 0; i < numTest; i++) {
			tree2.addPoint(input.get(i), location.get(i));
			tree3.addPoint(input.get(i), location.get(i));
		}

		double[] center = new double[3];
		center[0] = Math.random();
		center[1] = Math.random();
		center[2] = Math.random();

		System.out.println("Data generated.");
		System.out.println("Performing linear search...");

		linearTime = -System.nanoTime();
		PriorityQueue<Compare> pq = new PriorityQueue<Compare>();

		for (int i = 0; i < numTest; i++) {
			double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center,
					location.get(i), new double[] { 1, 1, 1 });

			pq.add(new Compare(input.get(i), distance));
		}

		double distance = -1;
		for (int i = 0; i < clusterSize; i++) {
			if (distance == -1)
				distance = pq.peek().val;
			answer[i] = pq.poll().data;
		}
		linearTime += System.nanoTime();
		System.out.println("Linear search complete; time = "
				+ (linearTime / 1E9));
		System.out.println("Performing binary k-d tree search...");

		tree2time = -System.nanoTime();

		PRKdBucketTree.KdCluster<String> tree2r = tree2
				.getNearestNeighbor(clusterSize, center,
						new double[] { 1, 1, 1 });
		Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues();

		tree2time += System.nanoTime();

		System.out.println("Distance = " + distance + "; "
				+ tree2a.iterator().next().getDistanceToCenter());

		System.out.println("Binary tree search complete; time = "
				+ (tree2time / 1E9));
		int correct = 0;
		int j = 0;
		for (PRKdBucketTree.KdPoint<String> p : tree2a) {
			System.out.print(p.getValue());
			System.out.print(" ");
			System.out.println(answer[j]);
			if (p.getValue().equals(answer[j++]))
				correct++;
		}

		System.out.println(": accuracy = " + ((double) correct / clusterSize));
	}

	private static class Compare implements Comparable<Compare> {
		String data;
		double val;

		@Override
		public int compareTo(Compare o) {
			return (int) -Math.signum(o.val - val);
		}

		public Compare(String data, double val) {
			this.data = data;
			this.val = val;
		}

	}

	private static String generateRandomString() {
		String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
		Random r = new Random();
		char[] buf = new char[15];

		for (int i = 0; i < buf.length; i++) {
			buf[i] = chars.charAt(r.nextInt(chars.length()));
		}

		return new String(buf);
	}
}

The linear search and tree search doesn't the same thing. I know this tree is a bit messy since I want it to supports other m-ary tree style too. Anyway, please help. Thank you in advance =) » Nat | Talk » 13:40, 15 August 2009 (UTC)

So sad no one help me =( By the way, finally I spot the bug. It is in createChildTree() where is should be upperBound[dimension] and lowerBound[dimension] instead of i. » Nat | Talk » 16:30, 15 August 2009 (UTC)