User:Pedersen/kdTree

From Robowiki
Jump to navigation Jump to search
/**
 * {@code KNNImplementation} that answers k-nearest-neighbor queries with a
 * {@link KDBucketTree} of {@link StringHoldingExemplar}s.
 */
public class PedersenKDBucketTreeKNNSearch extends KNNImplementation {

	/** Leaf capacity used by the single-argument constructor. */
	public static final int DEFAULT_BUCKET_SIZE = 24;

	private final KDBucketTree<StringHoldingExemplar> tree;

	/**
	 * Creates a search using {@link #DEFAULT_BUCKET_SIZE} exemplars per tree leaf.
	 *
	 * @param dimension dimensionality of the points; forwarded to the superclass constructor
	 */
	public PedersenKDBucketTreeKNNSearch(int dimension) {
		this(dimension, DEFAULT_BUCKET_SIZE);
	}

	/**
	 * Creates a search with an explicit tree leaf capacity.
	 *
	 * @param dimension  dimensionality of the points; forwarded to the superclass constructor
	 * @param bucketSize number of exemplars a leaf holds before it tries to split; must be >= 1
	 * @throws IllegalArgumentException if {@code bucketSize < 1}
	 */
	public PedersenKDBucketTreeKNNSearch(int dimension, int bucketSize) {
		super(dimension);
		if( bucketSize < 1 ) {
			throw new IllegalArgumentException( "bucketSize must be >= 1: " + bucketSize );
		}
		this.tree = new KDBucketTree<StringHoldingExemplar>( bucketSize );
	}

	/**
	 * Stores a point and its payload. The {@code location} array is kept by reference,
	 * not copied.
	 */
	@Override
	public void addPoint(double[] location, String value) {
		this.tree.add( new StringHoldingExemplar( location, value ) );
	}

	/** @return {@code "Pedersen's "} followed by the tree's simple class name */
	@Override
	public String getName() {
		return "Pedersen's " + this.tree.getClass().getSimpleName();
	}

	/**
	 * Finds up to {@code size} stored points nearest to {@code location}.
	 *
	 * @return an array of length {@code size} in no particular order (not sorted by distance);
	 *         when fewer than {@code size} points are stored, the missing entries are
	 *         {@code KNNPoint(null, NaN)}
	 */
	@Override
	public KNNPoint[] getNearestNeighbors(double[] location, int size) {
		Neighborhood<StringHoldingExemplar> neighborhood =
				new Neighborhood<StringHoldingExemplar>( location,
						new StringHoldingExemplar[ size ] );
		this.tree.findNearestNeighbors( neighborhood );
		return convert( neighborhood );
	}

	/**
	 * Converts the neighborhood's slots into {@code KNNPoint}s, one per slot, using the
	 * (non-squared) distance. An empty slot becomes {@code KNNPoint(null, NaN)} and is reported
	 * on standard output.
	 */
	private KNNPoint[] convert( Neighborhood<StringHoldingExemplar> neighborhood ) {
		StringHoldingExemplar[] neighbors = neighborhood.getNeighbors();
		double[] distances = neighborhood.getDistances();

		KNNPoint[] points = new KNNPoint[ neighbors.length ];
		for( int i = 0; i < neighbors.length; i++ ) {
			if( neighbors[ i ] != null ) {
				points[ i ] = new KNNPoint( neighbors[ i ].getPayload(), distances[ i ] );
			} else {
				System.out.println( "Sample [" + i + "] was null." );
				points[ i ] = new KNNPoint( null, Double.NaN );
			}
		}
		return points;
	}

}
/**
 * A point in feature space. The coordinate array is stored by reference (not copied), so later
 * changes to the caller's array are visible through {@link #domain}.
 */
public class Exemplar {

	/**
	 * Creates an exemplar backed by {@code domain}.
	 *
	 * @param domain coordinates of this exemplar (not copied)
	 */
	public Exemplar( double domain[] ) {
		this.domain = domain;
	}

