Help:Help/Nat/Kd-Tree
Jump to navigation
Jump to search
It took me last night to create and all day today to debug it but I still can't get it right. Anyone please help me!
package nat.tree;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.Queue;
import nat.util.M;
/**
*
* Implementation of bucket PR k-d tree.
*
* @author Nat Pavasant
*
* @param <V>
* The type of data to store
*/
public class PRKdBucketTree<V> implements Serializable {
private static final long serialVersionUID = 1L;
public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer();
public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer();
private final PRKdBucketTree<V>[] children;
private final Queue<KdEntry<V>> data;
private final Distancer distancer;
private final double[] lowerBound, upperBound;
private final int[] numChildren;
private final int allDimensions;
private final int dimension;
private final int maxDepth, maxDensity;
private final double splitMedian;
private boolean isLeaf = true;
/**
* Create new Bucket PR k-d tree.
*
* @param allDimensions
* number of dimensions in the tree
* @param lowerBound
* the minimum value of the location of each dimension
* @param upperBound
* the maximum value of the location of each dimension
* @param numChildren
* number of children in each dimension
* @param maxDepth
* the max depth of the tree
* @param maxDensity
* size of bucket in each leaf
* @param distancer
* distance measurer
*/
@SuppressWarnings("unchecked")
public PRKdBucketTree(int allDimensions, double[] lowerBound,
double[] upperBound, int[] numChildren, int maxDepth,
int maxDensity, Distancer distancer) {
if (allDimensions < 1 || maxDensity < 1)
throw new IllegalArgumentException(
"Either dimension or density isn't positive integer.");
if (lowerBound.length != allDimensions
|| upperBound.length != allDimensions
|| numChildren.length != allDimensions)
throw new IllegalArgumentException(
"Either bounds or children amount is more or less than dimension count.");
for (double a : lowerBound) {
if (a < 0)
throw new IllegalArgumentException(
"Can't set lower bound to negative number.");
}
for (int i = 0; i < lowerBound.length; i++) {
if (lowerBound[i] > upperBound[i])
throw new IllegalArgumentException(
"Upper bound must have a value higer than lower bound.");
}
this.allDimensions = allDimensions;
this.lowerBound = lowerBound;
this.upperBound = upperBound;
this.numChildren = numChildren;
this.distancer = distancer;
this.maxDepth = maxDepth;
this.maxDensity = maxDensity;
this.dimension = maxDepth % allDimensions;
this.data = new LinkedList<KdEntry<V>>();
this.children = new PRKdBucketTree[numChildren[this.dimension]];
this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension])
/ numChildren[this.dimension];
}
// Another constructor
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double lowerBound, double upperBound, int numChildren,
int maxDepth, int maxDensity, Distancer distance) {
double[] lowerBounds = new double[dimension];
double[] upperBounds = new double[dimension];
int[] numChildrens = new int[dimension];
Arrays.fill(lowerBounds, lowerBound);
Arrays.fill(upperBounds, upperBound);
Arrays.fill(numChildrens, numChildren);
return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds,
numChildrens, maxDepth, maxDensity, distance);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren, int maxDepth, int maxDensity,
Distancer distance) {
return getTree(dataType, dimension, 0, upperBound, numChildren,
maxDepth, maxDensity, distance);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren, int maxDepth, int maxDensity) {
return getTree(dataType, dimension, 0, upperBound, numChildren,
maxDepth, maxDensity, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int maxDepth, int maxDensity) {
return getTree(dataType, dimension, 0, upperBound, 2, maxDepth,
maxDensity, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren) {
return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8,
EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound) {
return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) {
return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN);
}
/**
* Add new point to the tree
*
* @param value
* the stored value
* @param location
* location of the point
* @return removed point if any
*/
public KdEntry<V> addPoint(V value, double[] location) {
if (location.length != allDimensions) {
throw new IllegalArgumentException(
"Provided location have either more or less dimensions than the tree.");
}
KdEntry<V> entry = new KdEntry<V>(value, location);
return addPoint(entry);
}
/**
* Get the n-nearest neighbor.
*
* @param size
* number of neighbors
* @param center
* center of the cluster
* @param weight
* weighting for the distancer
* @return
*/
public KdCluster<V> getNearestNeighbor(int size, double[] center,
double[] weight) {
KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer);
nearestNeighborSearch(cluster);
return cluster;
}
/**
* Back-end implementation of the entry adding.
*
* @param entry
* new entry to the tree
* @return removed point if any
*/
private KdEntry<V> addPoint(KdEntry<V> entry) {
if (isLeaf) {
// Still has spaces, add the data
if (data.size() < maxDensity) {
data.add(entry);
return null;
}
// final leaf, unsplitable, remove element and add new one.
if (maxDepth <= 1) {
data.add(entry);
return data.poll();
}
// if we reached here, we need to split this leaf to a branch.
isLeaf = false;
for (KdEntry<V> p : data) {
passToChildren(p);
}
data.clear();
}
// we are branch, pass to children
return passToChildren(entry);
}
/**
* Perform n-nearest neighbor search
*
* @param cluster
* current working cluster
*/
private void nearestNeighborSearch(KdCluster<V> cluster) {
if (isLeaf) {
for (KdEntry<V> p : data) {
cluster.consider(p);
}
} else {
for (PRKdBucketTree<V> child : children) {
if (child != null && cluster.isViable(child))
child.nearestNeighborSearch(cluster);
}
}
}
/**
* Pass the entry to correct children
*
* @param p
* entry
* @return removed point if any
*/
private KdEntry<V> passToChildren(KdEntry<V> p) {
int i = getChildrenIndex(p.getLocation(dimension));
if (children[i] == null)
children[i] = createChildTree(i);
return children[i].addPoint(p);
}
/**
* Return index of the children which will contains data with that value.
*
* @param value
* the value
* @return index of children
*/
private int getChildrenIndex(double value) {
return (int) M.limit(0, Math.floor(value / splitMedian),
numChildren[dimension] - 1);
}
/**
* Create children tree with correct bound.
*
* @param i
* the index of the children
* @return created tree
*/
private PRKdBucketTree<V> createChildTree(int i) {
double[] upperBound = this.upperBound.clone();
double[] lowerBound = this.lowerBound.clone();
lowerBound[i] = splitMedian * i;
upperBound[i] = splitMedian * (i + 1);
return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound,
numChildren, maxDepth - 1, maxDensity, distancer);
}
public static class KdCluster<K> implements Iterable<KdPoint<K>> {
private final PriorityQueue<KdPoint<K>> points;
private final double[] center;
private final double[] weight;
private final int size;
private final Distancer distancer;
public KdCluster(int size, double[] center, double[] weight, Distancer distancer) {
points = new PriorityQueue<KdPoint<K>>();
this.size = size;
this.center = center;
this.weight = weight;
this.distancer = distancer;
}
public void consider(KdEntry<K> k) {
KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer);
if (points.size() < size) {
points.add(p);
} else if (points.peek().isFurtherThan(p)) {
points.poll();
points.add(p);
}
}
public boolean isViable(PRKdBucketTree<K> tree) {
if (points.size() < size)
return true;
double[] testPoints = new double[center.length];
for (int i = 0; i < center.length; i++)
testPoints[i] = M.limit(tree.lowerBound[i], center[i],
tree.upperBound[i]);
return points.peek().isFurtherThan(
distancer.getDistance(center, testPoints, weight));
}
@Override
public Iterator<KdPoint<K>> iterator() {
return points.iterator();
}
public Collection<KdPoint<K>> getValues() {
Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points
.size());
for (KdPoint<K> p : points) {
collect.add(p);
}
return collect;
}
}
private static class KdEntry<K> implements Serializable {
private static final long serialVersionUID = 1L;
private final K value;
private final double[] location;
public KdEntry(K value, double[] location) {
super();
this.value = value;
this.location = location;
}
public K getValue() {
return value;
}
public double[] getLocation() {
return location;
}
public double getLocation(int a) {
return location[a];
}
}
public static class KdPoint<K> extends KdEntry<K> implements Serializable,
Comparable<KdPoint<K>> {
private static final long serialVersionUID = 1L;
private final double distanceToCenter;
public KdPoint(KdEntry<K> p, double[] center, double[] weight, Distancer distancer) {
super(p.getValue(), p.getLocation());
distanceToCenter = distancer.getDistance(center, getLocation(),
weight);
}
public double getDistanceToCenter() {
return distanceToCenter;
}
@Override
public int compareTo(KdPoint<K> o) {
return (int) Math.signum(o.distanceToCenter - distanceToCenter);
}
public boolean isFurtherThan(KdPoint<K> p) {
return compareTo(p) == -1;
}
public boolean isFurtherThan(double distance) {
return distanceToCenter > distance;
}
}
public static abstract class Distancer {
public double getDistance(double[] p1, double[] p2, double[] weight) {
if (p1.length != p2.length)
throw new IllegalArgumentException();
return getPointDistance(p1, p2, weight);
}
public abstract double getPointDistance(double[] p1, double[] p2,
double[] weight);
public static class EuclidianDistancer extends Distancer {
@Override
public double getPointDistance(double[] p1, double[] p2,
double[] weight) {
double result = 0;
for (int i = 0; i < p1.length; i++) {
result += M.sqr(p1[i] - p2[i]) * weight[i];
}
return M.sqrt(result);
}
}
public static class ManhattanDistancer extends Distancer {
@Override
public double getPointDistance(double[] p1, double[] p2,
double[] weight) {
double result = 0;
for (int i = 0; i < p1.length; i++) {
result += M.abs(p1[i] - p2[i]) * weight[i];
}
return result;
}
}
}
}
I use this code to check:
package nat.tree;
import java.util.ArrayList;
import java.util.Collection;
import java.util.PriorityQueue;
import java.util.Random;
import nat.tree.PRKdBucketTree.KdPoint;
public class Test {
public static void main(String[] args) {
final int numTest = 13000;
final int clusterSize = 10;
long linearTime, tree2time, tree3time;
String[] answer = new String[100];
System.out.println("Starting Bucket PR k-d tree performance test...");
System.out.println("Generating points...");
ArrayList<String> input = new ArrayList<String>();
ArrayList<double[]> location = new ArrayList<double[]>();
for (int i = 0; i < numTest; i++) {
input.add(generateRandomString());
double[] p = new double[3];
p[0] = Math.random();
p[1] = Math.random();
p[2] = Math.random();
location.add(p);
}
PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2);
PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3);
for (int i = 0; i < numTest; i++) {
tree2.addPoint(input.get(i), location.get(i));
tree3.addPoint(input.get(i), location.get(i));
}
double[] center = new double[3];
center[0] = Math.random();
center[1] = Math.random();
center[2] = Math.random();
System.out.println("Data generated.");
System.out.println("Performing linear search...");
linearTime = -System.nanoTime();
PriorityQueue<Compare> pq = new PriorityQueue<Compare>();
for (int i = 0; i < numTest; i++) {
double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center,
location.get(i), new double[] { 1, 1, 1 });
pq.add(new Compare(input.get(i), distance));
}
double distance = -1;
for (int i = 0; i < clusterSize; i++) {
if (distance == -1)
distance = pq.peek().val;
answer[i] = pq.poll().data;
}
linearTime += System.nanoTime();
System.out.println("Linear search complete; time = "
+ (linearTime / 1E9));
System.out.println("Performing binary k-d tree search...");
tree2time = -System.nanoTime();
PRKdBucketTree.KdCluster<String> tree2r = tree2
.getNearestNeighbor(clusterSize, center,
new double[] { 1, 1, 1 });
Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues();
tree2time += System.nanoTime();
System.out.println("Distance = " + distance + "; "
+ tree2a.iterator().next().getDistanceToCenter());
System.out.println("Binary tree search complete; time = "
+ (tree2time / 1E9));
int correct = 0;
int j = 0;
for (PRKdBucketTree.KdPoint<String> p : tree2a) {
System.out.print(p.getValue());
System.out.print(" ");
System.out.println(answer[j]);
if (p.getValue().equals(answer[j++]))
correct++;
}
System.out.println(": accuracy = " + ((double) correct / clusterSize));
}
private static class Compare implements Comparable<Compare> {
String data;
double val;
@Override
public int compareTo(Compare o) {
return (int) -Math.signum(o.val - val);
}
public Compare(String data, double val) {
this.data = data;
this.val = val;
}
}
private static String generateRandomString() {
String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
Random r = new Random();
char[] buf = new char[15];
for (int i = 0; i < buf.length; i++) {
buf[i] = chars.charAt(r.nextInt(chars.length()));
}
return new String(buf);
}
}
The linear search and tree search doesn't the same thing. I know this tree is a bit messy since I want it to supports other m-ary tree style too. Anyway, please help. Thank you in advance =) » Nat | Talk » 13:40, 15 August 2009 (UTC)