User:Jdev/Code/VP Tree

From Robowiki
Jump to navigation Jump to search

My implementation of a vantage-point tree. It is not well tested and may still contain debug code, but it should generally work. It is published mostly for the record: for kNN search I recommend User:Rednaxela/kD-Tree instead, because it is faster and more reliable. But if this code is interesting to you, feel free to use and adapt it.

package lxx.utils.vp_tree;

import lxx.utils.IntervalDouble;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

import static java.lang.Math.*;

/**
 * Vantage-point tree for k-nearest-neighbour search using Euclidean distance over the first
 * {@code dimensionsCount} coordinates of each location.
 *
 * <p>Leaves store locations directly. Once a leaf holding at least two distinct points reaches
 * {@value #LOCATIONS_LIMIT} locations, it picks a vantage point and redistributes its locations:
 * those closer than {@code splitRadius} go to an inner subtree, the rest to an outer subtree.
 * A leaf whose points all have identical coordinates ("singular") is never split and grows
 * without bound.
 *
 * <p>Not thread-safe: queries write {@link VPTreeLocation#lastDist} into stored locations, and the
 * static statistics counters are updated without synchronization.
 */
public class VPTree {

    /** Debug statistic: number of VPTree nodes constructed so far, across all trees. */
    public static int nodesCount = 0;
    /** Debug statistic: largest {@code level} passed to any VPTree constructor so far. */
    public static int maxLevel = 0;

    /** Number of locations at which a non-singular leaf is split. */
    private static final int LOCATIONS_LIMIT = 32;

    // Ranges of distance-to-vantage-point observed in the inner / outer subtree; used to prune
    // searches. Assumes IntervalDouble.extend widens [a, b] to include the value, so the initial
    // (MAX, MIN) interval means "nothing added yet" -- confirm against IntervalDouble.
    private final IntervalDouble innerBounds = new IntervalDouble(Integer.MAX_VALUE, Integer.MIN_VALUE);
    private final IntervalDouble outerBounds = new IntervalDouble(Integer.MAX_VALUE, Integer.MIN_VALUE);

    private final int level;
    private final int dimensionsCount;
    // Per-dimension coordinate range of this leaf's locations; used to detect a singular leaf.
    private final IntervalDouble[] dimensionsWidth;

    // Leaf storage: only the first `size` slots are valid. Released once this node splits.
    private VPTreeLocation[] locations;
    private int size = 0;

    private VPTree innerSubTree;
    private VPTree outerSubTree;
    // Non-null exactly when this node has been split into inner/outer subtrees.
    private double[] vantagePoint;
    // True while every location added to this leaf has identical coordinates.
    private boolean singular = true;
    private double splitRadius;

    /**
     * Creates an empty node.
     *
     * @param level depth of this node (0 for the root); children are created at {@code level + 1}
     * @param dimensionsCount number of leading coordinates compared in each location
     * @param locationsCount initial storage capacity hint; the node always starts empty
     */
    public VPTree(int level, int dimensionsCount, int locationsCount) {
        // The node holds no locations yet, so `size` must start at 0; setting it to
        // locationsCount would leave null slots that crash the first search.
        locations = new VPTreeLocation[max(LOCATIONS_LIMIT, locationsCount)];
        this.level = level;
        this.dimensionsCount = dimensionsCount;
        dimensionsWidth = new IntervalDouble[dimensionsCount];
        for (int i = 0; i < dimensionsCount; i++) {
            dimensionsWidth[i] = new IntervalDouble(Integer.MAX_VALUE, Integer.MIN_VALUE);
        }

        nodesCount++;
        maxLevel = max(maxLevel, level);
    }

    /**
     * Creates an empty node with the default storage capacity.
     *
     * @param level depth of this node (0 for the root)
     * @param dimensionsCount number of leading coordinates compared in each location
     */
    public VPTree(int level, int dimensionsCount) {
        this(level, dimensionsCount, 0);
    }

