Difference between revisions of "User:Pedersen/kdTree"
Jump to navigation
Jump to search
m (filler) |
RednaxelaBot (talk | contribs) m (Using <syntaxhighlight>.) |
||
(One intermediate revision by one other user not shown) | |||
Line 1: | Line 1: | ||
− | + | <syntaxhighlight> | |
+ | public class PedersenKDBucketTreeKNNSearch extends KNNImplementation { | ||
+ | |||
+ | private final KDBucketTree<StringHoldingExemplar> tree; | ||
+ | |||
+ | public PedersenKDBucketTreeKNNSearch(int dimension) { | ||
+ | super(dimension); | ||
+ | final int bucketSize = 24; | ||
+ | double[] domainMap = new double[ dimension ]; | ||
+ | Arrays.fill( domainMap, Double.NaN ); | ||
+ | this.tree = new KDBucketTree<StringHoldingExemplar>( bucketSize ); | ||
+ | } | ||
+ | |||
+ | @Override | ||
+ | public void addPoint(double[] location, String value) { | ||
+ | this.tree.add( new StringHoldingExemplar( location, value ) ); | ||
+ | } | ||
+ | |||
+ | @Override | ||
+ | public String getName() { | ||
+ | return "Pedersen's " + this.tree.getClass().getSimpleName(); | ||
+ | } | ||
+ | |||
+ | @Override | ||
+ | public KNNPoint[] getNearestNeighbors(double[] location, int size) { | ||
+ | Neighborhood<StringHoldingExemplar> neighborhood = | ||
+ | new Neighborhood<StringHoldingExemplar>( location, | ||
+ | new StringHoldingExemplar[ size ] ); | ||
+ | this.tree.findNearestNeighbors( neighborhood ); | ||
+ | |||
+ | // System.out.println(); | ||
+ | // System.out.print( this.tree.description( "1" ) ); | ||
+ | // System.out.println(); | ||
+ | |||
+ | // System.out.println( "\t" + new Exemplar( location ).description() ); | ||
+ | |||
+ | return convert( neighborhood ); | ||
+ | } | ||
+ | |||
+ | private KNNPoint[] convert( Neighborhood<StringHoldingExemplar> neighborhood ) { | ||
+ | StringHoldingExemplar[] neighbors = neighborhood.getNeighbors(); | ||
+ | double[] distances = neighborhood.getDistances(); | ||
+ | |||
+ | KNNPoint[] points = new KNNPoint[ neighbors.length ]; | ||
+ | for( int i = 0; i < neighbors.length; i++ ) { | ||
+ | if( neighbors[ i ] != null ) { | ||
+ | // System.out.println( "Sample:\t" + neighbors[i].description() ); | ||
+ | points[ i ] = new KNNPoint( neighbors[ i ].getPayload(), distances[ i ] ); | ||
+ | } else { | ||
+ | System.out.println( "Sample [" + i + "] was null." ); | ||
+ | points[ i ] = new KNNPoint( null, Double.NaN ); | ||
+ | } | ||
+ | } | ||
+ | return points; | ||
+ | } | ||
+ | |||
+ | } | ||
+ | </syntaxhighlight> | ||
+ | |||
+ | <syntaxhighlight> | ||
+ | public class Exemplar { | ||
+ | |||
+ | public Exemplar( double domain[] ) { | ||
+ | this.domain = domain; | ||
+ | } | ||
+ | |||
+ | public String description() { | ||
+ | StringBuilder buffer = new StringBuilder(); | ||
+ | buffer.append(this.domain[0]); | ||
+ | for( int i = 1; i < this.domain.length; i++ ) { | ||
+ | buffer.append( "\t" ).append( this.domain[ i ] ); | ||
+ | } | ||
+ | return buffer.toString(); | ||
+ | } | ||
+ | |||
+ | final double domain[]; | ||
+ | |||
+ | } | ||
+ | </syntaxhighlight> | ||
+ | |||
+ | <syntaxhighlight> | ||
+ | public class StringHoldingExemplar extends Exemplar { | ||
+ | |||
+ | public StringHoldingExemplar(double[] domain, String payload) { | ||
+ | super(domain); | ||
+ | this.payload = payload; | ||
+ | } | ||
+ | |||
+ | public String getPayload() { return this.payload; } | ||
+ | |||
+ | private final String payload; | ||
+ | |||
+ | } | ||
+ | </syntaxhighlight> | ||
+ | |||
+ | <syntaxhighlight> | ||
+ | public class KDBucketTree<T extends Exemplar> { | ||
+ | |||
+ | public KDBucketTree( int bucketSize ) { | ||
+ | this.bucketSize = bucketSize; | ||
+ | } | ||
+ | |||
+ | public void findNearestNeighbors( Neighborhood<T> neighborhood ) { | ||
+ | if( !this.isTree() ) { | ||
+ | neighborhood.evaluate( this.exemplars ); | ||
+ | } else { | ||
+ | if( neighborhood.getControlValueAtIndex( this.branchingIndex ) < this.branchingValue ) { | ||
+ | this.lt.findNearestNeighbors( neighborhood ); | ||
+ | if( neighborhood.isBranchEligible( this.branchingIndex, this.branchingValue ) ) { | ||
+ | this.gte.findNearestNeighbors( neighborhood ); | ||
+ | } | ||
+ | } else { | ||
+ | this.gte.findNearestNeighbors( neighborhood ); | ||
+ | if( neighborhood.isBranchEligible( this.branchingIndex, this.branchingValue ) ) { | ||
+ | this.lt.findNearestNeighbors( neighborhood ); | ||
+ | } | ||
+ | } | ||
+ | } | ||
+ | } | ||
+ | |||
+ | public void add( T e ) { | ||
+ | if( this.isTree() ) { | ||
+ | if( e.domain[ this.branchingIndex ] < this.getSplitValue() ) { | ||
+ | this.lt.add( e ); | ||
+ | } else { | ||
+ | this.gte.add( e ); | ||
+ | } | ||
+ | } else { | ||
+ | if( this.exemplars.size() > this.bucketSize ) { | ||
+ | System.out.println( "Excessive size: " + this.exemplars.size() ); | ||
+ | } | ||
+ | this.exemplars.add( e ); | ||
+ | if( this.exemplars.size() > this.bucketSize ) { | ||
+ | this.branch(); | ||
+ | } | ||
+ | } | ||
+ | } | ||
+ | |||
+ | private void branch() { | ||
+ | this.determineSplittingPoint(); | ||
+ | List<T> ltList = new ArrayList<T>(); | ||
+ | |||
+ | Iterator<T> iterator = this.exemplars.iterator(); | ||
+ | while( iterator.hasNext() ) { | ||
+ | T exemplar = iterator.next(); | ||
+ | if( exemplar.domain[ this.branchingIndex ] < this.getSplitValue() ) { | ||
+ | ltList.add( exemplar ); | ||
+ | iterator.remove(); | ||
+ | } | ||
+ | } | ||
+ | |||
+ | this.lt = new KDBucketTree<T>( this.bucketSize ); | ||
+ | this.lt.put( ltList ); | ||
+ | this.gte = new KDBucketTree<T>( this.bucketSize ); | ||
+ | this.gte.put( this.exemplars ); | ||
+ | this.exemplars = null; | ||
+ | } | ||
+ | |||
+ | private void put( List<T> exemplars ) { | ||
+ | this.exemplars = exemplars; | ||
+ | } | ||
+ | |||
+ | private void determineSplittingPoint() { | ||
+ | Iterator<T> iterator = this.exemplars.iterator(); | ||
+ | if( iterator.hasNext() ) { | ||
+ | T exemplar = iterator.next(); | ||
+ | double[] minimums = Arrays.copyOf( exemplar.domain, exemplar.domain.length ); | ||
+ | double[] maximums = Arrays.copyOf( exemplar.domain, exemplar.domain.length ); | ||
+ | while( iterator.hasNext() ) { | ||
+ | exemplar = iterator.next(); | ||
+ | for( int i = 0; i < exemplar.domain.length; i++ ) { | ||
+ | minimums[ i ] = Math.min( minimums[ i ], exemplar.domain[ i ] ); | ||
+ | maximums[ i ] = Math.max( maximums[ i ], exemplar.domain[ i ] ); | ||
+ | } | ||
+ | } | ||
+ | this.branchingIndex = 0; | ||
+ | double maxRange = maximums[ 0 ] - minimums[ 0 ]; | ||
+ | for( int i = 1; i < exemplar.domain.length; i++ ) { | ||
+ | double range = maximums[ i ] - minimums[ i ]; | ||
+ | if( range > maxRange ) { | ||
+ | this.branchingIndex = i; | ||
+ | maxRange = range; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | this.setSplitValue( minimums[ this.branchingIndex ] + 0.5 * maxRange ); | ||
+ | } | ||
+ | |||
+ | } | ||
+ | |||
+ | private double getSplitValue() { | ||
+ | return this.branchingValue; | ||
+ | } | ||
+ | |||
+ | private void setSplitValue( double value ) { | ||
+ | this.branchingValue = value; | ||
+ | } | ||
+ | |||
+ | public boolean isTree() { return branchingIndex > -1; } | ||
+ | |||
+ | public static String trim(double value) { | ||
+ | String untrimmed = String.valueOf(value); | ||
+ | String trimmed = untrimmed; | ||
+ | int indexOfDecimal = untrimmed.indexOf('.'); | ||
+ | if (indexOfDecimal > 0 && untrimmed.length() > indexOfDecimal + 4) { | ||
+ | trimmed = untrimmed.substring(0, indexOfDecimal + 5); | ||
+ | } | ||
+ | if (untrimmed.indexOf('e') > 0) | ||
+ | trimmed = trimmed + untrimmed.substring(untrimmed.indexOf('e')); | ||
+ | if (untrimmed.indexOf('E') > 0) | ||
+ | trimmed = trimmed + untrimmed.substring(untrimmed.indexOf('E')); | ||
+ | return trimmed; | ||
+ | } | ||
+ | |||
+ | public String description( String prefix ) { | ||
+ | StringBuilder buffer = new StringBuilder(); | ||
+ | if( this.isTree() ) { | ||
+ | buffer.append( "Tree " + prefix + "\t" + this.branchingIndex + "\t" + this.getSplitValue() + "\n" ); | ||
+ | buffer.append( this.lt.description( prefix + ".1" ) ); | ||
+ | buffer.append( this.gte.description( prefix + ".2" ) ); | ||
+ | } else { | ||
+ | buffer.append( "Leaf " + prefix + "\n" ); | ||
+ | // for( T e : this.exemplars ) { | ||
+ | // buffer.append( "\t" ).append( e.description() ).append( "\n" ); | ||
+ | // } | ||
+ | } | ||
+ | return buffer.toString(); | ||
+ | } | ||
+ | |||
+ | private final int bucketSize; | ||
+ | private List<T> exemplars = new ArrayList<T>(); | ||
+ | private int branchingIndex = -1; | ||
+ | private double branchingValue = Double.NaN; | ||
+ | private KDBucketTree<T> lt = null; | ||
+ | private KDBucketTree<T> gte = null; | ||
+ | |||
+ | } | ||
+ | </syntaxhighlight> | ||
+ | |||
+ | <syntaxhighlight> | ||
+ | public class Neighborhood<T extends Exemplar> { | ||
+ | |||
+ | public Neighborhood( double[] control, T sandbox[] ) { | ||
+ | this.control = control; | ||
+ | this.neighbors = sandbox; | ||
+ | this.dSquareds = new double[this.neighbors.length]; | ||
+ | Arrays.fill( this.dSquareds, Double.MAX_VALUE ); | ||
+ | } | ||
+ | |||
+ | public void evaluate( Iterable<T> exemplars ) { | ||
+ | for( T e : exemplars ) evaluate( e ); | ||
+ | } | ||
+ | |||
+ | public void evaluate( T e ) { | ||
+ | double dSquared = this.calculateDistanceSquared( this.control, e.domain ); | ||
+ | if( dSquared < this.getGreatestSquaredDistance() ) { | ||
+ | this.neighbors[ this.indexOfGreatestDistanceSquared ] = e; | ||
+ | this.dSquareds[ this.indexOfGreatestDistanceSquared ] = dSquared; | ||
+ | this.setIndexOfGreatestSquaredDistance(); | ||
+ | } | ||
+ | } | ||
+ | |||
+ | public double getGreatestSquaredDistance() { | ||
+ | return this.dSquareds[ this.indexOfGreatestDistanceSquared ]; | ||
+ | } | ||
+ | |||
+ | public boolean isBranchEligible( int branchIndex, double branchSplit ) { | ||
+ | return this.getGreatestSquaredDistance() > this.calculateDistanceSquaredToSplittingPlane( branchIndex, branchSplit ); | ||
+ | } | ||
+ | |||
+ | private void setIndexOfGreatestSquaredDistance() { | ||
+ | this.indexOfGreatestDistanceSquared = 0; | ||
+ | for( int i = 1; i < this.dSquareds.length; i++ ) { | ||
+ | if( this.dSquareds[ i ] > this.getGreatestSquaredDistance() ) { | ||
+ | this.indexOfGreatestDistanceSquared = i; | ||
+ | } | ||
+ | } | ||
+ | } | ||
+ | |||
+ | public double getControlValueAtIndex( int index ) { | ||
+ | return this.control[ index ]; | ||
+ | } | ||
+ | |||
+ | public T[] getNeighbors() { return this.neighbors; } | ||
+ | |||
+ | public double[] getDistances() { | ||
+ | double[] distances = new double[ this.dSquareds.length ]; | ||
+ | for( int i = 0; i < this.dSquareds.length; i++ ) { | ||
+ | distances[ i ] = Math.sqrt( this.dSquareds[ i ] ); | ||
+ | } | ||
+ | return distances; | ||
+ | } | ||
+ | |||
+ | private double calculateDistanceSquared( double[] control, double[] reference ) { | ||
+ | double ds = 0.0; | ||
+ | assert( reference.length == control.length ); | ||
+ | for( int i = 0; i < reference.length; i++ ) { | ||
+ | if( !Double.isNaN( reference[ i ] ) ) { | ||
+ | ds += this.square( control[ i ] - reference[ i ] ); | ||
+ | } | ||
+ | } | ||
+ | // System.out.println( "Calculated distance squared: " + ds ); | ||
+ | return ds; | ||
+ | } | ||
+ | |||
+ | private double calculateDistanceSquaredToSplittingPlane( int branchIndex, double branchSplit ) { | ||
+ | return this.square( this.control[ branchIndex ] - branchSplit ); | ||
+ | } | ||
+ | |||
+ | private double square( double v ) { return v * v; } | ||
+ | |||
+ | private final double[] control; | ||
+ | private final T neighbors[]; | ||
+ | private final double dSquareds[]; | ||
+ | private int indexOfGreatestDistanceSquared = 0; | ||
+ | |||
+ | } | ||
+ | </syntaxhighlight> |
Latest revision as of 09:37, 1 July 2010
/**
 * {@link KNNImplementation} backed by a {@link KDBucketTree} of
 * {@link StringHoldingExemplar}s.
 */
public class PedersenKDBucketTreeKNNSearch extends KNNImplementation {

    /** Bucket size handed to the underlying {@link KDBucketTree}. */
    private static final int BUCKET_SIZE = 24;

    private final KDBucketTree<StringHoldingExemplar> tree;

    /**
     * Creates an empty search structure.
     *
     * @param dimension forwarded to the {@link KNNImplementation} constructor
     */
    public PedersenKDBucketTreeKNNSearch(int dimension) {
        super(dimension);
        this.tree = new KDBucketTree<StringHoldingExemplar>(BUCKET_SIZE);
    }

    /**
     * Stores a point in the tree.
     *
     * @param location coordinates of the point
     * @param value    payload returned with the point when it is found as a neighbor
     */
    @Override
    public void addPoint(double[] location, String value) {
        this.tree.add(new StringHoldingExemplar(location, value));
    }

    /**
     * @return the label "Pedersen's " followed by the simple class name of the tree
     */
    @Override
    public String getName() {
        return "Pedersen's " + this.tree.getClass().getSimpleName();
    }

    /**
     * Searches the tree for the points closest to {@code location}.
     *
     * @param location query coordinates
     * @param size     number of result slots to fill
     * @return an array of {@code size} points; slots for which no neighbor was
     *         found hold a point with a {@code null} payload and a NaN distance
     */
    @Override
    public KNNPoint[] getNearestNeighbors(double[] location, int size) {
        Neighborhood<StringHoldingExemplar> neighborhood =
                new Neighborhood<StringHoldingExemplar>(location, new StringHoldingExemplar[size]);
        this.tree.findNearestNeighbors(neighborhood);
        return convert(neighborhood);
    }

    /**
     * Translates the neighbors collected in {@code neighborhood} into
     * {@link KNNPoint}s, slot by slot. An unfilled slot is reported on standard
     * output and represented by a point with a {@code null} payload and a NaN
     * distance.
     */
    private KNNPoint[] convert(Neighborhood<StringHoldingExemplar> neighborhood) {
        StringHoldingExemplar[] neighbors = neighborhood.getNeighbors();
        double[] distances = neighborhood.getDistances();

        KNNPoint[] points = new KNNPoint[neighbors.length];
        for (int i = 0; i < neighbors.length; i++) {
            if (neighbors[i] != null) {
                points[i] = new KNNPoint(neighbors[i].getPayload(), distances[i]);
            } else {
                System.out.println("Sample [" + i + "] was null.");
                points[i] = new KNNPoint(null, Double.NaN);
            }
        }
        return points;
    }
}
/**
 * A point described by an array of coordinates.
 */
public class Exemplar {

    /** Coordinates of this point; the array passed to the constructor is stored as-is, not copied. */
    final double domain[];

    /**
     * @param domain coordinates of the point
     */
    public Exemplar( double domain[] ) {
        this.domain = domain;
    }

    /**
     * Renders the coordinates as text.
     *
     * @return the coordinates in order, separated by tab characters; an empty
     *         string when the point has no coordinates
     */
    public String description() {
        StringBuilder buffer = new StringBuilder();
        for( int i = 0; i < this.domain.length; i++ ) {
            // Separator goes between values only, so a zero-length domain yields "".
            if( i > 0 ) {
                buffer.append( "\t" );
            }
            buffer.append( this.domain[ i ] );
        }
        return buffer.toString();
    }

}
/**
 * An {@link Exemplar} that carries a string payload in addition to its coordinates.
 */
public class StringHoldingExemplar extends Exemplar {

    /** Payload attached to this point; set once at construction. */
    private final String payload;

    /**
     * @param domain  coordinates, passed on to {@link Exemplar}
     * @param payload string to associate with the point
     */
    public StringHoldingExemplar(double[] domain, String payload) {
        super(domain);
        this.payload = payload;
    }

    /**
     * @return the payload given at construction
     */
    public String getPayload() {
        return payload;
    }
}
/**
 * A k-d tree whose leaves are buckets (lists) of exemplars.
 *
 * <p>A node is either a leaf holding exemplars directly, or a branch that
 * splits on one coordinate: exemplars whose coordinate is strictly below the
 * split value live under {@code lt}, all others under {@code gte}.
 *
 * @param <T> exemplar type stored in the tree
 */
public class KDBucketTree<T extends Exemplar> {

    /** Leaf size above which a split is attempted. */
    private final int bucketSize;
    /** Exemplars held by a leaf; {@code null} once this node has become a branch. */
    private List<T> exemplars = new ArrayList<T>();
    /** Coordinate index this branch splits on; -1 while this node is a leaf. */
    private int branchingIndex = -1;
    /** Split value along {@link #branchingIndex}; NaN while this node is a leaf. */
    private double branchingValue = Double.NaN;
    /** Child for coordinates strictly below the split value. */
    private KDBucketTree<T> lt = null;
    /** Child for coordinates at or above the split value. */
    private KDBucketTree<T> gte = null;

    /**
     * @param bucketSize number of exemplars a leaf may hold before a split is attempted
     */
    public KDBucketTree( int bucketSize ) {
        this.bucketSize = bucketSize;
    }

    /**
     * Feeds candidate exemplars to {@code neighborhood}. The child on the
     * query's side of the split is searched first; the other child is searched
     * only if the neighborhood reports that the splitting plane is still
     * within reach.
     *
     * @param neighborhood accumulator holding the query point and the best neighbors so far
     */
    public void findNearestNeighbors( Neighborhood<T> neighborhood ) {
        if( !this.isTree() ) {
            neighborhood.evaluate( this.exemplars );
            return;
        }
        final boolean queryIsBelow =
                neighborhood.getControlValueAtIndex( this.branchingIndex ) < this.branchingValue;
        final KDBucketTree<T> nearSide = queryIsBelow ? this.lt : this.gte;
        final KDBucketTree<T> farSide = queryIsBelow ? this.gte : this.lt;

        nearSide.findNearestNeighbors( neighborhood );
        if( neighborhood.isBranchEligible( this.branchingIndex, this.branchingValue ) ) {
            farSide.findNearestNeighbors( neighborhood );
        }
    }

    /**
     * Inserts an exemplar, descending to the leaf it belongs to and attempting
     * to split that leaf when it grows beyond the bucket size.
     *
     * @param e exemplar to insert
     */
    public void add( T e ) {
        if( this.isTree() ) {
            if( e.domain[ this.branchingIndex ] < this.branchingValue ) {
                this.lt.add( e );
            } else {
                this.gte.add( e );
            }
            return;
        }
        this.exemplars.add( e );
        if( this.exemplars.size() > this.bucketSize ) {
            // May leave this node as an over-full leaf when the bucket cannot be
            // separated; the split is retried on the next insertion because the
            // new exemplar may make the bucket separable.
            this.branch();
        }
    }

    /**
     * Tries to turn this leaf into a branch. The split coordinate is the one
     * with the widest value range in the bucket, and the split value is the
     * midpoint of that range.
     *
     * <p>The split is committed only when both sides receive at least one
     * exemplar. Otherwise (for example when all exemplars are identical, or
     * the midpoint is NaN or infinite) the node stays a leaf, which avoids
     * creating an empty child and growing the tree by one level per insert.
     */
    private void branch() {
        if( this.exemplars.isEmpty() ) {
            return;
        }
        final double[] first = this.exemplars.get( 0 ).domain;
        if( first.length == 0 ) {
            return; // no coordinate to split on
        }

        // Per-coordinate bounds over the bucket.
        final double[] minimums = first.clone();
        final double[] maximums = first.clone();
        for( T exemplar : this.exemplars ) {
            for( int i = 0; i < exemplar.domain.length; i++ ) {
                minimums[ i ] = Math.min( minimums[ i ], exemplar.domain[ i ] );
                maximums[ i ] = Math.max( maximums[ i ], exemplar.domain[ i ] );
            }
        }

        // First coordinate with the strictly largest range wins.
        int splitIndex = 0;
        double widestRange = maximums[ 0 ] - minimums[ 0 ];
        for( int i = 1; i < minimums.length; i++ ) {
            final double range = maximums[ i ] - minimums[ i ];
            if( range > widestRange ) {
                splitIndex = i;
                widestRange = range;
            }
        }
        final double splitValue = minimums[ splitIndex ] + 0.5 * widestRange;

        // Partition into fresh lists so nothing is modified if the split is rejected.
        final List<T> below = new ArrayList<T>();
        final List<T> atOrAbove = new ArrayList<T>();
        for( T exemplar : this.exemplars ) {
            if( exemplar.domain[ splitIndex ] < splitValue ) {
                below.add( exemplar );
            } else {
                atOrAbove.add( exemplar );
            }
        }
        if( below.isEmpty() || atOrAbove.isEmpty() ) {
            return; // degenerate split: remain a leaf
        }

        this.lt = new KDBucketTree<T>( this.bucketSize );
        this.lt.put( below );
        this.gte = new KDBucketTree<T>( this.bucketSize );
        this.gte.put( atOrAbove );
        this.branchingValue = splitValue;
        this.branchingIndex = splitIndex; // from here on isTree() is true
        this.exemplars = null;
    }

    /** Replaces this leaf's bucket with the given list (the list is adopted, not copied). */
    private void put( List<T> exemplars ) {
        this.exemplars = exemplars;
    }

    /**
     * @return {@code true} if this node is a branch with two children,
     *         {@code false} if it is a leaf bucket
     */
    public boolean isTree() {
        return this.branchingIndex > -1;
    }

    /**
     * Formats a double with at most four digits after the decimal point,
     * truncating (not rounding) any further digits. An exponent suffix, if
     * {@link String#valueOf(double)} produced one, is kept intact.
     *
     * @param value number to format
     * @return the shortened text, e.g. {@code "3.1415"} for {@code 3.14159}
     *         and {@code "1.2345E10"} for {@code 1.23456789E10}
     */
    public static String trim( double value ) {
        final String text = String.valueOf( value );

        // Separate mantissa and exponent first so that cutting the decimals can
        // never slice into, or duplicate, the exponent.
        int exponentAt = text.indexOf( 'E' );
        if( exponentAt < 0 ) {
            exponentAt = text.indexOf( 'e' );
        }
        String mantissa = text;
        String exponent = "";
        if( exponentAt > 0 ) {
            mantissa = text.substring( 0, exponentAt );
            exponent = text.substring( exponentAt );
        }

        final int decimalAt = mantissa.indexOf( '.' );
        if( decimalAt > 0 && mantissa.length() > decimalAt + 5 ) {
            mantissa = mantissa.substring( 0, decimalAt + 5 );
        }
        return mantissa + exponent;
    }

    /**
     * Produces a multi-line outline of the tree structure: one line per node,
     * branches listing their split coordinate and value, children labelled
     * with {@code prefix + ".1"} (below the split) and {@code prefix + ".2"}.
     *
     * @param prefix label for this node
     * @return the outline, each line terminated by a newline
     */
    public String description( String prefix ) {
        StringBuilder buffer = new StringBuilder();
        if( this.isTree() ) {
            buffer.append( "Tree " + prefix + "\t" + this.branchingIndex + "\t" + this.branchingValue + "\n" );
            buffer.append( this.lt.description( prefix + ".1" ) );
            buffer.append( this.gte.description( prefix + ".2" ) );
        } else {
            buffer.append( "Leaf " + prefix + "\n" );
        }
        return buffer.toString();
    }

}
/**
 * Accumulates the nearest neighbors found so far for one query point.
 *
 * <p>Results are kept in a fixed-size array; a new candidate replaces the
 * currently farthest entry when it is strictly closer. Distances are tracked
 * as squared Euclidean distances and only converted by {@link #getDistances()}.
 *
 * @param <T> exemplar type being collected
 */
public class Neighborhood<T extends Exemplar> {

    /** Query point. */
    private final double[] control;
    /** Result slots; entry {@code i} pairs with {@code dSquareds[i]}. */
    private final T neighbors[];
    /** Squared distance per slot; {@link Double#MAX_VALUE} marks a slot not yet filled. */
    private final double dSquareds[];
    /** Slot currently holding the largest squared distance, i.e. the next one to be replaced. */
    private int indexOfGreatestDistanceSquared = 0;

    /**
     * @param control query coordinates
     * @param sandbox array used directly as result storage; its length is the
     *                number of neighbors to collect (assumed to be passed in
     *                with all entries {@code null})
     */
    public Neighborhood( double[] control, T sandbox[] ) {
        this.control = control;
        this.neighbors = sandbox;
        this.dSquareds = new double[ this.neighbors.length ];
        Arrays.fill( this.dSquareds, Double.MAX_VALUE );
    }

    /**
     * Offers every exemplar in {@code exemplars} as a candidate neighbor.
     *
     * @param exemplars candidates to consider
     */
    public void evaluate( Iterable<T> exemplars ) {
        for( T e : exemplars ) {
            evaluate( e );
        }
    }

    /**
     * Offers one exemplar as a candidate; it is kept only if it is strictly
     * closer to the query than the farthest entry currently held.
     *
     * @param e candidate exemplar
     */
    public void evaluate( T e ) {
        double dSquared = this.calculateDistanceSquared( this.control, e.domain );
        if( dSquared < this.getGreatestSquaredDistance() ) {
            this.neighbors[ this.indexOfGreatestDistanceSquared ] = e;
            this.dSquareds[ this.indexOfGreatestDistanceSquared ] = dSquared;
            this.setIndexOfGreatestSquaredDistance();
        }
    }

    /**
     * @return the largest squared distance among the result slots
     *         ({@link Double#MAX_VALUE} while any slot is unfilled), or
     *         negative infinity when there are no result slots at all, so that
     *         no candidate is admitted and no branch is considered eligible
     */
    public double getGreatestSquaredDistance() {
        if( this.dSquareds.length == 0 ) {
            return Double.NEGATIVE_INFINITY;
        }
        return this.dSquareds[ this.indexOfGreatestDistanceSquared ];
    }

    /**
     * Tells whether the far side of a splitting plane could still contain a
     * closer neighbor than the farthest one currently held.
     *
     * @param branchIndex coordinate index of the splitting plane
     * @param branchSplit position of the plane along that coordinate
     * @return {@code true} if the plane is strictly nearer to the query than
     *         the current greatest squared distance
     */
    public boolean isBranchEligible( int branchIndex, double branchSplit ) {
        return this.getGreatestSquaredDistance()
                > this.calculateDistanceSquaredToSplittingPlane( branchIndex, branchSplit );
    }

    /**
     * Re-scans the slots and records the first one holding the largest squared
     * distance. Only called after a slot was written, so at least one slot exists.
     */
    private void setIndexOfGreatestSquaredDistance() {
        int farthest = 0;
        for( int i = 1; i < this.dSquareds.length; i++ ) {
            if( this.dSquareds[ i ] > this.dSquareds[ farthest ] ) {
                farthest = i;
            }
        }
        this.indexOfGreatestDistanceSquared = farthest;
    }

    /**
     * @param index coordinate index
     * @return the query point's value at that coordinate
     */
    public double getControlValueAtIndex( int index ) {
        return this.control[ index ];
    }

    /**
     * @return the result array itself (not a copy); unfilled slots are {@code null}
     */
    public T[] getNeighbors() {
        return this.neighbors;
    }

    /**
     * @return a new array with the distance (square root of the stored squared
     *         distance) for each result slot, in slot order
     */
    public double[] getDistances() {
        double[] distances = new double[ this.dSquareds.length ];
        for( int i = 0; i < this.dSquareds.length; i++ ) {
            distances[ i ] = Math.sqrt( this.dSquareds[ i ] );
        }
        return distances;
    }

    /**
     * Squared Euclidean distance between the two points, ignoring every
     * coordinate whose value in {@code reference} is NaN.
     */
    private double calculateDistanceSquared( double[] control, double[] reference ) {
        assert( reference.length == control.length );
        double ds = 0.0;
        for( int i = 0; i < reference.length; i++ ) {
            if( !Double.isNaN( reference[ i ] ) ) {
                ds += this.square( control[ i ] - reference[ i ] );
            }
        }
        return ds;
    }

    /** Squared distance from the query point to the plane {@code x[branchIndex] = branchSplit}. */
    private double calculateDistanceSquaredToSplittingPlane( int branchIndex, double branchSplit ) {
        return this.square( this.control[ branchIndex ] - branchSplit );
    }

    private double square( double v ) {
        return v * v;
    }

}