/* Copyright (C) 2003-2008 Dan Arlow
 * 
 * This file is part of motifADE.
 * 
 * motifADE is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * motifADE is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with motifADE.  If not, see <http://www.gnu.org/licenses/>.
 */

/* 
 *  pwm_motif.cpp
 *
 *  Class for storing "position weight matrix" motifs.
 */


#include <string>
#include <vector>
#include <algorithm>
#include <functional>
#include <numeric>
#include <iostream>
#include <iterator>
#include <cmath>


#include "common.hpp"
#include "id_object.hpp"
#include "buffered_reader.hpp"
#include "sequence.hpp"
#include "tokenizer.hpp"
#include "pwm_motif.hpp"


// Builds an empty, unnamed motif (no positions). The threshold is 1, so the
// cached log-threshold is log(1) == 0.
PWMMotif::PWMMotif()
	: name(),
	  threshold( 1.0 ),
	  logThreshold( 0.0 )
{
}


// Copy constructor: duplicates the name, both threshold representations and
// the per-position weight vector of `other`.
PWMMotif::PWMMotif( const PWMMotif& other )
	: name( other.name ), threshold( other.threshold ),
	  logThreshold( other.logThreshold ), weights( other.weights )
{
}


// Builds an empty motif (no positions) called init_name, with threshold 1
// and log-threshold 0.
PWMMotif::PWMMotif( const string& init_name )
	: name( init_name ), threshold( 1.0 ), logThreshold( 0.0 )
{
}


// Builds a motif named init_name from an IUPAC pattern string. Each
// character of init_IUPAC_pattern becomes one position whose weights are
// NucleotideDistribution::getIUPACSymbolDistribution( character ).
// The threshold is stored as given (unlike setThreshold, no range check is
// done here), and its natural log is cached in logThreshold.
// Throws MotifADEException (raised by resize) if the pattern is longer
// than 1000 characters.
PWMMotif::PWMMotif( const string& init_name, const string& init_IUPAC_pattern, double init_threshold )
	: name( init_name ),
	  threshold( init_threshold ),
	  // Derive from the argument rather than the member `threshold`: members are
	  // initialized in header declaration order, not in this list's order, so
	  // the member may not have been initialized yet at this point.
	  logThreshold( log( init_threshold ) )
{
	setToIUPACMotif( init_IUPAC_pattern );
}



// Copy assignment: delegates all copying to set() and returns *this.
PWMMotif&
PWMMotif::operator=( const PWMMotif& rhs )
{
	PWMMotif& self = *this;
	self.set( rhs );
	return self;
}


// Two motifs compare equal when their per-position weights and their
// threshold both match. Neither the name nor logThreshold is compared.
bool
PWMMotif::operator==( const PWMMotif& rhs ) const
{
	if( !( weights == rhs.weights ) )
		return false;
	return threshold == rhs.threshold;
}


// Logical negation of operator==.
bool
PWMMotif::operator!=( const PWMMotif& rhs ) const
{
	const bool same = operator==( rhs );
	return !same;
}


void
PWMMotif::set( const PWMMotif& rhs )
{
	weights = rhs.weights;
	name = rhs.name;
	threshold = rhs.threshold;
	logThreshold = rhs.logThreshold;
}


// Replaces the motif's positions with one position per character of
// IUPAC_pattern. Position k receives
// NucleotideDistribution::getIUPACSymbolDistribution( IUPAC_pattern[ k ] ).
// Throws MotifADEException (from resize) if the pattern has more than 1000
// characters.
void
PWMMotif::setToIUPACMotif( const string& IUPAC_pattern )
{
	resize( IUPAC_pattern.size() );
	
	const unsigned int length = size();
	unsigned int k = 0;
	while( k < length ) {
		setWeights( k, NucleotideDistribution::getIUPACSymbolDistribution( IUPAC_pattern[ k ] ) );
		++k;
	}
}


// Read-only access to the motif's name.
const string&
PWMMotif::getName() const
{
	return this->name;
}


// Number of positions (columns) in the motif. The vector's size is narrowed
// to unsigned int, as the declared return type requires.
unsigned int
PWMMotif::size() const
{
	return static_cast< unsigned int >( weights.size() );
}


// Returns a modifiable reference to the weight distribution at 0-based
// position pos.
// Throws MotifADEException if pos >= size().
NucleotideDistribution&
PWMMotif::getWeights( unsigned int pos )
{
	// The message states the condition actually rejected: pos == size() is out
	// of range too, so "pos > size()" would misdescribe that failure.
	if( pos >= size() )
		throw( MotifADEException( "PWMMotif::getWeights: pos >= size()" ) );
	
	return weights[ pos ];
}