	/**
	 * Formats the coordinates as tab-separated values.
	 *
	 * @return the coordinates joined with a tab character, or an empty string when the
	 *         exemplar has no coordinates
	 */
	public String description() {
		StringBuilder buffer = new StringBuilder();
		for( int i = 0; i < this.domain.length; i++ ) {
			// Separator before every value but the first; avoids reading domain[0]
			// unconditionally, which failed for a zero-length domain.
			if( i > 0 ) {
				buffer.append( "\t" );
			}
			buffer.append( this.domain[ i ] );
		}
		return buffer.toString();
	}

	/** Coordinates; read directly (package-private) by KDBucketTree and Neighborhood. */
	final double domain[];

}
/**
 * An {@link Exemplar} that carries an opaque {@code String} payload alongside its coordinates.
 */
public class StringHoldingExemplar extends Exemplar {

	/** Value handed back by {@link #getPayload()}; may be {@code null}. */
	private final String payload;

	/**
	 * @param domain  coordinates, passed unchanged to {@link Exemplar}
	 * @param payload value to associate with these coordinates
	 */
	public StringHoldingExemplar(double[] domain, String payload) {
		super(domain);
		this.payload = payload;
	}

	/**
	 * @return the payload given at construction time
	 */
	public String getPayload() {
		return payload;
	}

}
/**
 * A k-d tree whose leaves ("buckets") hold exemplars in a list and try to split once they
 * exceed {@code bucketSize}.
 *
 * <p>Each node is either a leaf ({@code branchingIndex == -1}; owns {@code exemplars}) or an
 * interior node ({@code exemplars == null}; owns {@code lt} and {@code gte}). An exemplar whose
 * coordinate on {@code branchingIndex} is strictly less than {@code branchingValue} belongs to
 * {@code lt}. Every other exemplar belongs to {@code gte}, including one whose coordinate is NaN,
 * because {@code NaN < x} is false.
 *
 * @param <T> the exemplar type stored in the tree
 */
public class KDBucketTree<T extends Exemplar> {
	
	/**
	 * Creates an empty tree consisting of a single leaf.
	 *
	 * @param bucketSize number of exemplars a leaf may hold before it attempts to split
	 */
	public KDBucketTree( int bucketSize ) {
		this.bucketSize = bucketSize;
	}

	/**
	 * Offers candidate exemplars from this subtree to {@code neighborhood}.
	 *
	 * <p>The child on the query's side of the splitting plane is searched first. The other child
	 * is searched only if {@link Neighborhood#isBranchEligible} reports that the plane is closer
	 * than the neighborhood's current farthest neighbor.
	 *
	 * @param neighborhood query point and running k-best result set; updated in place
	 */
	public void findNearestNeighbors( Neighborhood<T> neighborhood ) {
		if( !this.isTree() ) {
			neighborhood.evaluate( this.exemplars );
			return;
		}
		boolean queryBelowSplit =
				neighborhood.getControlValueAtIndex( this.branchingIndex ) < this.branchingValue;
		KDBucketTree<T> nearSide = queryBelowSplit ? this.lt : this.gte;
		KDBucketTree<T> farSide = queryBelowSplit ? this.gte : this.lt;
		nearSide.findNearestNeighbors( neighborhood );
		if( neighborhood.isBranchEligible( this.branchingIndex, this.branchingValue ) ) {
			farSide.findNearestNeighbors( neighborhood );
		}
	}
	
