Difference between revisions of "Help:Help/Nat/Kd-Tree"
Jump to navigation
Jump to search
(anyone please help me!!!) |
(latest code that I've problem) |
||
| (One intermediate revision by the same user not shown) | |||
| Line 1: | Line 1: | ||
| + | Here is the code for the problem I described in [[Talk:Kd-Tree]]. My ''M'' class can be found at [[User:Nat/Free code]]. | ||
| + | |||
| + | == Tree == | ||
| + | <pre> | ||
| + | package nat.tree; | ||
| + | |||
| + | import java.io.Serializable; | ||
| + | import java.util.ArrayList; | ||
| + | import java.util.Arrays; | ||
| + | import java.util.Collection; | ||
| + | import java.util.Iterator; | ||
| + | import java.util.LinkedList; | ||
| + | import java.util.PriorityQueue; | ||
| + | import java.util.Queue; | ||
| + | |||
| + | import nat.util.M; | ||
| + | |||
| + | /** | ||
| + | * | ||
| + | * Implementation of bucket PR k-d tree. | ||
| + | * | ||
| + | * @author Nat Pavasant | ||
| + | * | ||
| + | * @param <V> | ||
| + | * The type of data to store | ||
| + | */ | ||
| + | public class PRKdBucketTree<V> implements Serializable { | ||
| + | |||
| + | private static final long serialVersionUID = 1L; | ||
| + | |||
| + | public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer(); | ||
| + | public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer(); | ||
| + | |||
| + | private final PRKdBucketTree<V>[] children; | ||
| + | private final Queue<KdEntry<V>> data; | ||
| + | private final Distancer distancer; | ||
| + | private final double[] lowerBound, upperBound; | ||
| + | private final int[] numChildren; | ||
| + | private final int allDimensions; | ||
| + | private final int dimension; | ||
| + | private final int maxDepth, maxDensity; | ||
| + | private final double splitMedian; | ||
| + | |||
| + | private boolean isLeaf = true; | ||
| + | |||
| + | /** | ||
| + | * Create new Bucket PR k-d tree. | ||
| + | * | ||
| + | * @param allDimensions | ||
| + | * number of dimensions in the tree | ||
| + | * @param lowerBound | ||
| + | * the minimum value of the location of each dimension | ||
| + | * @param upperBound | ||
| + | * the maximum value of the location of each dimension | ||
| + | * @param numChildren | ||
| + | * number of children in each dimension | ||
| + | * @param maxDepth | ||
| + | * the max depth of the tree | ||
| + | * @param maxDensity | ||
| + | * size of bucket in each leaf | ||
| + | * @param distancer | ||
| + | * distance measurer | ||
| + | */ | ||
| + | @SuppressWarnings("unchecked") | ||
| + | public PRKdBucketTree(int allDimensions, double[] lowerBound, | ||
| + | double[] upperBound, int[] numChildren, int maxDepth, | ||
| + | int maxDensity, Distancer distancer) { | ||
| + | if (allDimensions < 1 || maxDensity < 1) | ||
| + | throw new IllegalArgumentException( | ||
| + | "Either dimension or density isn't positive integer."); | ||
| + | |||
| + | if (lowerBound.length != allDimensions | ||
| + | || upperBound.length != allDimensions | ||
| + | || numChildren.length != allDimensions) | ||
| + | throw new IllegalArgumentException( | ||
| + | "Either bounds or children amount is more or less than dimension count."); | ||
| + | |||
| + | for (double a : lowerBound) { | ||
| + | if (a < 0) | ||
| + | throw new IllegalArgumentException( | ||
| + | "Can't set lower bound to negative number."); | ||
| + | } | ||
| + | |||
| + | for (int i = 0; i < lowerBound.length; i++) { | ||
| + | if (lowerBound[i] > upperBound[i]) | ||
| + | throw new IllegalArgumentException( | ||
| + | "Upper bound must have a value higer than lower bound."); | ||
| + | } | ||
| + | |||
| + | this.allDimensions = allDimensions; | ||
| + | this.lowerBound = lowerBound; | ||
| + | this.upperBound = upperBound; | ||
| + | this.numChildren = numChildren; | ||
| + | this.distancer = distancer; | ||
| + | this.maxDepth = maxDepth; | ||
| + | this.maxDensity = maxDensity; | ||
| + | this.dimension = maxDepth % allDimensions; | ||
| + | this.data = new LinkedList<KdEntry<V>>(); | ||
| + | this.children = new PRKdBucketTree[numChildren[this.dimension]]; | ||
| + | this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension]) | ||
| + | / numChildren[this.dimension]; | ||
| + | } | ||
| + | |||
| + | // Another constructor | ||
| + | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
| + | double lowerBound, double upperBound, int numChildren, | ||
| + | int maxDepth, int maxDensity, Distancer distance) { | ||
| + | double[] lowerBounds = new double[dimension]; | ||
| + | double[] upperBounds = new double[dimension]; | ||
| + | int[] numChildrens = new int[dimension]; | ||
| + | Arrays.fill(lowerBounds, lowerBound); | ||
| + | Arrays.fill(upperBounds, upperBound); | ||
| + | Arrays.fill(numChildrens, numChildren); | ||
| + | return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds, | ||
| + | numChildrens, maxDepth, maxDensity, distance); | ||
| + | } | ||
| + | |||
| + | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
| + | double upperBound, int numChildren, int maxDepth, int maxDensity, | ||
| + | Distancer distance) { | ||
| + | return getTree(dataType, dimension, 0, upperBound, numChildren, | ||
| + | maxDepth, maxDensity, distance); | ||
| + | } | ||
| + | |||
| + | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
| + | double upperBound, int numChildren, int maxDepth, int maxDensity) { | ||
| + | return getTree(dataType, dimension, 0, upperBound, numChildren, | ||
| + | maxDepth, maxDensity, EUCLIDIAN); | ||
| + | } | ||
| + | |||
| + | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
| + | double upperBound, int maxDepth, int maxDensity) { | ||
| + | return getTree(dataType, dimension, 0, upperBound, 2, maxDepth, | ||
| + | maxDensity, EUCLIDIAN); | ||
| + | } | ||
| + | |||
| + | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
| + | double upperBound, int numChildren) { | ||
| + | return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8, | ||
| + | EUCLIDIAN); | ||
| + | } | ||
| + | |||
| + | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension, | ||
| + | double upperBound) { | ||
| + | return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN); | ||
| + | } | ||
| + | |||
| + | public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) { | ||
| + | return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN); | ||
| + | } | ||
| + | |||
| + | /** | ||
| + | * Add new point to the tree | ||
| + | * | ||
| + | * @param value | ||
| + | * the stored value | ||
| + | * @param location | ||
| + | * location of the point | ||
| + | * @return removed point if any | ||
| + | */ | ||
| + | public KdEntry<V> addPoint(V value, double[] location) { | ||
| + | if (location.length != allDimensions) { | ||
| + | throw new IllegalArgumentException( | ||
| + | "Provided location have either more or less dimensions than the tree."); | ||
| + | } | ||
| + | |||
| + | KdEntry<V> entry = new KdEntry<V>(value, location); | ||
| + | |||
| + | return addPoint(entry); | ||
| + | } | ||
| + | |||
| + | /** | ||
| + | * Get the n-nearest neighbor. | ||
| + | * | ||
| + | * @param size | ||
| + | * number of neighbors | ||
| + | * @param center | ||
| + | * center of the cluster | ||
| + | * @param weight | ||
| + | * weighting for the distancer | ||
| + | * @return | ||
| + | */ | ||
| + | public KdCluster<V> getNearestNeighbor(int size, double[] center, | ||
| + | double[] weight) { | ||
| + | KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer); | ||
| + | nearestNeighborSearch(cluster); | ||
| + | return cluster; | ||
| + | } | ||
| + | |||
| + | 	/** | ||
| + | 	 * Insert an already-wrapped entry into the subtree rooted at this node. | ||
| + | 	 * | ||
| + | 	 * A leaf with free capacity just stores the entry. A full leaf that may | ||
| + | 	 * not split any further ({@code maxDepth <= 1}) appends the entry and | ||
| + | 	 * evicts the head of its bucket. Any other full leaf turns into a branch | ||
| + | 	 * and pushes its bucket down to its children before routing the entry. | ||
| + | 	 * | ||
| + | 	 * @param entry | ||
| + | 	 *            new entry to the tree | ||
| + | 	 * @return the evicted entry, or {@code null} when nothing was removed | ||
| + | 	 */ | ||
| + | 	private KdEntry<V> addPoint(KdEntry<V> entry) { | ||
| + | 		// Branch node: routing is all that is left to do. | ||
| + | 		if (!isLeaf) | ||
| + | 			return passToChildren(entry); | ||
| + | ||
| + | 		// Leaf with room to spare. | ||
| + | 		if (data.size() < maxDensity) { | ||
| + | 			data.add(entry); | ||
| + | 			return null; | ||
| + | 		} | ||
| + | ||
| + | 		// Full leaf at the depth limit: append, then evict the head. | ||
| + | 		if (maxDepth <= 1) { | ||
| + | 			data.add(entry); | ||
| + | 			return data.poll(); | ||
| + | 		} | ||
| + | ||
| + | 		// Full leaf that may still split: become a branch and hand the | ||
| + | 		// whole bucket down before routing the new entry. | ||
| + | 		isLeaf = false; | ||
| + | 		Iterator<KdEntry<V>> pending = data.iterator(); | ||
| + | 		while (pending.hasNext()) | ||
| + | 			passToChildren(pending.next()); | ||
| + | 		data.clear(); | ||
| + | ||
| + | 		return passToChildren(entry); | ||
| + | 	} | ||
| + | |||
| + | 	/** | ||
| + | 	 * Offer every entry that could belong to the n nearest neighbours to the | ||
| + | 	 * given cluster. A leaf offers its whole bucket; a branch descends only | ||
| + | 	 * into existing children that the cluster still reports as viable. | ||
| + | 	 * | ||
| + | 	 * @param cluster | ||
| + | 	 *            current working cluster | ||
| + | 	 */ | ||
| + | 	private void nearestNeighborSearch(KdCluster<V> cluster) { | ||
| + | 		if (!isLeaf) { | ||
| + | 			for (int slot = 0; slot < children.length; slot++) { | ||
| + | 				PRKdBucketTree<V> child = children[slot]; | ||
| + | 				// Skip slots never created and regions the cluster rules out. | ||
| + | 				if (child == null || !cluster.isViable(child)) | ||
| + | 					continue; | ||
| + | 				child.nearestNeighborSearch(cluster); | ||
| + | 			} | ||
| + | 			return; | ||
| + | 		} | ||
| + | ||
| + | 		for (KdEntry<V> candidate : data) | ||
| + | 			cluster.consider(candidate); | ||
| + | 	} | ||
| + | |||
| + | /** | ||
| + | * Pass the entry to correct children | ||
| + | * | ||
| + | * @param p | ||
| + | * entry | ||
| + | * @return removed point if any | ||
| + | */ | ||
| + | private KdEntry<V> passToChildren(KdEntry<V> p) { | ||
| + | int i = getChildrenIndex(p.getLocation(dimension)); | ||
| + | |||
| + | if (children[i] == null) | ||
| + | children[i] = createChildTree(i); | ||
| + | |||
| + | return children[i].addPoint(p); | ||
| + | } | ||
| + | |||
| + | 	/** | ||
| + | 	 * Return the index of the child whose slab along the split dimension | ||
| + | 	 * contains the given coordinate. | ||
| + | 	 * | ||
| + | 	 * The coordinate is measured from this node's own lower bound, so the | ||
| + | 	 * result is also correct for nodes whose range does not start at 0. | ||
| + | 	 * Coordinates outside the node end up in the first or last child | ||
| + | 	 * (assuming M.limit clamps its middle argument between the outer two). | ||
| + | 	 * | ||
| + | 	 * @param value | ||
| + | 	 *            coordinate of the point in the split dimension | ||
| + | 	 * @return index of children | ||
| + | 	 */ | ||
| + | 	private int getChildrenIndex(double value) { | ||
| + | 		// splitMedian holds the width of one child slab, so the slab number | ||
| + | 		// is the distance from the lower bound divided by that width. | ||
| + | 		double offset = value - lowerBound[dimension]; | ||
| + | 		return (int) M.limit(0, Math.floor(offset / splitMedian), | ||
| + | 				numChildren[dimension] - 1); | ||
| + | 	} | ||
| + | ||
| + | 	/** | ||
| + | 	 * Create the child tree for the given slot. The child covers the i-th | ||
| + | 	 * slab of this node's range along the split dimension and this node's | ||
| + | 	 * full range in every other dimension. | ||
| + | 	 * | ||
| + | 	 * @param i | ||
| + | 	 *            the index of the children | ||
| + | 	 * @return created tree | ||
| + | 	 */ | ||
| + | 	private PRKdBucketTree<V> createChildTree(int i) { | ||
| + | 		double[] upperBound = this.upperBound.clone(); | ||
| + | 		double[] lowerBound = this.lowerBound.clone(); | ||
| + | ||
| + | 		// Slabs are laid out starting at this node's lower bound. Starting | ||
| + | 		// at 0 gave deeper children bounds outside their parent, which made | ||
| + | 		// the nearest-neighbour search prune regions it still needed. | ||
| + | 		double origin = this.lowerBound[dimension]; | ||
| + | 		lowerBound[dimension] = origin + splitMedian * i; | ||
| + | ||
| + | 		// The last slab keeps the parent's exact upper bound (already in the | ||
| + | 		// clone) so rounding cannot leave a gap at the top of the range. | ||
| + | 		if (i < numChildren[dimension] - 1) | ||
| + | 			upperBound[dimension] = origin + splitMedian * (i + 1); | ||
| + | ||
| + | 		return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound, | ||
| + | 				numChildren, maxDepth - 1, maxDensity, distancer); | ||
| + | 	} | ||
| + | |||
| + | public static class KdCluster<K> implements Iterable<KdPoint<K>> { | ||
| + | private final PriorityQueue<KdPoint<K>> points; | ||
| + | private final double[] center; | ||
| + | private final double[] weight; | ||
| + | private final int size; | ||
| + | private final Distancer distancer; | ||
| + | |||
| + | public KdCluster(int size, double[] center, double[] weight, | ||
| + | Distancer distancer) { | ||
| + | points = new PriorityQueue<KdPoint<K>>(); | ||
| + | this.size = size; | ||
| + | this.center = center; | ||
| + | this.weight = weight; | ||
| + | this.distancer = distancer; | ||
| + | } | ||
| + | |||
| + | 		/** | ||
| + | 		 * Offer an entry to the cluster. The entry is kept while the cluster | ||
| + | 		 * is not full; afterwards it replaces the current farthest member | ||
| + | 		 * only when it is strictly closer to the center. | ||
| + | 		 * | ||
| + | 		 * @param k | ||
| + | 		 *            candidate entry | ||
| + | 		 */ | ||
| + | 		public void consider(KdEntry<K> k) { | ||
| + | 			// A cluster of size zero (or less) can never keep anything; this | ||
| + | 			// also avoids calling peek() on a queue that stays empty. | ||
| + | 			if (size <= 0) | ||
| + | 				return; | ||
| + | ||
| + | 			KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer); | ||
| + | ||
| + | 			if (points.size() < size) { | ||
| + | 				points.add(p); | ||
| + | 			} else if (points.peek().getDistanceToCenter() > p | ||
| + | 					.getDistanceToCenter()) { | ||
| + | 				// KdPoint orders farthest-first, so the head is the member | ||
| + | 				// to drop when a closer candidate arrives. | ||
| + | 				points.poll(); | ||
| + | 				points.add(p); | ||
| + | 			} | ||
| + | 		} | ||
| + | ||
| + | 		/** | ||
| + | 		 * Tell whether the region covered by the given tree could still | ||
| + | 		 * contribute a member to this cluster. | ||
| + | 		 * | ||
| + | 		 * @param tree | ||
| + | 		 *            subtree whose bounds are tested | ||
| + | 		 * @return true while the cluster is not full, or when the point of | ||
| + | 		 *         the tree's bounding box nearest to the center is closer | ||
| + | 		 *         than the current farthest member | ||
| + | 		 */ | ||
| + | 		public boolean isViable(PRKdBucketTree<K> tree) { | ||
| + | ||
| + | 			// Nothing is worth visiting for a cluster that can hold nothing. | ||
| + | 			if (size <= 0) | ||
| + | 				return false; | ||
| + | ||
| + | 			if (points.size() < size) | ||
| + | 				return true; | ||
| + | ||
| + | 			// Move the center into the box, one dimension at a time; the | ||
| + | 			// result is the box's closest point to the center. | ||
| + | 			double[] testPoints = new double[center.length]; | ||
| + | ||
| + | 			for (int i = 0; i < center.length; i++) | ||
| + | 				testPoints[i] = M.limit(tree.lowerBound[i], center[i], | ||
| + | 						tree.upperBound[i]); | ||
| + | ||
| + | 			return points.peek().getDistanceToCenter() > distancer.getDistance( | ||
| + | 					center, testPoints, weight); | ||
| + | 		} | ||
| + | |||
| + | @Override | ||
| + | public Iterator<KdPoint<K>> iterator() { | ||
| + | return points.iterator(); | ||
| + | } | ||
| + | |||
| + | 		/** | ||
| + | 		 * Copy the current members of the cluster into a new collection, in | ||
| + | 		 * the queue's iteration order. | ||
| + | 		 * | ||
| + | 		 * @return a new list holding every point currently in the cluster | ||
| + | 		 */ | ||
| + | 		public Collection<KdPoint<K>> getValues() { | ||
| + | 			return new ArrayList<KdPoint<K>>(points); | ||
| + | 		} | ||
| + | |||
| + | @Override | ||
| + | public int hashCode() { | ||
| + | final int prime = 31; | ||
| + | int result = 1; | ||
| + | result = prime * result + Arrays.hashCode(center); | ||
| + | result = prime * result | ||
| + | + ((distancer == null) ? 0 : distancer.hashCode()); | ||
| + | result = prime * result | ||
| + | + ((points == null) ? 0 : points.hashCode()); | ||
| + | result = prime * result + size; | ||
| + | result = prime * result + Arrays.hashCode(weight); | ||
| + | return result; | ||
| + | } | ||
| + | |||
| + | @SuppressWarnings("unchecked") | ||
| + | @Override | ||
| + | public boolean equals(Object obj) { | ||
| + | if (this == obj) | ||
| + | return true; | ||
| + | if (obj == null) | ||
| + | return false; | ||
| + | if (!(obj instanceof KdCluster)) | ||
| + | return false; | ||
| + | KdCluster other = (KdCluster) obj; | ||
| + | if (!Arrays.equals(center, other.center)) | ||
| + | return false; | ||
| + | if (distancer == null) { | ||
| + | if (other.distancer != null) | ||
| + | return false; | ||
| + | } else if (!distancer.equals(other.distancer)) | ||
| + | return false; | ||
| + | if (points == null) { | ||
| + | if (other.points != null) | ||
| + | return false; | ||
| + | } else if (!points.equals(other.points)) | ||
| + | return false; | ||
| + | if (size != other.size) | ||
| + | return false; | ||
| + | if (!Arrays.equals(weight, other.weight)) | ||
| + | return false; | ||
| + | return true; | ||
| + | } | ||
| + | } | ||
| + | |||
| + | private static class KdEntry<K> implements Serializable { | ||
| + | private static final long serialVersionUID = 1L; | ||
| + | |||
| + | private final K value; | ||
| + | private final double[] location; | ||
| + | |||
| + | public KdEntry(K value, double[] location) { | ||
| + | super(); | ||
| + | this.value = value; | ||
| + | this.location = location; | ||
| + | } | ||
| + | |||
| + | public K getValue() { | ||
| + | return value; | ||
| + | } | ||
| + | |||
| + | public double[] getLocation() { | ||
| + | return location; | ||
| + | } | ||
| + | |||
| + | public double getLocation(int a) { | ||
| + | return location[a]; | ||
| + | } | ||
| + | |||
| + | @Override | ||
| + | public int hashCode() { | ||
| + | final int prime = 31; | ||
| + | int result = 1; | ||
| + | result = prime * result + Arrays.hashCode(location); | ||
| + | result = prime * result + ((value == null) ? 0 : value.hashCode()); | ||
| + | return result; | ||
| + | } | ||
| + | |||
| + | @SuppressWarnings("unchecked") | ||
| + | @Override | ||
| + | public boolean equals(Object obj) { | ||
| + | if (this == obj) | ||
| + | return true; | ||
| + | if (obj == null) | ||
| + | return false; | ||
| + | if (!(obj instanceof KdEntry)) | ||
| + | return false; | ||
| + | KdEntry other = (KdEntry) obj; | ||
| + | if (!Arrays.equals(location, other.location)) | ||
| + | return false; | ||
| + | if (value == null) { | ||
| + | if (other.value != null) | ||
| + | return false; | ||
| + | } else if (!value.equals(other.value)) | ||
| + | return false; | ||
| + | return true; | ||
| + | } | ||
| + | } | ||
| + | |||
| + | public static class KdPoint<K> extends KdEntry<K> implements Serializable, | ||
| + | Comparable<KdPoint<K>> { | ||
| + | private static final long serialVersionUID = 1L; | ||
| + | private final double distanceToCenter; | ||
| + | |||
| + | public KdPoint(KdEntry<K> p, double[] center, double[] weight, | ||
| + | Distancer distancer) { | ||
| + | super(p.getValue(), p.getLocation()); | ||
| + | distanceToCenter = distancer.getDistance(center, getLocation(), | ||
| + | weight); | ||
| + | } | ||
| + | |||
| + | public double getDistanceToCenter() { | ||
| + | return distanceToCenter; | ||
| + | } | ||
| + | |||
| + | 		/** | ||
| + | 		 * @return the distance to the center, formatted as a decimal string | ||
| + | 		 */ | ||
| + | 		@Override | ||
| + | 		public String toString() { | ||
| + | 			// Same text as boxing first, without the deprecated constructor. | ||
| + | 			return Double.toString(distanceToCenter); | ||
| + | 		} | ||
| + | |||
| + | 		/** | ||
| + | 		 * Order points by descending distance to the center, so that a | ||
| + | 		 * PriorityQueue of points keeps the farthest one at its head. | ||
| + | 		 * | ||
| + | 		 * @param o | ||
| + | 		 *            the point to compare with | ||
| + | 		 * @return negative when this point is farther than o, positive when | ||
| + | 		 *         it is closer, zero when the distances are equal | ||
| + | 		 */ | ||
| + | 		@Override | ||
| + | 		public int compareTo(KdPoint<K> o) { | ||
| + | 			// Arguments are swapped on purpose (farthest first). Unlike the | ||
| + | 			// sign of a subtraction, Double.compare stays a consistent | ||
| + | 			// ordering when a distance is NaN or infinite. | ||
| + | 			return Double.compare(o.distanceToCenter, distanceToCenter); | ||
| + | 		} | ||
| + | |||
| + | @Override | ||
| + | public int hashCode() { | ||
| + | final int prime = 31; | ||
| + | int result = super.hashCode(); | ||
| + | long temp; | ||
| + | temp = Double.doubleToLongBits(distanceToCenter); | ||
| + | result = prime * result + (int) (temp ^ (temp >>> 32)); | ||
| + | return result; | ||
| + | } | ||
| + | |||
| + | @SuppressWarnings("unchecked") | ||
| + | @Override | ||
| + | public boolean equals(Object obj) { | ||
| + | if (this == obj) | ||
| + | return true; | ||
| + | if (!super.equals(obj)) | ||
| + | return false; | ||
| + | if (!(obj instanceof KdPoint)) | ||
| + | return false; | ||
| + | KdPoint other = (KdPoint) obj; | ||
| + | if (Double.doubleToLongBits(distanceToCenter) != Double | ||
| + | .doubleToLongBits(other.distanceToCenter)) | ||
| + | return false; | ||
| + | return true; | ||
| + | } | ||
| + | } | ||
| + | |||
| + | public static abstract class Distancer { | ||
| + | public double getDistance(double[] p1, double[] p2, double[] weight) { | ||
| + | if (p1.length != p2.length) | ||
| + | throw new IllegalArgumentException(); | ||
| + | return getPointDistance(p1, p2, weight); | ||
| + | } | ||
| + | |||
| + | public abstract double getPointDistance(double[] p1, double[] p2, | ||
| + | double[] weight); | ||
| + | |||
| + | public static class EuclidianDistancer extends Distancer { | ||
| + | @Override | ||
| + | public double getPointDistance(double[] p1, double[] p2, | ||
| + | double[] weight) { | ||
| + | double result = 0; | ||
| + | for (int i = 0; i < p1.length; i++) { | ||
| + | result += M.sqr(p1[i] - p2[i]) * weight[i]; | ||
| + | } | ||
| + | return M.sqrt(result); | ||
| + | } | ||
| + | } | ||
| + | |||
| + | public static class ManhattanDistancer extends Distancer { | ||
| + | @Override | ||
| + | public double getPointDistance(double[] p1, double[] p2, | ||
| + | double[] weight) { | ||
| + | double result = 0; | ||
| + | for (int i = 0; i < p1.length; i++) { | ||
| + | result += M.abs(p1[i] - p2[i]) * weight[i]; | ||
| + | } | ||
| + | return result; | ||
| + | } | ||
| + | } | ||
| + | } | ||
| + | |||
| + | @Override | ||
| + | public int hashCode() { | ||
| + | final int prime = 31; | ||
| + | int result = 1; | ||
| + | result = prime * result + allDimensions; | ||
| + | result = prime * result + Arrays.hashCode(children); | ||
| + | result = prime * result + ((data == null) ? 0 : data.hashCode()); | ||
| + | result = prime * result + dimension; | ||
| + | result = prime * result + (isLeaf ? 1231 : 1237); | ||
| + | result = prime * result + Arrays.hashCode(lowerBound); | ||
| + | result = prime * result + maxDensity; | ||
| + | result = prime * result + maxDepth; | ||
| + | result = prime * result + Arrays.hashCode(numChildren); | ||
| + | long temp; | ||
| + | temp = Double.doubleToLongBits(splitMedian); | ||
| + | result = prime * result + (int) (temp ^ (temp >>> 32)); | ||
| + | result = prime * result + Arrays.hashCode(upperBound); | ||
| + | return result; | ||
| + | } | ||
| + | |||
| + | @SuppressWarnings("unchecked") | ||
| + | @Override | ||
| + | public boolean equals(Object obj) { | ||
| + | if (this == obj) | ||
| + | return true; | ||
| + | if (obj == null) | ||
| + | return false; | ||
| + | if (!(obj instanceof PRKdBucketTree)) | ||
| + | return false; | ||
| + | PRKdBucketTree other = (PRKdBucketTree) obj; | ||
| + | if (allDimensions != other.allDimensions) | ||
| + | return false; | ||
| + | if (!Arrays.equals(children, other.children)) | ||
| + | return false; | ||
| + | if (data == null) { | ||
| + | if (other.data != null) | ||
| + | return false; | ||
| + | } else if (!data.equals(other.data)) | ||
| + | return false; | ||
| + | if (dimension != other.dimension) | ||
| + | return false; | ||
| + | if (isLeaf != other.isLeaf) | ||
| + | return false; | ||
| + | if (!Arrays.equals(lowerBound, other.lowerBound)) | ||
| + | return false; | ||
| + | if (maxDensity != other.maxDensity) | ||
| + | return false; | ||
| + | if (maxDepth != other.maxDepth) | ||
| + | return false; | ||
| + | if (!Arrays.equals(numChildren, other.numChildren)) | ||
| + | return false; | ||
| + | if (Double.doubleToLongBits(splitMedian) != Double | ||
| + | .doubleToLongBits(other.splitMedian)) | ||
| + | return false; | ||
| + | if (!Arrays.equals(upperBound, other.upperBound)) | ||
| + | return false; | ||
| + | return true; | ||
| + | } | ||
| + | |||
| + | } | ||
| + | </pre> | ||
| + | |||
| + | == Tester == | ||
| + | <pre> | ||
| + | package nat.tree; | ||
| + | |||
| + | import java.util.ArrayList; | ||
| + | import java.util.Arrays; | ||
| + | import java.util.Collection; | ||
| + | import java.util.PriorityQueue; | ||
| + | import java.util.Random; | ||
| + | |||
| + | public class Test { | ||
| + | public static void main(String[] args) { | ||
| + | final int numTest = 100; | ||
| + | final int clusterSize = 15; | ||
| + | |||
| + | long linearTime, tree2time, tree3time, tree4time; | ||
| + | String[] answer = new String[clusterSize]; | ||
| + | |||
| + | System.out.println("Starting Bucket PR k-d tree performance test..."); | ||
| + | System.out.println("Generating points..."); | ||
| + | |||
| + | ArrayList<String> input = new ArrayList<String>(); | ||
| + | ArrayList<double[]> location = new ArrayList<double[]>(); | ||
| + | |||
| + | for (int i = 0; i < numTest; i++) { | ||
| + | input.add(generateRandomString()); | ||
| + | double[] p = new double[3]; | ||
| + | p[0] = Math.random(); | ||
| + | p[1] = Math.random(); | ||
| + | p[2] = Math.random(); | ||
| + | location.add(p); | ||
| + | } | ||
| + | |||
| + | PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2); | ||
| + | PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3); | ||
| + | PRKdBucketTree<String> tree4 = PRKdBucketTree.getTree("", 3, 1, 4); | ||
| + | |||
| + | for (int i = 0; i < numTest; i++) { | ||
| + | tree2.addPoint(input.get(i), location.get(i)); | ||
| + | tree3.addPoint(input.get(i), location.get(i)); | ||
| + | tree4.addPoint(input.get(i), location.get(i)); | ||
| + | } | ||
| + | |||
| + | double[] center = new double[3]; | ||
| + | center[0] = Math.random(); | ||
| + | center[1] = Math.random(); | ||
| + | center[2] = Math.random(); | ||
| + | |||
| + | System.out.println("Data generated."); | ||
| + | System.out.println("Performing linear search..."); | ||
| + | |||
| + | linearTime = -System.nanoTime(); | ||
| + | PriorityQueue<Compare> pq = new PriorityQueue<Compare>(); | ||
| + | |||
| + | for (int i = 0; i < numTest; i++) { | ||
| + | double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center, | ||
| + | location.get(i), new double[] { 1, 1, 1 }); | ||
| + | |||
| + | pq.add(new Compare(input.get(i), distance)); | ||
| + | } | ||
| + | |||
| + | double distance = -1; | ||
| + | for (int i = 0; i < clusterSize; i++) { | ||
| + | if (distance == -1) | ||
| + | distance = pq.peek().val; | ||
| + | answer[i] = pq.poll().data; | ||
| + | } | ||
| + | linearTime += System.nanoTime(); | ||
| + | Arrays.sort(answer); | ||
| + | System.out.println("Linear search complete; time = " | ||
| + | + (linearTime / 1E9)); | ||
| + | |||
| + | |||
| + | System.out.println("Performing binary k-d tree search..."); | ||
| + | |||
| + | tree2time = -System.nanoTime(); | ||
| + | |||
| + | PRKdBucketTree.KdCluster<String> tree2r = tree2 | ||
| + | .getNearestNeighbor(clusterSize, center, | ||
| + | new double[] { 1, 1, 1 }); | ||
| + | Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues(); | ||
| + | |||
| + | tree2time += System.nanoTime(); | ||
| + | |||
| + | System.out.println("Binary tree search complete; time = " | ||
| + | + (tree2time / 1E9)); | ||
| + | int correct = 0; | ||
| + | int j = 0; | ||
| + | String[] tree2o = new String[clusterSize]; | ||
| + | for (PRKdBucketTree.KdPoint<String> p : tree2a) { | ||
| + | tree2o[j++] = p.getValue(); | ||
| + | } | ||
| + | Arrays.sort(tree2o); | ||
| + | for (int i = 0; i < tree2o.length; i++) { | ||
| + | if (tree2o[i].equals(answer[i])) | ||
| + | correct++; | ||
| + | } | ||
| + | |||
| + | System.out.println(": accuracy = " + ((double) correct / clusterSize)); | ||
| + | |||
| + | System.out.println("Performing ternary k-d tree search..."); | ||
| + | |||
| + | tree3time = -System.nanoTime(); | ||
| + | |||
| + | PRKdBucketTree.KdCluster<String> tree3r = tree3 | ||
| + | .getNearestNeighbor(clusterSize, center, | ||
| + | new double[] { 1, 1, 1 }); | ||
| + | Collection<PRKdBucketTree.KdPoint<String>> tree3a = tree3r.getValues(); | ||
| + | |||
| + | tree3time += System.nanoTime(); | ||
| + | |||
| + | System.out.println("Ternary tree search complete; time = " | ||
| + | + (tree3time / 1E9)); | ||
| + | correct = 0; | ||
| + | j = 0; | ||
| + | String[] tree3o = new String[clusterSize]; | ||
| + | for (PRKdBucketTree.KdPoint<String> p : tree3a) { | ||
| + | tree3o[j++] = p.getValue(); | ||
| + | } | ||
| + | Arrays.sort(tree3o); | ||
| + | for (int i = 0; i < tree3o.length; i++) { | ||
| + | if (tree3o[i].equals(answer[i])) | ||
| + | correct++; | ||
| + | } | ||
| + | |||
| + | System.out.println(": accuracy = " + ((double) correct / clusterSize)); | ||
| + | |||
| + | System.out.println("Performing quaternary k-d tree search..."); | ||
| + | |||
| + | tree4time = -System.nanoTime(); | ||
| + | |||
| + | PRKdBucketTree.KdCluster<String> tree4r = tree4 | ||
| + | .getNearestNeighbor(clusterSize, center, | ||
| + | new double[] { 1, 1, 1 }); | ||
| + | Collection<PRKdBucketTree.KdPoint<String>> tree4a = tree4r.getValues(); | ||
| + | |||
| + | tree4time += System.nanoTime(); | ||
| + | |||
| + | System.out.println("Quaternary tree search complete; time = " | ||
| + | + (tree4time / 1E9)); | ||
| + | correct = 0; | ||
| + | j = 0; | ||
| + | String[] tree4o = new String[clusterSize]; | ||
| + | for (PRKdBucketTree.KdPoint<String> p : tree4a) { | ||
| + | tree4o[j++] = p.getValue(); | ||
| + | } | ||
| + | Arrays.sort(tree4o); | ||
| + | for (int i = 0; i < tree4o.length; i++) { | ||
| + | if (tree4o[i].equals(answer[i])) | ||
| + | correct++; | ||
| + | } | ||
| + | |||
| + | System.out.println(": accuracy = " + ((double) correct / clusterSize)); | ||
| + | } | ||
| + | |||
| + | 	/** | ||
| + | 	 * A name paired with its distance to the query point, ordered by | ||
| + | 	 * ascending distance so a PriorityQueue hands out the closest first. | ||
| + | 	 */ | ||
| + | 	private static class Compare implements Comparable<Compare> { | ||
| + | 		final String data; | ||
| + | 		final double val; | ||
| + | ||
| + | 		public Compare(String data, double val) { | ||
| + | 			this.data = data; | ||
| + | 			this.val = val; | ||
| + | 		} | ||
| + | ||
| + | 		@Override | ||
| + | 		public int compareTo(Compare o) { | ||
| + | 			// Double.compare keeps the ordering consistent even for NaN or | ||
| + | 			// infinite values, where the sign of a subtraction does not. | ||
| + | 			return Double.compare(val, o.val); | ||
| + | 		} | ||
| + | ||
| + | 	} | ||
| + | |||
| + | private static String generateRandomString() { | ||
| + | String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; | ||
| + | Random r = new Random(); | ||
| + | char[] buf = new char[25]; | ||
| + | |||
| + | for (int i = 0; i < buf.length; i++) { | ||
| + | buf[i] = chars.charAt(r.nextInt(chars.length())); | ||
| + | } | ||
| + | |||
| + | return new String(buf); | ||
| + | } | ||
| + | } | ||
| + | |||
| + | </pre> | ||
| + | |||
| + | == Old == | ||
| + | ---- | ||
| + | |||
It took me all of last night to write this and all day today to debug it, but I still can't get it right. Can anyone please help me? | It took me all of last night to write this and all day today to debug it, but I still can't get it right. Can anyone please help me? | ||
| Line 565: | Line 1,351: | ||
The linear search and the tree search don't return the same thing. I know this tree is a bit messy, since I want it to support other m-''ary'' tree styles too. Anyway, please help. Thank you in advance =) » <span style="font-size:0.9em;color:darkgreen;">[[User:Nat|Nat]] | [[User_talk:Nat|Talk]]</span> » 13:40, 15 August 2009 (UTC) | The linear search and the tree search don't return the same thing. I know this tree is a bit messy, since I want it to support other m-''ary'' tree styles too. Anyway, please help. Thank you in advance =) » <span style="font-size:0.9em;color:darkgreen;">[[User:Nat|Nat]] | [[User_talk:Nat|Talk]]</span> » 13:40, 15 August 2009 (UTC) | ||
| + | |||
| + | So sad no one helped me =( By the way, I finally spotted the bug. It is in <code>createChildTree()</code>, where it should be <code>upperBound[dimension]</code> and <code>lowerBound[dimension]</code> instead of <code>i</code>. » <span style="font-size:0.9em;color:darkgreen;">[[User:Nat|Nat]] | [[User_talk:Nat|Talk]]</span> » 16:30, 15 August 2009 (UTC) | ||
Latest revision as of 18:12, 15 August 2009
Here is the code for the problem I described in Talk:Kd-Tree. My M class can be found at User:Nat/Free code.
Tree
package nat.tree;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.Queue;
import nat.util.M;
/**
*
* Implementation of bucket PR k-d tree.
*
* @author Nat Pavasant
*
* @param <V>
* The type of data to store
*/
public class PRKdBucketTree<V> implements Serializable {
private static final long serialVersionUID = 1L;
public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer();
public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer();
private final PRKdBucketTree<V>[] children;
private final Queue<KdEntry<V>> data;
private final Distancer distancer;
private final double[] lowerBound, upperBound;
private final int[] numChildren;
private final int allDimensions;
private final int dimension;
private final int maxDepth, maxDensity;
private final double splitMedian;
private boolean isLeaf = true;
/**
 * Create new Bucket PR k-d tree.
 *
 * @param allDimensions
 *            number of dimensions in the tree
 * @param lowerBound
 *            the minimum value of the location of each dimension
 * @param upperBound
 *            the maximum value of the location of each dimension
 * @param numChildren
 *            number of children in each dimension; every entry must be at
 *            least 1
 * @param maxDepth
 *            the max depth of the tree; it also selects the split
 *            dimension of this node
 * @param maxDensity
 *            size of bucket in each leaf
 * @param distancer
 *            distance measurer
 * @throws IllegalArgumentException
 *             if the dimension count or density is not positive, an array
 *             length differs from the dimension count, a child count is
 *             below 1, a lower bound is negative, or a lower bound exceeds
 *             its upper bound
 */
@SuppressWarnings("unchecked")
public PRKdBucketTree(int allDimensions, double[] lowerBound,
		double[] upperBound, int[] numChildren, int maxDepth,
		int maxDensity, Distancer distancer) {
	if (allDimensions < 1 || maxDensity < 1)
		throw new IllegalArgumentException(
				"Either dimension or density isn't positive integer.");

	if (lowerBound.length != allDimensions
			|| upperBound.length != allDimensions
			|| numChildren.length != allDimensions)
		throw new IllegalArgumentException(
				"Either bounds or children amount is more or less than dimension count.");

	// A node with no children cannot split, and a negative count cannot
	// even size the child array; reject both here instead of failing
	// later with an unrelated exception.
	for (int count : numChildren) {
		if (count < 1)
			throw new IllegalArgumentException(
					"Number of children must be at least 1 in every dimension.");
	}

	for (double a : lowerBound) {
		if (a < 0)
			throw new IllegalArgumentException(
					"Can't set lower bound to negative number.");
	}

	for (int i = 0; i < lowerBound.length; i++) {
		if (lowerBound[i] > upperBound[i])
			throw new IllegalArgumentException(
					"Upper bound must have a value higer than lower bound.");
	}

	// The split dimension follows the remaining depth. Java's % keeps the
	// sign of the dividend, so a negative depth is shifted back into
	// range to remain a valid array index.
	int axis = maxDepth % allDimensions;
	if (axis < 0)
		axis += allDimensions;

	this.allDimensions = allDimensions;
	this.lowerBound = lowerBound;
	this.upperBound = upperBound;
	this.numChildren = numChildren;
	this.distancer = distancer;
	this.maxDepth = maxDepth;
	this.maxDensity = maxDensity;
	this.dimension = axis;
	this.data = new LinkedList<KdEntry<V>>();
	// Generic array creation is not allowed, hence the raw array and the
	// unchecked suppression on this constructor.
	this.children = new PRKdBucketTree[numChildren[this.dimension]];
	// Despite the name, this is the width of one child slab.
	this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension])
			/ numChildren[this.dimension];
}
/**
 * Factory for a tree in which every dimension shares the same lower bound,
 * upper bound and number of children.
 *
 * @param dataType
 *            not read by this method; it only fixes the type parameter
 * @param dimension
 *            number of dimensions in the tree
 * @param lowerBound
 *            minimum location value, applied to every dimension
 * @param upperBound
 *            maximum location value, applied to every dimension
 * @param numChildren
 *            number of children, applied to every dimension
 * @param maxDepth
 *            the max depth of the tree
 * @param maxDensity
 *            size of bucket in each leaf
 * @param distance
 *            distance measurer
 * @return the newly constructed tree
 */
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
        double lowerBound, double upperBound, int numChildren,
        int maxDepth, int maxDensity, Distancer distance) {
    final double[] lows = new double[dimension];
    final double[] highs = new double[dimension];
    final int[] fanOut = new int[dimension];
    for (int axis = 0; axis < dimension; axis++) {
        lows[axis] = lowerBound;
        highs[axis] = upperBound;
        fanOut[axis] = numChildren;
    }
    return new PRKdBucketTree<T>(dimension, lows, highs, fanOut, maxDepth,
            maxDensity, distance);
}
/**
 * Factory for a uniform tree whose lower bound is 0 in every dimension.
 * Delegates to the eight-argument {@code getTree}.
 */
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
        double upperBound, int numChildren, int maxDepth, int maxDensity,
        Distancer distance) {
    final double lowerBound = 0;
    return getTree(dataType, dimension, lowerBound, upperBound,
            numChildren, maxDepth, maxDensity, distance);
}
/**
 * Factory for a uniform tree with lower bound 0 and the {@code EUCLIDIAN}
 * distancer. Delegates to the eight-argument {@code getTree}.
 */
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
        double upperBound, int numChildren, int maxDepth, int maxDensity) {
    final double lowerBound = 0;
    final Distancer metric = EUCLIDIAN;
    return getTree(dataType, dimension, lowerBound, upperBound,
            numChildren, maxDepth, maxDensity, metric);
}
/**
 * Factory for a binary (two children per dimension) tree with lower bound 0
 * and the {@code EUCLIDIAN} distancer. Delegates to the eight-argument
 * {@code getTree}.
 */
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
        double upperBound, int maxDepth, int maxDensity) {
    final double lowerBound = 0;
    final int childrenPerDimension = 2;
    final Distancer metric = EUCLIDIAN;
    return getTree(dataType, dimension, lowerBound, upperBound,
            childrenPerDimension, maxDepth, maxDensity, metric);
}
/**
 * Factory for a tree with lower bound 0, max depth 500, bucket size 8 and
 * the {@code EUCLIDIAN} distancer. Delegates to the eight-argument
 * {@code getTree}.
 */
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
        double upperBound, int numChildren) {
    final double lowerBound = 0;
    final int depthLimit = 500;
    final int bucketSize = 8;
    final Distancer metric = EUCLIDIAN;
    return getTree(dataType, dimension, lowerBound, upperBound,
            numChildren, depthLimit, bucketSize, metric);
}
/**
 * Factory for a binary tree with lower bound 0, max depth 500, bucket size
 * 8 and the {@code EUCLIDIAN} distancer. Delegates to the eight-argument
 * {@code getTree}.
 */
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
        double upperBound) {
    final double lowerBound = 0;
    final int childrenPerDimension = 2;
    final int depthLimit = 500;
    final int bucketSize = 8;
    final Distancer metric = EUCLIDIAN;
    return getTree(dataType, dimension, lowerBound, upperBound,
            childrenPerDimension, depthLimit, bucketSize, metric);
}
/**
 * Factory for a binary tree over the range 0 to 1 in every dimension, with
 * max depth 500, bucket size 8 and the {@code EUCLIDIAN} distancer.
 * Delegates to the eight-argument {@code getTree}.
 */
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) {
    final double lowerBound = 0;
    final double upperBound = 1;
    final int childrenPerDimension = 2;
    final int depthLimit = 500;
    final int bucketSize = 8;
    final Distancer metric = EUCLIDIAN;
    return getTree(dataType, dimension, lowerBound, upperBound,
            childrenPerDimension, depthLimit, bucketSize, metric);
}
/**
 * Stores a value at the given location.
 *
 * @param value
 *            the stored value
 * @param location
 *            location of the point; must have exactly as many coordinates
 *            as the tree has dimensions
 * @return the entry that was evicted to make room, or {@code null} if
 *         nothing was removed
 * @throws IllegalArgumentException
 *             if the location's length differs from the tree's dimension
 *             count
 */
