/* Copyright (C) 2003-2008 Dan Arlow
 * 
 * This file is part of motifADE.
 * 
 * motifADE is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * motifADE is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with motifADE.  If not, see <http://www.gnu.org/licenses/>.
 */

/* 
 *  weighted_sampling_without_replacement.hpp
 *  
 *  (Potentially novel) CDF-splitting algorithm for efficient sampling without
 *  replacement from a sequence of items with probability proportional to given
 *  weights. The main advantage of this algorithm over a partial-sum tree is that
 *  resetting the data structure takes O(1) time.
 */


#ifndef WEIGHTED_SAMPLING_WITHOUT_REPLACEMENT_H
#define WEIGHTED_SAMPLING_WITHOUT_REPLACEMENT_H


#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>

#ifdef MOTIFADE_WSWR_DEBUG
#include <iomanip>
#endif


// NOTE(review): a using-directive at header scope is injected into every file
// that includes this header. It is left in place because existing includers
// may depend on it -- confirm before removing.
using namespace std;


/*
 *  Draws items without replacement from [dataStart, dataEnd), choosing each
 *  remaining item with probability proportional to its weight.
 *
 *  Weights are supplied as a cumulative sum: element k of [cdfStart, cdfEnd)
 *  holds the sum of the weights of items 0..k, and the element immediately
 *  BEFORE cdfStart must be readable and equal to 0.0.
 *
 *  The CDF is kept as a binary tree of "chunks" (contiguous runs of
 *  still-available items). Removing an item from the edge of a run shrinks the
 *  run; removing one from the interior splits the run in two. The CDF itself is
 *  never modified, so reset() only has to re-create the root chunk.
 *
 *  Both iterator types must support random-access arithmetic (it - 1, it + n,
 *  it - it), and both ranges are expected to have the same length.
 *
 *  Instances are not copyable: the object owns the chunk pool and the root
 *  chunk stores the address of this object's `root` member.
 */
template< class DataIterType, class CDFIterType >
class WeightedSamplingWithoutReplacement {
public:
	// important: requires *( cdfStart - 1 ) == 0.0
	// Prints a message to cerr and calls exit( -1 ) if that does not hold.
	WeightedSamplingWithoutReplacement( DataIterType dataStart, DataIterType dataEnd, CDFIterType cdfStart, CDFIterType cdfEnd )
	: startData( dataStart ), endData( dataEnd ), startCDF( cdfStart ), endCDF( cdfEnd ),
	  numRemaining( 0 ), chunks( 0 ), nextChunk( 0 ), root( 0 )
	{
		if( *( startCDF - 1 ) != 0.0 ) {
			std::cerr << "WeightedSamplingWithoutReplacement requires that *( startCDF - 1 ) == 0.0" << std::endl;
			std::exit( -1 );
		}
		// Every leaf chunk holds at least one live item and every split removes
		// one item, so the tree never needs more chunks than there are CDF
		// entries. The pool is sized from the CDF range because that is the
		// range the chunks partition.
		chunks = new CDFChunk[ endCDF - startCDF ];
		reset();
	}
	
	~WeightedSamplingWithoutReplacement()
	{
		delete[] chunks;	// allocated with new[], so must be released with delete[]
	}
	
	// Largest value returned by random(); defined after the class.
	static const double RANDOM_MAX;

	// Returns a pseudo-random value in [0, 1) drawn from random().
	// The divisor is RANDOM_MAX + 1 so that the result is strictly below 1.0;
	// a value of exactly 1.0 would address the position one past the last item.
	double random_double() const {
		return static_cast< double >( random() ) / ( RANDOM_MAX + 1.0 );
	}
	
	// Makes every item available again. Only the root chunk is re-created;
	// chunks allocated by earlier splits are simply abandoned in the pool.
	void reset()
	{
		numRemaining = static_cast< unsigned int >( endCDF - startCDF );
		nextChunk = chunks;
		if( startCDF == endCDF ) {
			// Nothing to sample from: leave the tree empty so that sample()
			// returns endData instead of touching a zero-length pool.
			root = 0;
			return;
		}
		root = chunks;
		*nextChunk++ = CDFChunk( 0, &root, true, startCDF, endCDF );
	}
	
	// Number of items that have not been returned by sample() since the last reset().
	unsigned int getNumRemaining() const
	{
		return numRemaining;
	}
	
