User:Jdev/Code/VP Tree
Jump to navigation
Jump to search
My implementation of a vantage-point tree[1]. It is not well tested and may still contain debug code, but in general it should work. It is published mainly for historical reasons; I recommend using User:Rednaxela/kD-Tree for kNN search, because it is faster and more reliable. But if this code is interesting to you, feel free to use and adapt it.
/*
* Copyright (c) 2011 Alexey Zhidkov (Jdev). All Rights Reserved.
*/
package lxx.utils.vp_tree;
import lxx.utils.IntervalDouble;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import static java.lang.Math.*;
/**
 * Vantage-point tree over {@code double[]} points, searched with the Euclidean metric.
 *
 * <p>A node is either a leaf (while {@code vantagePoint == null}) that buffers its points in an
 * array, or an inner node that routes every point to {@code innerSubTree} when its distance to
 * the vantage point is below {@code splitRadius} and to {@code outerSubTree} otherwise. A leaf is
 * split once it holds at least {@code LOCATIONS_LIMIT} points that are not all identical.
 *
 * <p>Not safe for concurrent use: queries write {@link VPTreeLocation#lastDist} on the stored
 * locations, and {@link #nodesCount} / {@link #maxLevel} are unsynchronized static counters.
 *
 * <p>NOTE(review): {@code IntervalDouble} is declared outside this file; this class relies on
 * {@code extend(x)} keeping the running minimum in {@code a} and the running maximum in
 * {@code b} - confirm against its definition.
 */
public class VPTree {

    /** Debug counter: number of nodes constructed so far, shared by all trees. */
    public static int nodesCount = 0;

    /** Debug counter: largest {@code level} passed to a constructor so far. */
    public static int maxLevel = 0;

    /** Number of points a leaf buffers before a split is attempted. */
    private static final int LOCATIONS_LIMIT = 32;

    /** Orders locations by ascending distance to the vantage point of the node being split. */
    private static final Comparator<VPTreeLocation> BY_VANTAGE_POINT_DISTANCE =
            new Comparator<VPTreeLocation>() {
                @Override
                public int compare(VPTreeLocation o1, VPTreeLocation o2) {
                    return Double.compare(o1.distToLastVantagePoint, o2.distToLastVantagePoint);
                }
            };

    /** Range of vantage-point distances of everything routed to {@link #innerSubTree}. */
    private final IntervalDouble innerBounds = new IntervalDouble(Integer.MAX_VALUE, Integer.MIN_VALUE);

    /** Range of vantage-point distances of everything routed to {@link #outerSubTree}. */
    private final IntervalDouble outerBounds = new IntervalDouble(Integer.MAX_VALUE, Integer.MIN_VALUE);

    private final int level;
    private final int dimensionsCount;

    /** Per-dimension coordinate range of the buffered points; used to detect identical points. */
    private final IntervalDouble[] dimensionsWidth;

    /** Leaf storage; only the first {@link #size} slots are valid. Released once the node splits. */
    private VPTreeLocation[] locations;
    private int size = 0;

    /** Leaf size at which the next split is attempted; raised after an unsuccessful attempt. */
    private int splitThreshold = LOCATIONS_LIMIT;

    private VPTree innerSubTree;
    private VPTree outerSubTree;

    /** Copy of the vantage point's coordinates; {@code null} while this node is a leaf. */
    private double[] vantagePoint;

    /** {@code true} while every buffered point has the same coordinates in all dimensions. */
    private boolean singular = true;
    private double splitRadius;

    /**
     * Creates an empty node.
     *
     * @param level           depth of this node (children are created with {@code level + 1})
     * @param dimensionsCount number of coordinates read from every point
     * @param locationsCount  capacity hint for the leaf buffer; the node always starts empty.
     *                        (Previously this value was stored as the size, which made null
     *                        slots count as stored points for any non-zero argument.)
     */
    public VPTree(int level, int dimensionsCount, int locationsCount) {
        this.level = level;
        this.dimensionsCount = dimensionsCount;
        locations = new VPTreeLocation[max(LOCATIONS_LIMIT, locationsCount)];
        dimensionsWidth = new IntervalDouble[dimensionsCount];
        for (int i = 0; i < dimensionsCount; i++) {
            dimensionsWidth[i] = new IntervalDouble(Integer.MAX_VALUE, Integer.MIN_VALUE);
        }
        nodesCount++;
        maxLevel = max(maxLevel, level);
    }

