/* Copyright (C) 2003-2008 Dan Arlow
 * 
 * This file is part of motifADE.
 * 
 * motifADE is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * motifADE is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with motifADE.  If not, see <http://www.gnu.org/licenses/>.
 */

/* 
 *  mann_whitney_u.cpp
 *
 *  Implementation of functions that perform a fast caching Mann-Whitney U-test
 */


#include "common.hpp"
#include "promoter.hpp"
#include "expression_statistics.hpp"
#include "univariate_expression_statistic_calculator.hpp"
#include "mark_set.hpp"
#include "mann_whitney_u.hpp"


using namespace std;


// MannWhitneyU constructor -- stores the test results and derives both p-values
//
//   u_val     value stored in u (compute() passes the rank sum of the marked sample)
//   z_val     Z-score of the statistic
//   freq_val  value stored in freq (compute() passes the fraction of data that is marked)
//   med1      median stored in m1
//   med2      median stored in m2
//
// p is the two-tailed p-value: twice normalcdf evaluated at the non-positive version of z.
// p_adj is p multiplied by Statistic::NUMBER_OF_TESTS and capped at 1.0 (a Bonferroni-style adjustment).
// The base class is built with Statistic( false ) -- presumably "no error"; confirm against Statistic.
MannWhitneyU::MannWhitneyU( double u_val, double z_val, double freq_val, double med1, double med2 )
    : Statistic( false )
    , u( u_val )
    , z( z_val )
    , freq( freq_val )
    , m1( med1 )
    , m2( med2 )
{
    // fold z onto the lower tail so a single CDF lookup covers both signs
    double tail = z;
    if( z > 0 ) tail = -z;
    p = 2.0 * normalcdf( tail );

    // scale by the number of tests, never reporting more than 1.0
    const double scaled = p * static_cast< double >( Statistic::NUMBER_OF_TESTS );
    p_adj = ( scaled > 1.0 ) ? 1.0 : scaled;
}


// MannWhitneyU method -- print the summary of the statistic
//
// Writes tab-separated fields to `os`, in the same column order as printHeader():
//   frequency, [delta-median (m1 - m2) unless USE_RANKS], Z-score, P-value,
//   [adjusted P-value if USE_ADJUSTED]
// If the statistic is flagged as an error, everything after the frequency is replaced by
// the single token "Err".
// The formatting flags and precision set on `os` here are left in effect on return.
void
MannWhitneyU::print( ostream& os ) const
{
    #ifndef MOTIFADE_CRAPPY_IOMANIP
        os << showpoint << fixed;
    #endif

    os << setprecision( 6 );

    os << freq << '\t';

    if( error ) {

        os << "Err";

    } else {

        // show an explicit sign only on the delta-median column
        #ifndef MOTIFADE_CRAPPY_IOMANIP
            os << showpos;
        #endif

        if( !USE_RANKS )
            os << ( m1 - m2 ) << '\t';

        #ifndef MOTIFADE_CRAPPY_IOMANIP
            os << noshowpos;
        #endif

        // every field must go to the caller's stream; writing to cout here would drop the
        // Z-score column from any output that is not standard output
        os << z << '\t';

        os << p;

        if( USE_ADJUSTED )
            os << '\t' << p_adj;

    }
}


// MannWhitneyU method -- print the header of the fields
//
// Writes the tab-separated column titles to `os`. The "Delta-Median" column is emitted only
// when USE_RANKS is false, and the "Adjusted P-value" column only when USE_ADJUSTED is true.
// No trailing newline is written.
void
MannWhitneyU::printHeader( ostream& os ) const
{
    const char TAB = '\t';

    os << "Frequency" << TAB;

    if( !USE_RANKS ) {
        os << "Delta-Median" << TAB;
    }

    os << "Z-score" << TAB << "P-value";

    if( USE_ADJUSTED ) {
        os << TAB << "Adjusted P-value";
    }
}



// MannWhitneyUCalculator constructor -- sizes the caches, then fills them
//
// sortedIndex and ranks get one slot per promoter. The two null-distribution tables get one
// extra slot because computeNullDistributions() fills an entry for every marked-sample size
// from 0 up to and including the number of ranks.
MannWhitneyUCalculator::MannWhitneyUCalculator( const PromoterVector& promoters, unsigned int dimension )
    : UnivariateExpressionStatisticCalculator( promoters, dimension )
    , sortedIndex( promoters.size() )
    , ranks( promoters.size() )
    , nullDistributionMean( promoters.size() + 1 )
    , nullDistributionStddev( promoters.size() + 1 )
{
    // rank the expression data first, then tabulate the null-distribution parameters
    buildRanks();
    computeNullDistributions();
}


