/* Copyright (C) 2003-2008 Dan Arlow
 * 
 * This file is part of motifADE.
 * 
 * motifADE is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * motifADE is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with motifADE.  If not, see <http://www.gnu.org/licenses/>.
 */

/* 
 *  nucleotide_distribution.cpp
 *
 *  Implementation of NucleotideDistribution (declared in
 *  nucleotide_distribution.hpp): the class for columns of a
 *  "position weight matrix" motif.
 */

#include <string>
#include <vector>
#include <algorithm>
#include <functional>
#include <numeric>
#include <iostream>
#include <iterator>
#include <cmath>

#include "common.hpp"
#include "tokenizer.hpp"
#include "nucleotide_distribution.hpp"


// ln( x ) is the natural logarithm used throughout this file.
// With MOTIFADE_DUMB_FLOATING_POINT defined it expands to an unqualified
// safeLog( x ) call (NucleotideDistribution::safeLog when used inside member
// functions), which clamps arguments below TINY_VAL to LOG_TINY_VAL instead
// of producing -infinity. Otherwise it is plain log( x ).
#if defined( MOTIFADE_DUMB_FLOATING_POINT )
	#define ln( x ) safeLog( x )
#else
	#define ln( x ) log( x )
#endif


// Numeric constants shared by every distribution. They are defined in
// dependency order: the log constants are computed from values defined
// above them in this translation unit.
const double NucleotideDistribution::TINY_VAL = 1e-100;
const double NucleotideDistribution::LOG_TINY_VAL = log( NucleotideDistribution::TINY_VAL );
const double NucleotideDistribution::EPSILON = 1e-8;  // tolerance used by the consistency checks
const double NucleotideDistribution::UNITY = 1.0;
// In "dumb floating point" mode a zero weight is stored as TINY_VAL so that
// its logarithm stays finite; otherwise it is an exact 0.0 whose log is -inf.
#if defined( MOTIFADE_DUMB_FLOATING_POINT )
	const double NucleotideDistribution::ZERO = NucleotideDistribution::TINY_VAL;
#else
	const double NucleotideDistribution::ZERO = 0.0;
#endif
const double NucleotideDistribution::LOG_UNITY = log( NucleotideDistribution::UNITY );
const double NucleotideDistribution::LOG_ZERO = log( NucleotideDistribution::ZERO );
const double NucleotideDistribution::CONSENSUS_THRESHOLD = 0.5;


// One distribution per IUPAC nucleotide code, with weights in (A, C, G, T)
// order. Each is built from 0/1 membership flags and normalized, so every
// base the code admits receives an equal share (e.g. R = A|G -> 0.5, 0, 0.5, 0).
// They are defined after the numeric constants above, which the normalizing
// constructor's checks (EPSILON, UNITY) rely on.

// Unambiguous bases.
const NucleotideDistribution NucleotideDistribution::IUPAC_A( 1, 0, 0, 0, true );
const NucleotideDistribution NucleotideDistribution::IUPAC_C( 0, 1, 0, 0, true );
const NucleotideDistribution NucleotideDistribution::IUPAC_G( 0, 0, 1, 0, true );
const NucleotideDistribution NucleotideDistribution::IUPAC_T( 0, 0, 0, 1, true );

// Two-base codes.
const NucleotideDistribution NucleotideDistribution::IUPAC_M( 1, 1, 0, 0, true );
const NucleotideDistribution NucleotideDistribution::IUPAC_R( 1, 0, 1, 0, true );
const NucleotideDistribution NucleotideDistribution::IUPAC_W( 1, 0, 0, 1, true );
const NucleotideDistribution NucleotideDistribution::IUPAC_S( 0, 1, 1, 0, true );
const NucleotideDistribution NucleotideDistribution::IUPAC_Y( 0, 1, 0, 1, true );
const NucleotideDistribution NucleotideDistribution::IUPAC_K( 0, 0, 1, 1, true );