public KdEntry<V> addPoint(V value, double[] location) {
    if (location.length == allDimensions) {
        return addPoint(new KdEntry<V>(value, location));
    }
    throw new IllegalArgumentException(
            "Provided location have either more or less dimensions than the tree.");
}
/**
 * Collects the nearest neighbours of a point.
 *
 * @param size
 *            number of neighbors
 * @param center
 *            center of the cluster
 * @param weight
 *            weighting for the distancer
 * @return the cluster holding the neighbours found by the search
 */
public KdCluster<V> getNearestNeighbor(int size, double[] center,
        double[] weight) {
    final KdCluster<V> result = new KdCluster<V>(size, center, weight,
            distancer);
    this.nearestNeighborSearch(result);
    return result;
}
/**
 * Inserts an already-wrapped entry into this subtree.
 *
 * @param entry
 *            new entry to the tree
 * @return the entry that was evicted, or {@code null} if nothing was
 *         removed
 */
private KdEntry<V> addPoint(KdEntry<V> entry) {
    // Branch node: the entry belongs to one of the children.
    if (!isLeaf) {
        return passToChildren(entry);
    }
    // Leaf with room left in its bucket.
    if (data.size() < maxDensity) {
        data.add(entry);
        return null;
    }
    // Full leaf that may not split any further: keep the bucket size
    // constant by dropping the oldest entry.
    if (maxDepth <= 1) {
        data.add(entry);
        return data.poll();
    }
    // Full leaf that can still split: become a branch and hand the bucket
    // contents, in insertion order, down to the children.
    isLeaf = false;
    while (!data.isEmpty()) {
        passToChildren(data.remove());
    }
    return passToChildren(entry);
}
/**
 * Recursive worker of the nearest-neighbour search: offers every entry of
 * a leaf to the cluster, and descends only into children the cluster
 * still considers viable.
 *
 * @param cluster
 *            current working cluster
 */