// MannWhitneyUCalculator copy constructor -- duplicates a calculator, optionally without its rank caches
//
//   calc       calculator to copy from
//   copyRanks  when true, the contents of calc's sorted index and ranks are copied as well;
//              when false, both containers are only created at the same size as calc's
//
// The null-distribution tables are always copied.
MannWhitneyUCalculator::MannWhitneyUCalculator( const MannWhitneyUCalculator& calc, bool copyRanks )
    : UnivariateExpressionStatisticCalculator( calc )
    , sortedIndex( calc.getSortedIndex().size() )
    , ranks( calc.getRanks().size() )
    , nullDistributionMean( calc.getNullDistributionMean() )
    , nullDistributionStddev( calc.getNullDistributionStddev() )
{
    // caller does not want the cached rank data: leave the freshly sized containers as they are
    if( !copyRanks ) return;

    copy( calc.getSortedIndex().begin(), calc.getSortedIndex().end(), sortedIndex.begin() );
    copy( calc.getRanks().begin(), calc.getRanks().end(), ranks.begin() );
}


// MannWhitneyUCalculator method -- builds the sorted indices and ranks
void
MannWhitneyUCalculator::buildRanks()
{
    unsigned int r, i, n = expression.size();
    double rank, tied, prev;
    
    // initialize sorted rank index to 0..expression.size()-1
    generate( sortedIndex.begin(), sortedIndex.end(), Increment() );
    
    // sort indices using their expression values
    sort( sortedIndex.begin(), sortedIndex.end(), ExpressionIndexCompare( expression ) );
    
    // assign the rank positions their correct values, averaging ranks for tied data
    r = 0;
    while( r < n ) {
        rank = r + 1;
        tied = 1;
        
        // "look ahead" to accumulate tied ranks
        prev = expression[ sortedIndex[ r ] ];
        for( i = r + 1; i < n && prev == expression[ sortedIndex[ i ] ]; ++i ) {
            ++tied;
            rank += i + 1;
        }
        
        // assign tied data the mean of their ranks
        rank = rank / tied;
        for( i = 0; i < tied; ++i ) ranks[ sortedIndex[ r + i ] ] = rank; // now uses sortedIndex to un-sort ranks
        
        
        // advance r to the end of the tied ranks
        r += static_cast< unsigned int >( tied );
    }
}


// MannWhitneyUCalculator method -- shuffles the expression data and rebuilds the rank metadata
void
MannWhitneyUCalculator::shuffleExpression()
{
    UnivariateExpressionStatisticCalculator::shuffleExpression();
    buildRanks();
}


// MannWhitneyUCalculator method -- compute and return the rank sum of the marked data for a given MarkSet
//
//   marks  which data are marked; indexed the same way as `ranks`
//   n1     (out) number of marked entries
//   n2     (out) number of unmarked entries, i.e. marks.size() - n1
//
// Returns the sum of ranks[ i ] over every i for which marks[ i ] is set.
double
MannWhitneyUCalculator::computeRankSum( const MarkSet& marks, unsigned int& n1, unsigned int& n2 )
{
    const unsigned int total = marks.size();
    double sum = 0;

    // reset both sample sizes before counting
    n2 = 0;
    n1 = 0;

    for( unsigned int idx = 0; idx != total; ++idx ) {
        if( marks[ idx ] ) {
            // marked datum: add its rank and count it
            sum += ranks[ idx ];
            ++n1;
        }
    }

    // everything that was not marked belongs to the second sample
    n2 = total - n1;

    return sum;
}