// Three-base codes.
const NucleotideDistribution NucleotideDistribution::IUPAC_V( 1, 1, 1, 0, true );
const NucleotideDistribution NucleotideDistribution::IUPAC_H( 1, 1, 0, 1, true );
const NucleotideDistribution NucleotideDistribution::IUPAC_D( 1, 0, 1, 1, true );
const NucleotideDistribution NucleotideDistribution::IUPAC_B( 0, 1, 1, 1, true );

// Any base.
const NucleotideDistribution NucleotideDistribution::IUPAC_N( 1, 1, 1, 1, true );


/* Maps an upper-case IUPAC nucleotide code to its normalized distribution.
 *
 * nt: one of A, C, G, T, M, R, W, S, Y, K, V, H, D, B or N.
 * Returns a reference to the matching static IUPAC_* constant.
 * Throws MotifADEException for any other character (lower case included).
 */
const NucleotideDistribution&
NucleotideDistribution::getIUPACSymbolDistribution( char nt )
{
	switch( nt ) {
		case 'A':
			return IUPAC_A;
		case 'C':
			return IUPAC_C;
		case 'G':
			return IUPAC_G;
		case 'T':
			return IUPAC_T;
		case 'M':
			return IUPAC_M;
		case 'R':
			return IUPAC_R;
		case 'W':
			return IUPAC_W;
		case 'S':
			return IUPAC_S;
		case 'Y':
			return IUPAC_Y;
		case 'K':
			return IUPAC_K;
		case 'V':
			return IUPAC_V;
		case 'H':
			return IUPAC_H;
		case 'D':
			return IUPAC_D;
		case 'B':
			return IUPAC_B;
		case 'N':
			return IUPAC_N;
		default:
			// Message names the function correctly (was "...Distributiont").
			throw( MotifADEException( "NucleotideDistribution::getIUPACSymbolDistribution: nt != A, C, G, T, M, R, W, S, Y, K, V, H, D, B, or N" ) );
	}
}


/* True exactly when nt is one of the four unambiguous upper-case bases
 * 'A', 'C', 'G' or 'T'; every other character (including lower case and
 * IUPAC ambiguity codes) yields false.
 */
bool
NucleotideDistribution::isNucleotide( char nt )
{
	return nt == 'A' || nt == 'C' || nt == 'G' || nt == 'T';
}


/* Natural logarithm with a floor: any argument below TINY_VAL (including 0
 * and negative values) yields LOG_TINY_VAL instead of -inf or NaN.
 * A NaN argument is not "below" TINY_VAL and is passed straight to log().
 */
double
NucleotideDistribution::safeLog( double value )
{
	const bool belowFloor = value < NucleotideDistribution::TINY_VAL;
	return belowFloor ? LOG_TINY_VAL : log( value );
}


/* Default constructor: every raw weight is ZERO and every log weight is
 * LOG_ZERO. The result is not normalized.
 */
NucleotideDistribution::NucleotideDistribution()
	: pA( ZERO ),
	  pC( ZERO ),
	  pG( ZERO ),
	  pT( ZERO ),
	  logA( LOG_ZERO ),
	  logC( LOG_ZERO ),
	  logG( LOG_ZERO ),
	  logT( LOG_ZERO )
{
	checkRep();
}

/* Stores the A, C, G and T weights exactly as given (no normalization) and
 * derives each log weight from its raw weight with ln().
 */
NucleotideDistribution::NucleotideDistribution( double a, double c, double g, double t )
	: pA( a ),
	  pC( c ),
	  pG( g ),
	  pT( t ),
	  logA( ln( a ) ),
	  logC( ln( c ) ),
	  logG( ln( g ) ),
	  logT( ln( t ) )
{
	checkRep();
}


/* Builds a distribution from raw A, C, G and T weights.
 *
 * normalize: when true, the weights are divided by their sum (log weights
 * recomputed) and checkNormalization() then throws MotifADEException if they
 * do not sum to UNITY within EPSILON (e.g. all-zero input).
 *
 * renormalize() is called directly: cv-qualifiers do not apply to an object
 * under construction, so `this` is non-const here even when a const object
 * is being built, and no const_cast is required.
 */
