Difference between revisions of "User:Pedersen/kdTree"

From Robowiki
Jump to navigation Jump to search
m (filler)
 
(Initial implementation)
Line 1: Line 1:
Nothing to share yet. Once I get the accuracy to 100% for 100k data samples, I'll release the code.
+
<pre>
 +
public class PedersenKDBucketTreeKNNSearch extends KNNImplementation {
 +
 
 +
private final KDBucketTree<StringHoldingExemplar> tree;
 +
 +
public PedersenKDBucketTreeKNNSearch(int dimension) {
 +
super(dimension);
 +
final int bucketSize = 24;
 +
double[] domainMap = new double[ dimension ];
 +
Arrays.fill( domainMap, Double.NaN );
 +
this.tree = new KDBucketTree<StringHoldingExemplar>( bucketSize );
 +
}
 +
 
 +
@Override
 +
public void addPoint(double[] location, String value) {
 +
this.tree.add( new StringHoldingExemplar( location, value ) );
 +
}
 +
 
 +
@Override
 +
public String getName() {
 +
return "Pedersen's " + this.tree.getClass().getSimpleName();
 +
}
 +
 
 +
@Override
 +
public KNNPoint[] getNearestNeighbors(double[] location, int size) {
 +
Neighborhood<StringHoldingExemplar> neighborhood =
 +
new Neighborhood<StringHoldingExemplar>( location,
 +
new StringHoldingExemplar[ size ] );
 +
this.tree.findNearestNeighbors( neighborhood );
 +
 
 +
// System.out.println();
 +
// System.out.print( this.tree.description( "1" ) );
 +
// System.out.println();
 +
 
 +
// System.out.println( "\t" + new Exemplar( location ).description() );
 +
 +
return convert( neighborhood );
 +
}
 +
 
 +
private KNNPoint[] convert( Neighborhood<StringHoldingExemplar> neighborhood ) {
 +
StringHoldingExemplar[] neighbors = neighborhood.getNeighbors();
 +
double[] distances = neighborhood.getDistances();
 +
 
 +
KNNPoint[] points = new KNNPoint[ neighbors.length ];
 +
for( int i = 0; i < neighbors.length; i++ ) {
 +
if( neighbors[ i ] != null ) {
 +
// System.out.println( "Sample:\t" + neighbors[i].description() );
 +
points[ i ] = new KNNPoint( neighbors[ i ].getPayload(), distances[ i ] );
 +
} else {
 +
System.out.println( "Sample [" + i + "] was null." );
 +
points[ i ] = new KNNPoint( null, Double.NaN );
 +
}
 +
}
 +
return points;
 +
}
 +
 +
}
 +
</pre>
 +
 
 +
<pre>
 +
public class Exemplar {
 +
 
 +
public Exemplar( double domain[] ) {
 +
this.domain = domain;
 +
}
 +
 
 +
public String description() {
 +
StringBuilder buffer = new StringBuilder();
 +
buffer.append(this.domain[0]);
 +
for( int i = 1; i < this.domain.length; i++ ) {
 +
buffer.append( "\t" ).append( this.domain[ i ] );
 +
}
 +
return buffer.toString();
 +
}
 +
 +
final double domain[];
 +
 +
}
 +
</pre>
 +
 
 +
<pre>
 +
public class StringHoldingExemplar extends Exemplar {
 +
 
 +
public StringHoldingExemplar(double[] domain, String payload) {
 +
super(domain);
 +
this.payload = payload;
 +
}
 +
 
 +
public String getPayload() { return this.payload; }
 +
 
 +
private final String payload;
 +
 +
}
 +
</pre>
 +
 
 +
<pre>
 +
public class KDBucketTree<T extends Exemplar> {
 +
 +
public KDBucketTree( int bucketSize ) {
 +
this.bucketSize = bucketSize;
 +
}
 +
 
