// User:Pedersen/kdTree — Pedersen's bucket kd-tree for k-nearest-neighbor search
/**
 * {@link KNNImplementation} adapter that stores points in a {@link KDBucketTree} of
 * {@link StringHoldingExemplar}s and answers k-nearest-neighbor queries from it.
 */
public class PedersenKDBucketTreeKNNSearch extends KNNImplementation {
    /** Leaf capacity handed to the tree; a leaf tries to split once it holds more than this. */
    private static final int BUCKET_SIZE = 24;

    private final KDBucketTree<StringHoldingExemplar> tree;

    /**
     * Creates an empty search structure.
     *
     * @param dimension number of coordinates per point; forwarded to the superclass
     *     (this class itself does not read it)
     */
    public PedersenKDBucketTreeKNNSearch(int dimension) {
        super(dimension);
        this.tree = new KDBucketTree<StringHoldingExemplar>( BUCKET_SIZE );
    }

    /**
     * Stores a point together with its string value.
     *
     * @param location coordinates of the point (stored by reference, not copied)
     * @param value    payload returned with this point from neighbor queries
     */
    @Override
    public void addPoint(double[] location, String value) {
        this.tree.add( new StringHoldingExemplar( location, value ) );
    }

    /** @return a display name built from the tree implementation's simple class name */
    @Override
    public String getName() {
        return "Pedersen's " + this.tree.getClass().getSimpleName();
    }

    /**
     * Finds up to {@code size} stored points closest to {@code location}.
     *
     * @param location query coordinates
     * @param size     number of neighbors requested
     * @return exactly {@code size} entries in slot order (not sorted by distance); slots the
     *     search left empty (e.g. when fewer than {@code size} points are stored) hold a
     *     {@code KNNPoint} with a null payload and NaN distance
     */
    @Override
    public KNNPoint[] getNearestNeighbors(double[] location, int size) {
        Neighborhood<StringHoldingExemplar> neighborhood =
            new Neighborhood<StringHoldingExemplar>( location, new StringHoldingExemplar[ size ] );
        this.tree.findNearestNeighbors( neighborhood );
        return convert( neighborhood );
    }

    /** Converts the neighborhood's parallel neighbor/distance arrays into {@code KNNPoint}s. */
    private KNNPoint[] convert( Neighborhood<StringHoldingExemplar> neighborhood ) {
        StringHoldingExemplar[] neighbors = neighborhood.getNeighbors();
        double[] distances = neighborhood.getDistances();
        KNNPoint[] points = new KNNPoint[ neighbors.length ];
        for( int i = 0; i < neighbors.length; i++ ) {
            if( neighbors[ i ] != null ) {
                points[ i ] = new KNNPoint( neighbors[ i ].getPayload(), distances[ i ] );
            } else {
                // An unfilled slot: report it rather than returning a bogus distance.
                System.out.println( "Sample [" + i + "] was null." );
                points[ i ] = new KNNPoint( null, Double.NaN );
            }
        }
        return points;
    }
}
/**
 * A point in a multi-dimensional space, as stored in {@link KDBucketTree} and scored by
 * {@link Neighborhood}.
 */
public class Exemplar {
    /**
     * Coordinates of this point. Stored by reference (not copied) and read directly by the
     * tree and neighborhood classes; a NaN coordinate is skipped by Neighborhood's distance.
     */
    final double[] domain;

    /**
     * @param domain coordinates of the point; must not be null
     * @throws NullPointerException if {@code domain} is null
     */
    public Exemplar( double domain[] ) {
        if( domain == null ) {
            // Fail here rather than with an anonymous NPE deep inside a tree operation.
            throw new NullPointerException( "domain" );
        }
        this.domain = domain;
    }

    /**
     * Returns the coordinates as tab-separated decimal values.
     *
     * @return e.g. {@code "1.0\t2.5"}; an empty string for a zero-dimensional point
     */
    public String description() {
        StringBuilder buffer = new StringBuilder();
        for( int i = 0; i < this.domain.length; i++ ) {
            if( i > 0 ) {
                buffer.append( '\t' );
            }
            buffer.append( this.domain[ i ] );
        }
        return buffer.toString();
    }
}
/**
 * An {@link Exemplar} that carries a string value alongside its coordinates, so a
 * neighbor search can report what each matched point stands for.
 */
public class StringHoldingExemplar extends Exemplar {
    /** Value supplied at construction; not validated. */
    private final String payload;