	/**
	 * Inserts an exemplar, splitting the receiving leaf if it grows past {@code bucketSize}.
	 *
	 * <p>A leaf whose exemplars cannot be separated (e.g. they all share the same coordinates)
	 * is allowed to exceed {@code bucketSize}. A split is re-attempted on each later insertion
	 * into that leaf, which costs one scan of the leaf.
	 *
	 * @param e exemplar to insert; stored by reference
	 */
	public void add( T e ) {
		if( this.isTree() ) {
			if( e.domain[ this.branchingIndex ] < this.getSplitValue() ) {
				this.lt.add( e );
			} else {
				this.gte.add( e );
			}
			return;
		}
		this.exemplars.add( e );
		if( this.exemplars.size() > this.bucketSize ) {
			this.branch();
		}
	}
		
	/**
	 * Turns this leaf into an interior node with two non-empty children, if possible.
	 *
	 * <p>When no split can separate the exemplars, the node stays a leaf. Splitting anyway would
	 * move every exemplar into {@code gte}. The tree would then gain one level per insertion,
	 * which can eventually overflow the stack in the recursive add/search.
	 */
	private void branch() {
		if( !this.determineSplittingPoint() ) {
			return;
		}
		List<T> ltList = new ArrayList<T>();

		Iterator<T> iterator = this.exemplars.iterator();
		while( iterator.hasNext() ) {
			T exemplar = iterator.next();
			if( exemplar.domain[ this.branchingIndex ] < this.getSplitValue() ) {
				ltList.add( exemplar );
				iterator.remove();
			}
		}

		this.lt = new KDBucketTree<T>( this.bucketSize );
		this.lt.put( ltList );
		this.gte = new KDBucketTree<T>( this.bucketSize );
		// The exemplars left in the list are exactly those not below the split.
		this.gte.put( this.exemplars );
		this.exemplars = null;
	}
	
	/** Replaces this leaf's exemplar list with {@code exemplars} (used by the list itself). */
	private void put( List<T> exemplars ) {
		this.exemplars = exemplars;
	}
	
	/**
	 * Chooses the dimension with the greatest spread among this leaf's exemplars, and a split
	 * value on it that leaves both sides non-empty.
	 *
	 * <p>NaN coordinates are ignored when measuring spread. This matches {@code Neighborhood},
	 * which skips NaN coordinates when computing distances. Ties between dimensions go to the
	 * lowest index.
	 *
	 * @return {@code true} if such a split exists, in which case {@code branchingIndex} and
	 *         {@code branchingValue} are set; {@code false} (node left unchanged) otherwise
	 */
	private boolean determineSplittingPoint() {
		if( this.exemplars.isEmpty() ) {
			return false;
		}
		int dimensions = this.exemplars.get( 0 ).domain.length;
		double[] minimums = new double[ dimensions ];
		double[] maximums = new double[ dimensions ];
		Arrays.fill( minimums, Double.POSITIVE_INFINITY );
		Arrays.fill( maximums, Double.NEGATIVE_INFINITY );
		for( T exemplar : this.exemplars ) {
			for( int i = 0; i < dimensions; i++ ) {
				double value = exemplar.domain[ i ];
				// Math.min/max propagate NaN, which would poison the whole dimension.
				if( !Double.isNaN( value ) ) {
					minimums[ i ] = Math.min( minimums[ i ], value );
					maximums[ i ] = Math.max( maximums[ i ], value );
				}
			}
		}

		// Only a strictly positive range can be split. A NaN range (identical infinite values)
		// and -Infinity (a dimension with no non-NaN values) never qualify.
		int bestIndex = -1;
		double bestRange = 0.0;
		for( int i = 0; i < dimensions; i++ ) {
			double range = maximums[ i ] - minimums[ i ];
			if( range > bestRange ) {
				bestIndex = i;
				bestRange = range;
			}
		}
		if( bestIndex < 0 ) {
			return false; // all exemplars coincide (ignoring NaN): nothing to separate
		}

		double low = minimums[ bestIndex ];
		double high = maximums[ bestIndex ];
		double split = low + 0.5 * bestRange;
		// The midpoint can fail in three ways: it rounds down to low (range of a few ulps), it
		// lands past high (range overflowed), or it is NaN (infinite low bound). In each case,
		// splitting at high still puts low in lt and high in gte.
		if( !( split > low && split <= high ) ) {
			split = high;
		}
		this.branchingIndex = bestIndex;
		this.setSplitValue( split );
		return true;
	}
	