NucleotideDistribution::NucleotideDistribution( double a, double c, double g, double t, bool normalize )
	: pA( a ), pC( c ), pG( g ), pT( t ),
	  logA( ln( a ) ), logC( ln( c ) ), logG( ln( g ) ), logT( ln( t ) )
{
	if( normalize ) {
		renormalize();
		checkNormalization();
	}
	checkRep();
}


/* Builds a distribution from a tab-separated text line: the first four
 * fields are parsed as the A, C, G and T weights by set( const string& ).
 * The parsed weights are stored as-is (not normalized).
 *
 * Members start as ZERO/LOG_ZERO and are then overwritten by set(). No
 * const_cast is needed: an object under construction is never const.
 */
NucleotideDistribution::NucleotideDistribution( const string& s )
	: pA( ZERO ), pC( ZERO ), pG( ZERO ), pT( ZERO ),
	  logA( LOG_ZERO ), logC( LOG_ZERO ), logG( LOG_ZERO ), logT( LOG_ZERO )
{
	set( s );
}


/* Builds a distribution from a tab-separated text line (first four fields =
 * A, C, G, T weights, parsed by set( const string& )).
 *
 * normalize: when true, the weights are divided by their sum and
 * checkNormalization() throws MotifADEException if they do not then sum to
 * UNITY within EPSILON.
 *
 * Ends with checkRep(), like the other normalizing constructors. The member
 * functions are called directly because `this` is non-const inside a
 * constructor; no const_cast is needed.
 */
NucleotideDistribution::NucleotideDistribution( const string& s, bool normalize )
	: pA( ZERO ), pC( ZERO ), pG( ZERO ), pT( ZERO ),
	  logA( LOG_ZERO ), logC( LOG_ZERO ), logG( LOG_ZERO ), logT( LOG_ZERO )
{
	set( s );
	if( normalize ) {
		renormalize();
		checkNormalization();
	}
	checkRep();
}


/* Copy constructor. Copies the four raw weights of dist and recomputes the
 * log weights from them with ln(), rather than copying dist's log weights.
 */
NucleotideDistribution::NucleotideDistribution( const NucleotideDistribution& dist )
	: pA( dist.getA() ),
	  pC( dist.getC() ),
	  pG( dist.getG() ),
	  pT( dist.getT() ),
	  logA( ln( dist.getA() ) ),
	  logC( ln( dist.getC() ) ),
	  logG( ln( dist.getG() ) ),
	  logT( ln( dist.getT() ) )
{
	checkRep();
}


/* Copies dist's raw weights (log weights recomputed with ln()).
 *
 * normalize: when true, the copy is divided by its weight sum and
 * checkNormalization() throws MotifADEException if it does not then sum to
 * UNITY within EPSILON.
 *
 * renormalize() is called directly; `this` is never const inside a
 * constructor, so the former const_cast was unnecessary.
 */
NucleotideDistribution::NucleotideDistribution( const NucleotideDistribution& dist, bool normalize )
	: pA( dist.getA() ), pC( dist.getC() ), pG( dist.getG() ), pT( dist.getT() ),
	  logA( ln( dist.getA() ) ), logC( ln( dist.getC() ) ), logG( ln( dist.getG() ) ), logT( ln( dist.getT() ) )
{
	if( normalize ) {
		renormalize();
		checkNormalization();
	}
	checkRep();
}


/* Assignment: takes rhs's raw weights via set(), which also refreshes the
 * log weights. Self-assignment simply rewrites the same values.
 */
NucleotideDistribution&
NucleotideDistribution::operator=( const NucleotideDistribution& rhs )
{
	this->set( rhs );
	return *this;
}


/* Exact floating-point equality of the four raw weights, checked in A, C, G,
 * T order; log weights are not compared. NaN weights never compare equal.
 */
bool
NucleotideDistribution::operator==( const NucleotideDistribution& rhs ) const
{
	if( getA() != rhs.getA() )
		return false;
	if( getC() != rhs.getC() )
		return false;
	if( getG() != rhs.getG() )
		return false;
	return getT() == rhs.getT();
}


