Help:Help/Nat/Kd-Tree

From Robowiki
< Help:Help
Revision as of 18:12, 15 August 2009 by Nat (talk | contribs) (latest code that I've problem)
(diff) ← Older revision | Latest revision (diff) | Newer revision → (diff)
Jump to navigation Jump to search

Here is my code that I have problem state in Talk:Kd-Tree. My M class can be found at User:Nat/Free code.

Tree

package nat.tree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.Queue;

import nat.util.M;

/**
 * 
 * Implementation of bucket PR k-d tree.
 * 
 * @author Nat Pavasant
 * 
 * @param <V>
 *            The type of data to store
 */
public class PRKdBucketTree<V> implements Serializable {

	private static final long serialVersionUID = 1L;

	public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer();
	public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer();

	private final PRKdBucketTree<V>[] children;
	private final Queue<KdEntry<V>> data;
	private final Distancer distancer;
	private final double[] lowerBound, upperBound;
	private final int[] numChildren;
	private final int allDimensions;
	private final int dimension;
	private final int maxDepth, maxDensity;
	private final double splitMedian;

	private boolean isLeaf = true;

	/**
	 * Create new Bucket PR k-d tree.
	 * 
	 * @param allDimensions
	 *            number of dimensions in the tree
	 * @param lowerBound
	 *            the minimum value of the location of each dimension
	 * @param upperBound
	 *            the maximum value of the location of each dimension
	 * @param numChildren
	 *            number of children in each dimension
	 * @param maxDepth
	 *            the max depth of the tree
	 * @param maxDensity
	 *            size of bucket in each leaf
	 * @param distancer
	 *            distance measurer
	 */
	@SuppressWarnings("unchecked")
	public PRKdBucketTree(int allDimensions, double[] lowerBound,
			double[] upperBound, int[] numChildren, int maxDepth,
			int maxDensity, Distancer distancer) {
		if (allDimensions < 1 || maxDensity < 1)
			throw new IllegalArgumentException(
					"Either dimension or density isn't positive integer.");

		if (lowerBound.length != allDimensions
				|| upperBound.length != allDimensions
				|| numChildren.length != allDimensions)
			throw new IllegalArgumentException(
					"Either bounds or children amount is more or less than dimension count.");

		for (double a : lowerBound) {
			if (a < 0)
				throw new IllegalArgumentException(
						"Can't set lower bound to negative number.");
		}

		for (int i = 0; i < lowerBound.length; i++) {
			if (lowerBound[i] > upperBound[i])
				throw new IllegalArgumentException(
						"Upper bound must have a value higer than lower bound.");
		}

		this.allDimensions = allDimensions;
		this.lowerBound = lowerBound;
		this.upperBound = upperBound;
		this.numChildren = numChildren;
		this.distancer = distancer;
		this.maxDepth = maxDepth;
		this.maxDensity = maxDensity;
		this.dimension = maxDepth % allDimensions;
		this.data = new LinkedList<KdEntry<V>>();
		this.children = new PRKdBucketTree[numChildren[this.dimension]];
		this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension])
				/ numChildren[this.dimension];
	}

	// Another constructor
	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double lowerBound, double upperBound, int numChildren,
			int maxDepth, int maxDensity, Distancer distance) {
		double[] lowerBounds = new double[dimension];
		double[] upperBounds = new double[dimension];
		int[] numChildrens = new int[dimension];
		Arrays.fill(lowerBounds, lowerBound);
		Arrays.fill(upperBounds, upperBound);
		Arrays.fill(numChildrens, numChildren);
		return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds,
				numChildrens, maxDepth, maxDensity, distance);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren, int maxDepth, int maxDensity,
			Distancer distance) {
		return getTree(dataType, dimension, 0, upperBound, numChildren,
				maxDepth, maxDensity, distance);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren, int maxDepth, int maxDensity) {
		return getTree(dataType, dimension, 0, upperBound, numChildren,
				maxDepth, maxDensity, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int maxDepth, int maxDensity) {
		return getTree(dataType, dimension, 0, upperBound, 2, maxDepth,
				maxDensity, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren) {
		return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8,
				EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound) {
		return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) {
		return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN);
	}

	/**
	 * Add new point to the tree
	 * 
	 * @param value
	 *            the stored value
	 * @param location
	 *            location of the point
	 * @return removed point if any
	 */
	public KdEntry<V> addPoint(V value, double[] location) {
		if (location.length != allDimensions) {
			throw new IllegalArgumentException(
					"Provided location have either more or less dimensions than the tree.");
		}

		KdEntry<V> entry = new KdEntry<V>(value, location);

		return addPoint(entry);
	}

	/**
	 * Get the n-nearest neighbor.
	 * 
	 * @param size
	 *            number of neighbors
	 * @param center
	 *            center of the cluster
	 * @param weight
	 *            weighting for the distancer
	 * @return
	 */
	public KdCluster<V> getNearestNeighbor(int size, double[] center,
			double[] weight) {
		KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer);
		nearestNeighborSearch(cluster);
		return cluster;
	}

	/**
	 * Back-end implementation of the entry adding.
	 * 
	 * @param entry
	 *            new entry to the tree
	 * @return removed point if any
	 */
	private KdEntry<V> addPoint(KdEntry<V> entry) {

		if (isLeaf) {

			// Still has spaces, add the data
			if (data.size() < maxDensity) {
				data.add(entry);
				return null;
			}

			// final leaf, unsplitable, remove element and add new one.
			if (maxDepth <= 1) {
				data.add(entry);
				return data.poll();
			}

			// if we reached here, we need to split this leaf to a branch.
			isLeaf = false;
			for (KdEntry<V> p : data) {
				passToChildren(p);
			}

			data.clear();
		}

		// we are branch, pass to children
		return passToChildren(entry);
	}

	/**
	 * Perform n-nearest neighbor search
	 * 
	 * @param cluster
	 *            current working cluster
	 */
	private void nearestNeighborSearch(KdCluster<V> cluster) {
		if (isLeaf) {
			for (KdEntry<V> p : data) {
				cluster.consider(p);
			}
		} else {
			for (PRKdBucketTree<V> child : children) {
				if (child != null && cluster.isViable(child))
					child.nearestNeighborSearch(cluster);
			}
		}
	}

	/**
	 * Pass the entry to correct children
	 * 
	 * @param p
	 *            entry
	 * @return removed point if any
	 */
	private KdEntry<V> passToChildren(KdEntry<V> p) {
		int i = getChildrenIndex(p.getLocation(dimension));

		if (children[i] == null)
			children[i] = createChildTree(i);

		return children[i].addPoint(p);
	}

	/**
	 * Return index of the children which will contains data with that value.
	 * 
	 * @param value
	 *            the value
	 * @return index of children
	 */
	private int getChildrenIndex(double value) {
		return (int) M.limit(0, Math.floor(value / splitMedian),
				numChildren[dimension] - 1);
	}

	/**
	 * Create children tree with correct bound.
	 * 
	 * @param i
	 *            the index of the children
	 * @return created tree
	 */
	private PRKdBucketTree<V> createChildTree(int i) {
		double[] upperBound = this.upperBound.clone();
		double[] lowerBound = this.lowerBound.clone();
		lowerBound[dimension] = splitMedian * i;
		upperBound[dimension] = splitMedian * (i + 1);
		return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound,
				numChildren, maxDepth - 1, maxDensity, distancer);
	}

	public static class KdCluster<K> implements Iterable<KdPoint<K>> {
		private final PriorityQueue<KdPoint<K>> points;
		private final double[] center;
		private final double[] weight;
		private final int size;
		private final Distancer distancer;

		public KdCluster(int size, double[] center, double[] weight,
				Distancer distancer) {
			points = new PriorityQueue<KdPoint<K>>();
			this.size = size;
			this.center = center;
			this.weight = weight;
			this.distancer = distancer;
		}

		public void consider(KdEntry<K> k) {
			KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer);

			if (points.size() < size) {
				points.add(p);
				System.out.println(points.toString());
			} else if (points.peek().getDistanceToCenter() > p
					.getDistanceToCenter()) {
				System.out.println(points.toString());
				points.poll();
				points.add(p);
			}
		}

		public boolean isViable(PRKdBucketTree<K> tree) {

			if (points.size() < size)
				return true;

			double[] testPoints = new double[center.length];

			for (int i = 0; i < center.length; i++)
				testPoints[i] = M.limit(tree.lowerBound[i], center[i],
						tree.upperBound[i]);

			return points.peek().getDistanceToCenter() > distancer.getDistance(
					center, testPoints, weight);
		}

		@Override
		public Iterator<KdPoint<K>> iterator() {
			return points.iterator();
		}

		public Collection<KdPoint<K>> getValues() {
			Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points
					.size());

			for (KdPoint<K> p : points) {
				collect.add(p);
			}

			return collect;
		}

		@Override
		public int hashCode() {
			final int prime = 31;
			int result = 1;
			result = prime * result + Arrays.hashCode(center);
			result = prime * result
					+ ((distancer == null) ? 0 : distancer.hashCode());
			result = prime * result
					+ ((points == null) ? 0 : points.hashCode());
			result = prime * result + size;
			result = prime * result + Arrays.hashCode(weight);
			return result;
		}

		@SuppressWarnings("unchecked")
		@Override
		public boolean equals(Object obj) {
			if (this == obj)
				return true;
			if (obj == null)
				return false;
			if (!(obj instanceof KdCluster))
				return false;
			KdCluster other = (KdCluster) obj;
			if (!Arrays.equals(center, other.center))
				return false;
			if (distancer == null) {
				if (other.distancer != null)
					return false;
			} else if (!distancer.equals(other.distancer))
				return false;
			if (points == null) {
				if (other.points != null)
					return false;
			} else if (!points.equals(other.points))
				return false;
			if (size != other.size)
				return false;
			if (!Arrays.equals(weight, other.weight))
				return false;
			return true;
		}
	}

	private static class KdEntry<K> implements Serializable {
		private static final long serialVersionUID = 1L;

		private final K value;
		private final double[] location;

		public KdEntry(K value, double[] location) {
			super();
			this.value = value;
			this.location = location;
		}

		public K getValue() {
			return value;
		}

		public double[] getLocation() {
			return location;
		}

		public double getLocation(int a) {
			return location[a];
		}

		@Override
		public int hashCode() {
			final int prime = 31;
			int result = 1;
			result = prime * result + Arrays.hashCode(location);
			result = prime * result + ((value == null) ? 0 : value.hashCode());
			return result;
		}

		@SuppressWarnings("unchecked")
		@Override
		public boolean equals(Object obj) {
			if (this == obj)
				return true;
			if (obj == null)
				return false;
			if (!(obj instanceof KdEntry))
				return false;
			KdEntry other = (KdEntry) obj;
			if (!Arrays.equals(location, other.location))
				return false;
			if (value == null) {
				if (other.value != null)
					return false;
			} else if (!value.equals(other.value))
				return false;
			return true;
		}
	}

	public static class KdPoint<K> extends KdEntry<K> implements Serializable,
			Comparable<KdPoint<K>> {
		private static final long serialVersionUID = 1L;
		private final double distanceToCenter;

		public KdPoint(KdEntry<K> p, double[] center, double[] weight,
				Distancer distancer) {
			super(p.getValue(), p.getLocation());
			distanceToCenter = distancer.getDistance(center, getLocation(),
					weight);
		}

		public double getDistanceToCenter() {
			return distanceToCenter;
		}
		
		@Override
		public String toString() {
			return (new Double(distanceToCenter)).toString();
		}

		@Override
		public int compareTo(KdPoint<K> o) {
			return (int) Math.signum(o.distanceToCenter - distanceToCenter);
		}

		@Override
		public int hashCode() {
			final int prime = 31;
			int result = super.hashCode();
			long temp;
			temp = Double.doubleToLongBits(distanceToCenter);
			result = prime * result + (int) (temp ^ (temp >>> 32));
			return result;
		}

		@SuppressWarnings("unchecked")
		@Override
		public boolean equals(Object obj) {
			if (this == obj)
				return true;
			if (!super.equals(obj))
				return false;
			if (!(obj instanceof KdPoint))
				return false;
			KdPoint other = (KdPoint) obj;
			if (Double.doubleToLongBits(distanceToCenter) != Double
					.doubleToLongBits(other.distanceToCenter))
				return false;
			return true;
		}
	}

	public static abstract class Distancer {
		public double getDistance(double[] p1, double[] p2, double[] weight) {
			if (p1.length != p2.length)
				throw new IllegalArgumentException();
			return getPointDistance(p1, p2, weight);
		}

		public abstract double getPointDistance(double[] p1, double[] p2,
				double[] weight);

		public static class EuclidianDistancer extends Distancer {
			@Override
			public double getPointDistance(double[] p1, double[] p2,
					double[] weight) {
				double result = 0;
				for (int i = 0; i < p1.length; i++) {
					result += M.sqr(p1[i] - p2[i]) * weight[i];
				}
				return M.sqrt(result);
			}
		}

		public static class ManhattanDistancer extends Distancer {
			@Override
			public double getPointDistance(double[] p1, double[] p2,
					double[] weight) {
				double result = 0;
				for (int i = 0; i < p1.length; i++) {
					result += M.abs(p1[i] - p2[i]) * weight[i];
				}
				return result;
			}
		}
	}

	@Override
	public int hashCode() {
		final int prime = 31;
		int result = 1;
		result = prime * result + allDimensions;
		result = prime * result + Arrays.hashCode(children);
		result = prime * result + ((data == null) ? 0 : data.hashCode());
		result = prime * result + dimension;
		result = prime * result + (isLeaf ? 1231 : 1237);
		result = prime * result + Arrays.hashCode(lowerBound);
		result = prime * result + maxDensity;
		result = prime * result + maxDepth;
		result = prime * result + Arrays.hashCode(numChildren);
		long temp;
		temp = Double.doubleToLongBits(splitMedian);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		result = prime * result + Arrays.hashCode(upperBound);
		return result;
	}

	@SuppressWarnings("unchecked")
	@Override
	public boolean equals(Object obj) {
		if (this == obj)
			return true;
		if (obj == null)
			return false;
		if (!(obj instanceof PRKdBucketTree))
			return false;
		PRKdBucketTree other = (PRKdBucketTree) obj;
		if (allDimensions != other.allDimensions)
			return false;
		if (!Arrays.equals(children, other.children))
			return false;
		if (data == null) {
			if (other.data != null)
				return false;
		} else if (!data.equals(other.data))
			return false;
		if (dimension != other.dimension)
			return false;
		if (isLeaf != other.isLeaf)
			return false;
		if (!Arrays.equals(lowerBound, other.lowerBound))
			return false;
		if (maxDensity != other.maxDensity)
			return false;
		if (maxDepth != other.maxDepth)
			return false;
		if (!Arrays.equals(numChildren, other.numChildren))
			return false;
		if (Double.doubleToLongBits(splitMedian) != Double
				.doubleToLongBits(other.splitMedian))
			return false;
		if (!Arrays.equals(upperBound, other.upperBound))
			return false;
		return true;
	}

}