    /**
     * Creates an empty node with the default leaf capacity.
     *
     * @param level           depth of this node
     * @param dimensionsCount number of coordinates read from every point
     */
    public VPTree(int level, int dimensionsCount) {
        this(level, dimensionsCount, 0);
    }

    /**
     * Inserts a location into this subtree.
     *
     * @param location the location to store; its coordinate array is referenced, not copied
     */
    public void add(VPTreeLocation location) {
        if (vantagePoint != null) {
            location.distToLastVantagePoint = getDistance(vantagePoint, location.location);
            addToChild(location);
        } else {
            addImpl(location);
        }
    }

    /**
     * Finds the {@code k} stored locations closest to {@code location}.
     *
     * @param location query point
     * @param k        number of neighbours requested
     * @return an array of length {@code k} ordered by ascending distance; trailing slots are
     *         {@code null} when fewer than {@code k} locations are stored. Each returned
     *         location's {@code lastDist} holds its distance to the query point.
     */
    public VPTreeLocation[] findNearestNeighbours(double[] location, int k) {
        final NeighborsSet set = new NeighborsSet(location, new VPTreeLocation[k]);
        findNearestNeighbours(set);
        return set.neighbors;
    }

    /** Recursive search step: scans a leaf, or descends into the subtrees that can still help. */
    private void findNearestNeighbours(NeighborsSet set) {
        if (vantagePoint == null) {
            searchLeaf(set);
            return;
        }

        final double distToVP = getDistance(vantagePoint, set.center.location);
        if (set.isFilled() && distToVP + set.getMaxDistance() < outerBounds.a) {
            // The whole search ball lies closer to the vantage point than any outer location.
            innerSubTree.findNearestNeighbours(set);
        } else if (set.isFilled() && distToVP - set.getMaxDistance() > innerBounds.b) {
            // The whole search ball lies farther from the vantage point than any inner location.
            outerSubTree.findNearestNeighbours(set);
        } else if (distToVP < (innerBounds.a + outerBounds.b) / 2) {
            // Visit the inner side first, then re-check whether the outer side can still matter.
            innerSubTree.findNearestNeighbours(set);
            if (!set.isFilled() || distToVP + set.getMaxDistance() >= outerBounds.a) {
                outerSubTree.findNearestNeighbours(set);
            }
        } else {
            outerSubTree.findNearestNeighbours(set);
            if (!set.isFilled() || distToVP - set.getMaxDistance() <= innerBounds.b) {
                innerSubTree.findNearestNeighbours(set);
            }
        }
    }

    /** Offers every location buffered in this leaf to {@code set}. */
    private void searchLeaf(NeighborsSet set) {
        set.visitedNodes++;
        if (size == 0) {
            // Empty leaf (e.g. a tree that was never filled): nothing to offer, and
            // locations[0] must not be dereferenced.
            return;
        }
        if (singular) {
            // All buffered points are identical, so one distance computation serves them all
            // and they can be inserted into the result as a single run.
            final double dist = getDistance(set.center.location, locations[0].location);
            for (int i = 0; i < size; i++) {
                locations[i].lastDist = dist;
            }
            set.add(size, locations);
        } else {
            // todo: try find out index of center in locations and go in both directions from it
            for (int i = 0; i < size; i++) {
                locations[i].lastDist = getDistance(set.center.location, locations[i].location);
                set.add(locations[i]);
            }
        }
    }

    /**
     * Collects every location stored in this subtree.
     *
     * @return a new, independent, mutable list (inner subtree first, then outer subtree)
     */
    public List<VPTreeLocation> getAll() {
        final List<VPTreeLocation> all = new ArrayList<VPTreeLocation>();
        collectInto(all);
        return all;
    }

    /** Appends the locations of this subtree to {@code target} without intermediate copies. */
    private void collectInto(List<VPTreeLocation> target) {
        if (vantagePoint == null) {
            for (int i = 0; i < size; i++) {
                target.add(locations[i]);
            }
        } else {
            innerSubTree.collectInto(target);
            outerSubTree.collectInto(target);
        }
    }