    /**
     * Adds a location to this (sub)tree. The location object and its coordinate array are stored
     * by reference, not copied.
     *
     * @param location the location to add
     */
    public void add(VPTreeLocation location) {
        if (vantagePoint == null) {
            addImpl(location);
        } else {
            location.distToLastVantagePoint = getDistance(vantagePoint, location.location);
            addToChild(location);
        }
    }

    /**
     * Finds the {@code k} stored locations closest to {@code location}.
     *
     * @param location query coordinates
     * @param k number of neighbours to return
     * @return an array of length {@code k} sorted by ascending distance; trailing slots are
     *     {@code null} when the tree holds fewer than {@code k} locations. Each returned
     *     location's {@code lastDist} is its distance to this query, and may be overwritten by
     *     later queries.
     */
    public VPTreeLocation[] findNearestNeighbours(double[] location, int k) {
        final VPTreeLocation[] neighbors = new VPTreeLocation[k];
        if (k == 0) {
            // Nothing to find; NeighborsSet cannot track a "worst neighbour" with zero slots.
            return neighbors;
        }
        final NeighborsSet set = new NeighborsSet(location, neighbors);
        findNearestNeighbours(set);
        return set.neighbors;
    }

    /** Recursively offers this subtree's locations to {@code set}, pruning unreachable branches. */
    private void findNearestNeighbours(NeighborsSet set) {
        if (vantagePoint == null) {
            searchLeaf(set);
            return;
        }

        final double distToVP = getDistance(vantagePoint, set.center.location);
        // Triangle inequality: every outer location is at least (outerBounds.a - distToVP) from
        // the query and every inner one at least (distToVP - innerBounds.b); once the set is full,
        // a side whose lower bound exceeds the current worst distance cannot contribute.
        if (set.isFilled() && distToVP + set.getMaxDistance() < outerBounds.a) {
            innerSubTree.findNearestNeighbours(set);
        } else if (set.isFilled() && distToVP - set.getMaxDistance() > innerBounds.b) {
            outerSubTree.findNearestNeighbours(set);
        } else if (distToVP < (innerBounds.a + outerBounds.b) / 2) {
            // Visit the more promising side first so the second pruning check is tighter.
            innerSubTree.findNearestNeighbours(set);
            if (!set.isFilled() || distToVP + set.getMaxDistance() >= outerBounds.a) {
                outerSubTree.findNearestNeighbours(set);
            }
        } else {
            outerSubTree.findNearestNeighbours(set);
            if (!set.isFilled() || distToVP - set.getMaxDistance() <= innerBounds.b) {
                innerSubTree.findNearestNeighbours(set);
            }
        }
    }

    /** Offers every location stored in this leaf to {@code set}. */
    private void searchLeaf(NeighborsSet set) {
        set.visitedNodes++;
        if (size == 0) {
            // e.g. searching a tree before anything was added: nothing to offer.
            return;
        }
        if (singular) {
            // All locations share the same coordinates, so one distance serves for all of them
            // and they can be inserted into the set as a single run.
            final double dist = getDistance(set.center.location, locations[0].location);
            for (int i = 0; i < size; i++) {
                locations[i].lastDist = dist;
            }
            set.add(size, locations);
        } else {
            // todo: try find out index of center in locations and go in both directions from it
            for (int i = 0; i < size; i++) {
                locations[i].lastDist = getDistance(set.center.location, locations[i].location);
                set.add(locations[i]);
            }
        }
    }

    /**
     * Returns every location stored in this (sub)tree, in no particular order.
     *
     * @return a list of the stored locations; for a leaf node the list is fixed-size
     */
    public List<VPTreeLocation> getAll() {
        if (vantagePoint == null) {
            return Arrays.asList(Arrays.copyOf(locations, size));
        }
        final List<VPTreeLocation> all = new ArrayList<VPTreeLocation>(innerSubTree.getAll());
        all.addAll(outerSubTree.getAll());
        return all;
    }