// Returns a read-only reference to the weight distribution at 0-based
// position pos.
// Throws MotifADEException if pos >= size().
const NucleotideDistribution&
PWMMotif::getWeights( unsigned int pos ) const
{
	// The message states the condition actually rejected: pos == size() is out
	// of range too, so "pos > size()" would misdescribe that failure.
	if( pos >= size() )
		throw( MotifADEException( "PWMMotif::getWeights: pos >= size()" ) );
	
	return weights[ pos ];
}


double
PWMMotif::getThreshold() const
{
	return threshold;
}


double
PWMMotif::getLogThreshold() const
{
	return logThreshold;
}


// Linear-space counterpart of getMaxLogScore(): exp of that value.
// Propagates the MotifADEException getMaxLogScore throws for an empty motif.
double
PWMMotif::getMaxScore() const
{
	const double maxLog = getMaxLogScore();
	return exp( maxLog );
}


// Linear-space counterpart of getMinLogScore(): exp of that value.
// Propagates the MotifADEException getMinLogScore throws for an empty motif.
double
PWMMotif::getMinScore() const
{
	const double minLog = getMinLogScore();
	return exp( minLog );
}


double
PWMMotif::getMaxLogScore() const
{
	if( size() == 0 )
		throw( MotifADEException( "PWMMotif::getMaxLogScore: size() == 0" ) );
	
	double score = 0;
	for( unsigned int i = 0; i < size(); ++i )
		score += getWeights( i ).getLogMax();
	return score;
}


double
PWMMotif::getMinLogScore() const
{
	if( size() == 0 )
		throw( MotifADEException( "PWMMotif::getMinLogScore: size() == 0" ) );
	
	double score = 0;
	for( unsigned int i = 0; i < size(); ++i )
		score += getWeights( i ).getLogMin();
	return score;
}


// Replaces the motif's name with a copy of newName.
void
PWMMotif::setName( const string& newName )
{
	name.assign( newName );
}


// Sets the number of positions to newSize. New positions are
// default-constructed by vector::resize, and surplus ones are dropped.
// Throws MotifADEException for sizes above 1000.
void
PWMMotif::resize( unsigned int newSize )
{
	const unsigned int maxPositions = 1000;
	if( newSize <= maxPositions ) {
		weights.resize( newSize );
		return;
	}
	throw( MotifADEException( "PWMMotif::resize: newSize > 1000 is not allowed." ) );
}


// Adds a copy of dist as the new last position. Unlike resize(), no length
// limit is enforced here.
void
PWMMotif::append( const NucleotideDistribution& dist )
{
	weights.insert( weights.end(), dist );
}


// Overwrites the distribution at 0-based position pos with dist, via
// NucleotideDistribution::set.
// Throws MotifADEException if pos >= size().
void
PWMMotif::setWeights( unsigned int pos, const NucleotideDistribution& dist )
{
	// The message states the condition actually rejected: pos == size() is out
	// of range too, so "pos > size()" would misdescribe that failure.
	if( pos >= size() )
		throw( MotifADEException( "PWMMotif::setWeights: pos >= size()" ) );
	
	weights[ pos ].set( dist );
}


// Sets the linear-space threshold and caches its natural log.
// Negative values are rejected with MotifADEException. Zero is accepted
// (log(0) yields -HUGE_VAL). NaN is not caught, since NaN < 0 is false.
void
PWMMotif::setThreshold( double value )
{
	if( value < 0 )
		throw( MotifADEException( "PWMMotif::setThreshold: value < 0" ) );
	
	logThreshold = log( value );
	threshold = value;
}


// Sets the threshold from its natural log: logThreshold = value and
// threshold = exp(value). No validation is performed.
void
PWMMotif::setLogThreshold( double value )
{
	threshold = exp( value );
	logThreshold = value;
}


void
PWMMotif::complement()
{
	for( vector< NucleotideDistribution >::iterator iter = weights.begin(); iter != weights.end(); ++iter )
		iter->complement();
}


void
PWMMotif::reverse()
{
	std::reverse( weights.begin(), weights.end() );
}


void
PWMMotif::reverseComplement()
{
	reverse();
	complement();
}


// matching
// Returns true when each of the size() characters starting at iter passes
// NucleotideDistribution::isNucleotide, and false at the first that does not.
// No bounds check: the caller must guarantee size() readable elements.
bool
PWMMotif::isValidSequence( Sequence::DataType::const_iterator iter ) const
{
	for( unsigned int remaining = size(); remaining != 0; --remaining, ++iter ) {
		if( !NucleotideDistribution::isNucleotide( *iter ) )
			return false;
	}
	return true;
}