 +
public void findNearestNeighbors( Neighborhood<T> neighborhood ) {
 +
if( !this.isTree() ) {
 +
neighborhood.evaluate( this.exemplars );
 +
} else {
 +
if( neighborhood.getControlValueAtIndex( this.branchingIndex ) < this.branchingValue ) {
 +
this.lt.findNearestNeighbors( neighborhood );
 +
if( neighborhood.isBranchEligible( this.branchingIndex, this.branchingValue ) ) {
 +
this.gte.findNearestNeighbors( neighborhood );
 +
}
 +
} else {
 +
this.gte.findNearestNeighbors( neighborhood );
 +
if( neighborhood.isBranchEligible( this.branchingIndex, this.branchingValue ) ) {
 +
this.lt.findNearestNeighbors( neighborhood );
 +
}
 +
}
 +
}
 +
}
 +
 +
public void add( T e ) {
 +
if( this.isTree() ) {
 +
if( e.domain[ this.branchingIndex ] < this.getSplitValue() ) {
 +
this.lt.add( e );
 +
} else {
 +
this.gte.add( e );
 +
}
 +
} else {
 +
if( this.exemplars.size() > this.bucketSize ) {
 +
System.out.println( "Excessive size: " + this.exemplars.size() );
 +
}
 +
this.exemplars.add( e );
 +
if( this.exemplars.size() > this.bucketSize ) {
 +
this.branch();
 +
}
 +
}
 +
}
 +
 +
private void branch() {
 +
this.determineSplittingPoint();
 +
List<T> ltList = new ArrayList<T>();
 +
 
 +
Iterator<T> iterator = this.exemplars.iterator();
 +
while( iterator.hasNext() ) {
 +
T exemplar = iterator.next();
 +
if( exemplar.domain[ this.branchingIndex ] < this.getSplitValue() ) {
 +
ltList.add( exemplar );
 +
iterator.remove();
 +
}
 +
}
 +
 
 +
this.lt = new KDBucketTree<T>( this.bucketSize );
 +
this.lt.put( ltList );
 +
this.gte = new KDBucketTree<T>( this.bucketSize );
 +
this.gte.put( this.exemplars );
 +
this.exemplars = null;
 +
}
 +
 +
private void put( List<T> exemplars ) {
 +
this.exemplars = exemplars;
 +
}
 +
 +
private void determineSplittingPoint() {
 +
Iterator<T> iterator = this.exemplars.iterator();
 +
if( iterator.hasNext() ) {
 +
T exemplar = iterator.next();
 +
double[] minimums = Arrays.copyOf( exemplar.domain, exemplar.domain.length );
 +
double[] maximums = Arrays.copyOf( exemplar.domain, exemplar.domain.length );
 +
while( iterator.hasNext() ) {
 +
exemplar = iterator.next();
 +
for( int i = 0; i < exemplar.domain.length; i++ ) {
 +
minimums[ i ] = Math.min( minimums[ i ], exemplar.domain[ i ] );
 +
maximums[ i ] = Math.max( maximums[ i ], exemplar.domain[ i ] );
 +
}
 +
}
 +
this.branchingIndex = 0;
 +
double maxRange = maximums[ 0 ] - minimums[ 0 ];
 +
for( int i = 1; i < exemplar.domain.length; i++ ) {
 +
double range = maximums[ i ] - minimums[ i ];
 +
if( range > maxRange ) {
 +
this.branchingIndex = i;
 +
maxRange = range;
 +
}
 +
}
 +
 
 +
this.setSplitValue( minimums[ this.branchingIndex ] + 0.5 * maxRange );
 +
}
 +
 +
}
 +
 +
private double getSplitValue() {
 +
return this.branchingValue;
 +
}
 +
 +
private void setSplitValue( double value ) {
 +
this.branchingValue = value;
 +
}
 +
 