Tester

package nat.tree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.PriorityQueue;
import java.util.Random;

public class Test {
	public static void main(String[] args) {
		final int numTest = 100;
		final int clusterSize = 15;

		long linearTime, tree2time, tree3time, tree4time;
		String[] answer = new String[clusterSize];

		System.out.println("Starting Bucket PR k-d tree performance test...");
		System.out.println("Generating points...");

		ArrayList<String> input = new ArrayList<String>();
		ArrayList<double[]> location = new ArrayList<double[]>();

		for (int i = 0; i < numTest; i++) {
			input.add(generateRandomString());
			double[] p = new double[3];
			p[0] = Math.random();
			p[1] = Math.random();
			p[2] = Math.random();
			location.add(p);
		}

		PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2);
		PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3);
		PRKdBucketTree<String> tree4 = PRKdBucketTree.getTree("", 3, 1, 4);

		for (int i = 0; i < numTest; i++) {
			tree2.addPoint(input.get(i), location.get(i));
			tree3.addPoint(input.get(i), location.get(i));
			tree4.addPoint(input.get(i), location.get(i));
		}

		double[] center = new double[3];
		center[0] = Math.random();
		center[1] = Math.random();
		center[2] = Math.random();

		System.out.println("Data generated.");
		System.out.println("Performing linear search...");

		linearTime = -System.nanoTime();
		PriorityQueue<Compare> pq = new PriorityQueue<Compare>();

		for (int i = 0; i < numTest; i++) {
			double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center,
					location.get(i), new double[] { 1, 1, 1 });

			pq.add(new Compare(input.get(i), distance));
		}

		double distance = -1;
		for (int i = 0; i < clusterSize; i++) {
			if (distance == -1)
				distance = pq.peek().val;
			answer[i] = pq.poll().data;
		}
		linearTime += System.nanoTime();
		Arrays.sort(answer);
		System.out.println("Linear search complete; time = "
				+ (linearTime / 1E9));
		
		
		System.out.println("Performing binary k-d tree search...");

		tree2time = -System.nanoTime();

		PRKdBucketTree.KdCluster<String> tree2r = tree2
				.getNearestNeighbor(clusterSize, center,
						new double[] { 1, 1, 1 });
		Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues();

		tree2time += System.nanoTime();

		System.out.println("Binary tree search complete; time = "
				+ (tree2time / 1E9));
		int correct = 0;
		int j = 0;
		String[] tree2o = new String[clusterSize];
		for (PRKdBucketTree.KdPoint<String> p : tree2a) {
			tree2o[j++] = p.getValue();
		}
		Arrays.sort(tree2o);
		for (int i = 0; i < tree2o.length; i++) {
			if (tree2o[i].equals(answer[i]))
				correct++;
		}

		System.out.println(": accuracy = " + ((double) correct / clusterSize));
		
		System.out.println("Performing ternary k-d tree search...");

		tree3time = -System.nanoTime();

		PRKdBucketTree.KdCluster<String> tree3r = tree3
				.getNearestNeighbor(clusterSize, center,
						new double[] { 1, 1, 1 });
		Collection<PRKdBucketTree.KdPoint<String>> tree3a = tree3r.getValues();

		tree3time += System.nanoTime();

		System.out.println("Ternary tree search complete; time = "
				+ (tree3time / 1E9));
		correct = 0;
		j = 0;
		String[] tree3o = new String[clusterSize];
		for (PRKdBucketTree.KdPoint<String> p : tree3a) {
			tree3o[j++] = p.getValue();
		}
		Arrays.sort(tree3o);
		for (int i = 0; i < tree3o.length; i++) {
			if (tree3o[i].equals(answer[i]))
				correct++;
		}

		System.out.println(": accuracy = " + ((double) correct / clusterSize));
		
		System.out.println("Performing quaternary k-d tree search...");

		tree4time = -System.nanoTime();

		PRKdBucketTree.KdCluster<String> tree4r = tree4
				.getNearestNeighbor(clusterSize, center,
						new double[] { 1, 1, 1 });
		Collection<PRKdBucketTree.KdPoint<String>> tree4a = tree4r.getValues();

		tree4time += System.nanoTime();

		System.out.println("Quaternary tree search complete; time = "
				+ (tree4time / 1E9));
		correct = 0;
		j = 0;
		String[] tree4o = new String[clusterSize];
		for (PRKdBucketTree.KdPoint<String> p : tree4a) {
			tree4o[j++] = p.getValue();
		}
		Arrays.sort(tree4o);
		for (int i = 0; i < tree4o.length; i++) {
			if (tree4o[i].equals(answer[i]))
				correct++;
		}

		System.out.println(": accuracy = " + ((double) correct / clusterSize));
	}

	private static class Compare implements Comparable<Compare> {
		String data;
		double val;

		@Override
		public int compareTo(Compare o) {
			return (int) -Math.signum(o.val - val);
		}

		public Compare(String data, double val) {
			this.data = data;
			this.val = val;
		}

	}

	private static String generateRandomString() {
		String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
		Random r = new Random();
		char[] buf = new char[25];

		for (int i = 0; i < buf.length; i++) {
			buf[i] = chars.charAt(r.nextInt(chars.length()));
		}

		return new String(buf);
	}
}