private void nearestNeighborSearch(KdCluster<V> cluster) {
    if (!isLeaf) {
        for (int slot = 0; slot < children.length; slot++) {
            final PRKdBucketTree<V> subtree = children[slot];
            // Children are created lazily, so unused slots are still null.
            if (subtree == null || !cluster.isViable(subtree)) {
                continue;
            }
            subtree.nearestNeighborSearch(cluster);
        }
        return;
    }
    for (KdEntry<V> candidate : data) {
        cluster.consider(candidate);
    }
}
/**
 * Routes an entry to the child responsible for its coordinate along this
 * node's split dimension, creating that child on first use.
 *
 * @param p
 *            entry
 * @return the entry that was evicted, or {@code null} if nothing was
 *         removed
 */
private KdEntry<V> passToChildren(KdEntry<V> p) {
    final int slot = getChildrenIndex(p.getLocation(dimension));
    PRKdBucketTree<V> target = children[slot];
    if (target == null) {
        target = createChildTree(slot);
        children[slot] = target;
    }
    return target.addPoint(p);
}
/**
 * Returns the index of the child whose slice of the split dimension
 * contains the given coordinate.
 * <p>
 * The coordinate is measured from this node's own lower bound. Dividing the
 * raw coordinate by the slice width is only right for a node whose lower
 * bound is 0; in any other node it overshoots, the result is clamped, and
 * every point ends up in the last child.
 *
 * @param value
 *            coordinate along this node's split dimension
 * @return index of the child, clamped to the valid child range
 */