    /**
     * @param domain  coordinates of the point, handed to {@link Exemplar}
     * @param payload value to associate with the point
     */
    public StringHoldingExemplar(double[] domain, String payload) {
        super(domain);
        this.payload = payload;
    }

    /** @return the value given to the constructor */
    public String getPayload() {
        return payload;
    }
}
/**
 * A bucket kd-tree: leaves hold lists of exemplars, and a leaf splits in two once it holds
 * more than {@code bucketSize} of them. Each split is on the dimension with the widest
 * positive coordinate range, at the midpoint of that range. Exemplars whose coordinate is
 * below the split value go to {@code lt}; all others go to {@code gte}.
 *
 * <p>A leaf whose exemplars cannot be separated stays a leaf even if it exceeds
 * {@code bucketSize}. This happens when all points coincide, or when the chosen split would
 * put every point on one side. Such a leaf re-attempts the split on every subsequent add.
 */
public class KDBucketTree<T extends Exemplar> {
    /**
     * Creates an empty tree consisting of a single leaf.
     *
     * @param bucketSize a leaf tries to split once it holds more than this many exemplars
     */
    public KDBucketTree( int bucketSize ) {
        this.bucketSize = bucketSize;
    }

    /**
     * Offers the exemplars of this subtree to {@code neighborhood}.
     *
     * <p>At an internal node, the side containing the query point is searched first. The
     * other side is searched only if the neighborhood reports that the splitting plane is
     * closer than its current worst neighbor.
     *
     * @param neighborhood query point plus the best candidates found so far; updated in place
     */
    public void findNearestNeighbors( Neighborhood<T> neighborhood ) {
        if( !this.isTree() ) {
            neighborhood.evaluate( this.exemplars );
            return;
        }
        boolean queryIsBelow =
            neighborhood.getControlValueAtIndex( this.branchingIndex ) < this.branchingValue;
        KDBucketTree<T> near = queryIsBelow ? this.lt : this.gte;
        KDBucketTree<T> far = queryIsBelow ? this.gte : this.lt;
        near.findNearestNeighbors( neighborhood );
        if( neighborhood.isBranchEligible( this.branchingIndex, this.branchingValue ) ) {
            far.findNearestNeighbors( neighborhood );
        }
    }

    /**
     * Inserts an exemplar, splitting the receiving leaf if it grows past {@code bucketSize}.
     *
     * @param e exemplar to store (its {@code domain} array is stored by reference)
     */
    public void add( T e ) {
        // Walk down iteratively so insertion depth never consumes stack.
        KDBucketTree<T> node = this;
        while( node.isTree() ) {
            node = e.domain[ node.branchingIndex ] < node.getSplitValue() ? node.lt : node.gte;
        }
        node.exemplars.add( e );
        if( node.exemplars.size() > node.bucketSize ) {
            node.branch();
        }
    }

    /**
     * Tries to turn this leaf into an internal node with two child leaves. Leaves the node
     * untouched if no split would put exemplars on both sides.
     */
    private void branch() {
        if( !this.determineSplittingPoint() ) {
            return;  // every dimension has zero (or NaN) range: nothing to split on
        }
        List<T> ltList = new ArrayList<T>();
        List<T> gteList = new ArrayList<T>();
        for( T exemplar : this.exemplars ) {
            if( exemplar.domain[ this.branchingIndex ] < this.getSplitValue() ) {
                ltList.add( exemplar );
            } else {
                gteList.add( exemplar );
            }
        }
        if( ltList.isEmpty() || gteList.isEmpty() ) {
            // A one-sided split (e.g. midpoint rounding onto the minimum) would only add an
            // empty level per insertion; undo it and stay an (oversized) leaf instead.
            this.branchingIndex = -1;
            this.setSplitValue( Double.NaN );
            return;
        }
        this.lt = this.newChild( ltList );
        this.gte = this.newChild( gteList );
        this.exemplars = null;
    }

    /**
     * Builds a child leaf holding {@code members}, splitting it further if it is still over
     * capacity. The recursion terminates because every successful split leaves both sides
     * strictly smaller.
     */
    private KDBucketTree<T> newChild( List<T> members ) {
        KDBucketTree<T> child = new KDBucketTree<T>( this.bucketSize );
        child.exemplars = members;
        if( members.size() > this.bucketSize ) {
            child.branch();
        }
        return child;
    }