Old


It took me last night to create and all day today to debug it but I still can't get it right. Anyone please help me!

package nat.tree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.Queue;

import nat.util.M;

/**
 * 
 * Implementation of bucket PR k-d tree.
 * 
 * @author Nat Pavasant
 * 
 * @param <V>
 *            The type of data to store
 */
public class PRKdBucketTree<V> implements Serializable {
	private static final long serialVersionUID = 1L;

	public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer();
	public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer();

	private final PRKdBucketTree<V>[] children;
	private final Queue<KdEntry<V>> data;
	private final Distancer distancer;
	private final double[] lowerBound, upperBound;
	private final int[] numChildren;
	private final int allDimensions;
	private final int dimension;
	private final int maxDepth, maxDensity;
	private final double splitMedian;

	private boolean isLeaf = true;

	/**
	 * Create new Bucket PR k-d tree.
	 * 
	 * @param allDimensions
	 *            number of dimensions in the tree
	 * @param lowerBound
	 *            the minimum value of the location of each dimension
	 * @param upperBound
	 *            the maximum value of the location of each dimension
	 * @param numChildren
	 *            number of children in each dimension
	 * @param maxDepth
	 *            the max depth of the tree
	 * @param maxDensity
	 *            size of bucket in each leaf
	 * @param distancer
	 *            distance measurer
	 */
	@SuppressWarnings("unchecked")
	public PRKdBucketTree(int allDimensions, double[] lowerBound,
			double[] upperBound, int[] numChildren, int maxDepth,
			int maxDensity, Distancer distancer) {
		if (allDimensions < 1 || maxDensity < 1)
			throw new IllegalArgumentException(
					"Either dimension or density isn't positive integer.");

		if (lowerBound.length != allDimensions
				|| upperBound.length != allDimensions
				|| numChildren.length != allDimensions)
			throw new IllegalArgumentException(
					"Either bounds or children amount is more or less than dimension count.");

		for (double a : lowerBound) {
			if (a < 0)
				throw new IllegalArgumentException(
						"Can't set lower bound to negative number.");
		}

		for (int i = 0; i < lowerBound.length; i++) {
			if (lowerBound[i] > upperBound[i])
				throw new IllegalArgumentException(
						"Upper bound must have a value higer than lower bound.");
		}

		this.allDimensions = allDimensions;
		this.lowerBound = lowerBound;
		this.upperBound = upperBound;
		this.numChildren = numChildren;
		this.distancer = distancer;
		this.maxDepth = maxDepth;
		this.maxDensity = maxDensity;
		this.dimension = maxDepth % allDimensions;
		this.data = new LinkedList<KdEntry<V>>();
		this.children = new PRKdBucketTree[numChildren[this.dimension]];
		this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension])
				/ numChildren[this.dimension];
	}

	// Another constructor
	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double lowerBound, double upperBound, int numChildren,
			int maxDepth, int maxDensity, Distancer distance) {
		double[] lowerBounds = new double[dimension];
		double[] upperBounds = new double[dimension];
		int[] numChildrens = new int[dimension];
		Arrays.fill(lowerBounds, lowerBound);
		Arrays.fill(upperBounds, upperBound);
		Arrays.fill(numChildrens, numChildren);
		return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds,
				numChildrens, maxDepth, maxDensity, distance);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren, int maxDepth, int maxDensity,
			Distancer distance) {
		return getTree(dataType, dimension, 0, upperBound, numChildren,
				maxDepth, maxDensity, distance);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren, int maxDepth, int maxDensity) {
		return getTree(dataType, dimension, 0, upperBound, numChildren,
				maxDepth, maxDensity, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int maxDepth, int maxDensity) {
		return getTree(dataType, dimension, 0, upperBound, 2, maxDepth,
				maxDensity, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound, int numChildren) {
		return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8,
				EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
			double upperBound) {
		return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN);
	}

	public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) {
		return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN);
	}

	/**
	 * Add new point to the tree
	 * 
	 * @param value
	 *            the stored value
	 * @param location
	 *            location of the point
	 * @return removed point if any
	 */
	public KdEntry<V> addPoint(V value, double[] location) {
		if (location.length != allDimensions) {
			throw new IllegalArgumentException(
					"Provided location have either more or less dimensions than the tree.");
		}

		KdEntry<V> entry = new KdEntry<V>(value, location);

		return addPoint(entry);
	}

	/**
	 * Get the n-nearest neighbor.
	 * 
	 * @param size
	 *            number of neighbors
	 * @param center
	 *            center of the cluster
	 * @param weight
	 *            weighting for the distancer
	 * @return
	 */
	public KdCluster<V> getNearestNeighbor(int size, double[] center,
			double[] weight) {
		KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer);
		nearestNeighborSearch(cluster);
		return cluster;
	}

	/**
	 * Back-end implementation of the entry adding.
	 * 
	 * @param entry
	 *            new entry to the tree
	 * @return removed point if any
	 */
	private KdEntry<V> addPoint(KdEntry<V> entry) {

		if (isLeaf) {

			// Still has spaces, add the data
			if (data.size() < maxDensity) {
				data.add(entry);
				return null;
			}

			// final leaf, unsplitable, remove element and add new one.
			if (maxDepth <= 1) {
				data.add(entry);
				return data.poll();
			}

			// if we reached here, we need to split this leaf to a branch.
			isLeaf = false;
			for (KdEntry<V> p : data) {
				passToChildren(p);
			}

			data.clear();
		}

		// we are branch, pass to children
		return passToChildren(entry);
	}

	/**
	 * Perform n-nearest neighbor search
	 * 
	 * @param cluster
	 *            current working cluster
	 */
	private void nearestNeighborSearch(KdCluster<V> cluster) {
		if (isLeaf) {
			for (KdEntry<V> p : data) {
				cluster.consider(p);
			}
		} else {
			for (PRKdBucketTree<V> child : children) {
				if (child != null && cluster.isViable(child))
					child.nearestNeighborSearch(cluster);
			}
		}
	}

	/**
	 * Pass the entry to correct children
	 * 
	 * @param p
	 *            entry
	 * @return removed point if any
	 */
	private KdEntry<V> passToChildren(KdEntry<V> p) {
		int i = getChildrenIndex(p.getLocation(dimension));

		if (children[i] == null)
			children[i] = createChildTree(i);

		return children[i].addPoint(p);
	}

	/**
	 * Return index of the children which will contains data with that value.
	 * 
	 * @param value
	 *            the value
	 * @return index of children
	 */
	private int getChildrenIndex(double value) {
		return (int) M.limit(0, Math.floor(value / splitMedian),
				numChildren[dimension] - 1);
	}

	/**
	 * Create children tree with correct bound.
	 * 
	 * @param i
	 *            the index of the children
	 * @return created tree
	 */
	private PRKdBucketTree<V> createChildTree(int i) {
		double[] upperBound = this.upperBound.clone();
		double[] lowerBound = this.lowerBound.clone();
		lowerBound[i] = splitMedian * i;
		upperBound[i] = splitMedian * (i + 1);
		return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound,
				numChildren, maxDepth - 1, maxDensity, distancer);
	}

	public static class KdCluster<K> implements Iterable<KdPoint<K>> {
		private final PriorityQueue<KdPoint<K>> points;
		private final double[] center;
		private final double[] weight;
		private final int size;
		private final Distancer distancer;

		public KdCluster(int size, double[] center, double[] weight, Distancer distancer) {
			points = new PriorityQueue<KdPoint<K>>();
			this.size = size;
			this.center = center;
			this.weight = weight;
			this.distancer = distancer;
		}

		public void consider(KdEntry<K> k) {
			KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer);

			if (points.size() < size) {
				points.add(p);
			} else if (points.peek().isFurtherThan(p)) {
				points.poll();
				points.add(p);
			}
		}

		public boolean isViable(PRKdBucketTree<K> tree) {

			if (points.size() < size)
				return true;

			double[] testPoints = new double[center.length];

			for (int i = 0; i < center.length; i++)
				testPoints[i] = M.limit(tree.lowerBound[i], center[i],
						tree.upperBound[i]);

			return points.peek().isFurtherThan(
					distancer.getDistance(center, testPoints, weight));
		}

		@Override
		public Iterator<KdPoint<K>> iterator() {
			return points.iterator();
		}

		public Collection<KdPoint<K>> getValues() {
			Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points
					.size());

			for (KdPoint<K> p : points) {
				collect.add(p);
			}

			return collect;
		}
	}

	private static class KdEntry<K> implements Serializable {
		private static final long serialVersionUID = 1L;

		private final K value;
		private final double[] location;

		public KdEntry(K value, double[] location) {
			super();
			this.value = value;
			this.location = location;
		}

		public K getValue() {
			return value;
		}

		public double[] getLocation() {
			return location;
		}

		public double getLocation(int a) {
			return location[a];
		}
	}

	public static class KdPoint<K> extends KdEntry<K> implements Serializable,
			Comparable<KdPoint<K>> {
		private static final long serialVersionUID = 1L;
		private final double distanceToCenter;

		public KdPoint(KdEntry<K> p, double[] center, double[] weight, Distancer distancer) {
			super(p.getValue(), p.getLocation());
			distanceToCenter = distancer.getDistance(center, getLocation(),
					weight);
		}

		public double getDistanceToCenter() {
			return distanceToCenter;
		}

		@Override
		public int compareTo(KdPoint<K> o) {
			return (int) Math.signum(o.distanceToCenter - distanceToCenter);
		}

		public boolean isFurtherThan(KdPoint<K> p) {
			return compareTo(p) == -1;
		}

		public boolean isFurtherThan(double distance) {
			return distanceToCenter > distance;
		}
	}

	public static abstract class Distancer {
		public double getDistance(double[] p1, double[] p2, double[] weight) {
			if (p1.length != p2.length)
				throw new IllegalArgumentException();
			return getPointDistance(p1, p2, weight);
		}

		public abstract double getPointDistance(double[] p1, double[] p2,
				double[] weight);

		public static class EuclidianDistancer extends Distancer {
			@Override
			public double getPointDistance(double[] p1, double[] p2,
					double[] weight) {
				double result = 0;
				for (int i = 0; i < p1.length; i++) {
					result += M.sqr(p1[i] - p2[i]) * weight[i];
				}
				return M.sqrt(result);
			}
		}

		public static class ManhattanDistancer extends Distancer {
			@Override
			public double getPointDistance(double[] p1, double[] p2,
					double[] weight) {
				double result = 0;
				for (int i = 0; i < p1.length; i++) {
					result += M.abs(p1[i] - p2[i]) * weight[i];
				}
				return result;
			}
		}
	}

	
}