    /**
     * Turns this leaf into an inner node: picks a vantage point and a split radius, creates both
     * children and redistributes the buffered locations. If no radius can separate the points
     * (all pairwise distances are zero or not comparable), the node stays a leaf and the next
     * attempt is postponed until the leaf has doubled in size.
     */
    private void split() {
        final VPTreeLocation[] farthestPair = findFarthestPair();
        if (farthestPair == null) {
            splitThreshold = size * 2;
            return;
        }

        // Either end of the farthest pair serves as vantage point; pick one at random.
        final VPTreeLocation chosen = farthestPair[(int) (2 * random())];
        final double[] candidate = Arrays.copyOf(chosen.location, chosen.location.length);
        for (int i = 0; i < size; i++) {
            locations[i].distToLastVantagePoint = getDistance(candidate, locations[i].location);
        }
        // Sort only the occupied prefix: the backing array may have unused (null) slots.
        Arrays.sort(locations, 0, size, BY_VANTAGE_POINT_DISTANCE);

        final double radius = getSplitRadius();
        if (Double.isNaN(radius)) {
            splitThreshold = size * 2;
            return;
        }

        splitRadius = radius;
        innerSubTree = new VPTree(level + 1, dimensionsCount);
        outerSubTree = new VPTree(level + 1, dimensionsCount);
        vantagePoint = candidate;
        for (int i = 0; i < size; i++) {
            addToChild(locations[i]);
        }

        // Inner nodes never read the leaf buffer again; let it be garbage collected.
        locations = null;
        size = 0;
    }

    /** Appends a location to this leaf and splits the leaf once it is large enough. */
    private void addImpl(VPTreeLocation location) {
        if (size == locations.length) {
            locations = Arrays.copyOf(locations, locations.length * 2);
        }
        locations[size++] = location;
        for (int i = 0; i < dimensionsCount; i++) {
            dimensionsWidth[i].extend(location.location[i]);
            singular &= dimensionsWidth[i].a == dimensionsWidth[i].b;
        }
        // ">=" rather than "==": a leaf of identical points may grow past the limit before the
        // first different point arrives, and must still be split afterwards.
        if (!singular && size >= splitThreshold) {
            split();
        }
    }

    /** Routes a location (with its vantage-point distance already set) to the matching child. */
    private void addToChild(VPTreeLocation location) {
        if (location.distToLastVantagePoint < splitRadius) {
            innerBounds.extend(location.distToLastVantagePoint);
            innerSubTree.add(location);
        } else {
            outerBounds.extend(location.distToLastVantagePoint);
            outerSubTree.add(location);
        }
    }

    /**
     * Returns the first pair of buffered locations with the greatest positive distance.
     *
     * @return a two-element array, or {@code null} when no pair has a distance above zero
     */
    private VPTreeLocation[] findFarthestPair() {
        double maxDist = 0;
        VPTreeLocation[] pair = null;
        for (int i = 0; i < size; i++) {
            for (int j = i + 1; j < size; j++) {
                final double dist = getDistance(locations[i].location, locations[j].location);
                if (dist > maxDist) {
                    maxDist = dist;
                    if (pair == null) {
                        pair = new VPTreeLocation[2];
                    }
                    pair[0] = locations[i];
                    pair[1] = locations[j];
                }
            }
        }
        return pair;
    }

    /**
     * Chooses the split radius from the buffered locations, which must already be sorted by
     * {@code distToLastVantagePoint}. The widest gap that isolates neither the first nor the
     * last location is preferred; if all of those gaps are empty, the two edge gaps are
     * considered as well so that both children always receive at least one location.
     *
     * @return the split radius, or {@code NaN} when all distances are equal
     */
    private double getSplitRadius() {
        final double interior = widestGapMidpoint(1, size - 2);
        return Double.isNaN(interior) ? widestGapMidpoint(0, size - 1) : interior;
    }

