User:Pedersen/kdTree
Jump to navigation
Jump to search
public class PedersenKDBucketTreeKNNSearch extends KNNImplementation { private final KDBucketTree<StringHoldingExemplar> tree; public PedersenKDBucketTreeKNNSearch(int dimension) { super(dimension); final int bucketSize = 24; double[] domainMap = new double[ dimension ]; Arrays.fill( domainMap, Double.NaN ); this.tree = new KDBucketTree<StringHoldingExemplar>( bucketSize ); } @Override public void addPoint(double[] location, String value) { this.tree.add( new StringHoldingExemplar( location, value ) ); } @Override public String getName() { return "Pedersen's " + this.tree.getClass().getSimpleName(); } @Override public KNNPoint[] getNearestNeighbors(double[] location, int size) { Neighborhood<StringHoldingExemplar> neighborhood = new Neighborhood<StringHoldingExemplar>( location, new StringHoldingExemplar[ size ] ); this.tree.findNearestNeighbors( neighborhood ); // System.out.println(); // System.out.print( this.tree.description( "1" ) ); // System.out.println(); // System.out.println( "\t" + new Exemplar( location ).description() ); return convert( neighborhood ); } private KNNPoint[] convert( Neighborhood<StringHoldingExemplar> neighborhood ) { StringHoldingExemplar[] neighbors = neighborhood.getNeighbors(); double[] distances = neighborhood.getDistances(); KNNPoint[] points = new KNNPoint[ neighbors.length ]; for( int i = 0; i < neighbors.length; i++ ) { if( neighbors[ i ] != null ) { // System.out.println( "Sample:\t" + neighbors[i].description() ); points[ i ] = new KNNPoint( neighbors[ i ].getPayload(), distances[ i ] ); } else { System.out.println( "Sample [" + i + "] was null." ); points[ i ] = new KNNPoint( null, Double.NaN ); } } return points; } }
/**
 * A point in the search space, identified by its coordinate vector.
 */
public class Exemplar {

    /**
     * Coordinates of this exemplar. The array passed to the constructor is stored as-is
     * (no defensive copy). Package-private because the tree and neighbourhood classes in
     * this file read it directly.
     */
    final double[] domain;

    /**
     * @param domain coordinates of the point; retained by reference
     */
    public Exemplar(double[] domain) {
        this.domain = domain;
    }

    /**
     * Renders the coordinates as a tab-separated list.
     *
     * @return the coordinates joined by tab characters, or an empty string when the
     *         exemplar has no coordinates
     */
    public String description() {
        StringBuilder buffer = new StringBuilder();
        for (int i = 0; i < this.domain.length; i++) {
            if (i > 0) {
                buffer.append("\t");
            }
            buffer.append(this.domain[i]);
        }
        return buffer.toString();
    }
}
/**
 * An {@link Exemplar} that carries a string payload alongside its coordinates.
 */
public class StringHoldingExemplar extends Exemplar {

    /** Value associated with this point; may be whatever the caller supplied. */
    private final String payload;

    /**
     * @param domain  coordinates of the point, handed to the {@link Exemplar} constructor
     * @param payload value to associate with the point
     */
    public StringHoldingExemplar(double[] domain, String payload) {
        super(domain);
        this.payload = payload;
    }