/* Logical negation of operator==. */
bool
NucleotideDistribution::operator!=( const NucleotideDistribution& rhs ) const
{
	const bool same = ( *this == rhs );
	return !same;
}


double
NucleotideDistribution::getA() const
{
	return pA;
}

double
NucleotideDistribution::getC() const
{
	return pC;
}

double
NucleotideDistribution::getG() const
{
	return pG;
}

double
NucleotideDistribution::getT() const
{
	return pT;
}

double
NucleotideDistribution::getLogA() const
{
	return logA;
}

double
NucleotideDistribution::getLogC() const
{
	return logC;
}

double
NucleotideDistribution::getLogG() const
{
	return logG;
}

double
NucleotideDistribution::getLogT() const
{
	return logT;
}

/* Raw weight by base index: 0 = A, 1 = C, 2 = G, 3 = T.
 * Throws MotifADEException for any other index.
 */
double
NucleotideDistribution::getWeight( unsigned int ntNum ) const
{
	if( ntNum == 0 ) return getA();
	if( ntNum == 1 ) return getC();
	if( ntNum == 2 ) return getG();
	if( ntNum == 3 ) return getT();
	throw( MotifADEException( "NucleotideDistribution::getWeight: ntNum != 0, 1, 2, or 3" ) );
}

/* Log weight by base index: 0 = A, 1 = C, 2 = G, 3 = T.
 * Throws MotifADEException for any other index.
 */
double
NucleotideDistribution::getLogWeight( unsigned int ntNum ) const
{
	if( ntNum == 0 ) return getLogA();
	if( ntNum == 1 ) return getLogC();
	if( ntNum == 2 ) return getLogG();
	if( ntNum == 3 ) return getLogT();
	throw( MotifADEException( "NucleotideDistribution::getLogWeight: ntNum != 0, 1, 2, or 3" ) );
}

/* Raw weight for an upper-case base letter 'A', 'C', 'G' or 'T'.
 * Any other character yields ZERO, or throws MotifADEException when
 * NUCLEOTIDE_DISTRIBUTION_THROW_ON_BAD_NT is defined.
 */
double
NucleotideDistribution::getWeight( char nt ) const
{
	if( nt == 'A' ) return getA();
	if( nt == 'C' ) return getC();
	if( nt == 'G' ) return getG();
	if( nt == 'T' ) return getT();
	#ifdef NUCLEOTIDE_DISTRIBUTION_THROW_ON_BAD_NT
		throw( MotifADEException( "NucleotideDistribution::getWeight: nt != A, C, G, or T" ) );
	#else
		return ZERO;
	#endif
}


/* Log weight for an upper-case base letter 'A', 'C', 'G' or 'T'.
 * Any other character yields LOG_ZERO, or throws MotifADEException when
 * NUCLEOTIDE_DISTRIBUTION_THROW_ON_BAD_NT is defined.
 */
double
NucleotideDistribution::getLogWeight( char nt ) const
{
	if( nt == 'A' ) return getLogA();
	if( nt == 'C' ) return getLogC();
	if( nt == 'G' ) return getLogG();
	if( nt == 'T' ) return getLogT();
	#ifdef NUCLEOTIDE_DISTRIBUTION_THROW_ON_BAD_NT
		throw( MotifADEException( "NucleotideDistribution::getLogWeight: nt != A, C, G, or T" ) );
	#else
		return LOG_ZERO;
	#endif
}


/* Largest of the four log weights, reduced pairwise: max( max(A, C), max(G, T) ). */
double
NucleotideDistribution::getLogMax() const
{
	const double maxAC = fast_max( getLogA(), getLogC() );
	const double maxGT = fast_max( getLogG(), getLogT() );
	return fast_max( maxAC, maxGT );
}


/* Smallest of the four log weights, reduced pairwise: min( min(A, C), min(G, T) ). */
double
NucleotideDistribution::getLogMin() const
{
	const double minAC = fast_min( getLogA(), getLogC() );
	const double minGT = fast_min( getLogG(), getLogT() );
	return fast_min( minAC, minGT );
}


