Help:Help/Nat/Kd-Tree
Jump to navigation
Jump to search
It took me last night to create and all day today to debug it but I still can't get it right. Anyone please help me!
package nat.tree; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Iterator; import java.util.LinkedList; import java.util.PriorityQueue; import java.util.Queue; import nat.util.M; /** * * Implementation of bucket PR k-d tree. * * @author Nat Pavasant * * @param <V> * The type of data to store */ public class PRKdBucketTree<V> implements Serializable { private static final long serialVersionUID = 1L; public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer(); public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer(); private final PRKdBucketTree<V>[] children; private final Queue<KdEntry<V>> data; private final Distancer distancer; private final double[] lowerBound, upperBound; private final int[] numChildren; private final int allDimensions; private final int dimension; private final int maxDepth, maxDensity; private final double splitMedian; private boolean isLeaf = true; /** * Create new Bucket PR k-d tree. * * @param allDimensions * number of dimensions in the tree * @param lowerBound * the minimum value of the location of each dimension * @param upperBound * the maximum value of the location of each dimension * @param numChildren * number of children in each dimension * @param maxDepth * the max depth of the tree * @param maxDensity * size of bucket in each leaf * @param distancer * distance measurer */ @SuppressWarnings("unchecked") public PRKdBucketTree(int allDimensions, double[] lowerBound, double[] upperBound, int[] numChildren, int maxDepth, int maxDensity, Distancer distancer) { if (allDimensions < 1 || maxDensity < 1) throw new IllegalArgumentException( "Either dimension or density isn't positive integer."); if (lowerBound.length != allDimensions || upperBound.length != allDimensions || numChildren.length != allDimensions) throw new IllegalArgumentException( "Either bounds or children amount is more or less than dimension count."); for (double a : lowerBound) { if (a < 0) throw new IllegalArgumentException( "Can't set lower bound to negative number."); } for (int i = 0; i < lowerBound.length; i++) { if (lowerBound[i] > upperBound[i]) throw new IllegalArgumentException( "Upper bound must have a value higer than lower bound."); } this.allDimensions = allDimensions; this.lowerBound = lowerBound; this.upperBound = upperBound; this.numChildren = numChildren; this.distancer = distancer; this.maxDepth = maxDepth; this.maxDensity = maxDensity; this.dimension = maxDepth % allDimensions; this.data = new LinkedList<KdEntry<V>>(); this.children = new PRKdBucketTree[numChildren[this.dimension]]; this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension]) / numChildren[this.dimension]; } // Another constructor public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double lowerBound, double upperBound, int numChildren, int maxDepth, int maxDensity, Distancer distance) { double[] lowerBounds = new double[dimension]; double[] upperBounds = new double[dimension]; int[] numChildrens = new int[dimension]; Arrays.fill(lowerBounds, lowerBound); Arrays.fill(upperBounds, upperBound); Arrays.fill(numChildrens, numChildren); return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds, numChildrens, maxDepth, maxDensity, distance); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int numChildren, int maxDepth, int maxDensity, Distancer distance) { return getTree(dataType, dimension, 0, upperBound, numChildren, maxDepth, maxDensity, distance); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int numChildren, int maxDepth, int maxDensity) { return getTree(dataType, dimension, 0, upperBound, numChildren, maxDepth, maxDensity, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int maxDepth, int maxDensity) { return getTree(dataType, dimension, 0, upperBound, 2, maxDepth, maxDensity, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int numChildren) { return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound) { return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) { return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN); } /** * Add new point to the tree * * @param value * the stored value * @param location * location of the point * @return removed point if any */ public KdEntry<V> addPoint(V value, double[] location) { if (location.length != allDimensions) { throw new IllegalArgumentException( "Provided location have either more or less dimensions than the tree."); } KdEntry<V> entry = new KdEntry<V>(value, location); return addPoint(entry); } /** * Get the n-nearest neighbor. * * @param size * number of neighbors * @param center * center of the cluster * @param weight * weighting for the distancer * @return */ public KdCluster<V> getNearestNeighbor(int size, double[] center, double[] weight) { KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer); nearestNeighborSearch(cluster); return cluster; } /** * Back-end implementation of the entry adding. * * @param entry * new entry to the tree * @return removed point if any */ private KdEntry<V> addPoint(KdEntry<V> entry) { if (isLeaf) { // Still has spaces, add the data if (data.size() < maxDensity) { data.add(entry); return null; } // final leaf, unsplitable, remove element and add new one. if (maxDepth <= 1) { data.add(entry); return data.poll(); } // if we reached here, we need to split this leaf to a branch. isLeaf = false; for (KdEntry<V> p : data) { passToChildren(p); } data.clear(); } // we are branch, pass to children return passToChildren(entry); } /** * Perform n-nearest neighbor search * * @param cluster * current working cluster */ private void nearestNeighborSearch(KdCluster<V> cluster) { if (isLeaf) { for (KdEntry<V> p : data) { cluster.consider(p); } } else { for (PRKdBucketTree<V> child : children) { if (child != null && cluster.isViable(child)) child.nearestNeighborSearch(cluster); } } } /** * Pass the entry to correct children * * @param p * entry * @return removed point if any */ private KdEntry<V> passToChildren(KdEntry<V> p) { int i = getChildrenIndex(p.getLocation(dimension)); if (children[i] == null) children[i] = createChildTree(i); return children[i].addPoint(p); } /** * Return index of the children which will contains data with that value. * * @param value * the value * @return index of children */ private int getChildrenIndex(double value) { return (int) M.limit(0, Math.floor(value / splitMedian), numChildren[dimension] - 1); } /** * Create children tree with correct bound. * * @param i * the index of the children * @return created tree */ private PRKdBucketTree<V> createChildTree(int i) { double[] upperBound = this.upperBound.clone(); double[] lowerBound = this.lowerBound.clone(); lowerBound[i] = splitMedian * i; upperBound[i] = splitMedian * (i + 1); return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound, numChildren, maxDepth - 1, maxDensity, distancer); } public static class KdCluster<K> implements Iterable<KdPoint<K>> { private final PriorityQueue<KdPoint<K>> points; private final double[] center; private final double[] weight; private final int size; private final Distancer distancer; public KdCluster(int size, double[] center, double[] weight, Distancer distancer) { points = new PriorityQueue<KdPoint<K>>(); this.size = size; this.center = center; this.weight = weight; this.distancer = distancer; } public void consider(KdEntry<K> k) { KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer); if (points.size() < size) { points.add(p); } else if (points.peek().isFurtherThan(p)) { points.poll(); points.add(p); } } public boolean isViable(PRKdBucketTree<K> tree) { if (points.size() < size) return true; double[] testPoints = new double[center.length]; for (int i = 0; i < center.length; i++) testPoints[i] = M.limit(tree.lowerBound[i], center[i], tree.upperBound[i]); return points.peek().isFurtherThan( distancer.getDistance(center, testPoints, weight)); } @Override public Iterator<KdPoint<K>> iterator() { return points.iterator(); } public Collection<KdPoint<K>> getValues() { Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points .size()); for (KdPoint<K> p : points) { collect.add(p); } return collect; } } private static class KdEntry<K> implements Serializable { private static final long serialVersionUID = 1L; private final K value; private final double[] location; public KdEntry(K value, double[] location) { super(); this.value = value; this.location = location; } public K getValue() { return value; } public double[] getLocation() { return location; } public double getLocation(int a) { return location[a]; } } public static class KdPoint<K> extends KdEntry<K> implements Serializable, Comparable<KdPoint<K>> { private static final long serialVersionUID = 1L; private final double distanceToCenter; public KdPoint(KdEntry<K> p, double[] center, double[] weight, Distancer distancer) { super(p.getValue(), p.getLocation()); distanceToCenter = distancer.getDistance(center, getLocation(), weight); } public double getDistanceToCenter() { return distanceToCenter; } @Override public int compareTo(KdPoint<K> o) { return (int) Math.signum(o.distanceToCenter - distanceToCenter); } public boolean isFurtherThan(KdPoint<K> p) { return compareTo(p) == -1; } public boolean isFurtherThan(double distance) { return distanceToCenter > distance; } } public static abstract class Distancer { public double getDistance(double[] p1, double[] p2, double[] weight) { if (p1.length != p2.length) throw new IllegalArgumentException(); return getPointDistance(p1, p2, weight); } public abstract double getPointDistance(double[] p1, double[] p2, double[] weight); public static class EuclidianDistancer extends Distancer { @Override public double getPointDistance(double[] p1, double[] p2, double[] weight) { double result = 0; for (int i = 0; i < p1.length; i++) { result += M.sqr(p1[i] - p2[i]) * weight[i]; } return M.sqrt(result); } } public static class ManhattanDistancer extends Distancer { @Override public double getPointDistance(double[] p1, double[] p2, double[] weight) { double result = 0; for (int i = 0; i < p1.length; i++) { result += M.abs(p1[i] - p2[i]) * weight[i]; } return result; } } } }
I use this code to check:
package nat.tree; import java.util.ArrayList; import java.util.Collection; import java.util.PriorityQueue; import java.util.Random; import nat.tree.PRKdBucketTree.KdPoint; public class Test { public static void main(String[] args) { final int numTest = 13000; final int clusterSize = 10; long linearTime, tree2time, tree3time; String[] answer = new String[100]; System.out.println("Starting Bucket PR k-d tree performance test..."); System.out.println("Generating points..."); ArrayList<String> input = new ArrayList<String>(); ArrayList<double[]> location = new ArrayList<double[]>(); for (int i = 0; i < numTest; i++) { input.add(generateRandomString()); double[] p = new double[3]; p[0] = Math.random(); p[1] = Math.random(); p[2] = Math.random(); location.add(p); } PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2); PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3); for (int i = 0; i < numTest; i++) { tree2.addPoint(input.get(i), location.get(i)); tree3.addPoint(input.get(i), location.get(i)); } double[] center = new double[3]; center[0] = Math.random(); center[1] = Math.random(); center[2] = Math.random(); System.out.println("Data generated."); System.out.println("Performing linear search..."); linearTime = -System.nanoTime(); PriorityQueue<Compare> pq = new PriorityQueue<Compare>(); for (int i = 0; i < numTest; i++) { double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center, location.get(i), new double[] { 1, 1, 1 }); pq.add(new Compare(input.get(i), distance)); } double distance = -1; for (int i = 0; i < clusterSize; i++) { if (distance == -1) distance = pq.peek().val; answer[i] = pq.poll().data; } linearTime += System.nanoTime(); System.out.println("Linear search complete; time = " + (linearTime / 1E9)); System.out.println("Performing binary k-d tree search..."); tree2time = -System.nanoTime(); PRKdBucketTree.KdCluster<String> tree2r = tree2 .getNearestNeighbor(clusterSize, center, new double[] { 1, 1, 1 }); Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues(); tree2time += System.nanoTime(); System.out.println("Distance = " + distance + "; " + tree2a.iterator().next().getDistanceToCenter()); System.out.println("Binary tree search complete; time = " + (tree2time / 1E9)); int correct = 0; int j = 0; for (PRKdBucketTree.KdPoint<String> p : tree2a) { System.out.print(p.getValue()); System.out.print(" "); System.out.println(answer[j]); if (p.getValue().equals(answer[j++])) correct++; } System.out.println(": accuracy = " + ((double) correct / clusterSize)); } private static class Compare implements Comparable<Compare> { String data; double val; @Override public int compareTo(Compare o) { return (int) -Math.signum(o.val - val); } public Compare(String data, double val) { this.data = data; this.val = val; } } private static String generateRandomString() { String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; Random r = new Random(); char[] buf = new char[15]; for (int i = 0; i < buf.length; i++) { buf[i] = chars.charAt(r.nextInt(chars.length())); } return new String(buf); } }
The linear search and tree search doesn't the same thing. I know this tree is a bit messy since I want it to supports other m-ary tree style too. Anyway, please help. Thank you in advance =) » Nat | Talk » 13:40, 15 August 2009 (UTC)
So sad no one help me =( By the way, finally I spot the bug. It is in createChildTree()
where is should be upperBound[dimension]
and lowerBound[dimension]
instead of i
. » Nat | Talk » 16:30, 15 August 2009 (UTC)