	private double getSplitValue() {
		return this.branchingValue;
	}
	
	private void setSplitValue( double value ) {
		this.branchingValue = value;
	}

	/** @return {@code true} for an interior node, {@code false} for a leaf */
	public boolean isTree() { return branchingIndex > -1; }
	
	/**
	 * Formats {@code value} with at most four digits after the decimal point (truncated, not
	 * rounded), keeping any exponent suffix intact. For example, {@code 3.14159265} becomes
	 * {@code "3.1415"}, {@code 1.23456789E10} becomes {@code "1.2345E10"}, and {@code 1.0E10}
	 * becomes {@code "1.0E10"}.
	 */
	public static String trim(double value) {
		String untrimmed = String.valueOf(value);
		int exponentIndex = untrimmed.indexOf('E');
		if (exponentIndex < 0)
			exponentIndex = untrimmed.indexOf('e');
		// Truncate only the mantissa. Truncating the whole string and then re-appending the
		// exponent duplicated it for short mantissas (e.g. "1.0E10" -> "1.0E10E10").
		String mantissa = exponentIndex > 0 ? untrimmed.substring(0, exponentIndex) : untrimmed;
		String exponent = exponentIndex > 0 ? untrimmed.substring(exponentIndex) : "";
		int indexOfDecimal = mantissa.indexOf('.');
		if (indexOfDecimal > 0 && mantissa.length() > indexOfDecimal + 4) {
			mantissa = mantissa.substring(0, indexOfDecimal + 5);
		}
		return mantissa + exponent;
	}

	/**
	 * Renders the tree structure, one line per node. Interior nodes show the branching index
	 * and split value; children are labelled {@code prefix + ".1"} (lt) and
	 * {@code prefix + ".2"} (gte).
	 *
	 * @param prefix label for this node
	 */
	public String description( String prefix ) {
		StringBuilder buffer = new StringBuilder();
		if( this.isTree() ) {
			buffer.append( "Tree " + prefix + "\t" + this.branchingIndex + "\t" + this.getSplitValue() + "\n" );
			buffer.append( this.lt.description( prefix + ".1" ) );
			buffer.append( this.gte.description( prefix + ".2" ) );
		} else {
			buffer.append( "Leaf " + prefix + "\n" );
		}
		return buffer.toString();
	}
		
	private final int bucketSize;
	// Leaf contents; null once this node has branched.
	private List<T> exemplars = new ArrayList<T>();
	// -1 while this node is a leaf.
	private int branchingIndex = -1;
	private double branchingValue = Double.NaN;
	private KDBucketTree<T> lt = null;
	private KDBucketTree<T> gte = null;
	
}
/**
 * Running result set for a k-nearest-neighbor query: the query point ({@code control}) plus
 * the best k exemplars seen so far (k = sandbox length) and their squared distances.
 *
 * <p>Slots are replaced in no particular order. A slot not yet replaced keeps whatever the
 * sandbox held (null for a fresh array) with a squared distance of {@link Double#MAX_VALUE}.
 * Exemplar coordinates that are NaN are skipped when computing distances.
 *
 * @param <T> exemplar type
 */
public class Neighborhood<T extends Exemplar> {

	/**
	 * @param control query coordinates (kept by reference)
	 * @param sandbox array that receives the neighbors; its length is k. It is filled in place
	 *                and returned by {@link #getNeighbors()}.
	 */
	public Neighborhood( double[] control, T sandbox[] ) {
		this.control = control;
		this.neighbors = sandbox;
		this.dSquareds = new double[this.neighbors.length];
		Arrays.fill( this.dSquareds, Double.MAX_VALUE );
	}