double
NucleotideDistribution::getEntropy() const
{
	checkNormalization();
	double entropy = 0;
	if( getA() != 0 ) entropy -= getA() * getLogA();
	if( getC() != 0 ) entropy -= getC() * getLogC();
	if( getG() != 0 ) entropy -= getG() * getLogG();
	if( getT() != 0 ) entropy -= getT() * getLogT();
	return entropy;
}


/* Information content of this column in bits: the maximum entropy of a
 * four-letter alphabet, log2( 4 ) = 2 bits, minus the column's entropy.
 *
 * getEntropy() uses natural logarithms (nats), so it is converted to bits by
 * dividing by ln( 2 ) before the subtraction. Subtracting nats from 2 bits,
 * as before, gave a uniform column about 0.61 instead of 0.
 * Throws (via getEntropy) if the distribution is not normalized.
 */
double
NucleotideDistribution::getInformationContent() const
{
	return 2.0 - getEntropy() / log( 2.0 );
}


/* Uses convention from Dhaeseleer Nat. Bio. Tech 4 (24) 2006 423-425
 * "Conventionally, a single base is shown if it occurs in more than half
 * the sites and at least twice as often as the second most frequent base.
 * Otherwise, a doubledegenerate symbol is used if two bases occur in more
 * than 75% of the sites, or a tripledegenerate symbol when one base does
 * not occur at all."
 * Rather than require a base be absent to use a tripledegenerate symbol,
 * here we require it to occur in less than 5% of all sites.
 */
char
NucleotideDistribution::getConsensusIUPACSymbol() const
{
	double a = getA(), c = getC(), g = getG(), t = getT();
	renormalize( a, c, g, t );
	
	double wts[] = { a, c, g, t };
	unsigned int sorting_perm[] = { 0, 1, 2, 3 };
	sort( sorting_perm, sorting_perm + 4, SortingPermutationCmp< double[ 4 ] >( wts ) );
	
	double biggest = wts[ sorting_perm[ 3 ] ], secondBiggest = wts[ sorting_perm[ 2 ] ];
	
	if( biggest >= 0.5 && biggest >= 2 * secondBiggest )
		switch( sorting_perm[ 3 ] ) {
			case 0: return 'A';
			case 1: return 'C';
			case 2: return 'G';
			case 3: return 'T';
			default:
				throw( MotifADEException( "NucleotideDistribution::getConsensusIUPACSymbol: reached default case of switch statement!" ) );
		}
	
	else if( a + c > 0.75 )
		return 'M';
	else if( a + g > 0.75 )
		return 'R';
	else if( a + t > 0.75 )
		return 'W';
	else if( c + g > 0.75 )
		return 'S';
	else if( c + t > 0.75 )
		return 'Y';
	else if( g + t > 0.75 )
		return 'K';
	else if( t < 0.05 )
		return 'V';
	else if( g < 0.05 )
		return 'H';
	else if( c < 0.05 )
		return 'D';
	else if( a < 0.05 )
		return 'B';
	else
		return 'N';
}


/* True when the raw weights sum to UNITY within EPSILON (inclusive). */
bool
NucleotideDistribution::isNormalized() const
{
	const double deviation = pA + pC + pG + pT - UNITY;
	return fast_abs( deviation ) <= EPSILON;
}


void
NucleotideDistribution::set( const NucleotideDistribution& dist )
{
	set( dist.getA(), dist.getC(), dist.getG(), dist.getT() );
}


void
NucleotideDistribution::set( double a, double c, double g, double t )
{
	setA( a );
	setC( c );
	setG( g );
	setT( t );
	resetLogWeights();
}


void
NucleotideDistribution::set( const string& rawDist )
{
	const unsigned int fields[] = { 0, 1, 2, 3 };
	string rslt[ 4 ];
	getCSVFields( fields, fields + 4, rslt, rawDist, '\t' );
	set( atof( rslt[ 0 ].c_str() ), atof( rslt[ 1 ].c_str() ), atof( rslt[ 2 ].c_str() ), atof( rslt[ 3 ].c_str() ) );
}