private int getChildrenIndex(double value) {
    final double offset = value - lowerBound[dimension];
    // M.limit comes from nat.util.M (not visible here); it is assumed to
    // clamp its middle argument between the first and the last one.
    return (int) M.limit(0, Math.floor(offset / splitMedian),
            numChildren[dimension] - 1);
}

/**
 * Creates the child at the given index. The child inherits this node's
 * bounds in every dimension except the split dimension, where it receives
 * its own slice of this node's range.
 *
 * @param i
 *            the index of the children
 * @return created tree
 */
private PRKdBucketTree<V> createChildTree(int i) {
    double[] upperBound = this.upperBound.clone();
    double[] lowerBound = this.lowerBound.clone();
    // Slices start at this node's lower bound, matching getChildrenIndex.
    final double origin = this.lowerBound[dimension];
    lowerBound[dimension] = origin + splitMedian * i;
    // The last slice ends exactly at the parent's upper bound, so rounding
    // in the multiplication cannot leave a gap at the edge.
    if (i < numChildren[dimension] - 1)
        upperBound[dimension] = origin + splitMedian * (i + 1);
    return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound,
            numChildren, maxDepth - 1, maxDensity, distancer);
}
/**
 * Working set of a nearest-neighbour search: keeps at most {@code size}
 * points around {@code center}.
 * <p>
 * The points live in a {@link PriorityQueue}. Because
 * {@code KdPoint.compareTo} orders by descending distance, the head of the
 * queue is the point furthest from the center, i.e. the one to replace when
 * a closer candidate shows up.
 */