    /**
     * Picks the dimension with the widest positive coordinate range among this leaf's
     * exemplars, and sets the split value to the midpoint of that range.
     *
     * @return {@code false} (leaving the split fields unchanged) when the leaf is empty or no
     *     dimension has a range greater than zero; NaN ranges never qualify
     */
    private boolean determineSplittingPoint() {
        if( this.exemplars.isEmpty() ) {
            return false;
        }
        double[] first = this.exemplars.get( 0 ).domain;
        double[] minimums = Arrays.copyOf( first, first.length );
        double[] maximums = Arrays.copyOf( first, first.length );
        for( T exemplar : this.exemplars ) {
            for( int i = 0; i < minimums.length; i++ ) {
                minimums[ i ] = Math.min( minimums[ i ], exemplar.domain[ i ] );
                maximums[ i ] = Math.max( maximums[ i ], exemplar.domain[ i ] );
            }
        }
        int widestIndex = -1;
        double widestRange = 0.0;
        for( int i = 0; i < minimums.length; i++ ) {
            double range = maximums[ i ] - minimums[ i ];
            if( range > widestRange ) {  // false for NaN, so a NaN dimension is never chosen
                widestIndex = i;
                widestRange = range;
            }
        }
        if( widestIndex < 0 ) {
            return false;
        }
        this.branchingIndex = widestIndex;
        this.setSplitValue( minimums[ widestIndex ] + 0.5 * widestRange );
        return true;
    }

    private double getSplitValue() {
        return this.branchingValue;
    }

    private void setSplitValue( double value ) {
        this.branchingValue = value;
    }

    /** @return {@code true} for an internal node, {@code false} for a leaf */
    public boolean isTree() { return branchingIndex > -1; }

    /**
     * Formats {@code value} with at most four digits after the decimal point, truncating
     * (not rounding). Any exponent suffix from {@link String#valueOf(double)} is preserved.
     *
     * @param value number to format
     * @return e.g. {@code "3.1415"} for pi, {@code "1.2345E-5"} for 1.2345678E-5
     */
    public static String trim(double value) {
        String untrimmed = String.valueOf(value);
        int exponentStart = untrimmed.indexOf('E');
        if (exponentStart < 0) {
            exponentStart = untrimmed.indexOf('e');
        }
        // Separate the exponent first so truncating the mantissa can never cut into it.
        String mantissa = exponentStart > 0 ? untrimmed.substring(0, exponentStart) : untrimmed;
        String exponent = exponentStart > 0 ? untrimmed.substring(exponentStart) : "";
        int indexOfDecimal = mantissa.indexOf('.');
        if (indexOfDecimal > 0 && mantissa.length() > indexOfDecimal + 4) {
            mantissa = mantissa.substring(0, indexOfDecimal + 5);
        }
        return mantissa + exponent;
    }

    /**
     * Renders the tree structure, one node per line. An internal node prints
     * "Tree", its prefix, its split dimension, and its split value; a leaf prints "Leaf" and
     * its prefix. Children are labelled {@code prefix + ".1"} (lt) and {@code prefix + ".2"} (gte).
     *
     * @param prefix label for this node
     * @return newline-terminated description of this subtree
     */
    public String description( String prefix ) {
        StringBuilder buffer = new StringBuilder();
        if( this.isTree() ) {
            buffer.append( "Tree " + prefix + "\t" + this.branchingIndex + "\t" + this.getSplitValue() + "\n" );
            buffer.append( this.lt.description( prefix + ".1" ) );
            buffer.append( this.gte.description( prefix + ".2" ) );
        } else {
            buffer.append( "Leaf " + prefix + "\n" );
        }
        return buffer.toString();
    }

    private final int bucketSize;
    /** Exemplars of a leaf; null once the node has split. */
    private List<T> exemplars = new ArrayList<T>();
    /** Split dimension, or -1 while this node is a leaf. */
    private int branchingIndex = -1;
    private double branchingValue = Double.NaN;
    private KDBucketTree<T> lt = null;
    private KDBucketTree<T> gte = null;
}
/**
 * Keeps the best {@code k} candidates seen so far for a nearest-neighbor query, where
 * {@code k} is the length of the array handed to the constructor.
 *
 * <p>Every slot starts empty with a squared distance of {@link Double#MAX_VALUE}. An
 * evaluated exemplar that is strictly closer than the current worst slot replaces that slot.
 * Results stay in slot order; they are not sorted by distance.
 */
