/* Copyright (C) 2003-2008 Dan Arlow
 * 
 * This file is part of motifADE.
 * 
 * motifADE is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * motifADE is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with motifADE.  If not, see <http://www.gnu.org/licenses/>.
 */

/* 
 *  hypergeometric_gene_set_motif_enrichment.cpp
 *
 *  Classes for comparing the incidence of motifs in a given gene set to the
 *  background using a hypergeometric null model: the enrichment statistic
 *  itself and the calculator that produces it.
 */


#include <cmath>

#include "common.hpp"
#include "promoter.hpp"
#include "expression_statistics.hpp"
#include "mark_set.hpp"
#include "gene_set_motif_enrichment_base.hpp"
#include "hypergeometric_gene_set_motif_enrichment.hpp"


/*
 * Build a hypergeometric enrichment statistic from an already computed log p-value.
 *
 * log_p_init                - natural log of the (unadjusted) p-value; stored in log_p.
 * num_inside / inside_size  - forwarded unchanged to GeneSetEnrichmentBase.
 * num_outside / outside_size- forwarded unchanged to GeneSetEnrichmentBase.
 *
 * When the base reports numTotal == 0 the statistic is flagged with error = true and
 * p / p_adj are not assigned here. Otherwise:
 *   p     = min( exp( log_p ), 1 )
 *   p_adj = min( exp( log_p + log( Statistic::NUMBER_OF_TESTS ) ), 1 )
 * i.e. a Bonferroni-style correction (p times the number of tests, capped at 1).
 */
HypergeometricGeneSetEnrichment::HypergeometricGeneSetEnrichment( double log_p_init, unsigned int num_inside, unsigned int inside_size, unsigned int num_outside, unsigned int outside_size )
	: GeneSetEnrichmentBase( num_inside, inside_size, num_outside, outside_size ),
	  log_p( log_p_init )
{
	if( numTotal != 0 ) {
		// Multiply by the number of tests as an addition in log space, so that a very
		// small p does not underflow before the correction is applied.
		const double logNumTests = log( double( Statistic::NUMBER_OF_TESTS ) );
		const double rawP = exp( log_p );
		const double rawAdjusted = exp( log_p + logNumTests );

		// A two-tailed estimate (or the correction) can exceed 1; cap both values.
		p = ( rawP > 1.0 ) ? 1.0 : rawP;
		p_adj = ( rawAdjusted > 1.0 ) ? 1.0 : rawAdjusted;
	} else {
		error = true;
	}
}


/*
 * Write one tab-separated result row: freq, freqInside, freqOutside, then the p-value
 * and, when USE_ADJUSTED is set, the adjusted p-value. If the statistic is flagged
 * as an error, "Err" is written in place of each p-value column.
 * No trailing tab or newline is emitted.
 */
void
HypergeometricGeneSetEnrichment::print( ostream& os ) const
{
	os << freq << '\t' << freqInside << '\t' << freqOutside << '\t';

	if( !error )
		os << p;
	else
		os << "Err";

	if( USE_ADJUSTED ) {
		os << '\t';
		if( !error )
			os << p_adj;
		else
			os << "Err";
	}
}


/*
 * Write the tab-separated column titles that match the row layout produced by print():
 * the three frequency columns and "P-value", plus "Adjusted P-value" when USE_ADJUSTED is set.
 * No trailing tab or newline is emitted.
 */
void
HypergeometricGeneSetEnrichment::printHeader( ostream& os ) const
{
	const char sep = '\t';

	os << "Frequency" << sep
	   << "Frequency In Set" << sep
	   << "Frequency In Rest" << sep
	   << "P-value";

	if( USE_ADJUSTED )
		os << sep << "Adjusted P-value";
}


/*
 * Build a calculator for the given promoter background and gene set.
 *
 * logNFactorial is constructed with promoters.size() + 1 elements (presumably a
 * std::vector<double>-like container -- confirm against the header), and is then
 * filled by initNFactorial() so that entry k holds log( k! ) for k = 0 .. promoters.size().
 */
HypergeometricGeneSetEnrichmentCalculator::HypergeometricGeneSetEnrichmentCalculator( const PromoterVector& promoters, const MarkSet& gene_set )
	: GeneSetEnrichmentBaseCalculator( promoters, gene_set ),
	  logNFactorial( promoters.size() + 1 )
{
	initNFactorial();
}


