Help:Help/Nat/Kd-Tree
Jump to navigation
Jump to search
Here is my code that I have problem state in Talk:Kd-Tree. My M class can be found at User:Nat/Free code.
Tree
package nat.tree;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.Queue;
import nat.util.M;
/**
*
* Implementation of bucket PR k-d tree.
*
* @author Nat Pavasant
*
* @param <V>
* The type of data to store
*/
public class PRKdBucketTree<V> implements Serializable {
private static final long serialVersionUID = 1L;
public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer();
public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer();
private final PRKdBucketTree<V>[] children;
private final Queue<KdEntry<V>> data;
private final Distancer distancer;
private final double[] lowerBound, upperBound;
private final int[] numChildren;
private final int allDimensions;
private final int dimension;
private final int maxDepth, maxDensity;
private final double splitMedian;
private boolean isLeaf = true;
/**
* Create new Bucket PR k-d tree.
*
* @param allDimensions
* number of dimensions in the tree
* @param lowerBound
* the minimum value of the location of each dimension
* @param upperBound
* the maximum value of the location of each dimension
* @param numChildren
* number of children in each dimension
* @param maxDepth
* the max depth of the tree
* @param maxDensity
* size of bucket in each leaf
* @param distancer
* distance measurer
*/
@SuppressWarnings("unchecked")
public PRKdBucketTree(int allDimensions, double[] lowerBound,
double[] upperBound, int[] numChildren, int maxDepth,
int maxDensity, Distancer distancer) {
if (allDimensions < 1 || maxDensity < 1)
throw new IllegalArgumentException(
"Either dimension or density isn't positive integer.");
if (lowerBound.length != allDimensions
|| upperBound.length != allDimensions
|| numChildren.length != allDimensions)
throw new IllegalArgumentException(
"Either bounds or children amount is more or less than dimension count.");
for (double a : lowerBound) {
if (a < 0)
throw new IllegalArgumentException(
"Can't set lower bound to negative number.");
}
for (int i = 0; i < lowerBound.length; i++) {
if (lowerBound[i] > upperBound[i])
throw new IllegalArgumentException(
"Upper bound must have a value higer than lower bound.");
}
this.allDimensions = allDimensions;
this.lowerBound = lowerBound;
this.upperBound = upperBound;
this.numChildren = numChildren;
this.distancer = distancer;
this.maxDepth = maxDepth;
this.maxDensity = maxDensity;
this.dimension = maxDepth % allDimensions;
this.data = new LinkedList<KdEntry<V>>();
this.children = new PRKdBucketTree[numChildren[this.dimension]];
this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension])
/ numChildren[this.dimension];
}
// Another constructor
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double lowerBound, double upperBound, int numChildren,
int maxDepth, int maxDensity, Distancer distance) {
double[] lowerBounds = new double[dimension];
double[] upperBounds = new double[dimension];
int[] numChildrens = new int[dimension];
Arrays.fill(lowerBounds, lowerBound);
Arrays.fill(upperBounds, upperBound);
Arrays.fill(numChildrens, numChildren);
return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds,
numChildrens, maxDepth, maxDensity, distance);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren, int maxDepth, int maxDensity,
Distancer distance) {
return getTree(dataType, dimension, 0, upperBound, numChildren,
maxDepth, maxDensity, distance);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren, int maxDepth, int maxDensity) {
return getTree(dataType, dimension, 0, upperBound, numChildren,
maxDepth, maxDensity, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int maxDepth, int maxDensity) {
return getTree(dataType, dimension, 0, upperBound, 2, maxDepth,
maxDensity, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren) {
return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8,
EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound) {
return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) {
return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN);
}
/**
* Add new point to the tree
*
* @param value
* the stored value
* @param location
* location of the point
* @return removed point if any
*/
public KdEntry<V> addPoint(V value, double[] location) {
if (location.length != allDimensions) {
throw new IllegalArgumentException(
"Provided location have either more or less dimensions than the tree.");
}
KdEntry<V> entry = new KdEntry<V>(value, location);
return addPoint(entry);
}
/**
* Get the n-nearest neighbor.
*
* @param size
* number of neighbors
* @param center
* center of the cluster
* @param weight
* weighting for the distancer
* @return
*/
public KdCluster<V> getNearestNeighbor(int size, double[] center,
double[] weight) {
KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer);
nearestNeighborSearch(cluster);
return cluster;
}
/**
* Back-end implementation of the entry adding.
*
* @param entry
* new entry to the tree
* @return removed point if any
*/
private KdEntry<V> addPoint(KdEntry<V> entry) {
if (isLeaf) {
// Still has spaces, add the data
if (data.size() < maxDensity) {
data.add(entry);
return null;
}
// final leaf, unsplitable, remove element and add new one.
if (maxDepth <= 1) {
data.add(entry);
return data.poll();
}
// if we reached here, we need to split this leaf to a branch.
isLeaf = false;
for (KdEntry<V> p : data) {
passToChildren(p);
}
data.clear();
}
// we are branch, pass to children
return passToChildren(entry);
}
/**
* Perform n-nearest neighbor search
*
* @param cluster
* current working cluster
*/
private void nearestNeighborSearch(KdCluster<V> cluster) {
if (isLeaf) {
for (KdEntry<V> p : data) {
cluster.consider(p);
}
} else {
for (PRKdBucketTree<V> child : children) {
if (child != null && cluster.isViable(child))
child.nearestNeighborSearch(cluster);
}
}
}
/**
* Pass the entry to correct children
*
* @param p
* entry
* @return removed point if any
*/
private KdEntry<V> passToChildren(KdEntry<V> p) {
int i = getChildrenIndex(p.getLocation(dimension));
if (children[i] == null)
children[i] = createChildTree(i);
return children[i].addPoint(p);
}
/**
* Return index of the children which will contains data with that value.
*
* @param value
* the value
* @return index of children
*/
private int getChildrenIndex(double value) {
return (int) M.limit(0, Math.floor(value / splitMedian),
numChildren[dimension] - 1);
}
/**
* Create children tree with correct bound.
*
* @param i
* the index of the children
* @return created tree
*/
private PRKdBucketTree<V> createChildTree(int i) {
double[] upperBound = this.upperBound.clone();
double[] lowerBound = this.lowerBound.clone();
lowerBound[dimension] = splitMedian * i;
upperBound[dimension] = splitMedian * (i + 1);
return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound,
numChildren, maxDepth - 1, maxDensity, distancer);
}
public static class KdCluster<K> implements Iterable<KdPoint<K>> {
private final PriorityQueue<KdPoint<K>> points;
private final double[] center;
private final double[] weight;
private final int size;
private final Distancer distancer;
public KdCluster(int size, double[] center, double[] weight,
Distancer distancer) {
points = new PriorityQueue<KdPoint<K>>();
this.size = size;
this.center = center;
this.weight = weight;
this.distancer = distancer;
}
public void consider(KdEntry<K> k) {
KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer);
if (points.size() < size) {
points.add(p);
System.out.println(points.toString());
} else if (points.peek().getDistanceToCenter() > p
.getDistanceToCenter()) {
System.out.println(points.toString());
points.poll();
points.add(p);
}
}
public boolean isViable(PRKdBucketTree<K> tree) {
if (points.size() < size)
return true;
double[] testPoints = new double[center.length];
for (int i = 0; i < center.length; i++)
testPoints[i] = M.limit(tree.lowerBound[i], center[i],
tree.upperBound[i]);
return points.peek().getDistanceToCenter() > distancer.getDistance(
center, testPoints, weight);
}
@Override
public Iterator<KdPoint<K>> iterator() {
return points.iterator();
}
public Collection<KdPoint<K>> getValues() {
Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points
.size());
for (KdPoint<K> p : points) {
collect.add(p);
}
return collect;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + Arrays.hashCode(center);
result = prime * result
+ ((distancer == null) ? 0 : distancer.hashCode());
result = prime * result
+ ((points == null) ? 0 : points.hashCode());
result = prime * result + size;
result = prime * result + Arrays.hashCode(weight);
return result;
}
@SuppressWarnings("unchecked")
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (!(obj instanceof KdCluster))
return false;
KdCluster other = (KdCluster) obj;
if (!Arrays.equals(center, other.center))
return false;
if (distancer == null) {
if (other.distancer != null)
return false;
} else if (!distancer.equals(other.distancer))
return false;
if (points == null) {
if (other.points != null)
return false;
} else if (!points.equals(other.points))
return false;
if (size != other.size)
return false;
if (!Arrays.equals(weight, other.weight))
return false;
return true;
}
}
private static class KdEntry<K> implements Serializable {
private static final long serialVersionUID = 1L;
private final K value;
private final double[] location;
public KdEntry(K value, double[] location) {
super();
this.value = value;
this.location = location;
}
public K getValue() {
return value;
}
public double[] getLocation() {
return location;
}
public double getLocation(int a) {
return location[a];
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + Arrays.hashCode(location);
result = prime * result + ((value == null) ? 0 : value.hashCode());
return result;
}
@SuppressWarnings("unchecked")
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (!(obj instanceof KdEntry))
return false;
KdEntry other = (KdEntry) obj;
if (!Arrays.equals(location, other.location))
return false;
if (value == null) {
if (other.value != null)
return false;
} else if (!value.equals(other.value))
return false;
return true;
}
}
public static class KdPoint<K> extends KdEntry<K> implements Serializable,
Comparable<KdPoint<K>> {
private static final long serialVersionUID = 1L;
private final double distanceToCenter;
public KdPoint(KdEntry<K> p, double[] center, double[] weight,
Distancer distancer) {
super(p.getValue(), p.getLocation());
distanceToCenter = distancer.getDistance(center, getLocation(),
weight);
}
public double getDistanceToCenter() {
return distanceToCenter;
}
@Override
public String toString() {
return (new Double(distanceToCenter)).toString();
}
@Override
public int compareTo(KdPoint<K> o) {
return (int) Math.signum(o.distanceToCenter - distanceToCenter);
}
@Override
public int hashCode() {
final int prime = 31;
int result = super.hashCode();
long temp;
temp = Double.doubleToLongBits(distanceToCenter);
result = prime * result + (int) (temp ^ (temp >>> 32));
return result;
}
@SuppressWarnings("unchecked")
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (!super.equals(obj))
return false;
if (!(obj instanceof KdPoint))
return false;
KdPoint other = (KdPoint) obj;
if (Double.doubleToLongBits(distanceToCenter) != Double
.doubleToLongBits(other.distanceToCenter))
return false;
return true;
}
}
public static abstract class Distancer {
public double getDistance(double[] p1, double[] p2, double[] weight) {
if (p1.length != p2.length)
throw new IllegalArgumentException();
return getPointDistance(p1, p2, weight);
}
public abstract double getPointDistance(double[] p1, double[] p2,
double[] weight);
public static class EuclidianDistancer extends Distancer {
@Override
public double getPointDistance(double[] p1, double[] p2,
double[] weight) {
double result = 0;
for (int i = 0; i < p1.length; i++) {
result += M.sqr(p1[i] - p2[i]) * weight[i];
}
return M.sqrt(result);
}
}
public static class ManhattanDistancer extends Distancer {
@Override
public double getPointDistance(double[] p1, double[] p2,
double[] weight) {
double result = 0;
for (int i = 0; i < p1.length; i++) {
result += M.abs(p1[i] - p2[i]) * weight[i];
}
return result;
}
}
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + allDimensions;
result = prime * result + Arrays.hashCode(children);
result = prime * result + ((data == null) ? 0 : data.hashCode());
result = prime * result + dimension;
result = prime * result + (isLeaf ? 1231 : 1237);
result = prime * result + Arrays.hashCode(lowerBound);
result = prime * result + maxDensity;
result = prime * result + maxDepth;
result = prime * result + Arrays.hashCode(numChildren);
long temp;
temp = Double.doubleToLongBits(splitMedian);
result = prime * result + (int) (temp ^ (temp >>> 32));
result = prime * result + Arrays.hashCode(upperBound);
return result;
}
@SuppressWarnings("unchecked")
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (!(obj instanceof PRKdBucketTree))
return false;
PRKdBucketTree other = (PRKdBucketTree) obj;
if (allDimensions != other.allDimensions)
return false;
if (!Arrays.equals(children, other.children))
return false;
if (data == null) {
if (other.data != null)
return false;
} else if (!data.equals(other.data))
return false;
if (dimension != other.dimension)
return false;
if (isLeaf != other.isLeaf)
return false;
if (!Arrays.equals(lowerBound, other.lowerBound))
return false;
if (maxDensity != other.maxDensity)
return false;
if (maxDepth != other.maxDepth)
return false;
if (!Arrays.equals(numChildren, other.numChildren))
return false;
if (Double.doubleToLongBits(splitMedian) != Double
.doubleToLongBits(other.splitMedian))
return false;
if (!Arrays.equals(upperBound, other.upperBound))
return false;
return true;
}
}
Tester
package nat.tree;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.PriorityQueue;
import java.util.Random;
public class Test {
public static void main(String[] args) {
final int numTest = 100;
final int clusterSize = 15;
long linearTime, tree2time, tree3time, tree4time;
String[] answer = new String[clusterSize];
System.out.println("Starting Bucket PR k-d tree performance test...");
System.out.println("Generating points...");
ArrayList<String> input = new ArrayList<String>();
ArrayList<double[]> location = new ArrayList<double[]>();
for (int i = 0; i < numTest; i++) {
input.add(generateRandomString());
double[] p = new double[3];
p[0] = Math.random();
p[1] = Math.random();
p[2] = Math.random();
location.add(p);
}
PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2);
PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3);
PRKdBucketTree<String> tree4 = PRKdBucketTree.getTree("", 3, 1, 4);
for (int i = 0; i < numTest; i++) {
tree2.addPoint(input.get(i), location.get(i));
tree3.addPoint(input.get(i), location.get(i));
tree4.addPoint(input.get(i), location.get(i));
}
double[] center = new double[3];
center[0] = Math.random();
center[1] = Math.random();
center[2] = Math.random();
System.out.println("Data generated.");
System.out.println("Performing linear search...");
linearTime = -System.nanoTime();
PriorityQueue<Compare> pq = new PriorityQueue<Compare>();
for (int i = 0; i < numTest; i++) {
double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center,
location.get(i), new double[] { 1, 1, 1 });
pq.add(new Compare(input.get(i), distance));
}
double distance = -1;
for (int i = 0; i < clusterSize; i++) {
if (distance == -1)
distance = pq.peek().val;
answer[i] = pq.poll().data;
}
linearTime += System.nanoTime();
Arrays.sort(answer);
System.out.println("Linear search complete; time = "
+ (linearTime / 1E9));
System.out.println("Performing binary k-d tree search...");
tree2time = -System.nanoTime();
PRKdBucketTree.KdCluster<String> tree2r = tree2
.getNearestNeighbor(clusterSize, center,
new double[] { 1, 1, 1 });
Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues();
tree2time += System.nanoTime();
System.out.println("Binary tree search complete; time = "
+ (tree2time / 1E9));
int correct = 0;
int j = 0;
String[] tree2o = new String[clusterSize];
for (PRKdBucketTree.KdPoint<String> p : tree2a) {
tree2o[j++] = p.getValue();
}
Arrays.sort(tree2o);
for (int i = 0; i < tree2o.length; i++) {
if (tree2o[i].equals(answer[i]))
correct++;
}
System.out.println(": accuracy = " + ((double) correct / clusterSize));
System.out.println("Performing ternary k-d tree search...");
tree3time = -System.nanoTime();
PRKdBucketTree.KdCluster<String> tree3r = tree3
.getNearestNeighbor(clusterSize, center,
new double[] { 1, 1, 1 });
Collection<PRKdBucketTree.KdPoint<String>> tree3a = tree3r.getValues();
tree3time += System.nanoTime();
System.out.println("Ternary tree search complete; time = "
+ (tree3time / 1E9));
correct = 0;
j = 0;
String[] tree3o = new String[clusterSize];
for (PRKdBucketTree.KdPoint<String> p : tree3a) {
tree3o[j++] = p.getValue();
}
Arrays.sort(tree3o);
for (int i = 0; i < tree3o.length; i++) {
if (tree3o[i].equals(answer[i]))
correct++;
}
System.out.println(": accuracy = " + ((double) correct / clusterSize));
System.out.println("Performing quaternary k-d tree search...");
tree4time = -System.nanoTime();
PRKdBucketTree.KdCluster<String> tree4r = tree4
.getNearestNeighbor(clusterSize, center,
new double[] { 1, 1, 1 });
Collection<PRKdBucketTree.KdPoint<String>> tree4a = tree4r.getValues();
tree4time += System.nanoTime();
System.out.println("Quaternary tree search complete; time = "
+ (tree4time / 1E9));
correct = 0;
j = 0;
String[] tree4o = new String[clusterSize];
for (PRKdBucketTree.KdPoint<String> p : tree4a) {
tree4o[j++] = p.getValue();
}
Arrays.sort(tree4o);
for (int i = 0; i < tree4o.length; i++) {
if (tree4o[i].equals(answer[i]))
correct++;
}
System.out.println(": accuracy = " + ((double) correct / clusterSize));
}
private static class Compare implements Comparable<Compare> {
String data;
double val;
@Override
public int compareTo(Compare o) {
return (int) -Math.signum(o.val - val);
}
public Compare(String data, double val) {
this.data = data;
this.val = val;
}
}
private static String generateRandomString() {
String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
Random r = new Random();
char[] buf = new char[25];
for (int i = 0; i < buf.length; i++) {
buf[i] = chars.charAt(r.nextInt(chars.length()));
}
return new String(buf);
}
}
Old
It took me last night to create and all day today to debug it but I still can't get it right. Anyone please help me!
package nat.tree;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.Queue;
import nat.util.M;
/**
*
* Implementation of bucket PR k-d tree.
*
* @author Nat Pavasant
*
* @param <V>
* The type of data to store
*/
public class PRKdBucketTree<V> implements Serializable {
private static final long serialVersionUID = 1L;
public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer();
public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer();
private final PRKdBucketTree<V>[] children;
private final Queue<KdEntry<V>> data;
private final Distancer distancer;
private final double[] lowerBound, upperBound;
private final int[] numChildren;
private final int allDimensions;
private final int dimension;
private final int maxDepth, maxDensity;
private final double splitMedian;
private boolean isLeaf = true;
/**
* Create new Bucket PR k-d tree.
*
* @param allDimensions
* number of dimensions in the tree
* @param lowerBound
* the minimum value of the location of each dimension
* @param upperBound
* the maximum value of the location of each dimension
* @param numChildren
* number of children in each dimension
* @param maxDepth
* the max depth of the tree
* @param maxDensity
* size of bucket in each leaf
* @param distancer
* distance measurer
*/
@SuppressWarnings("unchecked")
public PRKdBucketTree(int allDimensions, double[] lowerBound,
double[] upperBound, int[] numChildren, int maxDepth,
int maxDensity, Distancer distancer) {
if (allDimensions < 1 || maxDensity < 1)
throw new IllegalArgumentException(
"Either dimension or density isn't positive integer.");
if (lowerBound.length != allDimensions
|| upperBound.length != allDimensions
|| numChildren.length != allDimensions)
throw new IllegalArgumentException(
"Either bounds or children amount is more or less than dimension count.");
for (double a : lowerBound) {
if (a < 0)
throw new IllegalArgumentException(
"Can't set lower bound to negative number.");
}
for (int i = 0; i < lowerBound.length; i++) {
if (lowerBound[i] > upperBound[i])
throw new IllegalArgumentException(
"Upper bound must have a value higer than lower bound.");
}
this.allDimensions = allDimensions;
this.lowerBound = lowerBound;
this.upperBound = upperBound;
this.numChildren = numChildren;
this.distancer = distancer;
this.maxDepth = maxDepth;
this.maxDensity = maxDensity;
this.dimension = maxDepth % allDimensions;
this.data = new LinkedList<KdEntry<V>>();
this.children = new PRKdBucketTree[numChildren[this.dimension]];
this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension])
/ numChildren[this.dimension];
}
// Another constructor
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double lowerBound, double upperBound, int numChildren,
int maxDepth, int maxDensity, Distancer distance) {
double[] lowerBounds = new double[dimension];
double[] upperBounds = new double[dimension];
int[] numChildrens = new int[dimension];
Arrays.fill(lowerBounds, lowerBound);
Arrays.fill(upperBounds, upperBound);
Arrays.fill(numChildrens, numChildren);
return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds,
numChildrens, maxDepth, maxDensity, distance);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren, int maxDepth, int maxDensity,
Distancer distance) {
return getTree(dataType, dimension, 0, upperBound, numChildren,
maxDepth, maxDensity, distance);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren, int maxDepth, int maxDensity) {
return getTree(dataType, dimension, 0, upperBound, numChildren,
maxDepth, maxDensity, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int maxDepth, int maxDensity) {
return getTree(dataType, dimension, 0, upperBound, 2, maxDepth,
maxDensity, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren) {
return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8,
EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound) {
return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) {
return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN);
}
/**
* Add new point to the tree
*
* @param value
* the stored value
* @param location
* location of the point
* @return removed point if any
*/
public KdEntry<V> addPoint(V value, double[] location) {
if (location.length != allDimensions) {
throw new IllegalArgumentException(
"Provided location have either more or less dimensions than the tree.");
}
KdEntry<V> entry = new KdEntry<V>(value, location);
return addPoint(entry);
}
/**
* Get the n-nearest neighbor.
*
* @param size
* number of neighbors
* @param center
* center of the cluster
* @param weight
* weighting for the distancer
* @return
*/
public KdCluster<V> getNearestNeighbor(int size, double[] center,
double[] weight) {
KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer);
nearestNeighborSearch(cluster);
return cluster;
}
/**
* Back-end implementation of the entry adding.
*
* @param entry
* new entry to the tree
* @return removed point if any
*/
private KdEntry<V> addPoint(KdEntry<V> entry) {
if (isLeaf) {
// Still has spaces, add the data
if (data.size() < maxDensity) {
data.add(entry);
return null;
}
// final leaf, unsplitable, remove element and add new one.
if (maxDepth <= 1) {
data.add(entry);
return data.poll();
}
// if we reached here, we need to split this leaf to a branch.
isLeaf = false;
for (KdEntry<V> p : data) {
passToChildren(p);
}
data.clear();
}
// we are branch, pass to children
return passToChildren(entry);
}
/**
* Perform n-nearest neighbor search
*
* @param cluster
* current working cluster
*/
private void nearestNeighborSearch(KdCluster<V> cluster) {
if (isLeaf) {
for (KdEntry<V> p : data) {
cluster.consider(p);
}
} else {
for (PRKdBucketTree<V> child : children) {
if (child != null && cluster.isViable(child))
child.nearestNeighborSearch(cluster);
}
}
}
/**
* Pass the entry to correct children
*
* @param p
* entry
* @return removed point if any
*/
private KdEntry<V> passToChildren(KdEntry<V> p) {
int i = getChildrenIndex(p.getLocation(dimension));
if (children[i] == null)
children[i] = createChildTree(i);
return children[i].addPoint(p);
}
/**
* Return index of the children which will contains data with that value.
*
* @param value
* the value
* @return index of children
*/
private int getChildrenIndex(double value) {
return (int) M.limit(0, Math.floor(value / splitMedian),
numChildren[dimension] - 1);
}
/**
* Create children tree with correct bound.
*
* @param i
* the index of the children
* @return created tree
*/
private PRKdBucketTree<V> createChildTree(int i) {
double[] upperBound = this.upperBound.clone();
double[] lowerBound = this.lowerBound.clone();
lowerBound[i] = splitMedian * i;
upperBound[i] = splitMedian * (i + 1);
return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound,
numChildren, maxDepth - 1, maxDensity, distancer);
}
public static class KdCluster<K> implements Iterable<KdPoint<K>> {
private final PriorityQueue<KdPoint<K>> points;
private final double[] center;
private final double[] weight;
private final int size;
private final Distancer distancer;
public KdCluster(int size, double[] center, double[] weight, Distancer distancer) {
points = new PriorityQueue<KdPoint<K>>();
this.size = size;
this.center = center;
this.weight = weight;
this.distancer = distancer;
}
public void consider(KdEntry<K> k) {
KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer);
if (points.size() < size) {
points.add(p);
} else if (points.peek().isFurtherThan(p)) {
points.poll();
points.add(p);
}
}
public boolean isViable(PRKdBucketTree<K> tree) {
if (points.size() < size)
return true;
double[] testPoints = new double[center.length];
for (int i = 0; i < center.length; i++)
testPoints[i] = M.limit(tree.lowerBound[i], center[i],
tree.upperBound[i]);
return points.peek().isFurtherThan(
distancer.getDistance(center, testPoints, weight));
}
@Override
public Iterator<KdPoint<K>> iterator() {
return points.iterator();
}
public Collection<KdPoint<K>> getValues() {
Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points
.size());
for (KdPoint<K> p : points) {
collect.add(p);
}
return collect;
}
}
private static class KdEntry<K> implements Serializable {
private static final long serialVersionUID = 1L;
private final K value;
private final double[] location;
public KdEntry(K value, double[] location) {
super();
this.value = value;
this.location = location;
}
public K getValue() {
return value;
}
public double[] getLocation() {
return location;
}
public double getLocation(int a) {
return location[a];
}
}
public static class KdPoint<K> extends KdEntry<K> implements Serializable,
Comparable<KdPoint<K>> {
private static final long serialVersionUID = 1L;
private final double distanceToCenter;
public KdPoint(KdEntry<K> p, double[] center, double[] weight, Distancer distancer) {
super(p.getValue(), p.getLocation());
distanceToCenter = distancer.getDistance(center, getLocation(),
weight);
}
public double getDistanceToCenter() {
return distanceToCenter;
}
@Override
public int compareTo(KdPoint<K> o) {
return (int) Math.signum(o.distanceToCenter - distanceToCenter);
}
public boolean isFurtherThan(KdPoint<K> p) {
return compareTo(p) == -1;
}
public boolean isFurtherThan(double distance) {
return distanceToCenter > distance;
}
}
public static abstract class Distancer {
public double getDistance(double[] p1, double[] p2, double[] weight) {
if (p1.length != p2.length)
throw new IllegalArgumentException();
return getPointDistance(p1, p2, weight);
}
public abstract double getPointDistance(double[] p1, double[] p2,
double[] weight);
public static class EuclidianDistancer extends Distancer {
@Override
public double getPointDistance(double[] p1, double[] p2,
double[] weight) {
double result = 0;
for (int i = 0; i < p1.length; i++) {
result += M.sqr(p1[i] - p2[i]) * weight[i];
}
return M.sqrt(result);
}
}
public static class ManhattanDistancer extends Distancer {
@Override
public double getPointDistance(double[] p1, double[] p2,
double[] weight) {
double result = 0;
for (int i = 0; i < p1.length; i++) {
result += M.abs(p1[i] - p2[i]) * weight[i];
}
return result;
}
}
}
}
I use this code to check:
package nat.tree;
import java.util.ArrayList;
import java.util.Collection;
import java.util.PriorityQueue;
import java.util.Random;
import nat.tree.PRKdBucketTree.KdPoint;
public class Test {
public static void main(String[] args) {
final int numTest = 13000;
final int clusterSize = 10;
long linearTime, tree2time, tree3time;
String[] answer = new String[100];
System.out.println("Starting Bucket PR k-d tree performance test...");
System.out.println("Generating points...");
ArrayList<String> input = new ArrayList<String>();
ArrayList<double[]> location = new ArrayList<double[]>();
for (int i = 0; i < numTest; i++) {
input.add(generateRandomString());
double[] p = new double[3];
p[0] = Math.random();
p[1] = Math.random();
p[2] = Math.random();
location.add(p);
}
PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2);
PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3);
for (int i = 0; i < numTest; i++) {
tree2.addPoint(input.get(i), location.get(i));
tree3.addPoint(input.get(i), location.get(i));
}
double[] center = new double[3];
center[0] = Math.random();
center[1] = Math.random();
center[2] = Math.random();
System.out.println("Data generated.");
System.out.println("Performing linear search...");
linearTime = -System.nanoTime();
PriorityQueue<Compare> pq = new PriorityQueue<Compare>();
for (int i = 0; i < numTest; i++) {
double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center,
location.get(i), new double[] { 1, 1, 1 });
pq.add(new Compare(input.get(i), distance));
}
double distance = -1;
for (int i = 0; i < clusterSize; i++) {
if (distance == -1)
distance = pq.peek().val;
answer[i] = pq.poll().data;
}
linearTime += System.nanoTime();
System.out.println("Linear search complete; time = "
+ (linearTime / 1E9));
System.out.println("Performing binary k-d tree search...");
tree2time = -System.nanoTime();
PRKdBucketTree.KdCluster<String> tree2r = tree2
.getNearestNeighbor(clusterSize, center,
new double[] { 1, 1, 1 });
Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues();
tree2time += System.nanoTime();
System.out.println("Distance = " + distance + "; "
+ tree2a.iterator().next().getDistanceToCenter());
System.out.println("Binary tree search complete; time = "
+ (tree2time / 1E9));
int correct = 0;
int j = 0;
for (PRKdBucketTree.KdPoint<String> p : tree2a) {
System.out.print(p.getValue());
System.out.print(" ");
System.out.println(answer[j]);
if (p.getValue().equals(answer[j++]))
correct++;
}
System.out.println(": accuracy = " + ((double) correct / clusterSize));
}
private static class Compare implements Comparable<Compare> {
String data;
double val;
@Override
public int compareTo(Compare o) {
return (int) -Math.signum(o.val - val);
}
public Compare(String data, double val) {
this.data = data;
this.val = val;
}
}
private static String generateRandomString() {
String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
Random r = new Random();
char[] buf = new char[15];
for (int i = 0; i < buf.length; i++) {
buf[i] = chars.charAt(r.nextInt(chars.length()));
}
return new String(buf);
}
}
The linear search and tree search doesn't the same thing. I know this tree is a bit messy since I want it to supports other m-ary tree style too. Anyway, please help. Thank you in advance =) » Nat | Talk » 13:40, 15 August 2009 (UTC)
So sad no one help me =( By the way, finally I spot the bug. It is in createChildTree() where is should be upperBound[dimension] and lowerBound[dimension] instead of i. » Nat | Talk » 16:30, 15 August 2009 (UTC)