    /**
     * Turns this leaf into an internal node: chooses a vantage point and split radius, then
     * redistributes the stored locations into two new child leaves.
     */
    private void split() {
        vantagePoint = chooseVantagePoint();
        splitRadius = getSplitRadius();
        innerSubTree = new VPTree(level + 1, dimensionsCount);
        outerSubTree = new VPTree(level + 1, dimensionsCount);
        for (int i = 0; i < size; i++) {
            addToChild(locations[i]);
        }
        // Internal nodes are only reached through vantagePoint != null paths, which never read
        // leaf storage; drop the references so they are held only by the children.
        locations = null;
        size = 0;
    }

    /** Appends {@code location} to this leaf and splits the leaf once it is large and non-singular. */
    private void addImpl(VPTreeLocation location) {
        if (size == locations.length) {
            locations = Arrays.copyOf(locations, size * 2);
        }
        locations[size++] = location;
        for (int i = 0; i < dimensionsCount; i++) {
            dimensionsWidth[i].extend(location.location[i]);
            singular &= dimensionsWidth[i].a == dimensionsWidth[i].b;
        }

        // `>=` rather than `==`: a singular leaf may already have grown past the limit before its
        // first distinct point arrives, and it must still split at that moment.
        if (!singular && size >= LOCATIONS_LIMIT) {
            split();
        }
    }

    /**
     * Routes {@code location} to the inner or outer child by its precomputed
     * {@code distToLastVantagePoint}, widening that side's distance bounds.
     */
    private void addToChild(VPTreeLocation location) {
        if (location.distToLastVantagePoint < splitRadius) {
            innerBounds.extend(location.distToLastVantagePoint);
            innerSubTree.add(location);
        } else {
            outerBounds.extend(location.distToLastVantagePoint);
            outerSubTree.add(location);
        }
    }

    /**
     * Finds the farthest-apart pair of stored locations and returns a copy of the coordinates of
     * one of them, chosen at random.
     */
    private double[] chooseVantagePoint() {
        double maxDist = Double.NEGATIVE_INFINITY;
        VPTreeLocation first = locations[0];
        VPTreeLocation second = locations[0];
        for (int i = 0; i < size; i++) {
            for (int j = i + 1; j < size; j++) {
                final double dist = getDistance(locations[i].location, locations[j].location);
                if (dist > maxDist) {
                    maxDist = dist;
                    first = locations[i];
                    second = locations[j];
                }
            }
        }

        final VPTreeLocation chosen = random() < 0.5 ? first : second;
        return Arrays.copyOf(chosen.location, chosen.location.length);
    }

    /**
     * Computes each stored location's distance to {@link #vantagePoint}, sorts the locations by
     * that distance and returns the midpoint of the widest gap between consecutive distances.
     *
     * <p>Gaps leaving fewer than two locations on either side are used only when every other gap
     * is zero. Splitting at a zero gap would leave the inner child empty and re-split the same
     * points, so this fallback guarantees both children are non-empty.
     */
    private double getSplitRadius() {
        for (int i = 0; i < size; i++) {
            locations[i].distToLastVantagePoint = getDistance(vantagePoint, locations[i].location);
        }

        // Sort only the valid prefix: a leaf that grew past its capacity has null slots after it.
        Arrays.sort(locations, 0, size, new Comparator<VPTreeLocation>() {
            @Override
            public int compare(VPTreeLocation o1, VPTreeLocation o2) {
                return Double.compare(o1.distToLastVantagePoint, o2.distToLastVantagePoint);
            }
        });

        int gapIdx = findWidestGap(1, size - 2);
        if (gapIdx < 0) {
            // In a non-singular leaf the vantage point (distance 0) and its farthest partner
            // ensure at least one positive gap exists over the full range.
            gapIdx = findWidestGap(0, size - 1);
        }
        if (gapIdx < 0) {
            // Only reachable if all distances are equal (e.g. NaN coordinates).
            return locations[0].distToLastVantagePoint;
        }
        return (locations[gapIdx].distToLastVantagePoint
                + locations[gapIdx + 1].distToLastVantagePoint) / 2;
    }