/*
 * Copy constructor: copies the base calculator state and the already-filled
 * log-factorial table from calc, so the table is not recomputed.
 */
HypergeometricGeneSetEnrichmentCalculator::HypergeometricGeneSetEnrichmentCalculator( const HypergeometricGeneSetEnrichmentCalculator& calc )
	: GeneSetEnrichmentBaseCalculator( calc ),
	  logNFactorial( calc.logNFactorial )
{
}


/*
 * Compute a two-tailed hypergeometric enrichment statistic for one motif's marks.
 *
 * marks     - per-promoter marks; must contain exactly getTotalSize() entries.
 * statistic - output; its dynamic type must be HypergeometricGeneSetEnrichment.
 *
 * countIntersection() splits the marked promoters into those inside the gene set
 * (numInside) and those in the rest (numOutside). Under the null model, the gene set
 * is a sample of geneSetSize promoters drawn without replacement from totalSize
 * promoters, numInside + numOutside of which are marked. The reported p-value is
 * 2 * min( P( X <= numInside ), P( X >= numInside ) ), computed in log space.
 *
 * Throws MotifADEException if statistic has the wrong dynamic type or if
 * marks.size() != getTotalSize().
 *
 * TODO: take a HypergeometricGeneSetEnrichment& directly instead of downcasting.
 */
void
HypergeometricGeneSetEnrichmentCalculator::compute( const MarkSet& marks, Statistic& statistic )
{
	HypergeometricGeneSetEnrichment* stat = dynamic_cast< HypergeometricGeneSetEnrichment* >( &statistic );
	if( stat == 0 )
		throw( MotifADEException( "HypergeometricGeneSetEnrichmentCalculator::compute: not passed a HypergeometricGeneSetEnrichment pointer!" ) );

	if( marks.size() != getTotalSize() )
		throw( MotifADEException( "HypergeometricGeneSetEnrichmentCalculator::compute: marks.size() != getTotalSize()" ) );

	unsigned int numInside = 0, numOutside = 0;
	countIntersection( marks, numInside, numOutside );

	// Total number of marked promoters across the whole background.
	const unsigned int numMarked = numInside + numOutside;

	const double log_p_lower = logHypergeometricCDF( numInside, geneSetSize, numMarked, totalSize, true );
	const double log_p_upper = logHypergeometricCDF( numInside, geneSetSize, numMarked, totalSize, false );

	// Two-tailed test: p = min( p_lower, p_upper ) * 2.0, with LOG2 presumably log( 2 ).
	// The result may exceed 1; the statistic's constructor caps p at 1.
	const double log_p = fast_min( log_p_lower, log_p_upper ) + LOG2;

	*stat = HypergeometricGeneSetEnrichment( log_p, numInside, geneSetSize, numOutside, restSize );
}


double
HypergeometricGeneSetEnrichmentCalculator::computeValue( const MarkSet& marks )
{
	HypergeometricGeneSetEnrichment stat;
	compute( marks, stat );
	return stat.p;
}


/*
 * Fill the lookup table so that logNFactorial[ k ] == log( k! ) for every index k.
 * Each entry is built from its predecessor: log( k! ) = log( (k-1)! ) + log( k ).
 * Assumes the table has at least one element (the constructors guarantee size >= 1).
 */
void
HypergeometricGeneSetEnrichmentCalculator::initNFactorial()
{
	// log( 0! ) = log( 1 ) = 0
	logNFactorial[ 0 ] = 0.0;

	unsigned int k = 1;
	while( k < logNFactorial.size() ) {
		const double logK = log( double( k ) );
		logNFactorial[ k ] = logNFactorial[ k - 1 ] + logK;
		++k;
	}
}


/*
 * Return log( n! ) from the precomputed table.
 * Throws MotifADEException when n is not a valid table index (n >= table size).
 */
double
HypergeometricGeneSetEnrichmentCalculator::logFactorial( unsigned int n ) const
{
	if( n < logNFactorial.size() )
		return logNFactorial[ n ];

	throw( MotifADEException( "HypergeometricGeneSetEnrichmentCalculator::logFactorial: n > dynamic programming table size" ) );
}


