Difference between revisions of "Help:Help/Nat/Kd-Tree"
Jump to navigation
Jump to search
(anyone please help me!!!) |
(latest code that I've problem) |
||
(One intermediate revision by the same user not shown) | |||
Line 1: | Line 1: | ||
+ | Here is my code that I have problem state in [[Talk:Kd-Tree]]. My ''M'' class can be found at [[User:Nat/Free code]]. | ||
+ | |||
+ | == Tree == | ||
+ | <pre> | ||
+ | package nat.tree; | ||
+ | |||
+ | import java.io.Serializable; | ||
+ | import java.util.ArrayList; | ||
+ | import java.util.Arrays; | ||
+ | import java.util.Collection; | ||
+ | import java.util.Iterator; | ||
+ | import java.util.LinkedList; | ||
+ | import java.util.PriorityQueue; | ||
+ | import java.util.Queue; | ||
+ | |||
+ | import nat.util.M; | ||
+ | |||
+ | /** | ||
+ | * | ||
+ | * Implementation of bucket PR k-d tree. | ||
+ | * | ||
+ | * @author Nat Pavasant | ||
+ | * | ||
+ | * @param <V> | ||
+ | * The type of data to store | ||
+ | */ | ||
+ | public class PRKdBucketTree<V> implements Serializable { | ||
+ | |||
+ | private static final long serialVersionUID = 1L; | ||
+ | |||
+ | public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer(); | ||
+ | public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer(); | ||
+ | |||
+ | private final PRKdBucketTree<V>[] children; | ||
+ | private final Queue<KdEntry<V>> data; | ||
+ | private final Distancer distancer; | ||
+ | private final double[] lowerBound, upperBound; | ||
+ | private final int[] numChildren; | ||
+ | private final int allDimensions; | ||
+ | private final int dimension; | ||
+ | private final int maxDepth, maxDensity; | ||
+ | private final double splitMedian; | ||
+ | |||
+ | private boolean isLeaf = true; | ||
+ | |||
+ | /** | ||
+ | * Create new Bucket PR k-d tree. | ||
+ | * | ||
+ | * @param allDimensions | ||
+ | * number of dimensions in the tree | ||
+ | * @param lowerBound | ||
+ | * the minimum value of the location of each dimension | ||
+ | * @param upperBound | ||
+ | * the maximum value of the location of each dimension | ||
+ | * @param numChildren | ||
+ | * number of children in each dimension | ||
+ | * @param maxDepth | ||
+ | * the max depth of the tree | ||
+ | * @param maxDensity | ||
+ | * size of bucket in each leaf | ||
+ | * @param distancer | ||
+ | * distance measurer | ||
+ | */ | ||
+ | @SuppressWarnings("unchecked") | ||
+ | public PRKdBucketTree(int allDimensions, double[] lowerBound, | ||
+ | double[] upperBound, int[] numChildren, int maxDepth, | ||
+ | int maxDensity, Distancer distancer) { | ||
+ | if (allDimensions < 1 || maxDensity < 1) | ||
+ | throw new IllegalArgumentException( | ||
+ | "Either dimension or density isn't positive integer."); | ||
+ | |||
+ | if (lowerBound.length != allDimensions | ||
+ | || upperBound.length != allDimensions | ||
+ | || numChildren.length != allDimensions) | ||
+ | throw new IllegalArgumentException( | ||
+ | "Either bounds or children amount is more or less than dimension count."); | ||
+ | |||
+ | for (double a : lowerBound) { | ||
+ | if (a < 0) | ||
+ | throw new IllegalArgumentException( | ||
+ | "Can't set lower bound to negative number."); | ||
+ | } | ||
+ | |||
+ | for (int i = 0; i < lowerBound.length; i++) { | ||
+ | if (lowerBound[i] > upperBound[i]) | ||
+ | throw new IllegalArgumentException( | ||
+ | "Upper bound must have a value higer than lower bound."); | ||
+ | } | ||
+ | |||
+ | this.allDimensions = allDimensions; | ||
+ | this.lowerBound = lowerBound; | ||
+ | this.upperBound = upperBound; | ||
+ | this.numChildren = numChildren; | ||
+ | this.distancer = distancer; | ||
+ | this.maxDepth = maxDepth; | ||
+ | this.maxDensity = maxDensity; | ||
+ | this.dimension = maxDepth % allDimensions; | ||
+ | this.data = new LinkedList<KdEntry<V>>(); | ||
+ | this.children = new PRKdBucketTree[numChildren[this.dimension]]; | ||
+ | this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension]) | ||
+ | / numChildren[this.dimension]; | ||
+ | } | ||
+ | |||
+ | // Another constructor | ||
+ | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
+ | double lowerBound, double upperBound, int numChildren, | ||
+ | int maxDepth, int maxDensity, Distancer distance) { | ||
+ | double[] lowerBounds = new double[dimension]; | ||
+ | double[] upperBounds = new double[dimension]; | ||
+ | int[] numChildrens = new int[dimension]; | ||
+ | Arrays.fill(lowerBounds, lowerBound); | ||
+ | Arrays.fill(upperBounds, upperBound); | ||
+ | Arrays.fill(numChildrens, numChildren); | ||
+ | return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds, | ||
+ | numChildrens, maxDepth, maxDensity, distance); | ||
+ | } | ||
+ | |||
+ | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
+ | double upperBound, int numChildren, int maxDepth, int maxDensity, | ||
+ | Distancer distance) { | ||
+ | return getTree(dataType, dimension, 0, upperBound, numChildren, | ||
+ | maxDepth, maxDensity, distance); | ||
+ | } | ||
+ | |||
+ | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
+ | double upperBound, int numChildren, int maxDepth, int maxDensity) { | ||
+ | return getTree(dataType, dimension, 0, upperBound, numChildren, | ||
+ | maxDepth, maxDensity, EUCLIDIAN); | ||
+ | } | ||
+ | |||
+ | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
+ | double upperBound, int maxDepth, int maxDensity) { | ||
+ | return getTree(dataType, dimension, 0, upperBound, 2, maxDepth, | ||
+ | maxDensity, EUCLIDIAN); | ||
+ | } | ||
+ | |||
+ | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
+ | double upperBound, int numChildren) { | ||
+ | return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8, | ||
+ | EUCLIDIAN); | ||
+ | } | ||
+ | |||
+ | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
+ | double upperBound) { | ||
+ | return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN); | ||
+ | } | ||
+ | |||
+ | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) { | ||
+ | return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN); | ||
+ | } | ||
+ | |||
+ | /** | ||
+ | * Add new point to the tree | ||
+ | * | ||
+ | * @param value | ||
+ | * the stored value | ||
+ | * @param location | ||
+ | * location of the point | ||
+ | * @return removed point if any | ||
+ | */ | ||
+ | public KdEntry<V> addPoint(V value, double[] location) { | ||
+ | if (location.length != allDimensions) { | ||
+ | throw new IllegalArgumentException( | ||
+ | "Provided location have either more or less dimensions than the tree."); | ||
+ | } | ||
+ | |||
+ | KdEntry<V> entry = new KdEntry<V>(value, location); | ||
+ | |||
+ | return addPoint(entry); | ||
+ | } | ||
+ | |||
+ | /** | ||
+ | * Get the n-nearest neighbor. | ||
+ | * | ||
+ | * @param size | ||
+ | * number of neighbors | ||
+ | * @param center | ||
+ | * center of the cluster | ||
+ | * @param weight | ||
+ | * weighting for the distancer | ||
+ | * @return | ||
+ | */ | ||
+ | public KdCluster<V> getNearestNeighbor(int size, double[] center, | ||
+ | double[] weight) { | ||
+ | KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer); | ||
+ | nearestNeighborSearch(cluster); | ||
+ | return cluster; | ||
+ | } | ||
+ | |||
+ | /** | ||
+ | * Back-end implementation of the entry adding. | ||
+ | * | ||
+ | * @param entry | ||
+ | * new entry to the tree | ||
+ | * @return removed point if any | ||
+ | */ | ||
+ | private KdEntry<V> addPoint(KdEntry<V> entry) { | ||
+ | |||
+ | if (isLeaf) { | ||
+ | |||
+ | // Still has spaces, add the data | ||
+ | if (data.size() < maxDensity) { | ||
+ | data.add(entry); | ||
+ | return null; | ||
+ | } | ||
+ | |||
+ | // final leaf, unsplitable, remove element and add new one. | ||
+ | if (maxDepth <= 1) { | ||
+ | data.add(entry); | ||
+ | return data.poll(); | ||
+ | } | ||
+ | |||
+ | // if we reached here, we need to split this leaf to a branch. | ||
+ | isLeaf = false; | ||
+ | for (KdEntry<V> p : data) { | ||
+ | passToChildren(p); | ||
+ | } | ||
+ | |||
+ | data.clear(); | ||
+ | } | ||
+ | |||
+ | // we are branch, pass to children | ||
+ | return passToChildren(entry); | ||
+ | } | ||
+ | |||
+ | /** | ||
+ | * Perform n-nearest neighbor search | ||
+ | * | ||
+ | * @param cluster | ||
+ | * current working cluster | ||
+ | */ | ||
+ | private void nearestNeighborSearch(KdCluster<V> cluster) { | ||
+ | if (isLeaf) { | ||
+ | for (KdEntry<V> p : data) { | ||
+ | cluster.consider(p); | ||
+ | } | ||
+ | } else { | ||
+ | for (PRKdBucketTree<V> child : children) { | ||
+ | if (child != null && cluster.isViable(child)) | ||
+ | child.nearestNeighborSearch(cluster); | ||
+ | } | ||
+ | } | ||
+ | } | ||
+ | |||
+ | /** | ||
+ | * Pass the entry to correct children | ||
+ | * | ||
+ | * @param p | ||
+ | * entry | ||
+ | * @return removed point if any | ||
+ | */ | ||
+ | private KdEntry<V> passToChildren(KdEntry<V> p) { | ||
+ | int i = getChildrenIndex(p.getLocation(dimension)); | ||
+ | |||
+ | if (children[i] == null) | ||
+ | children[i] = createChildTree(i); | ||
+ | |||
+ | return children[i].addPoint(p); | ||
+ | } | ||
+ | |||
+ | /** | ||
+ | * Return index of the children which will contains data with that value. | ||
+ | * | ||
+ | * @param value | ||
+ | * the value | ||
+ | * @return index of children | ||
+ | */ | ||
+ | private int getChildrenIndex(double value) { | ||
+ | return (int) M.limit(0, Math.floor(value / splitMedian), | ||
+ | numChildren[dimension] - 1); | ||
+ | } | ||
+ | |||
+ | /** | ||
+ | * Create children tree with correct bound. | ||
+ | * | ||
+ | * @param i | ||
+ | * the index of the children | ||
+ | * @return created tree | ||
+ | */ | ||
+ | private PRKdBucketTree<V> createChildTree(int i) { | ||
+ | double[] upperBound = this.upperBound.clone(); | ||
+ | double[] lowerBound = this.lowerBound.clone(); | ||
+ | lowerBound[dimension] = splitMedian * i; | ||
+ | upperBound[dimension] = splitMedian * (i + 1); | ||
+ | return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound, | ||
+ | numChildren, maxDepth - 1, maxDensity, distancer); | ||
+ | } | ||
+ | |||
+ | public static class KdCluster<K> implements Iterable<KdPoint<K>> { | ||
+ | private final PriorityQueue<KdPoint<K>> points; | ||
+ | private final double[] center; | ||
+ | private final double[] weight; | ||
+ | private final int size; | ||
+ | private final Distancer distancer; | ||
+ | |||
+ | public KdCluster(int size, double[] center, double[] weight, | ||
+ | Distancer distancer) { | ||
+ | points = new PriorityQueue<KdPoint<K>>(); | ||
+ | this.size = size; | ||
+ | this.center = center; | ||
+ | this.weight = weight; | ||
+ | this.distancer = distancer; | ||
+ | } | ||
+ | |||
+ | public void consider(KdEntry<K> k) { | ||
+ | KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer); | ||
+ | |||
+ | if (points.size() < size) { | ||
+ | points.add(p); | ||
+ | System.out.println(points.toString()); | ||
+ | } else if (points.peek().getDistanceToCenter() > p | ||
+ | .getDistanceToCenter()) { | ||
+ | System.out.println(points.toString()); | ||
+ | points.poll(); | ||
+ | points.add(p); | ||
+ | } | ||
+ | } | ||
+ | |||
+ | public boolean isViable(PRKdBucketTree<K> tree) { | ||
+ | |||
+ | if (points.size() < size) | ||
+ | return true; | ||
+ | |||
+ | double[] testPoints = new double[center.length]; | ||
+ | |||
+ | for (int i = 0; i < center.length; i++) | ||
+ | testPoints[i] = M.limit(tree.lowerBound[i], center[i], | ||
+ | tree.upperBound[i]); | ||
+ | |||
+ | return points.peek().getDistanceToCenter() > distancer.getDistance( | ||
+ | center, testPoints, weight); | ||
+ | } | ||
+ | |||
+ | @Override | ||
+ | public Iterator<KdPoint<K>> iterator() { | ||
+ | return points.iterator(); | ||
+ | } | ||
+ | |||
+ | public Collection<KdPoint<K>> getValues() { | ||
+ | Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points | ||
+ | .size()); | ||
+ | |||
+ | for (KdPoint<K> p : points) { | ||
+ | collect.add(p); | ||
+ | } | ||
+ | |||
+ | return collect; | ||
+ | } | ||
+ | |||
+ | @Override | ||
+ | public int hashCode() { | ||
+ | final int prime = 31; | ||
+ | int result = 1; | ||
+ | result = prime * result + Arrays.hashCode(center); | ||
+ | result = prime * result | ||
+ | + ((distancer == null) ? 0 : distancer.hashCode()); | ||
+ | result = prime * result | ||
+ | + ((points == null) ? 0 : points.hashCode()); | ||
+ | result = prime * result + size; | ||
+ | result = prime * result + Arrays.hashCode(weight); | ||
+ | return result; | ||
+ | } | ||
+ | |||
+ | @SuppressWarnings("unchecked") | ||
+ | @Override | ||
+ | public boolean equals(Object obj) { | ||
+ | if (this == obj) | ||
+ | return true; | ||
+ | if (obj == null) | ||
+ | return false; | ||
+ | if (!(obj instanceof KdCluster)) | ||
+ | return false; | ||
+ | KdCluster other = (KdCluster) obj; | ||
+ | if (!Arrays.equals(center, other.center)) | ||
+ | return false; | ||
+ | if (distancer == null) { | ||
+ | if (other.distancer != null) | ||
+ | return false; | ||
+ | } else if (!distancer.equals(other.distancer)) | ||
+ | return false; | ||
+ | if (points == null) { | ||
+ | if (other.points != null) | ||
+ | return false; | ||
+ | } else if (!points.equals(other.points)) | ||
+ | return false; | ||
+ | if (size != other.size) | ||
+ | return false; | ||
+ | if (!Arrays.equals(weight, other.weight)) | ||
+ | return false; | ||
+ | return true; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | private static class KdEntry<K> implements Serializable { | ||
+ | private static final long serialVersionUID = 1L; | ||
+ | |||
+ | private final K value; | ||
+ | private final double[] location; | ||
+ | |||
+ | public KdEntry(K value, double[] location) { | ||
+ | super(); | ||
+ | this.value = value; | ||
+ | this.location = location; | ||
+ | } | ||
+ | |||
+ | public K getValue() { | ||
+ | return value; | ||
+ | } | ||
+ | |||
+ | public double[] getLocation() { | ||
+ | return location; | ||
+ | } | ||
+ | |||
+ | public double getLocation(int a) { | ||
+ | return location[a]; | ||
+ | } | ||
+ | |||
+ | @Override | ||
+ | public int hashCode() { | ||
+ | final int prime = 31; | ||
+ | int result = 1; | ||
+ | result = prime * result + Arrays.hashCode(location); | ||
+ | result = prime * result + ((value == null) ? 0 : value.hashCode()); | ||
+ | return result; | ||
+ | } | ||
+ | |||
+ | @SuppressWarnings("unchecked") | ||
+ | @Override | ||
+ | public boolean equals(Object obj) { | ||
+ | if (this == obj) | ||
+ | return true; | ||
+ | if (obj == null) | ||
+ | return false; | ||
+ | if (!(obj instanceof KdEntry)) | ||
+ | return false; | ||
+ | KdEntry other = (KdEntry) obj; | ||
+ | if (!Arrays.equals(location, other.location)) | ||
+ | return false; | ||
+ | if (value == null) { | ||
+ | if (other.value != null) | ||
+ | return false; | ||
+ | } else if (!value.equals(other.value)) | ||
+ | return false; | ||
+ | return true; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | public static class KdPoint<K> extends KdEntry<K> implements Serializable, | ||
+ | Comparable<KdPoint<K>> { | ||
+ | private static final long serialVersionUID = 1L; | ||
+ | private final double distanceToCenter; | ||
+ | |||
+ | public KdPoint(KdEntry<K> p, double[] center, double[] weight, | ||
+ | Distancer distancer) { | ||
+ | super(p.getValue(), p.getLocation()); | ||
+ | distanceToCenter = distancer.getDistance(center, getLocation(), | ||
+ | weight); | ||
+ | } | ||
+ | |||
+ | public double getDistanceToCenter() { | ||
+ | return distanceToCenter; | ||
+ | } | ||
+ | |||
+ | @Override | ||
+ | public String toString() { | ||
+ | return (new Double(distanceToCenter)).toString(); | ||
+ | } | ||
+ | |||
+ | @Override | ||
+ | public int compareTo(KdPoint<K> o) { | ||
+ | return (int) Math.signum(o.distanceToCenter - distanceToCenter); | ||
+ | } | ||
+ | |||
+ | @Override | ||
+ | public int hashCode() { | ||
+ | final int prime = 31; | ||
+ | int result = super.hashCode(); | ||
+ | long temp; | ||
+ | temp = Double.doubleToLongBits(distanceToCenter); | ||
+ | result = prime * result + (int) (temp ^ (temp >>> 32)); | ||
+ | return result; | ||
+ | } | ||
+ | |||
+ | @SuppressWarnings("unchecked") | ||
+ | @Override | ||
+ | public boolean equals(Object obj) { | ||
+ | if (this == obj) | ||
+ | return true; | ||
+ | if (!super.equals(obj)) | ||
+ | return false; | ||
+ | if (!(obj instanceof KdPoint)) | ||
+ | return false; | ||
+ | KdPoint other = (KdPoint) obj; | ||
+ | if (Double.doubleToLongBits(distanceToCenter) != Double | ||
+ | .doubleToLongBits(other.distanceToCenter)) | ||
+ | return false; | ||
+ | return true; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | public static abstract class Distancer { | ||
+ | public double getDistance(double[] p1, double[] p2, double[] weight) { | ||
+ | if (p1.length != p2.length) | ||
+ | throw new IllegalArgumentException(); | ||
+ | return getPointDistance(p1, p2, weight); | ||
+ | } | ||
+ | |||
+ | public abstract double getPointDistance(double[] p1, double[] p2, | ||
+ | double[] weight); | ||
+ | |||
+ | public static class EuclidianDistancer extends Distancer { | ||
+ | @Override | ||
+ | public double getPointDistance(double[] p1, double[] p2, | ||
+ | double[] weight) { | ||
+ | double result = 0; | ||
+ | for (int i = 0; i < p1.length; i++) { | ||
+ | result += M.sqr(p1[i] - p2[i]) * weight[i]; | ||
+ | } | ||
+ | return M.sqrt(result); | ||
+ | } | ||
+ | } | ||
+ | |||
+ | public static class ManhattanDistancer extends Distancer { | ||
+ | @Override | ||
+ | public double getPointDistance(double[] p1, double[] p2, | ||
+ | double[] weight) { | ||
+ | double result = 0; | ||
+ | for (int i = 0; i < p1.length; i++) { | ||
+ | result += M.abs(p1[i] - p2[i]) * weight[i]; | ||
+ | } | ||
+ | return result; | ||
+ | } | ||
+ | } | ||
+ | } | ||
+ | |||
+ | @Override | ||
+ | public int hashCode() { | ||
+ | final int prime = 31; | ||
+ | int result = 1; | ||
+ | result = prime * result + allDimensions; | ||
+ | result = prime * result + Arrays.hashCode(children); | ||
+ | result = prime * result + ((data == null) ? 0 : data.hashCode()); | ||
+ | result = prime * result + dimension; | ||
+ | result = prime * result + (isLeaf ? 1231 : 1237); | ||
+ | result = prime * result + Arrays.hashCode(lowerBound); | ||
+ | result = prime * result + maxDensity; | ||
+ | result = prime * result + maxDepth; | ||
+ | result = prime * result + Arrays.hashCode(numChildren); | ||
+ | long temp; | ||
+ | temp = Double.doubleToLongBits(splitMedian); | ||
+ | result = prime * result + (int) (temp ^ (temp >>> 32)); | ||
+ | result = prime * result + Arrays.hashCode(upperBound); | ||
+ | return result; | ||
+ | } | ||
+ | |||
+ | @SuppressWarnings("unchecked") | ||
+ | @Override | ||
+ | public boolean equals(Object obj) { | ||
+ | if (this == obj) | ||
+ | return true; | ||
+ | if (obj == null) | ||
+ | return false; | ||
+ | if (!(obj instanceof PRKdBucketTree)) | ||
+ | return false; | ||
+ | PRKdBucketTree other = (PRKdBucketTree) obj; | ||
+ | if (allDimensions != other.allDimensions) | ||
+ | return false; | ||
+ | if (!Arrays.equals(children, other.children)) | ||
+ | return false; | ||
+ | if (data == null) { | ||
+ | if (other.data != null) | ||
+ | return false; | ||
+ | } else if (!data.equals(other.data)) | ||
+ | return false; | ||
+ | if (dimension != other.dimension) | ||
+ | return false; | ||
+ | if (isLeaf != other.isLeaf) | ||
+ | return false; | ||
+ | if (!Arrays.equals(lowerBound, other.lowerBound)) | ||
+ | return false; | ||
+ | if (maxDensity != other.maxDensity) | ||
+ | return false; | ||
+ | if (maxDepth != other.maxDepth) | ||
+ | return false; | ||
+ | if (!Arrays.equals(numChildren, other.numChildren)) | ||
+ | return false; | ||
+ | if (Double.doubleToLongBits(splitMedian) != Double | ||
+ | .doubleToLongBits(other.splitMedian)) | ||
+ | return false; | ||
+ | if (!Arrays.equals(upperBound, other.upperBound)) | ||
+ | return false; | ||
+ | return true; | ||
+ | } | ||
+ | |||
+ | } | ||
+ | </pre> | ||
+ | |||
+ | == Tester == | ||
+ | <pre> | ||
+ | package nat.tree; | ||
+ | |||
+ | import java.util.ArrayList; | ||
+ | import java.util.Arrays; | ||
+ | import java.util.Collection; | ||
+ | import java.util.PriorityQueue; | ||
+ | import java.util.Random; | ||
+ | |||
+ | public class Test { | ||
+ | public static void main(String[] args) { | ||
+ | final int numTest = 100; | ||
+ | final int clusterSize = 15; | ||
+ | |||
+ | long linearTime, tree2time, tree3time, tree4time; | ||
+ | String[] answer = new String[clusterSize]; | ||
+ | |||
+ | System.out.println("Starting Bucket PR k-d tree performance test..."); | ||
+ | System.out.println("Generating points..."); | ||
+ | |||
+ | ArrayList<String> input = new ArrayList<String>(); | ||
+ | ArrayList<double[]> location = new ArrayList<double[]>(); | ||
+ | |||
+ | for (int i = 0; i < numTest; i++) { | ||
+ | input.add(generateRandomString()); | ||
+ | double[] p = new double[3]; | ||
+ | p[0] = Math.random(); | ||
+ | p[1] = Math.random(); | ||
+ | p[2] = Math.random(); | ||
+ | location.add(p); | ||
+ | } | ||
+ | |||
+ | PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2); | ||
+ | PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3); | ||
+ | PRKdBucketTree<String> tree4 = PRKdBucketTree.getTree("", 3, 1, 4); | ||
+ | |||
+ | for (int i = 0; i < numTest; i++) { | ||
+ | tree2.addPoint(input.get(i), location.get(i)); | ||
+ | tree3.addPoint(input.get(i), location.get(i)); | ||
+ | tree4.addPoint(input.get(i), location.get(i)); | ||
+ | } | ||
+ | |||
+ | double[] center = new double[3]; | ||
+ | center[0] = Math.random(); | ||
+ | center[1] = Math.random(); | ||
+ | center[2] = Math.random(); | ||
+ | |||
+ | System.out.println("Data generated."); | ||
+ | System.out.println("Performing linear search..."); | ||
+ | |||
+ | linearTime = -System.nanoTime(); | ||
+ | PriorityQueue<Compare> pq = new PriorityQueue<Compare>(); | ||
+ | |||
+ | for (int i = 0; i < numTest; i++) { | ||
+ | double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center, | ||
+ | location.get(i), new double[] { 1, 1, 1 }); | ||
+ | |||
+ | pq.add(new Compare(input.get(i), distance)); | ||
+ | } | ||
+ | |||
+ | double distance = -1; | ||
+ | for (int i = 0; i < clusterSize; i++) { | ||
+ | if (distance == -1) | ||
+ | distance = pq.peek().val; | ||
+ | answer[i] = pq.poll().data; | ||
+ | } | ||
+ | linearTime += System.nanoTime(); | ||
+ | Arrays.sort(answer); | ||
+ | System.out.println("Linear search complete; time = " | ||
+ | + (linearTime / 1E9)); | ||
+ | |||
+ | |||
+ | System.out.println("Performing binary k-d tree search..."); | ||
+ | |||
+ | tree2time = -System.nanoTime(); | ||
+ | |||
+ | PRKdBucketTree.KdCluster<String> tree2r = tree2 | ||
+ | .getNearestNeighbor(clusterSize, center, | ||
+ | new double[] { 1, 1, 1 }); | ||
+ | Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues(); | ||
+ | |||
+ | tree2time += System.nanoTime(); | ||
+ | |||
+ | System.out.println("Binary tree search complete; time = " | ||
+ | + (tree2time / 1E9)); | ||
+ | int correct = 0; | ||
+ | int j = 0; | ||
+ | String[] tree2o = new String[clusterSize]; | ||
+ | for (PRKdBucketTree.KdPoint<String> p : tree2a) { | ||
+ | tree2o[j++] = p.getValue(); | ||
+ | } | ||
+ | Arrays.sort(tree2o); | ||
+ | for (int i = 0; i < tree2o.length; i++) { | ||
+ | if (tree2o[i].equals(answer[i])) | ||
+ | correct++; | ||
+ | } | ||
+ | |||
+ | System.out.println(": accuracy = " + ((double) correct / clusterSize)); | ||
+ | |||
+ | System.out.println("Performing ternary k-d tree search..."); | ||
+ | |||
+ | tree3time = -System.nanoTime(); | ||
+ | |||
+ | PRKdBucketTree.KdCluster<String> tree3r = tree3 | ||
+ | .getNearestNeighbor(clusterSize, center, | ||
+ | new double[] { 1, 1, 1 }); | ||
+ | Collection<PRKdBucketTree.KdPoint<String>> tree3a = tree3r.getValues(); | ||
+ | |||
+ | tree3time += System.nanoTime(); | ||
+ | |||
+ | System.out.println("Ternary tree search complete; time = " | ||
+ | + (tree3time / 1E9)); | ||
+ | correct = 0; | ||
+ | j = 0; | ||
+ | String[] tree3o = new String[clusterSize]; | ||
+ | for (PRKdBucketTree.KdPoint<String> p : tree3a) { | ||
+ | tree3o[j++] = p.getValue(); | ||
+ | } | ||
+ | Arrays.sort(tree3o); | ||
+ | for (int i = 0; i < tree3o.length; i++) { | ||
+ | if (tree3o[i].equals(answer[i])) | ||
+ | correct++; | ||
+ | } | ||
+ | |||
+ | System.out.println(": accuracy = " + ((double) correct / clusterSize)); | ||
+ | |||
+ | System.out.println("Performing quaternary k-d tree search..."); | ||
+ | |||
+ | tree4time = -System.nanoTime(); | ||
+ | |||
+ | PRKdBucketTree.KdCluster<String> tree4r = tree4 | ||
+ | .getNearestNeighbor(clusterSize, center, | ||
+ | new double[] { 1, 1, 1 }); | ||
+ | Collection<PRKdBucketTree.KdPoint<String>> tree4a = tree4r.getValues(); | ||
+ | |||
+ | tree4time += System.nanoTime(); | ||
+ | |||
+ | System.out.println("Quaternary tree search complete; time = " | ||
+ | + (tree4time / 1E9)); | ||
+ | correct = 0; | ||
+ | j = 0; | ||
+ | String[] tree4o = new String[clusterSize]; | ||
+ | for (PRKdBucketTree.KdPoint<String> p : tree4a) { | ||
+ | tree4o[j++] = p.getValue(); | ||
+ | } | ||
+ | Arrays.sort(tree4o); | ||
+ | for (int i = 0; i < tree4o.length; i++) { | ||
+ | if (tree4o[i].equals(answer[i])) | ||
+ | correct++; | ||
+ | } | ||
+ | |||
+ | System.out.println(": accuracy = " + ((double) correct / clusterSize)); | ||
+ | } | ||
+ | |||
+ | private static class Compare implements Comparable<Compare> { | ||
+ | String data; | ||
+ | double val; | ||
+ | |||
+ | @Override | ||
+ | public int compareTo(Compare o) { | ||
+ | return (int) -Math.signum(o.val - val); | ||
+ | } | ||
+ | |||
+ | public Compare(String data, double val) { | ||
+ | this.data = data; | ||
+ | this.val = val; | ||
+ | } | ||
+ | |||
+ | } | ||
+ | |||
+ | private static String generateRandomString() { | ||
+ | String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; | ||
+ | Random r = new Random(); | ||
+ | char[] buf = new char[25]; | ||
+ | |||
+ | for (int i = 0; i < buf.length; i++) { | ||
+ | buf[i] = chars.charAt(r.nextInt(chars.length())); | ||
+ | } | ||
+ | |||
+ | return new String(buf); | ||
+ | } | ||
+ | } | ||
+ | |||
+ | </pre> | ||
+ | |||
+ | == Old == | ||
+ | ---- | ||
+ | |||
It took me last night to create and all day today to debug it but I still can't get it right. Anyone please help me! | It took me last night to create and all day today to debug it but I still can't get it right. Anyone please help me! | ||
Line 565: | Line 1,351: | ||
The linear search and tree search doesn't the same thing. I know this tree is a bit messy since I want it to supports other m-''ary'' tree style too. Anyway, please help. Thank you in advance =) » <span style="font-size:0.9em;color:darkgreen;">[[User:Nat|Nat]] | [[User_talk:Nat|Talk]]</span> » 13:40, 15 August 2009 (UTC) | The linear search and tree search doesn't the same thing. I know this tree is a bit messy since I want it to supports other m-''ary'' tree style too. Anyway, please help. Thank you in advance =) » <span style="font-size:0.9em;color:darkgreen;">[[User:Nat|Nat]] | [[User_talk:Nat|Talk]]</span> » 13:40, 15 August 2009 (UTC) | ||
+ | |||
+ | So sad no one help me =( By the way, finally I spot the bug. It is in <code>createChildTree()</code> where is should be <code>upperBound[dimension]</code> and <code>lowerBound[dimension]</code> instead of <code>i</code>. » <span style="font-size:0.9em;color:darkgreen;">[[User:Nat|Nat]] | [[User_talk:Nat|Talk]]</span> » 16:30, 15 August 2009 (UTC) |
Latest revision as of 18:12, 15 August 2009
Here is my code that I have problem state in Talk:Kd-Tree. My M class can be found at User:Nat/Free code.
Tree
package nat.tree; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Iterator; import java.util.LinkedList; import java.util.PriorityQueue; import java.util.Queue; import nat.util.M; /** * * Implementation of bucket PR k-d tree. * * @author Nat Pavasant * * @param <V> * The type of data to store */ public class PRKdBucketTree<V> implements Serializable { private static final long serialVersionUID = 1L; public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer(); public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer(); private final PRKdBucketTree<V>[] children; private final Queue<KdEntry<V>> data; private final Distancer distancer; private final double[] lowerBound, upperBound; private final int[] numChildren; private final int allDimensions; private final int dimension; private final int maxDepth, maxDensity; private final double splitMedian; private boolean isLeaf = true; /** * Create new Bucket PR k-d tree. * * @param allDimensions * number of dimensions in the tree * @param lowerBound * the minimum value of the location of each dimension * @param upperBound * the maximum value of the location of each dimension * @param numChildren * number of children in each dimension * @param maxDepth * the max depth of the tree * @param maxDensity * size of bucket in each leaf * @param distancer * distance measurer */ @SuppressWarnings("unchecked") public PRKdBucketTree(int allDimensions, double[] lowerBound, double[] upperBound, int[] numChildren, int maxDepth, int maxDensity, Distancer distancer) { if (allDimensions < 1 || maxDensity < 1) throw new IllegalArgumentException( "Either dimension or density isn't positive integer."); if (lowerBound.length != allDimensions || upperBound.length != allDimensions || numChildren.length != allDimensions) throw new IllegalArgumentException( "Either bounds or children amount is more or less than dimension count."); for (double a : lowerBound) { if (a < 0) throw new IllegalArgumentException( "Can't set lower bound to negative number."); } for (int i = 0; i < lowerBound.length; i++) { if (lowerBound[i] > upperBound[i]) throw new IllegalArgumentException( "Upper bound must have a value higer than lower bound."); } this.allDimensions = allDimensions; this.lowerBound = lowerBound; this.upperBound = upperBound; this.numChildren = numChildren; this.distancer = distancer; this.maxDepth = maxDepth; this.maxDensity = maxDensity; this.dimension = maxDepth % allDimensions; this.data = new LinkedList<KdEntry<V>>(); this.children = new PRKdBucketTree[numChildren[this.dimension]]; this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension]) / numChildren[this.dimension]; } // Another constructor public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double lowerBound, double upperBound, int numChildren, int maxDepth, int maxDensity, Distancer distance) { double[] lowerBounds = new double[dimension]; double[] upperBounds = new double[dimension]; int[] numChildrens = new int[dimension]; Arrays.fill(lowerBounds, lowerBound); Arrays.fill(upperBounds, upperBound); Arrays.fill(numChildrens, numChildren); return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds, numChildrens, maxDepth, maxDensity, distance); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int numChildren, int maxDepth, int maxDensity, Distancer distance) { return getTree(dataType, dimension, 0, upperBound, numChildren, maxDepth, maxDensity, distance); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int numChildren, int maxDepth, int maxDensity) { return getTree(dataType, dimension, 0, upperBound, numChildren, maxDepth, maxDensity, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int maxDepth, int maxDensity) { return getTree(dataType, dimension, 0, upperBound, 2, maxDepth, maxDensity, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int numChildren) { return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound) { return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) { return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN); } /** * Add new point to the tree * * @param value * the stored value * @param location * location of the point * @return removed point if any */ public KdEntry<V> addPoint(V value, double[] location) { if (location.length != allDimensions) { throw new IllegalArgumentException( "Provided location have either more or less dimensions than the tree."); } KdEntry<V> entry = new KdEntry<V>(value, location); return addPoint(entry); } /** * Get the n-nearest neighbor. * * @param size * number of neighbors * @param center * center of the cluster * @param weight * weighting for the distancer * @return */ public KdCluster<V> getNearestNeighbor(int size, double[] center, double[] weight) { KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer); nearestNeighborSearch(cluster); return cluster; } /** * Back-end implementation of the entry adding. * * @param entry * new entry to the tree * @return removed point if any */ private KdEntry<V> addPoint(KdEntry<V> entry) { if (isLeaf) { // Still has spaces, add the data if (data.size() < maxDensity) { data.add(entry); return null; } // final leaf, unsplitable, remove element and add new one. if (maxDepth <= 1) { data.add(entry); return data.poll(); } // if we reached here, we need to split this leaf to a branch. isLeaf = false; for (KdEntry<V> p : data) { passToChildren(p); } data.clear(); } // we are branch, pass to children return passToChildren(entry); } /** * Perform n-nearest neighbor search * * @param cluster * current working cluster */ private void nearestNeighborSearch(KdCluster<V> cluster) { if (isLeaf) { for (KdEntry<V> p : data) { cluster.consider(p); } } else { for (PRKdBucketTree<V> child : children) { if (child != null && cluster.isViable(child)) child.nearestNeighborSearch(cluster); } } } /** * Pass the entry to correct children * * @param p * entry * @return removed point if any */ private KdEntry<V> passToChildren(KdEntry<V> p) { int i = getChildrenIndex(p.getLocation(dimension)); if (children[i] == null) children[i] = createChildTree(i); return children[i].addPoint(p); } /** * Return index of the children which will contains data with that value. * * @param value * the value * @return index of children */ private int getChildrenIndex(double value) { return (int) M.limit(0, Math.floor(value / splitMedian), numChildren[dimension] - 1); } /** * Create children tree with correct bound. * * @param i * the index of the children * @return created tree */ private PRKdBucketTree<V> createChildTree(int i) { double[] upperBound = this.upperBound.clone(); double[] lowerBound = this.lowerBound.clone(); lowerBound[dimension] = splitMedian * i; upperBound[dimension] = splitMedian * (i + 1); return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound, numChildren, maxDepth - 1, maxDensity, distancer); } public static class KdCluster<K> implements Iterable<KdPoint<K>> { private final PriorityQueue<KdPoint<K>> points; private final double[] center; private final double[] weight; private final int size; private final Distancer distancer; public KdCluster(int size, double[] center, double[] weight, Distancer distancer) { points = new PriorityQueue<KdPoint<K>>(); this.size = size; this.center = center; this.weight = weight; this.distancer = distancer; } public void consider(KdEntry<K> k) { KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer); if (points.size() < size) { points.add(p); System.out.println(points.toString()); } else if (points.peek().getDistanceToCenter() > p .getDistanceToCenter()) { System.out.println(points.toString()); points.poll(); points.add(p); } } public boolean isViable(PRKdBucketTree<K> tree) { if (points.size() < size) return true; double[] testPoints = new double[center.length]; for (int i = 0; i < center.length; i++) testPoints[i] = M.limit(tree.lowerBound[i], center[i], tree.upperBound[i]); return points.peek().getDistanceToCenter() > distancer.getDistance( center, testPoints, weight); } @Override public Iterator<KdPoint<K>> iterator() { return points.iterator(); } public Collection<KdPoint<K>> getValues() { Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points .size()); for (KdPoint<K> p : points) { collect.add(p); } return collect; } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + Arrays.hashCode(center); result = prime * result + ((distancer == null) ? 0 : distancer.hashCode()); result = prime * result + ((points == null) ? 0 : points.hashCode()); result = prime * result + size; result = prime * result + Arrays.hashCode(weight); return result; } @SuppressWarnings("unchecked") @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (!(obj instanceof KdCluster)) return false; KdCluster other = (KdCluster) obj; if (!Arrays.equals(center, other.center)) return false; if (distancer == null) { if (other.distancer != null) return false; } else if (!distancer.equals(other.distancer)) return false; if (points == null) { if (other.points != null) return false; } else if (!points.equals(other.points)) return false; if (size != other.size) return false; if (!Arrays.equals(weight, other.weight)) return false; return true; } } private static class KdEntry<K> implements Serializable { private static final long serialVersionUID = 1L; private final K value; private final double[] location; public KdEntry(K value, double[] location) { super(); this.value = value; this.location = location; } public K getValue() { return value; } public double[] getLocation() { return location; } public double getLocation(int a) { return location[a]; } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + Arrays.hashCode(location); result = prime * result + ((value == null) ? 0 : value.hashCode()); return result; } @SuppressWarnings("unchecked") @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (!(obj instanceof KdEntry)) return false; KdEntry other = (KdEntry) obj; if (!Arrays.equals(location, other.location)) return false; if (value == null) { if (other.value != null) return false; } else if (!value.equals(other.value)) return false; return true; } } public static class KdPoint<K> extends KdEntry<K> implements Serializable, Comparable<KdPoint<K>> { private static final long serialVersionUID = 1L; private final double distanceToCenter; public KdPoint(KdEntry<K> p, double[] center, double[] weight, Distancer distancer) { super(p.getValue(), p.getLocation()); distanceToCenter = distancer.getDistance(center, getLocation(), weight); } public double getDistanceToCenter() { return distanceToCenter; } @Override public String toString() { return (new Double(distanceToCenter)).toString(); } @Override public int compareTo(KdPoint<K> o) { return (int) Math.signum(o.distanceToCenter - distanceToCenter); } @Override public int hashCode() { final int prime = 31; int result = super.hashCode(); long temp; temp = Double.doubleToLongBits(distanceToCenter); result = prime * result + (int) (temp ^ (temp >>> 32)); return result; } @SuppressWarnings("unchecked") @Override public boolean equals(Object obj) { if (this == obj) return true; if (!super.equals(obj)) return false; if (!(obj instanceof KdPoint)) return false; KdPoint other = (KdPoint) obj; if (Double.doubleToLongBits(distanceToCenter) != Double .doubleToLongBits(other.distanceToCenter)) return false; return true; } } public static abstract class Distancer { public double getDistance(double[] p1, double[] p2, double[] weight) { if (p1.length != p2.length) throw new IllegalArgumentException(); return getPointDistance(p1, p2, weight); } public abstract double getPointDistance(double[] p1, double[] p2, double[] weight); public static class EuclidianDistancer extends Distancer { @Override public double getPointDistance(double[] p1, double[] p2, double[] weight) { double result = 0; for (int i = 0; i < p1.length; i++) { result += M.sqr(p1[i] - p2[i]) * weight[i]; } return M.sqrt(result); } } public static class ManhattanDistancer extends Distancer { @Override public double getPointDistance(double[] p1, double[] p2, double[] weight) { double result = 0; for (int i = 0; i < p1.length; i++) { result += M.abs(p1[i] - p2[i]) * weight[i]; } return result; } } } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + allDimensions; result = prime * result + Arrays.hashCode(children); result = prime * result + ((data == null) ? 0 : data.hashCode()); result = prime * result + dimension; result = prime * result + (isLeaf ? 1231 : 1237); result = prime * result + Arrays.hashCode(lowerBound); result = prime * result + maxDensity; result = prime * result + maxDepth; result = prime * result + Arrays.hashCode(numChildren); long temp; temp = Double.doubleToLongBits(splitMedian); result = prime * result + (int) (temp ^ (temp >>> 32)); result = prime * result + Arrays.hashCode(upperBound); return result; } @SuppressWarnings("unchecked") @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (!(obj instanceof PRKdBucketTree)) return false; PRKdBucketTree other = (PRKdBucketTree) obj; if (allDimensions != other.allDimensions) return false; if (!Arrays.equals(children, other.children)) return false; if (data == null) { if (other.data != null) return false; } else if (!data.equals(other.data)) return false; if (dimension != other.dimension) return false; if (isLeaf != other.isLeaf) return false; if (!Arrays.equals(lowerBound, other.lowerBound)) return false; if (maxDensity != other.maxDensity) return false; if (maxDepth != other.maxDepth) return false; if (!Arrays.equals(numChildren, other.numChildren)) return false; if (Double.doubleToLongBits(splitMedian) != Double .doubleToLongBits(other.splitMedian)) return false; if (!Arrays.equals(upperBound, other.upperBound)) return false; return true; } }
Tester
package nat.tree; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.PriorityQueue; import java.util.Random; public class Test { public static void main(String[] args) { final int numTest = 100; final int clusterSize = 15; long linearTime, tree2time, tree3time, tree4time; String[] answer = new String[clusterSize]; System.out.println("Starting Bucket PR k-d tree performance test..."); System.out.println("Generating points..."); ArrayList<String> input = new ArrayList<String>(); ArrayList<double[]> location = new ArrayList<double[]>(); for (int i = 0; i < numTest; i++) { input.add(generateRandomString()); double[] p = new double[3]; p[0] = Math.random(); p[1] = Math.random(); p[2] = Math.random(); location.add(p); } PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2); PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3); PRKdBucketTree<String> tree4 = PRKdBucketTree.getTree("", 3, 1, 4); for (int i = 0; i < numTest; i++) { tree2.addPoint(input.get(i), location.get(i)); tree3.addPoint(input.get(i), location.get(i)); tree4.addPoint(input.get(i), location.get(i)); } double[] center = new double[3]; center[0] = Math.random(); center[1] = Math.random(); center[2] = Math.random(); System.out.println("Data generated."); System.out.println("Performing linear search..."); linearTime = -System.nanoTime(); PriorityQueue<Compare> pq = new PriorityQueue<Compare>(); for (int i = 0; i < numTest; i++) { double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center, location.get(i), new double[] { 1, 1, 1 }); pq.add(new Compare(input.get(i), distance)); } double distance = -1; for (int i = 0; i < clusterSize; i++) { if (distance == -1) distance = pq.peek().val; answer[i] = pq.poll().data; } linearTime += System.nanoTime(); Arrays.sort(answer); System.out.println("Linear search complete; time = " + (linearTime / 1E9)); System.out.println("Performing binary k-d tree search..."); tree2time = -System.nanoTime(); PRKdBucketTree.KdCluster<String> tree2r = tree2 .getNearestNeighbor(clusterSize, center, new double[] { 1, 1, 1 }); Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues(); tree2time += System.nanoTime(); System.out.println("Binary tree search complete; time = " + (tree2time / 1E9)); int correct = 0; int j = 0; String[] tree2o = new String[clusterSize]; for (PRKdBucketTree.KdPoint<String> p : tree2a) { tree2o[j++] = p.getValue(); } Arrays.sort(tree2o); for (int i = 0; i < tree2o.length; i++) { if (tree2o[i].equals(answer[i])) correct++; } System.out.println(": accuracy = " + ((double) correct / clusterSize)); System.out.println("Performing ternary k-d tree search..."); tree3time = -System.nanoTime(); PRKdBucketTree.KdCluster<String> tree3r = tree3 .getNearestNeighbor(clusterSize, center, new double[] { 1, 1, 1 }); Collection<PRKdBucketTree.KdPoint<String>> tree3a = tree3r.getValues(); tree3time += System.nanoTime(); System.out.println("Ternary tree search complete; time = " + (tree3time / 1E9)); correct = 0; j = 0; String[] tree3o = new String[clusterSize]; for (PRKdBucketTree.KdPoint<String> p : tree3a) { tree3o[j++] = p.getValue(); } Arrays.sort(tree3o); for (int i = 0; i < tree3o.length; i++) { if (tree3o[i].equals(answer[i])) correct++; } System.out.println(": accuracy = " + ((double) correct / clusterSize)); System.out.println("Performing quaternary k-d tree search..."); tree4time = -System.nanoTime(); PRKdBucketTree.KdCluster<String> tree4r = tree4 .getNearestNeighbor(clusterSize, center, new double[] { 1, 1, 1 }); Collection<PRKdBucketTree.KdPoint<String>> tree4a = tree4r.getValues(); tree4time += System.nanoTime(); System.out.println("Quaternary tree search complete; time = " + (tree4time / 1E9)); correct = 0; j = 0; String[] tree4o = new String[clusterSize]; for (PRKdBucketTree.KdPoint<String> p : tree4a) { tree4o[j++] = p.getValue(); } Arrays.sort(tree4o); for (int i = 0; i < tree4o.length; i++) { if (tree4o[i].equals(answer[i])) correct++; } System.out.println(": accuracy = " + ((double) correct / clusterSize)); } private static class Compare implements Comparable<Compare> { String data; double val; @Override public int compareTo(Compare o) { return (int) -Math.signum(o.val - val); } public Compare(String data, double val) { this.data = data; this.val = val; } } private static String generateRandomString() { String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; Random r = new Random(); char[] buf = new char[25]; for (int i = 0; i < buf.length; i++) { buf[i] = chars.charAt(r.nextInt(chars.length())); } return new String(buf); } }
Old
It took me last night to create and all day today to debug it but I still can't get it right. Anyone please help me!
package nat.tree; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Iterator; import java.util.LinkedList; import java.util.PriorityQueue; import java.util.Queue; import nat.util.M; /** * * Implementation of bucket PR k-d tree. * * @author Nat Pavasant * * @param <V> * The type of data to store */ public class PRKdBucketTree<V> implements Serializable { private static final long serialVersionUID = 1L; public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer(); public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer(); private final PRKdBucketTree<V>[] children; private final Queue<KdEntry<V>> data; private final Distancer distancer; private final double[] lowerBound, upperBound; private final int[] numChildren; private final int allDimensions; private final int dimension; private final int maxDepth, maxDensity; private final double splitMedian; private boolean isLeaf = true; /** * Create new Bucket PR k-d tree. * * @param allDimensions * number of dimensions in the tree * @param lowerBound * the minimum value of the location of each dimension * @param upperBound * the maximum value of the location of each dimension * @param numChildren * number of children in each dimension * @param maxDepth * the max depth of the tree * @param maxDensity * size of bucket in each leaf * @param distancer * distance measurer */ @SuppressWarnings("unchecked") public PRKdBucketTree(int allDimensions, double[] lowerBound, double[] upperBound, int[] numChildren, int maxDepth, int maxDensity, Distancer distancer) { if (allDimensions < 1 || maxDensity < 1) throw new IllegalArgumentException( "Either dimension or density isn't positive integer."); if (lowerBound.length != allDimensions || upperBound.length != allDimensions || numChildren.length != allDimensions) throw new IllegalArgumentException( "Either bounds or children amount is more or less than dimension count."); for (double a : lowerBound) { if (a < 0) throw new IllegalArgumentException( "Can't set lower bound to negative number."); } for (int i = 0; i < lowerBound.length; i++) { if (lowerBound[i] > upperBound[i]) throw new IllegalArgumentException( "Upper bound must have a value higer than lower bound."); } this.allDimensions = allDimensions; this.lowerBound = lowerBound; this.upperBound = upperBound; this.numChildren = numChildren; this.distancer = distancer; this.maxDepth = maxDepth; this.maxDensity = maxDensity; this.dimension = maxDepth % allDimensions; this.data = new LinkedList<KdEntry<V>>(); this.children = new PRKdBucketTree[numChildren[this.dimension]]; this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension]) / numChildren[this.dimension]; } // Another constructor public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double lowerBound, double upperBound, int numChildren, int maxDepth, int maxDensity, Distancer distance) { double[] lowerBounds = new double[dimension]; double[] upperBounds = new double[dimension]; int[] numChildrens = new int[dimension]; Arrays.fill(lowerBounds, lowerBound); Arrays.fill(upperBounds, upperBound); Arrays.fill(numChildrens, numChildren); return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds, numChildrens, maxDepth, maxDensity, distance); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int numChildren, int maxDepth, int maxDensity, Distancer distance) { return getTree(dataType, dimension, 0, upperBound, numChildren, maxDepth, maxDensity, distance); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int numChildren, int maxDepth, int maxDensity) { return getTree(dataType, dimension, 0, upperBound, numChildren, maxDepth, maxDensity, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int maxDepth, int maxDensity) { return getTree(dataType, dimension, 0, upperBound, 2, maxDepth, maxDensity, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound, int numChildren) { return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, double upperBound) { return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN); } public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) { return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN); } /** * Add new point to the tree * * @param value * the stored value * @param location * location of the point * @return removed point if any */ public KdEntry<V> addPoint(V value, double[] location) { if (location.length != allDimensions) { throw new IllegalArgumentException( "Provided location have either more or less dimensions than the tree."); } KdEntry<V> entry = new KdEntry<V>(value, location); return addPoint(entry); } /** * Get the n-nearest neighbor. * * @param size * number of neighbors * @param center * center of the cluster * @param weight * weighting for the distancer * @return */ public KdCluster<V> getNearestNeighbor(int size, double[] center, double[] weight) { KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer); nearestNeighborSearch(cluster); return cluster; } /** * Back-end implementation of the entry adding. * * @param entry * new entry to the tree * @return removed point if any */ private KdEntry<V> addPoint(KdEntry<V> entry) { if (isLeaf) { // Still has spaces, add the data if (data.size() < maxDensity) { data.add(entry); return null; } // final leaf, unsplitable, remove element and add new one. if (maxDepth <= 1) { data.add(entry); return data.poll(); } // if we reached here, we need to split this leaf to a branch. isLeaf = false; for (KdEntry<V> p : data) { passToChildren(p); } data.clear(); } // we are branch, pass to children return passToChildren(entry); } /** * Perform n-nearest neighbor search * * @param cluster * current working cluster */ private void nearestNeighborSearch(KdCluster<V> cluster) { if (isLeaf) { for (KdEntry<V> p : data) { cluster.consider(p); } } else { for (PRKdBucketTree<V> child : children) { if (child != null && cluster.isViable(child)) child.nearestNeighborSearch(cluster); } } } /** * Pass the entry to correct children * * @param p * entry * @return removed point if any */ private KdEntry<V> passToChildren(KdEntry<V> p) { int i = getChildrenIndex(p.getLocation(dimension)); if (children[i] == null) children[i] = createChildTree(i); return children[i].addPoint(p); } /** * Return index of the children which will contains data with that value. * * @param value * the value * @return index of children */ private int getChildrenIndex(double value) { return (int) M.limit(0, Math.floor(value / splitMedian), numChildren[dimension] - 1); } /** * Create children tree with correct bound. * * @param i * the index of the children * @return created tree */ private PRKdBucketTree<V> createChildTree(int i) { double[] upperBound = this.upperBound.clone(); double[] lowerBound = this.lowerBound.clone(); lowerBound[i] = splitMedian * i; upperBound[i] = splitMedian * (i + 1); return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound, numChildren, maxDepth - 1, maxDensity, distancer); } public static class KdCluster<K> implements Iterable<KdPoint<K>> { private final PriorityQueue<KdPoint<K>> points; private final double[] center; private final double[] weight; private final int size; private final Distancer distancer; public KdCluster(int size, double[] center, double[] weight, Distancer distancer) { points = new PriorityQueue<KdPoint<K>>(); this.size = size; this.center = center; this.weight = weight; this.distancer = distancer; } public void consider(KdEntry<K> k) { KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer); if (points.size() < size) { points.add(p); } else if (points.peek().isFurtherThan(p)) { points.poll(); points.add(p); } } public boolean isViable(PRKdBucketTree<K> tree) { if (points.size() < size) return true; double[] testPoints = new double[center.length]; for (int i = 0; i < center.length; i++) testPoints[i] = M.limit(tree.lowerBound[i], center[i], tree.upperBound[i]); return points.peek().isFurtherThan( distancer.getDistance(center, testPoints, weight)); } @Override public Iterator<KdPoint<K>> iterator() { return points.iterator(); } public Collection<KdPoint<K>> getValues() { Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points .size()); for (KdPoint<K> p : points) { collect.add(p); } return collect; } } private static class KdEntry<K> implements Serializable { private static final long serialVersionUID = 1L; private final K value; private final double[] location; public KdEntry(K value, double[] location) { super(); this.value = value; this.location = location; } public K getValue() { return value; } public double[] getLocation() { return location; } public double getLocation(int a) { return location[a]; } } public static class KdPoint<K> extends KdEntry<K> implements Serializable, Comparable<KdPoint<K>> { private static final long serialVersionUID = 1L; private final double distanceToCenter; public KdPoint(KdEntry<K> p, double[] center, double[] weight, Distancer distancer) { super(p.getValue(), p.getLocation()); distanceToCenter = distancer.getDistance(center, getLocation(), weight); } public double getDistanceToCenter() { return distanceToCenter; } @Override public int compareTo(KdPoint<K> o) { return (int) Math.signum(o.distanceToCenter - distanceToCenter); } public boolean isFurtherThan(KdPoint<K> p) { return compareTo(p) == -1; } public boolean isFurtherThan(double distance) { return distanceToCenter > distance; } } public static abstract class Distancer { public double getDistance(double[] p1, double[] p2, double[] weight) { if (p1.length != p2.length) throw new IllegalArgumentException(); return getPointDistance(p1, p2, weight); } public abstract double getPointDistance(double[] p1, double[] p2, double[] weight); public static class EuclidianDistancer extends Distancer { @Override public double getPointDistance(double[] p1, double[] p2, double[] weight) { double result = 0; for (int i = 0; i < p1.length; i++) { result += M.sqr(p1[i] - p2[i]) * weight[i]; } return M.sqrt(result); } } public static class ManhattanDistancer extends Distancer { @Override public double getPointDistance(double[] p1, double[] p2, double[] weight) { double result = 0; for (int i = 0; i < p1.length; i++) { result += M.abs(p1[i] - p2[i]) * weight[i]; } return result; } } } }
I use this code to check:
package nat.tree; import java.util.ArrayList; import java.util.Collection; import java.util.PriorityQueue; import java.util.Random; import nat.tree.PRKdBucketTree.KdPoint; public class Test { public static void main(String[] args) { final int numTest = 13000; final int clusterSize = 10; long linearTime, tree2time, tree3time; String[] answer = new String[100]; System.out.println("Starting Bucket PR k-d tree performance test..."); System.out.println("Generating points..."); ArrayList<String> input = new ArrayList<String>(); ArrayList<double[]> location = new ArrayList<double[]>(); for (int i = 0; i < numTest; i++) { input.add(generateRandomString()); double[] p = new double[3]; p[0] = Math.random(); p[1] = Math.random(); p[2] = Math.random(); location.add(p); } PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2); PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3); for (int i = 0; i < numTest; i++) { tree2.addPoint(input.get(i), location.get(i)); tree3.addPoint(input.get(i), location.get(i)); } double[] center = new double[3]; center[0] = Math.random(); center[1] = Math.random(); center[2] = Math.random(); System.out.println("Data generated."); System.out.println("Performing linear search..."); linearTime = -System.nanoTime(); PriorityQueue<Compare> pq = new PriorityQueue<Compare>(); for (int i = 0; i < numTest; i++) { double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center, location.get(i), new double[] { 1, 1, 1 }); pq.add(new Compare(input.get(i), distance)); } double distance = -1; for (int i = 0; i < clusterSize; i++) { if (distance == -1) distance = pq.peek().val; answer[i] = pq.poll().data; } linearTime += System.nanoTime(); System.out.println("Linear search complete; time = " + (linearTime / 1E9)); System.out.println("Performing binary k-d tree search..."); tree2time = -System.nanoTime(); PRKdBucketTree.KdCluster<String> tree2r = tree2 .getNearestNeighbor(clusterSize, center, new double[] { 1, 1, 1 }); Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues(); tree2time += System.nanoTime(); System.out.println("Distance = " + distance + "; " + tree2a.iterator().next().getDistanceToCenter()); System.out.println("Binary tree search complete; time = " + (tree2time / 1E9)); int correct = 0; int j = 0; for (PRKdBucketTree.KdPoint<String> p : tree2a) { System.out.print(p.getValue()); System.out.print(" "); System.out.println(answer[j]); if (p.getValue().equals(answer[j++])) correct++; } System.out.println(": accuracy = " + ((double) correct / clusterSize)); } private static class Compare implements Comparable<Compare> { String data; double val; @Override public int compareTo(Compare o) { return (int) -Math.signum(o.val - val); } public Compare(String data, double val) { this.data = data; this.val = val; } } private static String generateRandomString() { String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; Random r = new Random(); char[] buf = new char[15]; for (int i = 0; i < buf.length; i++) { buf[i] = chars.charAt(r.nextInt(chars.length())); } return new String(buf); } }
The linear search and tree search doesn't the same thing. I know this tree is a bit messy since I want it to supports other m-ary tree style too. Anyway, please help. Thank you in advance =) » Nat | Talk » 13:40, 15 August 2009 (UTC)
So sad no one help me =( By the way, finally I spot the bug. It is in createChildTree()
where is should be upperBound[dimension]
and lowerBound[dimension]
instead of i
. » Nat | Talk » 16:30, 15 August 2009 (UTC)