I use this code to check:

package nat.tree;

import java.util.ArrayList;
import java.util.Collection;
import java.util.PriorityQueue;
import java.util.Random;

import nat.tree.PRKdBucketTree.KdPoint;

public class Test {
	public static void main(String[] args) {
		final int numTest = 13000;
		final int clusterSize = 10;

		long linearTime, tree2time, tree3time;
		String[] answer = new String[100];

		System.out.println("Starting Bucket PR k-d tree performance test...");
		System.out.println("Generating points...");

		ArrayList<String> input = new ArrayList<String>();
		ArrayList<double[]> location = new ArrayList<double[]>();

		for (int i = 0; i < numTest; i++) {
			input.add(generateRandomString());
			double[] p = new double[3];
			p[0] = Math.random();
			p[1] = Math.random();
			p[2] = Math.random();
			location.add(p);
		}

		PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2);
		PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3);

		for (int i = 0; i < numTest; i++) {
			tree2.addPoint(input.get(i), location.get(i));
			tree3.addPoint(input.get(i), location.get(i));
		}

		double[] center = new double[3];
		center[0] = Math.random();
		center[1] = Math.random();
		center[2] = Math.random();

		System.out.println("Data generated.");
		System.out.println("Performing linear search...");

		linearTime = -System.nanoTime();
		PriorityQueue<Compare> pq = new PriorityQueue<Compare>();

		for (int i = 0; i < numTest; i++) {
			double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center,
					location.get(i), new double[] { 1, 1, 1 });

			pq.add(new Compare(input.get(i), distance));
		}

		double distance = -1;
		for (int i = 0; i < clusterSize; i++) {
			if (distance == -1)
				distance = pq.peek().val;
			answer[i] = pq.poll().data;
		}
		linearTime += System.nanoTime();
		System.out.println("Linear search complete; time = "
				+ (linearTime / 1E9));
		System.out.println("Performing binary k-d tree search...");

		tree2time = -System.nanoTime();

		PRKdBucketTree.KdCluster<String> tree2r = tree2
				.getNearestNeighbor(clusterSize, center,
						new double[] { 1, 1, 1 });
		Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues();

		tree2time += System.nanoTime();

		System.out.println("Distance = " + distance + "; "
				+ tree2a.iterator().next().getDistanceToCenter());

		System.out.println("Binary tree search complete; time = "
				+ (tree2time / 1E9));
		int correct = 0;
		int j = 0;
		for (PRKdBucketTree.KdPoint<String> p : tree2a) {
			System.out.print(p.getValue());
			System.out.print(" ");
			System.out.println(answer[j]);
			if (p.getValue().equals(answer[j++]))
				correct++;
		}

		System.out.println(": accuracy = " + ((double) correct / clusterSize));
	}

	private static class Compare implements Comparable<Compare> {
		String data;
		double val;

		@Override
		public int compareTo(Compare o) {
			return (int) -Math.signum(o.val - val);
		}

		public Compare(String data, double val) {
			this.data = data;
			this.val = val;
		}

	}

	private static String generateRandomString() {
		String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
		Random r = new Random();
		char[] buf = new char[15];

		for (int i = 0; i < buf.length; i++) {
			buf[i] = chars.charAt(r.nextInt(chars.length()));
		}

		return new String(buf);
	}
}

The linear search and tree search doesn't the same thing. I know this tree is a bit messy since I want it to supports other m-ary tree style too. Anyway, please help. Thank you in advance =) » Nat | Talk » 13:40, 15 August 2009 (UTC)

So sad no one help me =( By the way, finally I spot the bug. It is in createChildTree() where is should be upperBound[dimension] and lowerBound[dimension] instead of i. » Nat | Talk » 16:30, 15 August 2009 (UTC)