// MannWhitneyUCalculator method -- compute and store the U statistic for a given MarkSet
//
//   marks      which data form the marked sample
//   statistic  (out) must refer to a MannWhitneyU; it is overwritten with the result
//
// Throws MotifADEException if `statistic` is not a MannWhitneyU.
// If either sample has fewer than 5 members, the result is the MannWhitneyU built from the
// frequency alone (presumably an error-flagged statistic -- confirm in the header).
// Otherwise the result carries the rank sum, its continuity-corrected Z-score, the frequency of
// marked data and the medians of the marked and unmarked samples.
//
// TODO: change the signature to take a MannWhitneyU& so the dynamic_cast is unnecessary.
void
MannWhitneyUCalculator::compute( const MarkSet& marks, Statistic& statistic )
{
    MannWhitneyU* stat = dynamic_cast< MannWhitneyU* >( &statistic );
    if( stat == 0 ) throw( MotifADEException( "MannWhitneyUCalculator::compute: not passed a MannWhitneyU pointer!" ) );

    const unsigned int n = expression.size();
    unsigned int n1, n2;

    const double rank_sum = computeRankSum( marks, n1, n2 );
    const double freq     = static_cast< double >( n1 ) / static_cast< double >( n );

    // check that the samples are big enough to satisfy the assumptions of the sampling distribution, otherwise fail
    if( ( n1 < 5 ) || ( n2 < 5 ) ) {
        *stat = MannWhitneyU( freq );
        return;
    }

    // mean and standard deviation of the sampling distribution, tabulated by marked-sample size
    const double mean   = nullDistributionMean[ n1 ];
    const double stddev = nullDistributionStddev[ n1 ];

    // correction for continuity: move the rank sum half a unit toward the mean.
    // When the rank sum equals the mean there is nothing to correct; applying +0.5 there
    // would report a spurious positive Z-score and a p-value below 1.
    double contcn = 0.0;
    if( rank_sum > mean )      contcn = -0.5;
    else if( rank_sum < mean ) contcn =  0.5;

    // Z-score of the rank sum under the null distribution
    const double z = ( rank_sum - mean + contcn ) / stddev;

    *stat = MannWhitneyU( rank_sum, z, freq, computeMedian( marks, true ), computeMedian( marks, false ) );
}


// terrible!
// MannWhitneyUCalculator method -- compute and return the Z-score of the U statistic for a given MarkSet
double
MannWhitneyUCalculator::computeValue( const MarkSet& marks )
{
	MannWhitneyU stat;
	compute( marks, stat );
	return stat.z;
}


// MannWhitneyUCalculator method -- compute the median of the marked (or unmarked) set of expression values
//
//   marks   which data are marked
//   marked  true to take the median of the marked data, false for the unmarked data
//
// Returns the middle value when the selected sample has an odd number of members, the average
// of the two middle values when it has an even number, and 0.0 when it is empty.
// Relies on sortedIndex listing the data in ascending expression order (see buildRanks()).
double
MannWhitneyUCalculator::computeMedian( const MarkSet& marks, bool marked ) const
{
    const unsigned int count = marks.countMarked( marked );

    // an empty sample has no median; without this guard the search below would match on the
    // very first datum and return a value that is not in the sample at all
    if( count == 0 ) return 0.0;

    const unsigned int total = marks.size();
    const bool         odd   = isOdd( count );

    // 1-based position, within the selected sample in ascending order, of the middle value
    // (odd count) or of the lower of the two middle values (even count)
    const unsigned int lowerPos = odd ? ( count / 2 + 1 ) : ( count / 2 );

    unsigned int seen = 0;
    double       m    = 0.0;

    // single ascending pass over the data, counting only members of the selected sample
    for( unsigned int i = 0; i < total; ++i ) {
        if( marks[ sortedIndex[ i ] ] == marked ) {
            ++seen;

            if( seen == lowerPos ) {
                m = expression[ sortedIndex[ i ] ];
                if( odd ) break;                    // odd count: this is the median
            } else if( seen == lowerPos + 1 ) {
                // even count: average the lower middle value with the upper one
                m = ( m + expression[ sortedIndex[ i ] ] ) / 2;
                break;
            }
        }
    }

    return m;
}


// MannWhitneyUCalculator method -- compute and store parameters of the null distributions for the U statistic
//
// For every possible marked-sample size n1 in 0..n, where n is the number of ranked data,
// tabulates the mean and standard deviation of the rank sum under the null hypothesis:
//   mean   = n1 * ( n + 1 ) / 2
//   stddev = sqrt( n1 * ( n - n1 ) * ( n + 1 ) / 12 )
// Both tables must therefore hold n + 1 entries.
void
MannWhitneyUCalculator::computeNullDistributions()
{
    const unsigned int n     = ranks.size();
    const double       total = static_cast< double >( n );

    for( unsigned int n1 = 0; n1 <= n; ++n1 ) {
        // do the arithmetic in double: the product n1 * ( n - n1 ) silently wraps around in
        // unsigned int once it no longer fits in that type
        const double marked = static_cast< double >( n1 );

        nullDistributionMean[ n1 ]   = marked * ( total + 1.0 ) / 2.0;
        nullDistributionStddev[ n1 ] = sqrt( marked * ( total - marked ) * ( total + 1.0 ) / 12.0 );
    }
}