 +
public boolean isTree() { return branchingIndex > -1; }
 +
 +
public static String trim(double value) {
 +
String untrimmed = String.valueOf(value);
 +
String trimmed = untrimmed;
 +
int indexOfDecimal = untrimmed.indexOf('.');
 +
if (indexOfDecimal > 0 && untrimmed.length() > indexOfDecimal + 4) {
 +
trimmed = untrimmed.substring(0, indexOfDecimal + 5);
 +
}
 +
if (untrimmed.indexOf('e') > 0)
 +
trimmed = trimmed + untrimmed.substring(untrimmed.indexOf('e'));
 +
if (untrimmed.indexOf('E') > 0)
 +
trimmed = trimmed + untrimmed.substring(untrimmed.indexOf('E'));
 +
return trimmed;
 +
}
 +
 
 +
public String description( String prefix ) {
 +
StringBuilder buffer = new StringBuilder();
 +
if( this.isTree() ) {
 +
buffer.append( "Tree " + prefix + "\t" + this.branchingIndex + "\t" + this.getSplitValue() + "\n" );
 +
buffer.append( this.lt.description( prefix + ".1" ) );
 +
buffer.append( this.gte.description( prefix + ".2" ) );
 +
} else {
 +
buffer.append( "Leaf " + prefix + "\n" );
 +
// for( T e : this.exemplars ) {
 +
// buffer.append( "\t" ).append( e.description() ).append( "\n" );
 +
// }
 +
}
 +
return buffer.toString();
 +
}
 +
 +
private final int bucketSize;
 +
private List<T> exemplars = new ArrayList<T>();
 +
private int branchingIndex = -1;
 +
private double branchingValue = Double.NaN;
 +
private KDBucketTree<T> lt = null;
 +
private KDBucketTree<T> gte = null;
 +
 +
}
 +
</pre>
 +
 
 +
<pre>
 +
public class Neighborhood<T extends Exemplar> {
 +
 
 +
public Neighborhood( double[] control, T sandbox[] ) {
 +
this.control = control;
 +
this.neighbors = sandbox;
 +
this.dSquareds = new double[this.neighbors.length];
 +
Arrays.fill( this.dSquareds, Double.MAX_VALUE );
 +
}
 +
 
 +
public void evaluate( Iterable<T> exemplars ) {
 +
for( T e : exemplars ) evaluate( e );
 +
}
 +
 +
public void evaluate( T e ) {
 +
double dSquared = this.calculateDistanceSquared( this.control, e.domain );
 +
if( dSquared < this.getGreatestSquaredDistance() ) {
 +
this.neighbors[ this.indexOfGreatestDistanceSquared ] = e;
 +
this.dSquareds[ this.indexOfGreatestDistanceSquared ] = dSquared;
 +
this.setIndexOfGreatestSquaredDistance();
 +
}
 +
}
 +
 +
public double getGreatestSquaredDistance() {
 +
return this.dSquareds[ this.indexOfGreatestDistanceSquared ];
 +
}
 +
 +
public boolean isBranchEligible( int branchIndex, double branchSplit ) {
 +
return this.getGreatestSquaredDistance() > this.calculateDistanceSquaredToSplittingPlane( branchIndex, branchSplit );
 +
}
 +
 
 +
private void setIndexOfGreatestSquaredDistance() {
 +
this.indexOfGreatestDistanceSquared = 0;
 +
for( int i = 1; i < this.dSquareds.length; i++ ) {
 +
if( this.dSquareds[ i ] > this.getGreatestSquaredDistance() ) {
 +
this.indexOfGreatestDistanceSquared = i;
 +
}
 +
}
 +
}
 +
 +
public double getControlValueAtIndex( int index ) {
 +
return this.control[ index ];
 +
}
 +
 +
public T[] getNeighbors() { return this.neighbors; }
 +
 
 +
public double[] getDistances() {
 +
double[] distances = new double[ this.dSquareds.length ];
 +
for( int i = 0; i < this.dSquareds.length; i++ ) {
 +
distances[ i ] = Math.sqrt( this.dSquareds[ i ] );
 +
}
 +
return distances;
 +
}
 +
 
