Difference between revisions of "User:Rednaxela/kD-Tree"

From Robowiki
Jump to navigation Jump to search
(Use the current weightings when deciding what axis to split on. No speed impact when unweighted, and should be better when weights are used.)
m (reformat a line)
Line 137: Line 137:
 
             if (cursor.locations != null) {
 
             if (cursor.locations != null) {
 
                 cursor.splitDimension = cursor.findWidestAxis(this.weights);
 
                 cursor.splitDimension = cursor.findWidestAxis(this.weights);
                 cursor.splitValue = (cursor.minLimit[cursor.splitDimension] +
+
                 cursor.splitValue = (cursor.minLimit[cursor.splitDimension] + cursor.maxLimit[cursor.splitDimension]) * 0.5;
                        cursor.maxLimit[cursor.splitDimension]) * 0.5;
 
  
 
                 // Don't split node if it has no width in any axis. Double the bucket size instead
 
                 // Don't split node if it has no width in any axis. Double the bucket size instead

Revision as of 04:47, 28 August 2009

A nice efficent small kD-Tree. It's quite fast... Feel free to use

/**
 * Copyright 2009 Rednaxela
 * 
 * This software is provided 'as-is', without any express or implied
 * warranty. In no event will the authors be held liable for any damages
 * arising from the use of this software.
 * 
 * Permission is granted to anyone to use this software for any purpose,
 * including commercial applications, and to alter it and redistribute it
 * freely, subject to the following restrictions:
 * 
 *    1. The origin of this software must not be misrepresented; you must not
 *    claim that you wrote the original software. If you use this software
 *    in a product, an acknowledgment in the product documentation would be
 *    appreciated but is not required.
 * 
 *    2. This notice may not be removed or altered from any source
 *    distribution.
 */

package ags.utils.newtree2;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;

/**
 * An efficent well-optimized kd-tree
 * 
 * @author Rednaxela
 */
public class KdTree<T> {
    // Static variables
    private static final int bucketSize = 32;

    // All types
    private final int dimensions;
    private final KdTree<T> parent;

    // Root only
    private final HashMap<Object, T> map;
    private double[] weights;

    // Leaf only
    private double[][] locations;
    private int locationCount;

    // Stem only
    private KdTree<T> left, right;
    private int splitDimension;
    private double splitValue;

    // Bounds
    private double[] minLimit, maxLimit;

    // Temporary
    private Status status;

    /**
     * Extends the bounds of this node do include a new location
     */
    private final void extendBounds(double[] location) {
        if (minLimit == null) {
            minLimit = new double[dimensions];
            System.arraycopy(location, 0, minLimit, 0, dimensions);
            maxLimit = new double[dimensions];
            System.arraycopy(location, 0, maxLimit, 0, dimensions);
            return;
        }

        for (int i=0; i<dimensions; i++) {
            if (minLimit[i] > location[i]) {
                minLimit[i] = location[i];
            }
            if (maxLimit[i] < location[i]) {
                maxLimit[i] = location[i];
            }
        }
    }

    /**
     * Find the widest axis of the bounds of this node
     */
    private final int findWidestAxis(double[] weights) {
        int widest = 0;
        double width = (maxLimit[0] - minLimit[0]) * weights[0];
        for (int i = 1; i < dimensions; i++) {
            double nwidth = (maxLimit[i] - minLimit[i]) * weights[i];
            if (nwidth > width) {
                widest = i;
                width = nwidth;
            }
        }
        return widest;
    }

    // Main constructor
    public KdTree(int dimensions) {
        this.dimensions = dimensions;

        // Init as leaf
        this.locations = new double[bucketSize][];
        this.locationCount = 0;

        // Init as root
        this.map = new HashMap<Object, T>();
        this.weights = new double[dimensions];
        Arrays.fill(this.weights, 1.0);
        this.parent = null;
    }

    // Child constructor
    private KdTree(KdTree<T> parent, boolean right) {
        this.dimensions = parent.dimensions;

        // Init as leaf
        this.locations = new double[bucketSize][];
        this.locationCount = 0;

        // Init as non-root
        this.map = null;
        this.parent = parent;
    }

    /**
     * Add a point and associated value to the tree
     */
    public void addPoint(double[] location, T value) {
        KdTree<T> cursor = this;

        while (cursor.locations == null || cursor.locationCount >=
            cursor.locations.length) {
            if (cursor.locations != null) {
                cursor.splitDimension = cursor.findWidestAxis(this.weights);
                cursor.splitValue = (cursor.minLimit[cursor.splitDimension] + cursor.maxLimit[cursor.splitDimension]) * 0.5;

                // Don't split node if it has no width in any axis. Double the bucket size instead
                if ((cursor.minLimit[cursor.splitDimension] - cursor.maxLimit[cursor.splitDimension]) == 0) {
                    double[][] newLocations = new double[cursor.locations.length * 2][];
                    System.arraycopy(cursor.locations, 0, newLocations, 0, cursor.locationCount);
                    cursor.locations = newLocations;
                    break;
                }

                // Create child leaves
                KdTree<T> left = new KdTree<T>(cursor, false);
                KdTree<T> right = new KdTree<T>(cursor, true);

                // Move locations into children
                for (double[] oldLocation : cursor.locations) {
                    if (oldLocation[cursor.splitDimension] > cursor.splitValue) {
                        // Right
                        right.locations[right.locationCount] = oldLocation;
                        right.locationCount++;
                        right.extendBounds(oldLocation);
                    }
                    else {
                        // Left
                        left.locations[left.locationCount] = oldLocation;
                        left.locationCount++;
                        left.extendBounds(oldLocation);
                    }
                }

                // Make into stem
                cursor.left = left;
                cursor.right = right;
                cursor.locations = null;
            }

            cursor.extendBounds(location);

            if (location[cursor.splitDimension] > cursor.splitValue) {
                cursor = cursor.right;
            }
            else {
                cursor = cursor.left;
            }
        }

        cursor.locations[cursor.locationCount] = location;
        cursor.locationCount++;
        cursor.extendBounds(location);

        this.map.put(location, value);
    }

    /**
     * Enumeration representing the status of a node during the running 
     */
    private static enum Status {
        NONE,
        LEFTVISITED,
        RIGHTVISITED,
        ALLVISITED
    }

    /**
     * Stores a distance and value to output
     */
    public static class Entry<T> {
        public final double distance;
        public final T value;
        private Entry(double distance, T value) {
            this.distance = distance;
            this.value = value;
        }
    }

    /**
     * Sets the weighting on dimensions used
     */
    public void setWeights(double[] weights) {
        this.weights = weights;
    }

    /**
     * Calculates the nearest 'count' points to 'location'
     */
    public List<Entry<T>> nearestNeighbor(double[] location, int count) {
        KdTree<T> cursor = this;
        cursor.status = Status.NONE;
        double range = Double.POSITIVE_INFINITY;
        ResultHeap resultHeap = new ResultHeap(count); 

        do {
            if (cursor.status == Status.ALLVISITED) {
                // At a fully visited part. Move up the tree
                cursor = cursor.parent;
                continue;
            }

            if (cursor.status == Status.NONE && cursor.locations != null) {
                // At a leaf. Use the data.
                for (int i=0; i<cursor.locationCount; i++) {
                    double dist = sqrPointDist(cursor.locations[i], location, this.weights);
                    resultHeap.addValue(dist, cursor.locations[i]);
                }
                range = resultHeap.getMaxDist();

                if (cursor.parent == null) {
                    break;
                }
                cursor = cursor.parent;
                continue;
            }

            // Going to descend
            KdTree<T> nextCursor = null;
            if (cursor.status == Status.NONE) {
                // At a fresh node, descend the most probably useful direction
                if (location[cursor.splitDimension] > cursor.splitValue) {
                    // Descend right
                    nextCursor = cursor.right;
                    cursor.status = Status.RIGHTVISITED;
                }
                else {
                    // Descend left;
                    nextCursor = cursor.left;
                    cursor.status = Status.LEFTVISITED;
                }
            }
            else if (cursor.status == Status.LEFTVISITED) {
                // Left node visited, descend right.
                nextCursor = cursor.right;
                cursor.status = Status.ALLVISITED;
            }
            else if (cursor.status == Status.RIGHTVISITED) {
                // Right node visited, descend left.
                nextCursor = cursor.left;
                cursor.status = Status.ALLVISITED;
            }

            // Check if it's worth descending. Assume it is if it's sibling has not been visited yet. 
            if (cursor.status == Status.ALLVISITED) {
                if (nextCursor.locationCount == 0 || sqrPointRegionDist(location, nextCursor.minLimit, nextCursor.maxLimit, this.weights) > range) {
                    continue;
                }
            }

            // Descend down the tree
            cursor = nextCursor;
            cursor.status = Status.NONE;
        } while (cursor.parent != null || cursor.status != Status.ALLVISITED);

        ArrayList<Entry<T>> results = new ArrayList<Entry<T>>(count);
        Object[] data = resultHeap.getData();
        double[] dist = resultHeap.getDistances();
        for (int i=0; i<resultHeap.values; i++) {
            T value = this.map.get(data[i]);
            results.add(new Entry<T>(dist[i], value));
        }

        return results;
    }

    /**
     * Calculates the (squared euclidean) distance between two points
     */
    private static final double sqrPointDist(double[] p1, double[] p2, double[] weights) {
        double d = 0;

        for (int i=0; i<p1.length; i++) {
            double diff = (p1[i] - p2[i]) * weights[i];
            d += diff * diff;
        }

        return d;
    }

    /**
     * Calculates the closest (squared euclidean) distance between in a point and a bounding region
     */
    private static final double sqrPointRegionDist(double[] point, double[] min, double[] max, double[] weights) {
        double d = 0;

        for (int i=0; i<point.length; i++) {
            if (point[i] > max[i]) {
                double diff = (point[i] - max[i]) * weights[i];
                d += diff * diff;
            } else if (point[i] < min[i]) {
                double diff = (point[i] - min[i]) * weights[i];
                d += diff * diff;
            }
        }

        return d;
    }

    /**
     * Class for tracking up to 'size' closest values
     */
    private static class ResultHeap {
        private final Object[] data;
        private final double[] distance;
        private final int size;
        private int values;

        public ResultHeap(int size) {
            this.data = new Object[size+1];
            this.distance = new double[size+1];
            this.size = size;
            this.values = 0;
        }

        public void addValue(double dist, Object value) {
            if (values == size && dist >= distance[0]) {
                return;
            }

            // Insert value
            data[values] = value;
            distance[values] = dist;
            values++;

            // Up-Heapify
            for (int c = values-1, p = (c-1)/2; c != 0 && distance[c] > distance[p]; c = p, p = (c-1)/2) {
                Object pData = data[p];
                double pDist = distance[p];
                data[p] = data[c];
                distance[p] = distance[c];
                data[c] = pData;
                distance[c] = pDist;
            }

            // If too big, remove the highest value
            if (values > size) {
                // Move the last entry to the top
                values--;
                data[0] = data[values];
                distance[0] = distance[values];

                // Down-Heapify
                for (int p = 0, c = 1; c < values; p = c,c = p*2+1) {
                    if (c+1 < values && distance[c] < distance[c+1]) {
                        c++;
                    }
                    if (distance[p] < distance[c]) {
                        // Swap the points
                        Object pData = data[p];
                        double pDist = distance[p];
                        data[p] = data[c];
                        distance[p] = distance[c];
                        data[c] = pData;
                        distance[c] = pDist;
                    }
                    else {
                        break;
                    }
                }
            }
        }

        public double getMaxDist() {
            if (values < size) {
                return Double.POSITIVE_INFINITY;
            }
            return distance[0];
        }

        public Object[] getData() {
            return data;
        }

        public double[] getDistances() {
            return distance;
        }
    }
}