    /**
     * Scans the gaps between sorted neighbours {@code i} and {@code i + 1} for
     * {@code from <= i < to} and returns a radius inside the first widest positive gap.
     *
     * @return a value {@code r} with {@code lower < r <= upper} for the chosen gap, or
     *         {@code NaN} when no gap in the range is positive
     */
    private double widestGapMidpoint(int from, int to) {
        double widest = 0;
        double radius = Double.NaN;
        for (int i = from; i < to; i++) {
            final double lower = locations[i].distToLastVantagePoint;
            final double upper = locations[i + 1].distToLastVantagePoint;
            if (upper - lower > widest) {
                widest = upper - lower;
                radius = (lower + upper) / 2;
                if (radius <= lower) {
                    // The midpoint rounded down onto the lower bound; the upper bound still
                    // separates the two sides because routing uses "distance < radius".
                    radius = upper;
                }
            }
        }
        return radius;
    }

    /** Euclidean distance over the first {@code dimensionsCount} coordinates of both points. */
    private double getDistance(double[] pnt1, double[] pnt2) {
        double distance = 0;
        for (int i = 0; i < dimensionsCount; i++) {
            double d = pnt1[i] - pnt2[i];
            distance += d * d;
        }
        return sqrt(distance);
    }

    /** A point stored in the tree, plus scratch fields written by the tree. */
    public static class VPTreeLocation {

        /** Distance to the vantage point of the inner node that routed this location last. */
        private double distToLastVantagePoint;

        /** Coordinates of the point; the array passed to the constructor is kept by reference. */
        public final double[] location;

        /** Distance to the query point of the most recent search that examined this location. */
        public double lastDist;

        public VPTreeLocation(double[] location) {
            this.location = location;
        }
    }

    /**
     * Bounded result buffer of a search: keeps at most {@code neighbors.length} locations,
     * ordered by ascending {@code lastDist}.
     */
    public static class NeighborsSet {

        private final VPTreeLocation center;
        private final VPTreeLocation[] neighbors;

        // Debug statistics; written during a search but never read in this file.
        private int visitedNodes = 0;
        private int locsCount = 0;

        private int size = 0;

        /**
         * @param center    query point
         * @param neighbors array that receives the results; its length is the capacity
         */
        public NeighborsSet(double[] center, VPTreeLocation[] neighbors) {
            this.center = new VPTreeLocation(center);
            this.neighbors = neighbors;
        }

        /**
         * Offers a run of entries that all share the same {@code lastDist}.
         *
         * @param entries candidates; their {@code lastDist} must already be set
         */
        public void add(VPTreeLocation... entries) {
            add(entries.length, entries);
        }

        /**
         * Offers the first {@code count} elements of {@code entries} as one run. Only
         * {@code entries[0].lastDist} is inspected, so all offered entries are expected to share
         * that distance. Entries that no longer fit are dropped from the far end.
         *
         * @param count   number of leading elements of {@code entries} to offer
         * @param entries candidates; the array may be longer than {@code count}
         */
        public void add(int count, VPTreeLocation... entries) {
            locsCount++;
            if (count <= 0 || neighbors.length == 0) {
                // Nothing offered, or a zero-capacity result: avoids reading entries[0] and
                // neighbors[-1] below.
                return;
            }
            if (isFilled() && entries[0].lastDist > neighbors[size - 1].lastDist) {
                return;
            }
            int idx = 0;
            if (size > 0) {
                idx = findPosition(entries[0]);
                final int destPos = idx + count;
                if (destPos < neighbors.length) {
                    // Make room for the run; whatever is pushed past the end is discarded.
                    System.arraycopy(neighbors, idx, neighbors, destPos, neighbors.length - destPos);
                }
            }
            System.arraycopy(entries, 0, neighbors, idx, min(count, neighbors.length - idx));
            size = min(size + count, neighbors.length);
        }

        /** Returns the index just after the last neighbour that is strictly closer than {@code entry}. */
        private int findPosition(VPTreeLocation entry) {
            int idx = size - 1;
            while (idx >= 0 && entry.lastDist <= neighbors[idx].lastDist) {
                idx--;
            }
            return idx + 1;
        }

        /**
         * @return the distance of the farthest neighbour kept so far, or
         *         {@code Integer.MAX_VALUE} while the set is empty
         */
        public double getMaxDistance() {
            return size > 0 ? neighbors[size - 1].lastDist : Integer.MAX_VALUE;
        }

        /** @return {@code true} once as many neighbours are kept as the result array can hold */
        public boolean isFilled() {
            return size == neighbors.length;
        }
    }
}