/*
 * Return log of the binomial coefficient C( n, k ) = n! / ( k! (n-k)! ),
 * evaluated as log( n! ) - log( k! ) - log( (n-k)! ) using the factorial table.
 * Throws MotifADEException if k > n, or (via logFactorial) if n exceeds the table.
 */
double
HypergeometricGeneSetEnrichmentCalculator::logNChooseK( unsigned int n, unsigned int k ) const
{
	if( n < k )
		throw( MotifADEException( "HypergeometricGeneSetEnrichmentCalculator::logNChooseK: k > n" ) );

	// Subtract the two denominator terms one at a time, left to right.
	double logCoefficient = logFactorial( n );
	logCoefficient -= logFactorial( k );
	logCoefficient -= logFactorial( n - k );
	return logCoefficient;
}


/*
 * Log of the hypergeometric probability mass function:
 *
 *   log P( X = numDrawn ) = log C( numPop, numDrawn )
 *                         + log C( popSize - numPop, sampleSize - numDrawn )
 *                         - log C( popSize, sampleSize )
 *
 * i.e. the probability that exactly numDrawn of the sampleSize items drawn without
 * replacement from popSize items are among the numPop "success" items.
 *
 * Throws MotifADEException when the arguments do not describe an attainable outcome,
 * or (via logNChooseK / logFactorial) when popSize exceeds the factorial table.
 */
double
HypergeometricGeneSetEnrichmentCalculator::logHypergeometricPDF( unsigned int numDrawn, unsigned int sampleSize, unsigned int numPop, unsigned int popSize ) const
{
	if( numDrawn > sampleSize )
		throw( MotifADEException( "HypergeometricGeneSetEnrichmentCalculator::logHypergeometricPDF: numDrawn > sampleSize" ) );

	if( numPop > popSize )
		throw( MotifADEException( "HypergeometricGeneSetEnrichmentCalculator::logHypergeometricPDF: numPop > popSize" ) );

	if( sampleSize > popSize )
		throw( MotifADEException( "HypergeometricGeneSetEnrichmentCalculator::logHypergeometricPDF: sampleSize > popSize" ) );

	if( numDrawn > numPop )
		throw( MotifADEException( "HypergeometricGeneSetEnrichmentCalculator::logHypergeometricPDF: numDrawn > numPop" ) );

	// The checks above make both differences non-negative (no unsigned wrap-around).
	// This rejects outcomes that would need more "failure" draws than failures exist;
	// previously it only surfaced as an opaque "k > n" from logNChooseK.
	if( sampleSize - numDrawn > popSize - numPop )
		throw( MotifADEException( "HypergeometricGeneSetEnrichmentCalculator::logHypergeometricPDF: sampleSize - numDrawn > popSize - numPop" ) );

	return logNChooseK( numPop, numDrawn ) + logNChooseK( popSize - numPop, sampleSize - numDrawn ) - logNChooseK( popSize, sampleSize );
}

double
HypergeometricGeneSetEnrichmentCalculator::logHypergeometricCDF( unsigned int numDrawn, unsigned int sampleSize, unsigned int numPop, unsigned int popSize, bool lower_tail )
{	
	cdfTemp.clear();
	
	unsigned int lb, ub;
	
	if( lower_tail ) {
		unsigned int restSize = popSize - sampleSize;
		if( numPop > restSize ) // fixed a bug here where lb = 0 was inappropriate
			lb = numPop - restSize;
		else
			lb = 0;
		ub = fast_min( sampleSize, numPop );
	} else { // upper tail
		lb = numDrawn;
		ub = fast_min( sampleSize, numPop );
	}
	
	for( unsigned int i = lb; i <= ub; ++i )
		cdfTemp.push_back( logHypergeometricPDF( i, sampleSize, numPop, popSize ) );
	
//	cout << "lower_tail = " << lower_tail << endl;
//	cout << "logHypergeometricCDF( " << numDrawn << ", " << sampleSize << ", " << numPop << ", " << popSize << " ) = " << addLogs( cdfTemp ) << endl;
//	cout << "   HypergeometricCDF( " << numDrawn << ", " << sampleSize << ", " << numPop << ", " << popSize << " ) = " << exp( addLogs( cdfTemp ) ) << endl;
	
	return addLogs( cdfTemp );
}