 +
private double calculateDistanceSquared( double[] control, double[] reference ) {
 +
double ds = 0.0;
 +
assert( reference.length == control.length );
 +
for( int i = 0; i < reference.length; i++ ) {
 +
if( !Double.isNaN( reference[ i ] ) ) {
 +
ds += this.square( control[ i ] - reference[ i ] );
 +
}
 +
}
 +
// System.out.println( "Calculated distance squared: " + ds );
 +
return ds;
 +
}
 +
 
 +
private double calculateDistanceSquaredToSplittingPlane( int branchIndex, double branchSplit ) {
 +
return this.square( this.control[ branchIndex ] - branchSplit );
 +
}
 +
 
 +
private double square( double v ) { return v * v; }
 +
 
 +
private final double[] control;
 +
private final T neighbors[];
 +
private final double dSquareds[];
 +
private int indexOfGreatestDistanceSquared = 0;
 +
 +
}
 +
</pre>

Revision as of 01:51, 9 March 2010

public class PedersenKDBucketTreeKNNSearch extends KNNImplementation {

	private final KDBucketTree<StringHoldingExemplar> tree;
	
	public PedersenKDBucketTreeKNNSearch(int dimension) {
		super(dimension);
		final int bucketSize = 24;
		double[] domainMap = new double[ dimension ];
		Arrays.fill( domainMap, Double.NaN );
		this.tree = new KDBucketTree<StringHoldingExemplar>( bucketSize );
	}

	@Override
	public void addPoint(double[] location, String value) {
		this.tree.add( new StringHoldingExemplar( location, value ) );
	}

	@Override
	public String getName() {
		return "Pedersen's " + this.tree.getClass().getSimpleName();
	}

	@Override
	public KNNPoint[] getNearestNeighbors(double[] location, int size) {
		Neighborhood<StringHoldingExemplar> neighborhood = 
				new Neighborhood<StringHoldingExemplar>( location,
						new StringHoldingExemplar[ size ] );
		this.tree.findNearestNeighbors( neighborhood );

//		System.out.println();
//		System.out.print( this.tree.description( "1" ) );
//		System.out.println();

//		System.out.println( "\t" + new Exemplar( location ).description() );
		
		return convert( neighborhood );
	}

	private KNNPoint[] convert( Neighborhood<StringHoldingExemplar> neighborhood ) {
		StringHoldingExemplar[] neighbors = neighborhood.getNeighbors();
		double[] distances = neighborhood.getDistances();

		KNNPoint[] points = new KNNPoint[ neighbors.length ];
		for( int i = 0; i < neighbors.length; i++ ) {
			if( neighbors[ i ] != null ) {
//				System.out.println( "Sample:\t" + neighbors[i].description() );
				points[ i ] = new KNNPoint( neighbors[ i ].getPayload(), distances[ i ] );
			} else {
				System.out.println( "Sample [" + i + "] was null." );
				points[ i ] = new KNNPoint( null, Double.NaN );
			}
		}
		return points;
	}
	
}
/**
 * A point in the search space, held as a raw coordinate array.
 */
public class Exemplar {

	/**
	 * Wraps the given coordinates. The array is stored by reference, not copied.
	 *
	 * @param domain coordinates of the point, one entry per dimension
	 */
	public Exemplar( double domain[] ) {
		this.domain = domain;
	}

	/**
	 * Renders the coordinates as one tab-separated line.
	 *
	 * @return the coordinates joined by tab characters, or an empty string
	 *         when the point has no coordinates
	 */
	public String description() {
		StringBuilder buffer = new StringBuilder();
		for( int i = 0; i < this.domain.length; i++ ) {
			// Separator goes between values only, so a zero-length domain
			// yields "" instead of failing on domain[0].
			if( i > 0 ) {
				buffer.append( "\t" );
			}
			buffer.append( this.domain[ i ] );
		}
		return buffer.toString();
	}

	/** Coordinates of this point; package-visible because the tree and neighborhood classes read it directly. */
	final double domain[];

}
/**
 * An {@link Exemplar} that carries a string payload alongside its coordinates.
 */
public class StringHoldingExemplar extends Exemplar {