bool
PWMMotif::isValidSequence( const Sequence& seq, Sequence::DataType::size_type pos ) const
{
	return isValidSequence( seq.getData().begin() + pos );
}


// Linear-space score of the window starting at iter: exp of computeLogScore.
double
PWMMotif::computeScore( Sequence::DataType::const_iterator iter ) const
{
	const double logScore = computeLogScore( iter );
	return exp( logScore );
}


// Log-space score of the window starting at iter. Each position's
// getLogWeight is applied to the corresponding character, and the results
// are summed from the first position to the last.
// No bounds check: the caller must guarantee size() readable elements.
double
PWMMotif::computeLogScore( Sequence::DataType::const_iterator iter ) const
{
	double total = 0;
	vector< NucleotideDistribution >::const_iterator column = weights.begin();
	for( ; column != weights.end(); ++column, ++iter )
		total += column->getLogWeight( *iter );
	return total;
}


double
PWMMotif::computeScore( const Sequence& seq, Sequence::DataType::size_type pos ) const
{
	return computeScore( seq.getData().begin() + pos );
}


double
PWMMotif::computeLogScore( const Sequence& seq, Sequence::DataType::size_type pos ) const
{
	return computeLogScore( seq.getData().begin() + pos );
}


// Writes a human-readable table of the motif to os:
//   - a "name: NAME {" line;
//   - a tab-separated column header;
//   - one row per position with the A/C/G/T weights and the consensus IUPAC symbol;
//   - a closing line with both thresholds.
// Rows end with '\n' rather than endl, so the stream is flushed once after
// the final line instead of after every row. The emitted characters are
// unchanged.
ostream& operator<<( ostream& os, const PWMMotif& m )
{
	os << "name: " << m.getName() << " {" << '\n';
	os << "\tA\tC\tG\tT\tConsensus" << '\n';
	for( unsigned int i = 0; i < m.size(); ++i ) {
		const NucleotideDistribution& d = m.getWeights( i );
		os << '\t' << d.getA() << '\t' << d.getC() << '\t' << d.getG() << '\t' << d.getT() << '\t' << d.getConsensusIUPACSymbol() << '\n';
	}
	// NOTE: this emits two spaces after '}'; kept as-is so the output format does not change.
	os << "} " << " threshold = " << m.getThreshold() << ", logThreshold = " << m.getLogThreshold() << endl;
	return os;
}


// Reads one motif record from `is` into `m`. The record layout, as parsed here:
//   line 1: header, split() into a name field and a log-threshold field.
//           The name field must start with '>', which is stripped.
//   line 2: must begin with "A\tC\tG\tT".
//   then:   every following line up to the next line starting with '>' (or
//           end of stream) becomes one position via
//           NucleotideDistribution( line, true ). Empty lines are skipped.
// The log-threshold is parsed with atof, so a missing or unparsable value
// yields 0 (threshold 1).
// The motif is built in a temporary and assigned to `m` only after the whole
// record has parsed, so `m` is left untouched if parsing throws.
// Throws MotifADEException if:
//   - the stream is unreadable;
//   - the header is empty or does not start with '>';
//   - the column-header line is missing.
BufferedReader& operator>>( BufferedReader& is, PWMMotif& m )
{
	if( !is ) throw( MotifADEException( "operator>>( BufferedReader&, PWMMotif& ): couldn't read from stream!" ) );
	
	PWMMotif parsed;
	
	string buf, motifName, motifLogThreshold;
	is.getline( buf, '\n' );
	
	split( buf, motifName, motifLogThreshold );
	
	// Test for emptiness first: a blank header line must be reported as a bad
	// header rather than indexed at position 0.
	if( motifName.empty() || motifName[ 0 ] != '>' )
		throw( MotifADEException( "operator>>( BufferedReader&, PWMMotif& ): motif header does not begin with '>'!" ) );
	
	motifName.erase( motifName.begin() );
	
	parsed.setName( motifName );
	
	parsed.setLogThreshold( atof( motifLogThreshold.c_str() ) );
	
	is.getline( buf, '\n' );
	
	if( buf.substr( 0, 7 ) != "A\tC\tG\tT" )
		throw( MotifADEException( "operator>>( BufferedReader&, PWMMotif& ): motif header does not have A\tC\tG\tT line!" ) );
	
	// Stop before the next record's '>' header so the following read starts on it.
	while( !is.eof() && is.peek() != '>' ) {
		is.getline( buf, '\n' );
		if( buf.empty() ) continue;
		parsed.append( NucleotideDistribution( buf, true ) );
	}
	
	// Commit only after a complete, successful parse.
	m = parsed;
	return is;
}