    /**
     * Returns the first index {@code i} in {@code [from, to)} with the largest strictly positive
     * gap between {@code locations[i]} and {@code locations[i + 1]} distances, or -1 if none.
     * Expects {@code locations[0..size)} to be sorted by {@code distToLastVantagePoint}.
     */
    private int findWidestGap(int from, int to) {
        int best = -1;
        double bestGap = 0;
        for (int i = from; i < to; i++) {
            final double gap = locations[i + 1].distToLastVantagePoint
                    - locations[i].distToLastVantagePoint;
            if (gap > bestGap) {
                bestGap = gap;
                best = i;
            }
        }
        return best;
    }

    /** Euclidean distance between the first {@code dimensionsCount} coordinates of two points. */
    private double getDistance(double[] pnt1, double[] pnt2) {
        double distance = 0;

        for (int i = 0; i < dimensionsCount; i++) {
            double d = pnt1[i] - pnt2[i];
            distance += d * d;
        }

        return sqrt(distance);
    }

    /** A point stored in the tree, plus scratch fields the tree writes during inserts and queries. */
    public static class VPTreeLocation {

        // Distance to the vantage point of the node currently routing this location.
        private double distToLastVantagePoint;

        /** Coordinates; stored by reference, not copied. */
        public final double[] location;
        /** Distance to the query of the most recent search that visited this location. */
        public double lastDist;

        public VPTreeLocation(double[] location) {
            this.location = location;
        }

    }

    /**
     * Fixed-capacity buffer of the best neighbours found so far, kept sorted by ascending
     * {@link VPTreeLocation#lastDist}.
     */
    public static class NeighborsSet {

        private final VPTreeLocation center;
        private final VPTreeLocation[] neighbors;
        // Debug statistics: leaves visited and add() calls made during a search.
        private int visitedNodes = 0;
        private int locsCount = 0;

        // Number of filled slots at the front of `neighbors`.
        private int size = 0;

        /**
         * @param center query coordinates
         * @param neighbors result buffer; its length is the number of neighbours to keep
         */
        public NeighborsSet(double[] center, VPTreeLocation[] neighbors) {
            this.center = new VPTreeLocation(center);
            this.neighbors = neighbors;
        }

        /**
         * Offers {@code entries} as one run positioned by {@code entries[0].lastDist}; see
         * {@link #add(int, VPTreeLocation...)}.
         */
        public void add(VPTreeLocation... entries) {
            add(entries.length, entries);
        }

        /**
         * Inserts the first {@code count} entries as one contiguous run at the sorted position of
         * {@code entries[0].lastDist}, dropping whatever is pushed past the buffer's end. Callers
         * in this file pass either a single entry or a run of entries with equal {@code lastDist}.
         */
        public void add(int count, VPTreeLocation... entries) {
            locsCount++;
            if (count <= 0 || neighbors.length == 0) {
                return;
            }
            if (isFilled() && entries[0].lastDist > getMaxDistance()) {
                return;
            }
            int idx = 0;
            if (size > 0) {
                idx = findPosition(entries[0]);
                final int destPos = idx + count;
                if (destPos < neighbors.length) {
                    // Shift worse neighbours right to make room; those pushed off the end drop out.
                    System.arraycopy(neighbors, idx, neighbors, destPos, neighbors.length - destPos);
                }
            }
            System.arraycopy(entries, 0, neighbors, idx, min(count, neighbors.length - idx));
            size = min(size + count, neighbors.length);
        }

        /**
         * Returns the insertion index for {@code entry}: just after the last neighbour whose
         * {@code lastDist} is strictly smaller, so ties go in front of existing equal entries.
         */
        private int findPosition(VPTreeLocation entry) {
            int idx = size - 1;
            while (idx >= 0 && !(entry.lastDist > neighbors[idx].lastDist)) {
                idx--;
            }
            return idx + 1;
        }

        /** Distance of the current worst kept neighbour, or {@code Integer.MAX_VALUE} if empty. */
        public double getMaxDistance() {
            return size > 0 ? neighbors[size - 1].lastDist : Integer.MAX_VALUE;
        }

        /** Whether every slot of the result buffer is filled. */
        public boolean isFilled() {
            return size == neighbors.length;
        }
    }
}