public class Neighborhood<T extends Exemplar> {
    /**
     * @param control query point that distances are measured from (stored by reference)
     * @param sandbox storage for the neighbors; its length is {@code k}, which may be zero.
     *     The array is filled in place and is the same array {@link #getNeighbors()} returns.
     *     Its initial contents are not cleared — presumably callers pass a fresh, all-null
     *     array (TODO confirm).
     */
    public Neighborhood( double[] control, T sandbox[] ) {
        this.control = control;
        this.neighbors = sandbox;
        this.dSquareds = new double[this.neighbors.length];
        Arrays.fill( this.dSquareds, Double.MAX_VALUE );
    }

    /** Offers each exemplar in turn; see {@link #evaluate(Exemplar)}. */
    public void evaluate( Iterable<T> exemplars ) {
        for( T e : exemplars ) evaluate( e );
    }

    /**
     * Replaces the current worst neighbor with {@code e} if {@code e} is strictly closer to
     * the query point. Does nothing when {@code k} is zero.
     */
    public void evaluate( T e ) {
        if( this.dSquareds.length == 0 ) {
            return;  // no slots: nothing can ever be kept
        }
        double dSquared = this.calculateDistanceSquared( this.control, e.domain );
        if( dSquared < this.getGreatestSquaredDistance() ) {
            this.neighbors[ this.indexOfGreatestDistanceSquared ] = e;
            this.dSquareds[ this.indexOfGreatestDistanceSquared ] = dSquared;
            this.setIndexOfGreatestSquaredDistance();
        }
    }

    /**
     * @return the squared distance of the worst slot, i.e. the current search radius
     *     squared; {@link Double#NEGATIVE_INFINITY} when {@code k} is zero, so that no
     *     candidate or branch qualifies
     */
    public double getGreatestSquaredDistance() {
        if( this.dSquareds.length == 0 ) {
            return Double.NEGATIVE_INFINITY;
        }
        return this.dSquareds[ this.indexOfGreatestDistanceSquared ];
    }

    /**
     * @return {@code true} if the axis-aligned plane {@code x[branchIndex] == branchSplit} is
     *     strictly closer to the query than the current worst neighbor, meaning the far side
     *     of that plane could still contain a better candidate
     */
    public boolean isBranchEligible( int branchIndex, double branchSplit ) {
        return this.getGreatestSquaredDistance() > this.calculateDistanceSquaredToSplittingPlane( branchIndex, branchSplit );
    }

    /** Re-points {@code indexOfGreatestDistanceSquared} at the slot with the largest value. */
    private void setIndexOfGreatestSquaredDistance() {
        int worst = 0;
        for( int i = 1; i < this.dSquareds.length; i++ ) {
            if( this.dSquareds[ i ] > this.dSquareds[ worst ] ) {
                worst = i;
            }
        }
        this.indexOfGreatestDistanceSquared = worst;
    }

    /** @return coordinate {@code index} of the query point */
    public double getControlValueAtIndex( int index ) {
        return this.control[ index ];
    }

    /** @return the sandbox array passed to the constructor (not a copy); empty slots are null */
    public T[] getNeighbors() { return this.neighbors; }

    /**
     * @return a new array of Euclidean distances, parallel to {@link #getNeighbors()}; an
     *     empty slot reports {@code Math.sqrt(Double.MAX_VALUE)}
     */
    public double[] getDistances() {
        double[] distances = new double[ this.dSquareds.length ];
        for( int i = 0; i < this.dSquareds.length; i++ ) {
            distances[ i ] = Math.sqrt( this.dSquareds[ i ] );
        }
        return distances;
    }

    /**
     * Squared Euclidean distance between {@code control} and {@code reference}. Dimensions
     * where the reference coordinate is NaN are skipped and contribute nothing.
     */
    private double calculateDistanceSquared( double[] control, double[] reference ) {
        double ds = 0.0;
        assert( reference.length == control.length );
        for( int i = 0; i < reference.length; i++ ) {
            if( !Double.isNaN( reference[ i ] ) ) {
                ds += this.square( control[ i ] - reference[ i ] );
            }
        }
        return ds;
    }

    /** Squared distance from the query point to the plane {@code x[branchIndex] == branchSplit}. */
    private double calculateDistanceSquaredToSplittingPlane( int branchIndex, double branchSplit ) {
        return this.square( this.control[ branchIndex ] - branchSplit );
    }

    private double square( double v ) { return v * v; }

    private final double[] control;
    private final T neighbors[];
    /** Squared distance for each slot in {@code neighbors}; MAX_VALUE marks an empty slot. */
    private final double dSquareds[];
    private int indexOfGreatestDistanceSquared = 0;
}