// Per-base raw-weight setters: store the weight, recompute the matching log
// weight with ln(), then re-verify the representation with checkRep().
// Range validation through checkWeight() is currently not performed.

void
NucleotideDistribution::setA( double value )
{
	pA = value;
	logA = ln( pA );
	checkRep();
}


void
NucleotideDistribution::setC( double value )
{
	pC = value;
	logC = ln( pC );
	checkRep();
}


void
NucleotideDistribution::setG( double value )
{
	pG = value;
	logG = ln( pG );
	checkRep();
}


void
NucleotideDistribution::setT( double value )
{
	pT = value;
	logT = ln( pT );
	checkRep();
}


// Per-base log-weight setters: validate via checkLogWeight() (currently a
// no-op), store the log weight, set the raw weight to exp() of it, then run
// checkRep(). checkRep() can throw if ln( exp( value ) ) no longer matches
// value within EPSILON (e.g. when exp() underflows).

void
NucleotideDistribution::setLogA( double value )
{
	checkLogWeight( value );
	logA = value;
	pA = exp( logA );
	checkRep();
}


void
NucleotideDistribution::setLogC( double value )
{
	checkLogWeight( value );
	logC = value;
	pC = exp( logC );
	checkRep();
}


void
NucleotideDistribution::setLogG( double value )
{
	checkLogWeight( value );
	logG = value;
	pG = exp( logG );
	checkRep();
}


void
NucleotideDistribution::setLogT( double value )
{
	checkLogWeight( value );
	logT = value;
	pT = exp( logT );
	checkRep();
}

/* Sets a raw weight by base index (0 = A, 1 = C, 2 = G, 3 = T).
 * Throws MotifADEException for any other index.
 */
void
NucleotideDistribution::setWeight( unsigned int ntNum, double value )
{
	if( ntNum == 0 ) { setA( value ); return; }
	if( ntNum == 1 ) { setC( value ); return; }
	if( ntNum == 2 ) { setG( value ); return; }
	if( ntNum == 3 ) { setT( value ); return; }
	throw( MotifADEException( "NucleotideDistribution::setWeight: ntNum != 0, 1, 2, or 3" ) );
}

/* Sets a log weight by base index (0 = A, 1 = C, 2 = G, 3 = T).
 * Throws MotifADEException for any other index.
 */
void
NucleotideDistribution::setLogWeight( unsigned int ntNum, double value )
{
	if( ntNum == 0 ) { setLogA( value ); return; }
	if( ntNum == 1 ) { setLogC( value ); return; }
	if( ntNum == 2 ) { setLogG( value ); return; }
	if( ntNum == 3 ) { setLogT( value ); return; }
	throw( MotifADEException( "NucleotideDistribution::setLogWeight: ntNum != 0, 1, 2, or 3" ) );
}

/* Sets a raw weight by upper-case base letter ('A', 'C', 'G', 'T').
 * Throws MotifADEException for any other character.
 */
void
NucleotideDistribution::setWeight( char nt, double value )
{
	if( nt == 'A' ) { setA( value ); return; }
	if( nt == 'C' ) { setC( value ); return; }
	if( nt == 'G' ) { setG( value ); return; }
	if( nt == 'T' ) { setT( value ); return; }
	throw( MotifADEException( "NucleotideDistribution::setWeight: nt != A, C, G, or T" ) );
}

/* Sets a log weight by upper-case base letter ('A', 'C', 'G', 'T').
 * Throws MotifADEException for any other character.
 */
void
NucleotideDistribution::setLogWeight( char nt, double value )
{
	if( nt == 'A' ) { setLogA( value ); return; }
	if( nt == 'C' ) { setLogC( value ); return; }
	if( nt == 'G' ) { setLogG( value ); return; }
	if( nt == 'T' ) { setLogT( value ); return; }
	throw( MotifADEException( "NucleotideDistribution::setLogWeight: nt != A, C, G, or T" ) );
}


/* Divides the four referenced values in place by their (pre-division) sum.
 * Operates only through its reference arguments. A zero sum is not guarded
 * against and yields NaN/infinite results.
 */