	/** Value attached to this point; fixed at construction. */
	private final String payload;

	/**
	 * @param domain  coordinates, passed through to the {@link Exemplar} constructor
	 * @param payload string to associate with those coordinates
	 */
	public StringHoldingExemplar(double[] domain, String payload) {
		super(domain);
		this.payload = payload;
	}

	/** @return the string supplied at construction */
	public String getPayload() {
		return payload;
	}

}
/**
 * A k-d tree whose leaves are buckets of points. A node is either a leaf
 * holding a list of exemplars, or an interior node holding a splitting
 * dimension, a splitting value and two children: {@code lt} for points whose
 * coordinate is strictly below the splitting value and {@code gte} for the rest.
 *
 * @param <T> type of point stored in the tree
 */
public class KDBucketTree<T extends Exemplar> {

	/**
	 * Creates an empty leaf.
	 *
	 * @param bucketSize number of points a leaf may hold before a split is attempted
	 */
	public KDBucketTree( int bucketSize ) {
		this.bucketSize = bucketSize;
		this.splitAttemptSize = bucketSize;
	}

	/**
	 * Offers the points of this subtree to the given neighborhood.
	 * A leaf offers every point it holds. An interior node searches the child
	 * on the query's side of the split first and searches the other child only
	 * if the neighborhood reports that the splitting plane is still closer than
	 * its current worst candidate.
	 *
	 * @param neighborhood query point plus the running set of best candidates; updated in place
	 */
	public void findNearestNeighbors( Neighborhood<T> neighborhood ) {
		if( !this.isTree() ) {
			neighborhood.evaluate( this.exemplars );
			return;
		}
		boolean queryBelowSplit =
				neighborhood.getControlValueAtIndex( this.branchingIndex ) < this.branchingValue;
		KDBucketTree<T> nearSide = queryBelowSplit ? this.lt : this.gte;
		KDBucketTree<T> farSide = queryBelowSplit ? this.gte : this.lt;

		nearSide.findNearestNeighbors( neighborhood );
		if( neighborhood.isBranchEligible( this.branchingIndex, this.branchingValue ) ) {
			farSide.findNearestNeighbors( neighborhood );
		}
	}

	/**
	 * Inserts a point, routing it down to a leaf and splitting that leaf when
	 * it has outgrown its bucket and its points can be separated.
	 *
	 * @param e point to store (kept by reference)
	 */
	public void add( T e ) {
		if( this.isTree() ) {
			if( e.domain[ this.branchingIndex ] < this.getSplitValue() ) {
				this.lt.add( e );
			} else {
				this.gte.add( e );
			}
			return;
		}
		// A leaf may legitimately exceed bucketSize when its points cannot be
		// separated (e.g. all duplicates), so the earlier "Excessive size"
		// diagnostic is gone: it would fire on every such insert.
		this.exemplars.add( e );
		if( this.exemplars.size() > this.splitAttemptSize ) {
			this.branch();
		}
	}

	/**
	 * Tries to turn this leaf into an interior node with two leaf children.
	 * If no dimension separates the points, or the chosen split would leave one
	 * child empty, the node stays a leaf and the next attempt is postponed.
	 * The earlier revision split unconditionally, which for coincident points
	 * produced an empty {@code lt} child and one extra tree level per insert.
	 */
	private void branch() {
		if( !this.determineSplittingPoint() ) {
			this.postponeSplit();
			return;
		}

		List<T> ltList = new ArrayList<T>();
		List<T> gteList = new ArrayList<T>();
		for( T exemplar : this.exemplars ) {
			if( exemplar.domain[ this.branchingIndex ] < this.getSplitValue() ) {
				ltList.add( exemplar );
			} else {
				gteList.add( exemplar );
			}
		}

		if( ltList.isEmpty() || gteList.isEmpty() ) {
			// Rounding or a non-finite range placed the split outside the data;
			// undo the tentative split so this node is still a leaf.
			this.branchingIndex = -1;
			this.setSplitValue( Double.NaN );
			this.postponeSplit();
			return;
		}

		this.lt = new KDBucketTree<T>( this.bucketSize );
		this.lt.put( ltList );
		this.gte = new KDBucketTree<T>( this.bucketSize );
		this.gte.put( gteList );
		this.exemplars = null;
	}