	// Removes one of the remaining items and returns an iterator to it in the
	// data range. Returns endData when no items remain.
	DataIterType sample()
	{
		if( root == 0 ) return endData;
		
		CDFChunk *chunk = root;
		
		// r is kept as a fraction of the current chunk's total weight while
		// walking down from the root to a leaf.
		double r = random_double();
		
		while( chunk->isSplit() ) {
			r *= chunk->totalWeight;
			if( r < chunk->leftWeight ) {
				chunk = chunk->left;
			} else {
				r -= chunk->leftWeight;
				chunk = chunk->right;
			}
			r /= chunk->totalWeight;
		}
		// Position of the draw on the original (unmodified) CDF scale.
		double p = *chunk->preFirst + r * chunk->totalWeight;
		
		#ifdef MOTIFADE_WSWR_DEBUG
		std::cout << std::setprecision( 8 ) << std::fixed;
		std::cout << "r = " << r << "\n";
		std::cout << "p = " << p << "\n";
		std::cout << std::endl;
		#endif
		
		#ifdef MOTIFADE_WSWR_DEBUG
		std::cout << "CDFChunk tree before sampling:" << "\n{\n\n";
		printChunkTree( root );
		std::cout << "}\n\n";
		#endif
		
		CDFIterType j;			// CDF entry of the item being removed
		double weight;			// weight of that item
		CDFChunk* parent = chunk->parent;
		
		if( chunk->isSingleton() ) {
			// The leaf holds exactly one item, so the whole leaf disappears.
			j = chunk->start;
			weight = chunk->totalWeight;
			if( parent == 0 ) {
				// That was the last item overall.
				root = 0;
			} else {
				// Splice the sibling into the parent's place in the tree.
				CDFChunk* sibling = chunk->isLeft ? parent->right : parent->left;
				sibling->parentPtr = parent->parentPtr;
				*parent->parentPtr = sibling;
				sibling->parent = parent->parent; // FIXED 5/24/04 -- must update sibling->parentPtr and sibling-isLeft!!!
				if( sibling->parent != 0 )
					sibling->isLeft = sibling == sibling->parent->left;
				else
					sibling->isLeft = true;
			}
		} else {
			j = std::upper_bound( chunk->start, chunk->end, p );
			// Floating-point rounding can push p up to (or past) the chunk's
			// last CDF value; clamp so that j always names an item inside it.
			if( j == chunk->end )
				j = chunk->last;
			weight = *j - *( j - 1 );
			
			if( j == chunk->start ) {
				chunk->removeLeftmost();
				chunk->removeFromTotal( weight );
			} else if( j == chunk->last ) {
				chunk->removeRightmost();
				chunk->removeFromTotal( weight );
			} else { // make a split
				chunk->left  = newChunk( chunk, &chunk->left, true, chunk->start, j );
				chunk->right = newChunk( chunk, &chunk->right, false, j + 1, chunk->end );
				
				chunk->leftWeight = chunk->left->totalWeight;
				chunk->removeFromTotal( weight );
			}
		}
		
		// re-weight the parents after removing the item
		// (In the singleton case the first step updates the parent that was
		// just spliced out; its isLeft flag is then used to update the
		// grandparent, whose child in that position is now the sibling.)
		while( parent != 0 ) {
			if( chunk->isLeft )
				parent->removeFromLeft( weight );
			else
				parent->removeFromRight( weight );
			
			chunk = parent;
			parent = chunk->parent;
		}
		
		#ifdef MOTIFADE_WSWR_DEBUG
		std::cout << "after sampling:" << "\n{\n\n";
		printChunkTree( root );
		std::cout << "}\n\n";
		#endif
		
		--numRemaining;
		return startData + ( j - startCDF );
	}
	
private:
	// Copying is disabled (declared, never defined): a copy would delete the
	// same chunk pool twice and its root chunk would still point at the
	// original object's `root` member.
	WeightedSamplingWithoutReplacement( const WeightedSamplingWithoutReplacement& );
	WeightedSamplingWithoutReplacement& operator=( const WeightedSamplingWithoutReplacement& );
	
	// One node of the chunk tree. A leaf (left == 0) covers the CDF entries
	// [start, end); an inner node only carries the weights of its subtrees.
	struct CDFChunk {
		CDFChunk()
			: start(), end(), preFirst(), last(),
			  left( 0 ), right( 0 ), parent( 0 ), parentPtr( 0 ),
			  leftWeight( 0.0 ), totalWeight( 0.0 ), isLeft( false ) {}
		CDFChunk( CDFChunk* p, CDFChunk** pptr, bool isleft, CDFIterType s, CDFIterType e )
			: start( s ), end( e ), preFirst( s - 1 ), last( e - 1 ),
			  left( 0 ), right( 0 ), parent( p ), parentPtr( pptr ),
			  leftWeight( 0.0 ), totalWeight( *last - *preFirst ), isLeft( isleft ) {}
		