	/** Offers every exemplar in {@code exemplars} to {@link #evaluate(Exemplar)}. */
	public void evaluate( Iterable<T> exemplars ) {
		for( T e : exemplars ) evaluate( e );
	}
	
	/**
	 * Replaces the current farthest neighbor with {@code e} if {@code e} is strictly closer
	 * to the query.
	 */
	public void evaluate( T e ) {
		double dSquared = this.calculateDistanceSquared( this.control, e.domain );
		if( dSquared < this.getGreatestSquaredDistance() ) {
			this.neighbors[ this.indexOfGreatestDistanceSquared ] = e;
			this.dSquareds[ this.indexOfGreatestDistanceSquared ] = dSquared;
			this.setIndexOfGreatestSquaredDistance();
		}
	}
	
	/**
	 * @return the largest squared distance currently held ({@link Double#MAX_VALUE} while any
	 *         slot is unfilled). Returns {@link Double#NEGATIVE_INFINITY} when k is zero, so that
	 *         no candidate is ever accepted and no branch is ever eligible.
	 */
	public double getGreatestSquaredDistance() {
		if( this.dSquareds.length == 0 ) {
			// k == 0: dSquareds[0] does not exist.
			return Double.NEGATIVE_INFINITY;
		}
		return this.dSquareds[ this.indexOfGreatestDistanceSquared ];
	}
	
	/**
	 * @return {@code true} if the splitting plane {@code x[branchIndex] == branchSplit} is
	 *         strictly closer to the query than the current farthest neighbor, i.e. the other
	 *         side of the plane could still contain a closer exemplar
	 */
	public boolean isBranchEligible( int branchIndex, double branchSplit ) {
		return this.getGreatestSquaredDistance() > this.calculateDistanceSquaredToSplittingPlane( branchIndex, branchSplit );
	}

	/** Recomputes which slot holds the largest squared distance (first such slot on ties). */
	private void setIndexOfGreatestSquaredDistance() {
		this.indexOfGreatestDistanceSquared = 0;
		for( int i = 1; i < this.dSquareds.length; i++ ) {
			if( this.dSquareds[ i ] > this.getGreatestSquaredDistance() ) {
				this.indexOfGreatestDistanceSquared = i;
			}
		}
	}
	
	/** @return the query coordinate at {@code index} */
	public double getControlValueAtIndex( int index ) {
		return this.control[ index ];
	}
	
	/** @return the live neighbor array (the sandbox passed to the constructor, not a copy) */
	public T[] getNeighbors() { return this.neighbors; }

	/** @return a new array of (non-squared) distances, parallel to {@link #getNeighbors()} */
	public double[] getDistances() {
		double[] distances = new double[ this.dSquareds.length ];
		for( int i = 0; i < this.dSquareds.length; i++ ) {
			distances[ i ] = Math.sqrt( this.dSquareds[ i ] );
		}
		return distances;
	}

	/** Squared Euclidean distance, skipping dimensions where {@code reference} is NaN. */
	private double calculateDistanceSquared( double[] control, double[] reference ) {
		double ds = 0.0;
		assert( reference.length == control.length );
		for( int i = 0; i < reference.length; i++ ) {
			if( !Double.isNaN( reference[ i ] ) ) {
				ds += this.square( control[ i ] - reference[ i ] );
			}
		}
		return ds;
	}

	/** Squared distance from the query to the axis-aligned plane {@code x[branchIndex] == branchSplit}. */
	private double calculateDistanceSquaredToSplittingPlane( int branchIndex, double branchSplit ) {
		return this.square( this.control[ branchIndex ] - branchSplit );
	}

	private double square( double v ) { return v * v; }

	private final double[] control;
	private final T neighbors[];
	// Squared distance of each slot in neighbors; MAX_VALUE marks an unfilled slot.
	private final double dSquareds[];
	private int indexOfGreatestDistanceSquared = 0;
	
}