	/**
	 * Defers the next split attempt until the leaf has doubled in size, so an
	 * unsplittable leaf is not rescanned on every insert.
	 */
	private void postponeSplit() {
		this.splitAttemptSize = 2 * this.exemplars.size();
	}

	/** Replaces this leaf's bucket with the given list (used when seeding a new child). */
	private void put( List<T> exemplars ) {
		this.exemplars = exemplars;
	}

	/**
	 * Picks the dimension with the largest spread among the bucket's points
	 * and sets the split to the midpoint of that spread. Ties go to the lowest
	 * dimension index. A dimension whose spread is zero or NaN is never chosen.
	 *
	 * @return true if a dimension with positive spread was found and the
	 *         splitting index and value were set; false if nothing was changed
	 */
	private boolean determineSplittingPoint() {
		double[] minimums = null;
		double[] maximums = null;
		for( T exemplar : this.exemplars ) {
			if( minimums == null ) {
				minimums = Arrays.copyOf( exemplar.domain, exemplar.domain.length );
				maximums = Arrays.copyOf( exemplar.domain, exemplar.domain.length );
				continue;
			}
			for( int i = 0; i < exemplar.domain.length; i++ ) {
				minimums[ i ] = Math.min( minimums[ i ], exemplar.domain[ i ] );
				maximums[ i ] = Math.max( maximums[ i ], exemplar.domain[ i ] );
			}
		}
		if( minimums == null ) {
			return false;
		}

		int widestIndex = -1;
		double widestRange = 0.0;
		for( int i = 0; i < minimums.length; i++ ) {
			double range = maximums[ i ] - minimums[ i ];
			// Strict comparison: zero and NaN ranges never win.
			if( range > widestRange ) {
				widestIndex = i;
				widestRange = range;
			}
		}
		if( widestIndex < 0 ) {
			return false;
		}

		this.branchingIndex = widestIndex;
		this.setSplitValue( minimums[ widestIndex ] + 0.5 * widestRange );
		return true;
	}

	private double getSplitValue() {
		return this.branchingValue;
	}

	private void setSplitValue( double value ) {
		this.branchingValue = value;
	}

	/** @return true if this node has been split into children, false if it is a leaf */
	public boolean isTree() { return branchingIndex > -1; }

	/**
	 * Formats a double with at most four digits after the decimal point
	 * (truncated, not rounded), keeping any exponent suffix.
	 * The exponent is separated before truncating so it is emitted exactly
	 * once; previously a value such as {@code 1.0E10} came out as
	 * {@code "1.0E10E10"}.
	 *
	 * @param value number to format
	 * @return the shortened textual form of {@code value}
	 */
	public static String trim(double value) {
		String untrimmed = String.valueOf(value);

		int exponentStart = untrimmed.indexOf('E');
		if (exponentStart < 0) {
			exponentStart = untrimmed.indexOf('e');
		}
		String mantissa = untrimmed;
		String exponent = "";
		if (exponentStart > 0) {
			mantissa = untrimmed.substring(0, exponentStart);
			exponent = untrimmed.substring(exponentStart);
		}

		int indexOfDecimal = mantissa.indexOf('.');
		if (indexOfDecimal > 0 && mantissa.length() > indexOfDecimal + 5) {
			mantissa = mantissa.substring(0, indexOfDecimal + 5);
		}
		return mantissa + exponent;
	}