		bool isSplit() const		{ return left != 0; }
		bool isSingleton() const	{ return start == last; }
		
		void removeLeftmost()   { ++start; ++preFirst; }
		void removeRightmost()  { --end; --last; }
		
		void removeFromTotal( double weight )
		{
			totalWeight -= weight;
		}
		
		void removeFromLeft( double weight )
		{
			leftWeight -= weight;
			removeFromTotal( weight );
		}
		
		void removeFromRight( double weight )
		{
			removeFromTotal( weight );
		}
		
		CDFIterType		start, end, preFirst, last;	// preFirst == start - 1, last == end - 1
		CDFChunk		*left, *right, *parent;
		CDFChunk		**parentPtr;				// address of the pointer that refers to this chunk
		double			leftWeight, totalWeight;
		bool			isLeft;
	};
	
	// Takes the next unused slot of the pool and initialises it as a leaf over [s, e).
	CDFChunk* newChunk( CDFChunk* p, CDFChunk** pptr, bool isLeft, CDFIterType s, CDFIterType e )
	{
		CDFChunk* result = nextChunk++;
		*result = CDFChunk( p, pptr, isLeft, s, e );
		return result;
	}
	
	#ifdef MOTIFADE_WSWR_DEBUG
	// Dumps one chunk to cout; chunks are identified by their index in the pool.
	void printChunk( CDFChunk* chunk )
	{
		std::cout << std::setprecision( 2 ) << std::fixed;
		std::cout << "chunk: " << chunk - chunks << std::endl;
		std::cout << "  isSplit()   = " << ( chunk->isSplit() ? "true" : "false" ) << "\n";
		std::cout << "  isLeft      = " << ( chunk->isLeft ? "true" : "false" ) << "\n";
		std::cout << "  totalWeight = " << chunk->totalWeight << "\n";
		if( chunk->isSplit() )
			std::cout << "  leftWeight  = " << chunk->leftWeight << std::endl;
		std::cout << "  *parentPtr  = " << ( *chunk->parentPtr == 0 ? -1 : *chunk->parentPtr - chunks ) << "\n";
		std::cout << "  parent      = " << ( chunk->parent == 0 ? -1 : chunk->parent - chunks ) << "\n";
		std::cout << "  left        = " << ( chunk->left == 0 ? -1 : chunk->left - chunks ) << "\n";
		std::cout << "  right       = " << ( chunk->right == 0 ? -1 : chunk->right - chunks ) << "\n";
		if( !chunk->isSplit() ) {
			std::cout << "  cdf         = ";
			for( CDFIterType i = chunk->start; i != chunk->end; ++i )
				std::cout << *i << " ";
			std::cout << std::endl;
			std::cout << "  pdf         = ";
			for( CDFIterType i = chunk->start; i != chunk->end; ++i )
				std::cout << *i - *( i - 1 ) << " ";
			std::cout << std::endl;
			std::cout << "  x           = ";
			for( CDFIterType i = chunk->start; i != chunk->end; ++i )
				std::cout << startData[ i - startCDF ] << " "; 
			std::cout << std::endl;
		}
	}
	
	// Dumps the subtree rooted at `chunk` in pre-order.
	void printChunkTree( CDFChunk* chunk )
	{
		if( chunk == 0 ) return;
		printChunk( chunk );
		std::cout << std::endl;
		printChunkTree( chunk->left );
		printChunkTree( chunk->right );
	}
	#endif
	
	DataIterType		startData, endData;
	CDFIterType			startCDF, endCDF;
	
	unsigned int		numRemaining;
	
	CDFChunk			*chunks, *nextChunk, *root;	// pool, next free pool slot, tree root (0 when empty)
};


// 2^31 - 1: the upper bound of the values POSIX random() returns.
template< class DataIterType, class CDFIterType >
const double WeightedSamplingWithoutReplacement< DataIterType, CDFIterType >::RANDOM_MAX = 0x7FFFFFFF;


#endif // WEIGHTED_SAMPLING_WITHOUT_REPLACEMENT_H