public static class KdCluster<K> implements Iterable<KdPoint<K>> {
    private final PriorityQueue<KdPoint<K>> points;
    private final double[] center;
    private final double[] weight;
    private final int size;
    private final Distancer distancer;

    /**
     * @param size
     *            maximum number of points to keep
     * @param center
     *            point the distances are measured from
     * @param weight
     *            weighting passed to the distancer
     * @param distancer
     *            distance measurer
     */
    public KdCluster(int size, double[] center, double[] weight,
            Distancer distancer) {
        points = new PriorityQueue<KdPoint<K>>();
        this.size = size;
        this.center = center;
        this.weight = weight;
        this.distancer = distancer;
    }

    /**
     * Offers a candidate to the cluster. It is kept if the cluster is not
     * full yet, or if it is strictly closer than the furthest point
     * currently kept (which is then dropped).
     *
     * @param k
     *            candidate entry
     */
    public void consider(KdEntry<K> k) {
        KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer);
        if (points.size() < size) {
            points.add(p);
            return;
        }
        // The queue is empty here only when size <= 0; nothing is kept then.
        KdPoint<K> furthest = points.peek();
        if (furthest != null
                && furthest.getDistanceToCenter() > p.getDistanceToCenter()) {
            points.poll();
            points.add(p);
        }
    }

    /**
     * Tells whether a subtree could still contribute a point to this
     * cluster.
     *
     * @param tree
     *            subtree to test
     * @return {@code true} while the cluster is not full, otherwise whether
     *         the point of the subtree's bounding box nearest to the center
     *         is closer than the furthest point currently kept
     */
    public boolean isViable(PRKdBucketTree<K> tree) {
        if (points.size() < size)
            return true;
        // Only possible when size <= 0: no point is wanted at all.
        KdPoint<K> furthest = points.peek();
        if (furthest == null)
            return false;
        // Clamp the center into the subtree's box to get the nearest point
        // of the box (M.limit is assumed to clamp; it is defined in
        // nat.util.M, which is not visible here).
        double[] testPoints = new double[center.length];
        for (int i = 0; i < center.length; i++)
            testPoints[i] = M.limit(tree.lowerBound[i], center[i],
                    tree.upperBound[i]);
        return furthest.getDistanceToCenter() > distancer.getDistance(
                center, testPoints, weight);
    }

    /** Iterates the kept points in the queue's (unsorted) order. */
    @Override
    public Iterator<KdPoint<K>> iterator() {
        return points.iterator();
    }

    /**
     * @return a new collection holding the kept points, in the queue's
     *         (unsorted) order
     */
    public Collection<KdPoint<K>> getValues() {
        return new ArrayList<KdPoint<K>>(points);
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + Arrays.hashCode(center);
        result = prime * result
                + ((distancer == null) ? 0 : distancer.hashCode());
        // NOTE(review): PriorityQueue does not override hashCode/equals,
        // so this term (and the matching test in equals) is identity-based.
        result = prime * result + points.hashCode();
        result = prime * result + size;
        result = prime * result + Arrays.hashCode(weight);
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (!(obj instanceof KdCluster))
            return false;
        KdCluster<?> other = (KdCluster<?>) obj;
        if (!Arrays.equals(center, other.center))
            return false;
        if (distancer == null) {
            if (other.distancer != null)
                return false;
        } else if (!distancer.equals(other.distancer))
            return false;
        if (!points.equals(other.points))
            return false;
        if (size != other.size)
            return false;
        return Arrays.equals(weight, other.weight);
    }
}
/**
 * A stored value together with its location in the tree. The location
 * array is kept and handed out as-is (no copies are made).
 */