	/**
	 * Produces a textual outline of the subtree, one line per node.
	 * Interior nodes print "Tree", their label, splitting index and splitting
	 * value; leaves print "Leaf" and their label. Children are labelled by
	 * appending ".1" (lt) and ".2" (gte) to the parent's label.
	 *
	 * @param prefix label for this node
	 * @return the outline, each line terminated by a newline
	 */
	public String description( String prefix ) {
		StringBuilder buffer = new StringBuilder();
		if( this.isTree() ) {
			buffer.append( "Tree " + prefix + "\t" + this.branchingIndex + "\t" + this.getSplitValue() + "\n" );
			buffer.append( this.lt.description( prefix + ".1" ) );
			buffer.append( this.gte.description( prefix + ".2" ) );
		} else {
			buffer.append( "Leaf " + prefix + "\n" );
		}
		return buffer.toString();
	}

	/** Leaf capacity above which the first split is attempted. */
	private final int bucketSize;
	/** Leaf size that triggers the next split attempt; grows when a split is not possible. */
	private int splitAttemptSize;
	/** Points held by a leaf; null once the node has been split. */
	private List<T> exemplars = new ArrayList<T>();
	/** Splitting dimension, or -1 while this node is a leaf. */
	private int branchingIndex = -1;
	/** Splitting value along {@code branchingIndex}; NaN while this node is a leaf. */
	private double branchingValue = Double.NaN;
	/** Child holding points strictly below the splitting value. */
	private KDBucketTree<T> lt = null;
	/** Child holding points at or above the splitting value. */
	private KDBucketTree<T> gte = null;

}
public class Neighborhood<T extends Exemplar> {

	public Neighborhood( double[] control, T sandbox[] ) {
		this.control = control;
		this.neighbors = sandbox;
		this.dSquareds = new double[this.neighbors.length];
		Arrays.fill( this.dSquareds, Double.MAX_VALUE );
	}

	public void evaluate( Iterable<T> exemplars ) {
		for( T e : exemplars ) evaluate( e );
	}
	
	public void evaluate( T e ) {
		double dSquared = this.calculateDistanceSquared( this.control, e.domain );
		if( dSquared < this.getGreatestSquaredDistance() ) {
			this.neighbors[ this.indexOfGreatestDistanceSquared ] = e;
			this.dSquareds[ this.indexOfGreatestDistanceSquared ] = dSquared;
			this.setIndexOfGreatestSquaredDistance();
		}
	}
	
	public double getGreatestSquaredDistance() {
		return this.dSquareds[ this.indexOfGreatestDistanceSquared ];
	}
	
	public boolean isBranchEligible( int branchIndex, double branchSplit ) {
		return this.getGreatestSquaredDistance() > this.calculateDistanceSquaredToSplittingPlane( branchIndex, branchSplit );
	}

	private void setIndexOfGreatestSquaredDistance() {
		this.indexOfGreatestDistanceSquared = 0;
		for( int i = 1; i < this.dSquareds.length; i++ ) {
			if( this.dSquareds[ i ] > this.getGreatestSquaredDistance() ) {
				this.indexOfGreatestDistanceSquared = i;
			}
		}
	}
	
	public double getControlValueAtIndex( int index ) {
		return this.control[ index ];
	}
	
	public T[] getNeighbors() { return this.neighbors; }

	public double[] getDistances() {
		double[] distances = new double[ this.dSquareds.length ];
		for( int i = 0; i < this.dSquareds.length; i++ ) {
			distances[ i ] = Math.sqrt( this.dSquareds[ i ] );
		}
		return distances;
	}

	private double calculateDistanceSquared( double[] control, double[] reference ) {
		double ds = 0.0;
		assert( reference.length == control.length );
		for( int i = 0; i < reference.length; i++ ) {
			if( !Double.isNaN( reference[ i ] ) ) {
				ds += this.square( control[ i ] - reference[ i ] );
			}
		}
//		System.out.println( "Calculated distance squared: " + ds );
		return ds;
	}

	private double calculateDistanceSquaredToSplittingPlane( int branchIndex, double branchSplit ) {
		return this.square( this.control[ branchIndex ] - branchSplit );
	}

	private double square( double v ) { return v * v; }

	private final double[] control;
	private final T neighbors[];
	private final double dSquareds[];
	private int indexOfGreatestDistanceSquared = 0;
	
}