void
NucleotideDistribution::renormalize( double& a, double& c, double& g, double& t ) const
{
	const double sum = a + c + g + t;
	a = a / sum;
	c = c / sum;
	g = g / sum;
	t = t / sum;
}


void
NucleotideDistribution::resetLogWeights()
{
	logA = ln( getA() );
	logC = ln( getC() );
	logG = ln( getG() );
	logT = ln( getT() );
}


void
NucleotideDistribution::renormalize()
{
	renormalize( pA, pC, pG, pT );
	resetLogWeights();
}


/* Replaces the distribution by its base complement, swapping A <-> T and
 * C <-> G for both the raw and the log weights.
 */
void
NucleotideDistribution::complement()
{
	// A <-> T
	fast_swap( pA, pT );
	fast_swap( logA, logT );
	// C <-> G
	fast_swap( pC, pG );
	fast_swap( logC, logG );
}


/* Range check for a raw weight. The check is disabled, so every value is
 * accepted. The former check, kept for reference:
 *   if( value < ZERO )  throw "NucleotideDistribution::checkWeight: value < 0"
 *   if( value > UNITY ) throw "NucleotideDistribution::checkWeight: value > 1"
 */
void
NucleotideDistribution::checkWeight( double value ) const
{
	(void)value; // unused while the check is disabled
}


/* Range check for a log weight. The check is disabled, so every value is
 * accepted. The former check, kept for reference:
 *   if( value > LOG_UNITY ) throw "NucleotideDistribution::checkLogWeight: value > ln( 1 )"
 */
void
NucleotideDistribution::checkLogWeight( double value ) const
{
	(void)value; // unused while the check is disabled
}


void
NucleotideDistribution::checkRep() const
{
	if( fast_abs( logA - ln( pA ) ) > EPSILON )
		throw( MotifADEException( "NucleotideDistribution::checkRep: | logA - ln( pA ) | > EPSILON" ) );
	if( fast_abs( logC - ln( pC ) ) > EPSILON )
		throw( MotifADEException( "NucleotideDistribution::checkRep: | logC - ln( pC ) | > EPSILON" ) );
	if( fast_abs( logG - ln( pG ) ) > EPSILON )
		throw( MotifADEException( "NucleotideDistribution::checkRep: | logG - ln( pG ) | > EPSILON" ) );
	if( fast_abs( logT - ln( pT ) ) > EPSILON )
		throw( MotifADEException( "NucleotideDistribution::checkRep: | logT - ln( pT ) | > EPSILON" ) );
}


/* Throws MotifADEException unless both the raw weights and the
 * exponentiated log weights sum to UNITY within EPSILON. The raw weights are
 * checked first.
 */
void
NucleotideDistribution::checkNormalization() const
{
	if( !isNormalized() )
		throw( MotifADEException( "NucleotideDistribution::checkNormalization: | pA + pC + pG + pT - UNITY | > EPSILON" ) );
	
	const double expSum = exp( logA ) + exp( logC ) + exp( logG ) + exp( logT );
	if( fast_abs( expSum - UNITY ) > EPSILON )
		throw( MotifADEException( "NucleotideDistribution::checkNormalization: | exp( logA ) + exp( logC ) + exp( logG ) + exp( logT ) - UNITY | > EPSILON" ) );
}


/* Writes a one-line dump of dist: its consensus IUPAC symbol, then the four
 * raw weights and the four log weights. Returns os for chaining.
 */
ostream& operator<<( ostream& os, const NucleotideDistribution& dist )
{
	const char consensus = dist.getConsensusIUPACSymbol();
	os << "consensus: " << consensus << " "
	   << "{ A: " << dist.getA() << ", " << "C: " << dist.getC() << ", "
	   << "G: " << dist.getG() << ", " << "T: " << dist.getT() << ", "
	   << "logA: " << dist.getLogA() << ", " << "logC: " << dist.getLogC() << ", "
	   << "logG: " << dist.getLogG() << ", " << "logT: " << dist.getLogT() << " }";
	return os;
}