private static class KdEntry<K> implements Serializable {
    private static final long serialVersionUID = 1L;

    private final K value;
    private final double[] location;

    /**
     * @param value
     *            the stored value
     * @param location
     *            coordinates of the value
     */
    public KdEntry(K value, double[] location) {
        super();
        this.location = location;
        this.value = value;
    }

    /** @return the stored value */
    public K getValue() {
        return this.value;
    }

    /** @return the coordinate array of this entry */
    public double[] getLocation() {
        return this.location;
    }

    /**
     * @param a
     *            index of the dimension
     * @return the coordinate along that dimension
     */
    public double getLocation(int a) {
        return this.location[a];
    }

    @Override
    public int hashCode() {
        final int valueHash = (value == null) ? 0 : value.hashCode();
        // Arrays.hashCode(int[]) is the same "31 * h + x" fold, seeded
        // with 1, as the usual hand-written version.
        return Arrays.hashCode(new int[] { Arrays.hashCode(location),
                valueHash });
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        // instanceof is false for null, so no separate null test is needed.
        if (!(obj instanceof KdEntry)) {
            return false;
        }
        final KdEntry<?> that = (KdEntry<?>) obj;
        if (!Arrays.equals(location, that.location)) {
            return false;
        }
        return (value == null) ? that.value == null
                : value.equals(that.value);
    }
}
/**
 * An entry annotated with its distance to a search center.
 * <p>
 * The natural ordering is by <em>descending</em> distance, so a
 * {@code PriorityQueue} of points has the furthest point at its head.
 */
public static class KdPoint<K> extends KdEntry<K> implements Serializable,
        Comparable<KdPoint<K>> {
    private static final long serialVersionUID = 1L;

    private final double distanceToCenter;

    /**
     * @param p
     *            entry to wrap; its value and location array are reused
     * @param center
     *            point the distance is measured from
     * @param weight
     *            weighting passed to the distancer
     * @param distancer
     *            distance measurer
     */
    public KdPoint(KdEntry<K> p, double[] center, double[] weight,
            Distancer distancer) {
        super(p.getValue(), p.getLocation());
        distanceToCenter = distancer.getDistance(center, getLocation(),
                weight);
    }

    /** @return the distance to the center, computed at construction */
    public double getDistanceToCenter() {
        return distanceToCenter;
    }

    /** @return the distance to the center as text */
    @Override
    public String toString() {
        return Double.toString(distanceToCenter);
    }

    /**
     * Orders points from furthest to nearest. The arguments are swapped on
     * purpose to get the descending order. {@code Double.compare} is used
     * instead of the sign of a subtraction so that NaN distances still get
     * a consistent position rather than comparing equal to everything.
     */
    @Override
    public int compareTo(KdPoint<K> o) {
        return Double.compare(o.distanceToCenter, distanceToCenter);
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = super.hashCode();
        long temp = Double.doubleToLongBits(distanceToCenter);
        result = prime * result + (int) (temp ^ (temp >>> 32));
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (!super.equals(obj))
            return false;
        if (!(obj instanceof KdPoint))
            return false;
        KdPoint<?> other = (KdPoint<?>) obj;
        return Double.doubleToLongBits(distanceToCenter) == Double
                .doubleToLongBits(other.distanceToCenter);
    }
}
/**
 * Strategy for measuring a weighted distance between two points.
 */
public static abstract class Distancer {

    /**
     * Measures the distance between two points after checking that they
     * have the same number of coordinates.
     *
     * @param p1
     *            first point
     * @param p2
     *            second point
     * @param weight
     *            per-dimension weights
     * @return the distance computed by {@link #getPointDistance}
     * @throws IllegalArgumentException
     *             if the two points differ in length
     */
    public double getDistance(double[] p1, double[] p2, double[] weight) {
        if (p1.length == p2.length) {
            return getPointDistance(p1, p2, weight);
        }
        throw new IllegalArgumentException();
    }

    /**
     * Computes the distance between two points of equal length.
     *
     * @param p1
     *            first point
     * @param p2
     *            second point
     * @param weight
     *            per-dimension weights
     * @return the distance
     */
    public abstract double getPointDistance(double[] p1, double[] p2,
            double[] weight);

    /**
     * Sums {@code M.sqr(difference) * weight} over the dimensions and
     * returns {@code M.sqrt} of the sum (M is a project helper, presumably
     * plain square and square root).
     */
    public static class EuclidianDistancer extends Distancer {
        @Override
        public double getPointDistance(double[] p1, double[] p2,
                double[] weight) {
            double sum = 0;
            for (int axis = 0; axis < p1.length; axis++) {
                final double delta = p1[axis] - p2[axis];
                sum += M.sqr(delta) * weight[axis];
            }
            return M.sqrt(sum);
        }
    }

    /**
     * Sums {@code M.abs(difference) * weight} over the dimensions (M is a
     * project helper, presumably plain absolute value).
     */
    public static class ManhattanDistancer extends Distancer {
        @Override
        public double getPointDistance(double[] p1, double[] p2,
                double[] weight) {
            double sum = 0;
            for (int axis = 0; axis < p1.length; axis++) {
                final double delta = p1[axis] - p2[axis];
                sum += M.abs(delta) * weight[axis];
            }
            return sum;
        }
    }
}
/**
 * Hash over the node's configuration, bucket and children; the distancer
 * does not take part.
 */
@Override
public int hashCode() {
    final long medianBits = Double.doubleToLongBits(splitMedian);
    // Arrays.hashCode(int[]) folds its elements as "31 * h + x" starting
    // from 1, so listing the components in order gives the classic result.
    return Arrays.hashCode(new int[] {
            allDimensions,
            Arrays.hashCode(children),
            (data == null) ? 0 : data.hashCode(),
            dimension,
            isLeaf ? 1231 : 1237,
            Arrays.hashCode(lowerBound),
            maxDensity,
            maxDepth,
            Arrays.hashCode(numChildren),
            (int) (medianBits ^ (medianBits >>> 32)),
            Arrays.hashCode(upperBound) });
}
/**
 * Two nodes are equal when their configuration, bucket and children are
 * equal; the distancer does not take part.
 */
@Override
public boolean equals(Object obj) {
    if (this == obj) {
        return true;
    }
    // instanceof is false for null, so no separate null test is needed.
    if (!(obj instanceof PRKdBucketTree)) {
        return false;
    }
    final PRKdBucketTree<?> that = (PRKdBucketTree<?>) obj;
    return allDimensions == that.allDimensions
            && Arrays.equals(children, that.children)
            && ((data == null) ? that.data == null
                    : data.equals(that.data))
            && dimension == that.dimension
            && isLeaf == that.isLeaf
            && Arrays.equals(lowerBound, that.lowerBound)
            && maxDensity == that.maxDensity
            && maxDepth == that.maxDepth
            && Arrays.equals(numChildren, that.numChildren)
            && Double.doubleToLongBits(splitMedian) == Double
                    .doubleToLongBits(that.splitMedian)
            && Arrays.equals(upperBound, that.upperBound);
}
}
Tester
package nat.tree;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.PriorityQueue;
import java.util.Random;
/**
 * Compares the nearest-neighbour search of binary, ternary and quaternary
 * {@code PRKdBucketTree}s against a brute-force linear search over the same
 * random points, printing timing and accuracy for each.
 */
public class Test {
    /** Shared generator; creating one per string is wasteful. */
    private static final Random RANDOM = new Random();