    /**
     * @return the payload given at construction time
     */
    public String getPayload() {
        return payload;
    }
}
public class KDBucketTree<T extends Exemplar> { public KDBucketTree( int bucketSize ) { this.bucketSize = bucketSize; } public void findNearestNeighbors( Neighborhood<T> neighborhood ) { if( !this.isTree() ) { neighborhood.evaluate( this.exemplars ); } else { if( neighborhood.getControlValueAtIndex( this.branchingIndex ) < this.branchingValue ) { this.lt.findNearestNeighbors( neighborhood ); if( neighborhood.isBranchEligible( this.branchingIndex, this.branchingValue ) ) { this.gte.findNearestNeighbors( neighborhood ); } } else { this.gte.findNearestNeighbors( neighborhood ); if( neighborhood.isBranchEligible( this.branchingIndex, this.branchingValue ) ) { this.lt.findNearestNeighbors( neighborhood ); } } } } public void add( T e ) { if( this.isTree() ) { if( e.domain[ this.branchingIndex ] < this.getSplitValue() ) { this.lt.add( e ); } else { this.gte.add( e ); } } else { if( this.exemplars.size() > this.bucketSize ) { System.out.println( "Excessive size: " + this.exemplars.size() ); } this.exemplars.add( e ); if( this.exemplars.size() > this.bucketSize ) { this.branch(); } } } private void branch() { this.determineSplittingPoint(); List<T> ltList = new ArrayList<T>(); Iterator<T> iterator = this.exemplars.iterator(); while( iterator.hasNext() ) { T exemplar = iterator.next(); if( exemplar.domain[ this.branchingIndex ] < this.getSplitValue() ) { ltList.add( exemplar ); iterator.remove(); } } this.lt = new KDBucketTree<T>( this.bucketSize ); this.lt.put( ltList ); this.gte = new KDBucketTree<T>( this.bucketSize ); this.gte.put( this.exemplars ); this.exemplars = null; } private void put( List<T> exemplars ) { this.exemplars = exemplars; } private void determineSplittingPoint() { Iterator<T> iterator = this.exemplars.iterator(); if( iterator.hasNext() ) { T exemplar = iterator.next(); double[] minimums = Arrays.copyOf( exemplar.domain, exemplar.domain.length ); double[] maximums = Arrays.copyOf( exemplar.domain, exemplar.domain.length ); while( iterator.hasNext() ) { 
exemplar = iterator.next(); for( int i = 0; i < exemplar.domain.length; i++ ) { minimums[ i ] = Math.min( minimums[ i ], exemplar.domain[ i ] ); maximums[ i ] = Math.max( maximums[ i ], exemplar.domain[ i ] ); } } this.branchingIndex = 0; double maxRange = maximums[ 0 ] - minimums[ 0 ]; for( int i = 1; i < exemplar.domain.length; i++ ) { double range = maximums[ i ] - minimums[ i ]; if( range > maxRange ) { this.branchingIndex = i; maxRange = range; } } this.setSplitValue( minimums[ this.branchingIndex ] + 0.5 * maxRange ); } } private double getSplitValue() { return this.branchingValue; } private void setSplitValue( double value ) { this.branchingValue = value; } public boolean isTree() { return branchingIndex > -1; } public static String trim(double value) { String untrimmed = String.valueOf(value); String trimmed = untrimmed; int indexOfDecimal = untrimmed.indexOf('.'); if (indexOfDecimal > 0 && untrimmed.length() > indexOfDecimal + 4) { trimmed = untrimmed.substring(0, indexOfDecimal + 5); } if (untrimmed.indexOf('e') > 0) trimmed = trimmed + untrimmed.substring(untrimmed.indexOf('e')); if (untrimmed.indexOf('E') > 0) trimmed = trimmed + untrimmed.substring(untrimmed.indexOf('E')); return trimmed; } public String description( String prefix ) { StringBuilder buffer = new StringBuilder(); if( this.isTree() ) { buffer.append( "Tree " + prefix + "\t" + this.branchingIndex + "\t" + this.getSplitValue() + "\n" ); buffer.append( this.lt.description( prefix + ".1" ) ); buffer.append( this.gte.description( prefix + ".2" ) ); } else { buffer.append( "Leaf " + prefix + "\n" ); // for( T e : this.exemplars ) { // buffer.append( "\t" ).append( e.description() ).append( "\n" ); // } } return buffer.toString(); } private final int bucketSize; private List<T> exemplars = new ArrayList<T>(); private int branchingIndex = -1; private double branchingValue = Double.NaN; private KDBucketTree<T> lt = null; private KDBucketTree<T> gte = null; }
public class Neighborhood<T extends Exemplar> { public Neighborhood( double[] control, T sandbox[] ) { this.control = control; this.neighbors = sandbox; this.dSquareds = new double[this.neighbors.length]; Arrays.fill( this.dSquareds, Double.MAX_VALUE ); } public void evaluate( Iterable<T> exemplars ) { for( T e : exemplars ) evaluate( e ); } public void evaluate( T e ) { double dSquared = this.calculateDistanceSquared( this.control, e.domain ); if( dSquared < this.getGreatestSquaredDistance() ) { this.neighbors[ this.indexOfGreatestDistanceSquared ] = e; this.dSquareds[ this.indexOfGreatestDistanceSquared ] = dSquared; this.setIndexOfGreatestSquaredDistance(); } } public double getGreatestSquaredDistance() { return this.dSquareds[ this.indexOfGreatestDistanceSquared ]; } public boolean isBranchEligible( int branchIndex, double branchSplit ) { return this.getGreatestSquaredDistance() > this.calculateDistanceSquaredToSplittingPlane( branchIndex, branchSplit ); } private void setIndexOfGreatestSquaredDistance() { this.indexOfGreatestDistanceSquared = 0; for( int i = 1; i < this.dSquareds.length; i++ ) { if( this.dSquareds[ i ] > this.getGreatestSquaredDistance() ) { this.indexOfGreatestDistanceSquared = i; } } } public double getControlValueAtIndex( int index ) { return this.control[ index ]; } public T[] getNeighbors() { return this.neighbors; } public double[] getDistances() { double[] distances = new double[ this.dSquareds.length ]; for( int i = 0; i < this.dSquareds.length; i++ ) { distances[ i ] = Math.sqrt( this.dSquareds[ i ] ); } return distances; } private double calculateDistanceSquared( double[] control, double[] reference ) { double ds = 0.0; assert( reference.length == control.length ); for( int i = 0; i < reference.length; i++ ) { if( !Double.isNaN( reference[ i ] ) ) { ds += this.square( control[ i ] - reference[ i ] ); } } // System.out.println( "Calculated distance squared: " + ds ); return ds; } private double 
calculateDistanceSquaredToSplittingPlane( int branchIndex, double branchSplit ) { return this.square( this.control[ branchIndex ] - branchSplit ); } private double square( double v ) { return v * v; } private final double[] control; private final T neighbors[]; private final double dSquareds[]; private int indexOfGreatestDistanceSquared = 0; }