    public static void main(String[] args) {
        final int numTest = 100;
        final int clusterSize = 15;
        final double[] weight = { 1, 1, 1 };

        System.out.println("Starting Bucket PR k-d tree performance test...");
        System.out.println("Generating points...");
        ArrayList<String> input = new ArrayList<String>();
        ArrayList<double[]> location = new ArrayList<double[]>();
        for (int i = 0; i < numTest; i++) {
            input.add(generateRandomString());
            location.add(randomPoint());
        }
        PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2);
        PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3);
        PRKdBucketTree<String> tree4 = PRKdBucketTree.getTree("", 3, 1, 4);
        for (int i = 0; i < numTest; i++) {
            tree2.addPoint(input.get(i), location.get(i));
            tree3.addPoint(input.get(i), location.get(i));
            tree4.addPoint(input.get(i), location.get(i));
        }
        double[] center = randomPoint();
        System.out.println("Data generated.");

        // Reference answer: sort every point by distance, keep the nearest.
        System.out.println("Performing linear search...");
        long linearTime = -System.nanoTime();
        PriorityQueue<Compare> pq = new PriorityQueue<Compare>();
        for (int i = 0; i < numTest; i++) {
            double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center,
                    location.get(i), weight);
            pq.add(new Compare(input.get(i), distance));
        }
        String[] answer = new String[clusterSize];
        for (int i = 0; i < clusterSize; i++) {
            answer[i] = pq.poll().data;
        }
        linearTime += System.nanoTime();
        // Both sides are sorted by name so they can be compared slot by
        // slot regardless of the order in which the searches return them.
        Arrays.sort(answer);
        System.out.println("Linear search complete; time = "
                + (linearTime / 1E9));

        runTreeSearch("binary", tree2, clusterSize, center, weight, answer);
        runTreeSearch("ternary", tree3, clusterSize, center, weight, answer);
        runTreeSearch("quaternary", tree4, clusterSize, center, weight,
                answer);
    }

    /**
     * Times one tree search and prints how many of its results match the
     * reference answer.
     *
     * @param name
     *            lower-case name of the tree flavour, used in the messages
     * @param tree
     *            tree to search
     * @param clusterSize
     *            number of neighbours to ask for
     * @param center
     *            search center
     * @param weight
     *            weighting for the distancer
     * @param answer
     *            expected values, sorted
     */
    private static void runTreeSearch(String name,
            PRKdBucketTree<String> tree, int clusterSize, double[] center,
            double[] weight, String[] answer) {
        System.out.println("Performing " + name + " k-d tree search...");
        long elapsed = -System.nanoTime();
        PRKdBucketTree.KdCluster<String> cluster = tree.getNearestNeighbor(
                clusterSize, center, weight);
        Collection<PRKdBucketTree.KdPoint<String>> found = cluster
                .getValues();
        elapsed += System.nanoTime();
        String title = Character.toUpperCase(name.charAt(0))
                + name.substring(1);
        System.out.println(title + " tree search complete; time = "
                + (elapsed / 1E9));

        String[] values = new String[clusterSize];
        int filled = 0;
        for (PRKdBucketTree.KdPoint<String> p : found) {
            values[filled++] = p.getValue();
        }
        Arrays.sort(values);
        int correct = 0;
        for (int i = 0; i < values.length; i++) {
            if (values[i].equals(answer[i]))
                correct++;
        }
        System.out.println(": accuracy = " + ((double) correct / clusterSize));
    }

    /** @return a point with three coordinates drawn from Math.random() */
    private static double[] randomPoint() {
        double[] p = new double[3];
        p[0] = Math.random();
        p[1] = Math.random();
        p[2] = Math.random();
        return p;
    }

    /** A value and its distance, ordered by ascending distance. */
    private static class Compare implements Comparable<Compare> {
        final String data;
        final double val;

        public Compare(String data, double val) {
            this.data = data;
            this.val = val;
        }

        @Override
        public int compareTo(Compare o) {
            return Double.compare(val, o.val);
        }
    }

    /** @return a random 25-character alphanumeric string */
    private static String generateRandomString() {
        String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
        char[] buf = new char[25];
        for (int i = 0; i < buf.length; i++) {
            buf[i] = chars.charAt(RANDOM.nextInt(chars.length()));
        }
        return new String(buf);
    }
}
Old
It took me all of last night to write this and all day today to debug it, but I still can't get it right. Can anyone please help me?
package nat.tree;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.Queue;
import nat.util.M;
/**
*
* Implementation of bucket PR k-d tree.
*
* @author Nat Pavasant
*
* @param <V>
* The type of data to store
*/
public class PRKdBucketTree<V> implements Serializable {
private static final long serialVersionUID = 1L;
// Shared, ready-made distance measurers.
public static final Distancer EUCLIDIAN = new Distancer.EuclidianDistancer();
public static final Distancer MANHATTAN = new Distancer.ManhattanDistancer();
// One slot per child along this node's split dimension; slots stay null
// until passToChildren() first needs them.
private final PRKdBucketTree<V>[] children;
// Bucket of entries held while this node is a leaf.
private final Queue<KdEntry<V>> data;
private final Distancer distancer;
private final double[] lowerBound, upperBound;
private final int[] numChildren;
private final int allDimensions;
// Split dimension of this node: maxDepth % allDimensions.
private final int dimension;
private final int maxDepth, maxDensity;
// Despite the name, this is the width of one child's slice along the split
// dimension: (upper - lower) / number of children.
private final double splitMedian;
private boolean isLeaf = true;
/**
* Create new Bucket PR k-d tree.
*
* @param allDimensions
* number of dimensions in the tree
* @param lowerBound
* the minimum value of the location of each dimension
* @param upperBound
* the maximum value of the location of each dimension
* @param numChildren
* number of children in each dimension
* @param maxDepth
* the max depth of the tree
* @param maxDensity
* size of bucket in each leaf
* @param distancer
* distance measurer
*/
@SuppressWarnings("unchecked")
public PRKdBucketTree(int allDimensions, double[] lowerBound,
double[] upperBound, int[] numChildren, int maxDepth,
int maxDensity, Distancer distancer) {
if (allDimensions < 1 || maxDensity < 1)
throw new IllegalArgumentException(
"Either dimension or density isn't positive integer.");
if (lowerBound.length != allDimensions
|| upperBound.length != allDimensions
|| numChildren.length != allDimensions)
throw new IllegalArgumentException(
"Either bounds or children amount is more or less than dimension count.");
for (double a : lowerBound) {
if (a < 0)
throw new IllegalArgumentException(
"Can't set lower bound to negative number.");
}
for (int i = 0; i < lowerBound.length; i++) {
if (lowerBound[i] > upperBound[i])
throw new IllegalArgumentException(
"Upper bound must have a value higer than lower bound.");
}
// The arrays are stored as given, not copied.
this.allDimensions = allDimensions;
this.lowerBound = lowerBound;
this.upperBound = upperBound;
this.numChildren = numChildren;
this.distancer = distancer;
this.maxDepth = maxDepth;
this.maxDensity = maxDensity;
this.dimension = maxDepth % allDimensions;
this.data = new LinkedList<KdEntry<V>>();
this.children = new PRKdBucketTree[numChildren[this.dimension]];
this.splitMedian = (upperBound[this.dimension] - lowerBound[this.dimension])
/ numChildren[this.dimension];
}
// Another constructor
// Builds a tree in which every dimension shares the same bounds and child
// count. dataType is not read here; presumably it only fixes the type T.
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double lowerBound, double upperBound, int numChildren,
int maxDepth, int maxDensity, Distancer distance) {
double[] lowerBounds = new double[dimension];
double[] upperBounds = new double[dimension];
int[] numChildrens = new int[dimension];
Arrays.fill(lowerBounds, lowerBound);
Arrays.fill(upperBounds, upperBound);
Arrays.fill(numChildrens, numChildren);
return new PRKdBucketTree<T>(dimension, lowerBounds, upperBounds,
numChildrens, maxDepth, maxDensity, distance);
}
// The overloads below fill in defaults (lower bound 0, EUCLIDIAN, two
// children, max depth 500, bucket size 8, upper bound 1) and delegate to
// the eight-argument factory above.
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren, int maxDepth, int maxDensity,
Distancer distance) {
return getTree(dataType, dimension, 0, upperBound, numChildren,
maxDepth, maxDensity, distance);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren, int maxDepth, int maxDensity) {
return getTree(dataType, dimension, 0, upperBound, numChildren,
maxDepth, maxDensity, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int maxDepth, int maxDensity) {
return getTree(dataType, dimension, 0, upperBound, 2, maxDepth,
maxDensity, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound, int numChildren) {
return getTree(dataType, dimension, 0, upperBound, numChildren, 500, 8,
EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension,
double upperBound) {
return getTree(dataType, dimension, 0, upperBound, 2, 500, 8, EUCLIDIAN);
}
public static <T> PRKdBucketTree<T> getTree(T dataType, int dimension) {
return getTree(dataType, dimension, 0, 1, 2, 500, 8, EUCLIDIAN);
}
/**
* Add new point to the tree
*
* @param value
* the stored value
* @param location
* location of the point
* @return removed point if any
*/
public KdEntry<V> addPoint(V value, double[] location) {
if (location.length != allDimensions) {
throw new IllegalArgumentException(
"Provided location have either more or less dimensions than the tree.");
}
KdEntry<V> entry = new KdEntry<V>(value, location);
return addPoint(entry);
}
/**
* Get the n-nearest neighbor.
*
* @param size
* number of neighbors
* @param center
* center of the cluster
* @param weight
* weighting for the distancer
* @return the cluster filled by the search
*/
public KdCluster<V> getNearestNeighbor(int size, double[] center,
double[] weight) {
KdCluster<V> cluster = new KdCluster<V>(size, center, weight, distancer);
nearestNeighborSearch(cluster);
return cluster;
}
/**
* Back-end implementation of the entry adding.
*
* @param entry
* new entry to the tree
* @return removed point if any
*/
private KdEntry<V> addPoint(KdEntry<V> entry) {
if (isLeaf) {
// Still has spaces, add the data
if (data.size() < maxDensity) {
data.add(entry);
return null;
}
// final leaf, unsplitable, remove element and add new one.
// (the queue is FIFO, so the oldest entry is the one dropped)
if (maxDepth <= 1) {
data.add(entry);
return data.poll();
}
// if we reached here, we need to split this leaf to a branch.
isLeaf = false;
for (KdEntry<V> p : data) {
passToChildren(p);
}
data.clear();
}
// we are branch, pass to children
return passToChildren(entry);
}
/**
* Perform n-nearest neighbor search
*
* @param cluster
* current working cluster
*/
private void nearestNeighborSearch(KdCluster<V> cluster) {
if (isLeaf) {
for (KdEntry<V> p : data) {
cluster.consider(p);
}
} else {
// Children are visited in index order; a child is skipped when it was
// never created or the cluster rules it out by its bounds.
for (PRKdBucketTree<V> child : children) {
if (child != null && cluster.isViable(child))
child.nearestNeighborSearch(cluster);
}
}
}
/**
* Pass the entry to correct children
*
* @param p
* entry
* @return removed point if any
*/
private KdEntry<V> passToChildren(KdEntry<V> p) {
int i = getChildrenIndex(p.getLocation(dimension));
if (children[i] == null)
children[i] = createChildTree(i);
return children[i].addPoint(p);
}
/**
* Return index of the children which will contains data with that value.
*
* @param value
* the value
* @return index of children
*/
private int getChildrenIndex(double value) {
// NOTE(review): the raw value is divided by the slice width without first
// subtracting lowerBound[dimension], so the result is measured from 0
// rather than from this node's own lower bound.
// M.limit is defined in nat.util.M (not shown); presumably it clamps the
// middle argument between the first and the last one.
return (int) M.limit(0, Math.floor(value / splitMedian),
numChildren[dimension] - 1);
}
/**
* Create children tree with correct bound.
*
* @param i
* the index of the children
* @return created tree
*/
private PRKdBucketTree<V> createChildTree(int i) {
double[] upperBound = this.upperBound.clone();
double[] lowerBound = this.lowerBound.clone();
// NOTE(review): the bound arrays hold one entry per dimension, but the two
// assignments below index them with the child index i, not with the split
// dimension. The wrong dimension's bounds get narrowed, and i can exceed
// the array length when a node has more children than dimensions.
lowerBound[i] = splitMedian * i;
upperBound[i] = splitMedian * (i + 1);
return new PRKdBucketTree<V>(allDimensions, lowerBound, upperBound,
numChildren, maxDepth - 1, maxDensity, distancer);
}
// Working set of a nearest-neighbour search: keeps at most `size` points.
// KdPoint orders by descending distance, so the head of the priority queue
// is the furthest point currently kept.
public static class KdCluster<K> implements Iterable<KdPoint<K>> {
private final PriorityQueue<KdPoint<K>> points;
private final double[] center;
private final double[] weight;
private final int size;
private final Distancer distancer;
public KdCluster(int size, double[] center, double[] weight, Distancer distancer) {
points = new PriorityQueue<KdPoint<K>>();
this.size = size;
this.center = center;
this.weight = weight;
this.distancer = distancer;
}
// Keeps the candidate if there is still room, or if it is closer than the
// furthest point kept so far (which is then dropped).
public void consider(KdEntry<K> k) {
KdPoint<K> p = new KdPoint<K>(k, center, weight, distancer);
if (points.size() < size) {
points.add(p);
} else if (points.peek().isFurtherThan(p)) {
points.poll();
points.add(p);
}
}
// A subtree is worth searching while the cluster is not full, or when the
// center clamped into the subtree's bounds is closer than the furthest
// point kept so far.
public boolean isViable(PRKdBucketTree<K> tree) {
if (points.size() < size)
return true;
double[] testPoints = new double[center.length];
for (int i = 0; i < center.length; i++)
testPoints[i] = M.limit(tree.lowerBound[i], center[i],
tree.upperBound[i]);
return points.peek().isFurtherThan(
distancer.getDistance(center, testPoints, weight));
}
@Override
public Iterator<KdPoint<K>> iterator() {
return points.iterator();
}
// Copies the kept points into a new list, in the queue's iteration order
// (which PriorityQueue does not guarantee to be sorted).
public Collection<KdPoint<K>> getValues() {
Collection<KdPoint<K>> collect = new ArrayList<KdPoint<K>>(points
.size());
for (KdPoint<K> p : points) {
collect.add(p);
}
return collect;
}
}
// A stored value together with its location; the location array is kept and
// returned as-is (no copies).
private static class KdEntry<K> implements Serializable {
private static final long serialVersionUID = 1L;
private final K value;
private final double[] location;
public KdEntry(K value, double[] location) {
super();
this.value = value;
this.location = location;
}
public K getValue() {
return value;
}
public double[] getLocation() {
return location;
}
public double getLocation(int a) {
return location[a];
}
}
// An entry annotated with its distance to the search center, computed once
// in the constructor.
public static class KdPoint<K> extends KdEntry<K> implements Serializable,
Comparable<KdPoint<K>> {
private static final long serialVersionUID = 1L;
private final double distanceToCenter;
public KdPoint(KdEntry<K> p, double[] center, double[] weight, Distancer distancer) {
super(p.getValue(), p.getLocation());
distanceToCenter = distancer.getDistance(center, getLocation(),
weight);
}
public double getDistanceToCenter() {
return distanceToCenter;
}
// Reversed on purpose: a larger distance sorts first, which turns the
// PriorityQueue into a "furthest point at the head" queue.
@Override
public int compareTo(KdPoint<K> o) {
return (int) Math.signum(o.distanceToCenter - distanceToCenter);
}
// compareTo(p) == -1 means p's distance is smaller than this one's.
public boolean isFurtherThan(KdPoint<K> p) {
return compareTo(p) == -1;
}
public boolean isFurtherThan(double distance) {
return distanceToCenter > distance;
}
}
// Strategy for measuring a weighted distance between two points.
public static abstract class Distancer {
// Rejects points of different length, then delegates to getPointDistance.
public double getDistance(double[] p1, double[] p2, double[] weight) {
if (p1.length != p2.length)
throw new IllegalArgumentException();
return getPointDistance(p1, p2, weight);
}
public abstract double getPointDistance(double[] p1, double[] p2,
double[] weight);
// M.sqrt of the sum of M.sqr(difference) * weight over all dimensions
// (M is a project helper; presumably plain square / square root).
public static class EuclidianDistancer extends Distancer {
@Override
public double getPointDistance(double[] p1, double[] p2,
double[] weight) {
double result = 0;
for (int i = 0; i < p1.length; i++) {
result += M.sqr(p1[i] - p2[i]) * weight[i];
}
return M.sqrt(result);
}
}
// Sum of M.abs(difference) * weight over all dimensions.
public static class ManhattanDistancer extends Distancer {
@Override
public double getPointDistance(double[] p1, double[] p2,
double[] weight) {
double result = 0;
for (int i = 0; i < p1.length; i++) {
result += M.abs(p1[i] - p2[i]) * weight[i];
}
return result;
}
}
}
}
I use this code to check:
package nat.tree;
import java.util.ArrayList;
import java.util.Collection;
import java.util.PriorityQueue;
import java.util.Random;
import nat.tree.PRKdBucketTree.KdPoint;
// Checks the tree's nearest-neighbour search against a brute-force linear
// search over the same random points and prints how many results agree.
public class Test {
public static void main(String[] args) {
final int numTest = 13000;
final int clusterSize = 10;
// tree3time is declared but never assigned or read below.
long linearTime, tree2time, tree3time;
// Only the first clusterSize slots are filled by the linear search.
String[] answer = new String[100];
System.out.println("Starting Bucket PR k-d tree performance test...");
System.out.println("Generating points...");
ArrayList<String> input = new ArrayList<String>();
ArrayList<double[]> location = new ArrayList<double[]>();
for (int i = 0; i < numTest; i++) {
input.add(generateRandomString());
double[] p = new double[3];
p[0] = Math.random();
p[1] = Math.random();
p[2] = Math.random();
location.add(p);
}
// Arguments: type witness, 3 dimensions, upper bound 1, children per
// dimension. tree3 is filled below but never searched.
PRKdBucketTree<String> tree2 = PRKdBucketTree.getTree("", 3, 1, 2);
PRKdBucketTree<String> tree3 = PRKdBucketTree.getTree("", 3, 1, 3);
for (int i = 0; i < numTest; i++) {
tree2.addPoint(input.get(i), location.get(i));
tree3.addPoint(input.get(i), location.get(i));
}
double[] center = new double[3];
center[0] = Math.random();
center[1] = Math.random();
center[2] = Math.random();
System.out.println("Data generated.");
System.out.println("Performing linear search...");
linearTime = -System.nanoTime();
// Reference answer: queue every point by ascending distance, then take the
// first clusterSize of them.
PriorityQueue<Compare> pq = new PriorityQueue<Compare>();
for (int i = 0; i < numTest; i++) {
double distance = PRKdBucketTree.EUCLIDIAN.getDistance(center,
location.get(i), new double[] { 1, 1, 1 });
pq.add(new Compare(input.get(i), distance));
}
// Ends up holding the smallest distance (set on the first iteration only).
double distance = -1;
for (int i = 0; i < clusterSize; i++) {
if (distance == -1)
distance = pq.peek().val;
answer[i] = pq.poll().data;
}
linearTime += System.nanoTime();
System.out.println("Linear search complete; time = "
+ (linearTime / 1E9));
System.out.println("Performing binary k-d tree search...");
tree2time = -System.nanoTime();
PRKdBucketTree.KdCluster<String> tree2r = tree2
.getNearestNeighbor(clusterSize, center,
new double[] { 1, 1, 1 });
Collection<PRKdBucketTree.KdPoint<String>> tree2a = tree2r.getValues();
tree2time += System.nanoTime();
System.out.println("Distance = " + distance + "; "
+ tree2a.iterator().next().getDistanceToCenter());
System.out.println("Binary tree search complete; time = "
+ (tree2time / 1E9));
int correct = 0;
int j = 0;
// NOTE(review): answer[] is in ascending-distance order, while tree2a is in
// the cluster's PriorityQueue iteration order (unsorted, furthest point at
// the head). Comparing slot by slot can therefore report mismatches even
// when both searches found the same set of points.
for (PRKdBucketTree.KdPoint<String> p : tree2a) {
System.out.print(p.getValue());
System.out.print(" ");
System.out.println(answer[j]);
if (p.getValue().equals(answer[j++]))
correct++;
}
System.out.println(": accuracy = " + ((double) correct / clusterSize));
}
// A value and its distance; ordered by ascending distance.
private static class Compare implements Comparable<Compare> {
String data;
double val;
@Override
public int compareTo(Compare o) {
return (int) -Math.signum(o.val - val);
}
public Compare(String data, double val) {
this.data = data;
this.val = val;
}
}
// Returns a random 15-character alphanumeric string; a new Random is
// created on every call.
private static String generateRandomString() {
String chars = "abcdefghijklmonpqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
Random r = new Random();
char[] buf = new char[15];
for (int i = 0; i < buf.length; i++) {
buf[i] = chars.charAt(r.nextInt(chars.length()));
}
return new String(buf);
}
}
The linear search and the tree search don't return the same thing. I know this tree is a bit messy, since I want it to support other m-ary tree styles too. Anyway, please help. Thank you in advance =) » Nat | Talk » 13:40, 15 August 2009 (UTC)
So sad that no one helped me =( By the way, I finally spotted the bug. It is in createChildTree(), where it should be upperBound[dimension] and lowerBound[dimension] instead of i. » Nat | Talk » 16:30, 15 August 2009 (UTC)