package org.broadinstitute.cga.tools.seq;

import net.sf.samtools.*;
import net.sf.samtools.util.CloseableIterator;
import java.io.*;
import java.lang.*;
import java.util.*;
import java.lang.reflect.*;

import org.broadinstitute.cga.tools.seq.Read;
import org.broadinstitute.cga.tools.seq.Reference;

/**
 * Opens a BAM file through {@link SAMFileReader} and loads the reads overlapping a requested
 * region into parallel arrays (one entry per read) plus flat per-base arrays, exposed through the
 * {@code get*()} accessors. Read groups listed in an optional blacklist file are suppressed.
 */
public class BamGrasp {

  private static final boolean DEBUG = false;

  private static int MAX_READLENGTH = 2000;   // to accommodate 1KG long reads

  private static final String default_refdir = "/xchip/tcga/gbm/analysis/lawrence/genome/hg18";

  /** Chromosome code for sequences whose name is not M/MT, X, Y or an integer. */
  private static final int UNKNOWN_CHR = -40;

  /** Reference opened by openFile(); handed to Read.parse_bases() in parse(). */
  public Reference ref = null;

  /** Reader for the currently opened BAM file. */
  public SAMFileReader reader = null;
  /** Scratch record overwritten by every call to parse(). */
  public Read read = new Read();

  private boolean fileOpen = false;
  private boolean quietMode = false;
  private int maxReads = -1;   // negative = no limit on the number of reads loadRegion() keeps

  private Map<Integer, Integer> dictSeqIndex2Chr = new HashMap<Integer, Integer>();
  // key = SequenceIndex; value = chr (0=M, 1-22, 23=X, 24=Y)
  private Map<Integer, String> dictSeqChr2Name = new HashMap<Integer, String>();
  // key = chr (0=M, 1-22, 23=X, 24=Y); value = SequenceName

  private Map<String,Integer> readGroupNum = new HashMap<String,Integer>();
  private Map<String,String> readGroupName = new HashMap<String,String>();
  private Map<String,Boolean> readGroupIsBlacklisted = new HashMap<String,Boolean>();
  private Integer nReadGroups = null;

  /** Sets the maximum number of reads loadRegion() keeps; a negative value means no limit. */
  public void set_maxReads(int m) {
      maxReads = m;
  }

  public int get_maxReads() {
      return maxReads;
  }

  /** Returns the BAM sequence name for chr code {@code chr}, or null if the header has none. */
  public String getChrName(int chr) {
      return dictSeqChr2Name.get(chr);
  }

  /** Maps a header sequence index to its chr code; throws NullPointerException for an unknown index. */
  public int getChrFromIndex(Integer index) {
      return dictSeqIndex2Chr.get(index);
  }

  public Boolean isReadGroupBlacklisted(String index) {
      return readGroupIsBlacklisted.get(index);
  }

  /** Returns the zero-based number assigned to read group ID {@code index}, or null. */
  public Integer getReadGroupNum(String index) {
      return readGroupNum.get(index);
  }

  /** Returns the PU name (or the ID when PU is unavailable) for read group ID {@code index}. */
  public String getReadGroupName(String index) {
      return readGroupName.get(index);
  }

  /** Number of read groups registered by openFile(); null before any file was opened. */
  public Integer getNumReadGroups() {
      return nReadGroups;
  }

  public BamGrasp(String bamname, String blacklistname, String refdir) throws Exception {
      openFile(bamname,blacklistname,refdir);
  }

  public BamGrasp(String bamname, String blacklistname) throws Exception {
      openFile(bamname,blacklistname);
  }

  public BamGrasp(String bamname) throws Exception {
      openFile(bamname,"none");
  }

  public BamGrasp() throws Exception {
  }

  public void openFile(String bamname) throws Exception {
      openFile(bamname,"none");
  }

  public void openFile(String bamname, String blacklistname) throws Exception {
      System.out.println("Using default refdir");
      openFile(bamname,blacklistname,default_refdir);
  }

  //////////////////////////////////////////////////////////////////////
  //////////////////////////////////////////////////////////////////////

  /**
   * Opens a BAM file, its reference and an optional lane blacklist, and builds the read-group and
   * chromosome lookup tables from the BAM header.
   *
   * @param bamname       path of the BAM file
   * @param blacklistname path of the blacklist file, or "none" (case-insensitive) for no blacklist
   * @param refdir        reference directory passed to {@code new Reference(refdir)}
   * @throws Exception if any part cannot be read; the BAM reader is closed again in that case
   */
  public void openFile(String bamname, String blacklistname, String refdir) throws Exception {

      if (DEBUG) System.out.println("BamGrasp: openFile");

      // Release a previously opened file and forget its header tables; otherwise re-opening
      // reports spurious "Duplicate dictionary entries" and leaks the old reader.
      close();
      clearHeaderState();

      SAMFileReader.setDefaultValidationStringency(SAMFileReader.ValidationStringency.SILENT);
      reader = new SAMFileReader(new File(bamname));
      try {
          ref = new Reference(refdir);

          if (DEBUG) System.out.println("BamGrasp: read blacklist");
          List<String> blacklist = new ArrayList<String>(1000);
          boolean isGsaFormat = readBlacklist(blacklistname, blacklist);

          if (DEBUG) System.out.println("BamGrasp: readgroups");
          loadReadGroups(blacklist, isGsaFormat);

          if (DEBUG) System.out.println("BamGrasp: chr fmt");
          loadSequenceDictionary();
      } catch (Exception e) {
          // fileOpen stays false, so close() would never release this reader: do it here
          reader.close();
          throw e;
      }

      fileOpen = true;

  }   // end openFile()

  /** Empties every table derived from a BAM header. */
  private void clearHeaderState() {
      dictSeqIndex2Chr.clear();
      dictSeqChr2Name.clear();
      readGroupNum.clear();
      readGroupName.clear();
      readGroupIsBlacklisted.clear();
      nReadGroups = null;
  }

  /**
   * Appends the non-blank lines of a blacklist file to {@code blacklist}.
   * GSA/GATK-formatted blacklists start with "PU:<full_name>", where <full_name> is the PU name
   * exactly as listed in the BAM header; when the first line has that prefix, the first three
   * characters of every line are stripped.
   *
   * @return true if the GSA/GATK format was detected
   */
  private static boolean readBlacklist(String blacklistname, List<String> blacklist) throws Exception {
      boolean isGsaFormat = false;
      if (blacklistname.equalsIgnoreCase("none")) return isGsaFormat;

      BufferedReader bl = null;
      try {
          bl = new BufferedReader(new FileReader(blacklistname));
          String s = bl.readLine();
          if (s != null && s.startsWith("PU:")) {
              isGsaFormat = true;
              System.out.println("GATK-formatted black list is detected");
          }
          for (; s != null; s = bl.readLine()) {
              if (s.matches("^\\s*$")) continue;   // skip blank lines
              blacklist.add(isGsaFormat ? s.substring(3) : s);
          }
      } catch (Exception e) {
          throw new Exception("Could not read blacklist "+blacklistname+": "+e.getMessage(), e);
      } finally {
          if (bl != null) {
              try { bl.close(); } catch (IOException ignored) { /* read-only file; nothing to recover */ }
          }
      }
      return isGsaFormat;
  }

  /**
   * Returns the PU attribute of {@code rg}, falling back to its read-group ID when PU cannot be
   * read; returns null when neither can be obtained.
   */
  private static String platformUnitOf(SAMReadGroupRecord rg) {
      try {
          if (DEBUG) System.out.print("trying getAttribute: ");
          String pu = rg.getAttribute("PU").toString();   // NPE when the PU tag is absent
          if (DEBUG) System.out.print("succeeded\n");
          return pu;
      } catch (Throwable e) {
          // Throwable (not Exception) is kept from the original; presumably it also guards
          // against samtools API differences (e.g. NoSuchMethodError) -- TODO confirm.
          if (DEBUG) System.out.print("failed: ");
      }
      try {
          if (DEBUG) System.out.print("trying getReadGroupId: ");
          String id = rg.getReadGroupId();
          if (DEBUG) System.out.print("succeeded\n");
          return id;
      } catch (Exception e2) {
          if (DEBUG) System.out.print("failed\n");
          return null;
      }
  }

  /**
   * Decides whether lane {@code pu} matches blacklist entry {@code blpu}. GSA format requires an
   * exact (case-insensitive) match. Old-style entries match on the first 5 and last 2 characters,
   * and names shorter than 7 characters never match.
   */
  private static boolean matchesBlacklistEntry(String pu, String blpu, boolean isGsaFormat) {
      if (isGsaFormat) return pu.equalsIgnoreCase(blpu);
      if (pu.length() < 7 || blpu.length() < 7) return false;
      return pu.regionMatches(true, 0, blpu, 0, 5)
          && pu.regionMatches(true, pu.length()-2, blpu, blpu.length()-2, 2);
  }

  /** Numbers the header's read groups (zero-based) and flags the blacklisted ones. */
  private void loadReadGroups(List<String> blacklist, boolean isGsaFormat) throws Exception {

      if (DEBUG) {
          Class<?> cls = Class.forName("net.sf.samtools.SAMReadGroupRecord");
          Method methlist[] = cls.getDeclaredMethods();
          System.out.println(methlist.length + " methods:");
          for (Method mm : methlist) {
              System.out.println("METHOD: "+mm);
          }
      }

      nReadGroups = Integer.valueOf(0);               // ZERO-BASED !
      boolean anyBlacklisted = false;
      for (final SAMReadGroupRecord rg : reader.getFileHeader().getReadGroups()) {
          final String pu = platformUnitOf(rg);
          if (pu == null) {
              System.out.println("Problem reading list of readgroups in BAM file header");
              continue;
          }
          if (DEBUG) System.out.println("BamGrasp: readgroups: "+pu);

          final String id = rg.getReadGroupId();
          readGroupNum.put(id, nReadGroups++);
          readGroupName.put(id, pu);

          boolean isb = false;
          for (final String blpu : blacklist) {
              if (matchesBlacklistEntry(pu, blpu, isGsaFormat)) {
                  isb = true;
                  if (!quietMode) { System.out.println("Suppressing blacklisted lane "+pu); }
                  anyBlacklisted = true;
                  break;
              }
          }
          readGroupIsBlacklisted.put(id, isb);
      }

      if (!anyBlacklisted && !quietMode) {
          System.out.println(blacklist.isEmpty() ? "No blacklist." : "No blacklisted lanes in this BAM.");
      }
  }

  /**
   * Converts a sequence name to a chr code: an optional "chr" prefix (any case) is ignored,
   * M/MT -> 0, X -> 23, Y -> 24, integers -> themselves, anything else -> UNKNOWN_CHR.
   */
  private static int chrNumberFromName(String name) {
      String n = name.toUpperCase();
      if (n.length() > 3 && n.startsWith("CHR")) n = n.substring(3);
      if (n.equals("M") || n.equals("MT")) return 0;
      if (n.equals("X")) return 23;
      if (n.equals("Y")) return 24;
      try {
          return Integer.parseInt(n);
      } catch (NumberFormatException e) {
          return UNKNOWN_CHR;
      }
  }

  /** Determines the chromosome naming format from the BAM header's sequence dictionary. */
  private void loadSequenceDictionary() throws Exception {
      for (final SAMSequenceRecord s : reader.getFileHeader().getSequenceDictionary().getSequences()) {
          final String name = s.getSequenceName();
          final int chr = chrNumberFromName(name);
          if (dictSeqChr2Name.containsKey(chr)) {
              throw new Exception("Duplicate dictionary entries identified for chr="+chr);
          }
          dictSeqIndex2Chr.put(s.getSequenceIndex(), chr);
          if (chr != UNKNOWN_CHR) { dictSeqChr2Name.put(chr, name); }
      }
      // Completeness of chr 1..24 is intentionally not enforced.
  }

  //////////////////////////////////////////////////////////////////////
  //////////////////////////////////////////////////////////////////////

  public void closeFile() {
      close();
  }

  /** Closes the BAM reader if a file is open; safe to call repeatedly. */
  public void close() {
      if (fileOpen) {
          reader.close();
          fileOpen = false;
      }
  }

  @Override
  public void finalize() throws Throwable {
      closeFile();
      super.finalize();
  }

  //////////////////////////////////////////////////////////////////////
  //////////////////////////////////////////////////////////////////////

  public void setQuietModeOn() {
      quietMode = true;
  }
  public void setQuietModeOff() {
      quietMode = false;
  }

  //////////////////////////////////////////////////////////////////////
  //////////////////////////////////////////////////////////////////////

  // storage of reads (parallel arrays, one entry per stored read)
  private int[] readgroup = null;
  private int[] namenumber = null;
  private int[] whichpairmate = null;
  private int[] readstart = null;
  private int[] readend = null;
  private int[] readstrand = null;
  private int[] readmapqual = null;
  private int[] baseindex = null;     // offset of the read's first base in base[]/basequal[]
  private int[] nmismatches = null;
  private int[] pairmatechr = null;
  private int[] pairmatestart = null;
  private int[] pairmatestrand = null;
  private String[] readname = null;

  private int numReads = 0;
  private int numReadsAllocated = 0;
  private int readsGrowStep = 100000;

  /** Returns a new array of length newsize holding the first {@code used} entries of old (if any). */
  private static int[] resized(int[] old, int newsize, int used) {
      int[] fresh = new int[newsize];
      if (old != null) { System.arraycopy(old, 0, fresh, 0, used); }
      return fresh;
  }

  /** String counterpart of {@link #resized(int[], int, int)}. */
  private static String[] resized(String[] old, int newsize, int used) {
      String[] fresh = new String[newsize];
      if (old != null) { System.arraycopy(old, 0, fresh, 0, used); }
      return fresh;
  }

  /** Reallocates every per-read array to newsize entries, keeping the first numReads. */
  private void growReads(int newsize) {
      Runtime.getRuntime().gc();   // request a collection before the (potentially large) allocation
      readgroup = resized(readgroup, newsize, numReads);
      namenumber = resized(namenumber, newsize, numReads);
      whichpairmate = resized(whichpairmate, newsize, numReads);
      readstart = resized(readstart, newsize, numReads);
      readend = resized(readend, newsize, numReads);
      readstrand = resized(readstrand, newsize, numReads);
      readmapqual = resized(readmapqual, newsize, numReads);
      baseindex = resized(baseindex, newsize, numReads);
      nmismatches = resized(nmismatches, newsize, numReads);
      pairmatechr = resized(pairmatechr, newsize, numReads);
      pairmatestart = resized(pairmatestart, newsize, numReads);
      pairmatestrand = resized(pairmatestrand, newsize, numReads);
      readname = resized(readname, newsize, numReads);
      numReadsAllocated = newsize;
  }

  // storage of bases (flat, all stored reads concatenated)
  private int[] base = null;  // 1,2,3,4
  private int[] basequal = null;

  private int numBases = 0;
  private int numBasesAllocated = 0;
  private int basesGrowStep = 10000000;

  /** Reallocates the per-base arrays to newsize entries, keeping the first numBases. */
  private void growBases(int newsize) {
      Runtime.getRuntime().gc();   // request a collection before the (potentially large) allocation
      base = resized(base, newsize, numBases);
      basequal = resized(basequal, newsize, numBases);
      numBasesAllocated = newsize;
  }

  //////////////////////////////////////////////////////////////////////
  //////////////////////////////////////////////////////////////////////

  // storage of filtering criteria
  private boolean include_unmapped_reads = true;
  public void set_unmapped_reads_included() {
      include_unmapped_reads = true;
  }
  public void set_unmapped_reads_excluded() {
      include_unmapped_reads = false;
  }

  /** Loads the reads overlapping the single position chr:start. */
  public void loadRegion(int chr, int start) throws Exception {
      loadRegion(chr,start,start);
  }

  // storage of query parameters
  private int query_chr = -1;
  private int query_start = -1;
  private int query_end = -1;

  /**
   * Loads the reads overlapping chr:start-end into the per-read and per-base arrays, replacing any
   * previously loaded region. Duplicates, non-primary alignments, reads from blacklisted read
   * groups and (if excluded) unmapped reads are skipped; at most maxReads reads are kept when
   * maxReads is non-negative.
   *
   * @throws Exception if no file is open, or if parsing a record fails
   */
  public void loadRegion(int chr, int start, int end) throws Exception {

      if (!fileOpen) throw new Exception("call openFile() first");

      // save query parameters
      query_chr = chr;
      query_start = start;
      query_end = end;

      // A chr code absent from the header yields a null name; passed through unchanged, as before
      // (presumably SAMFileReader then returns no records -- TODO confirm for the samtools version).
      String seqname = dictSeqChr2Name.get(chr);
      SAMRecordIterator c = reader.queryOverlapping(seqname,start,end);

      numReads = 0;
      numReadsAllocated = 0;
      numBases = 0;
      numBasesAllocated = 0;

      try {
          long starttime = System.currentTimeMillis();
          long idx = 0;

          // Check the limit before consuming a record so that maxReads == 0 keeps no reads.
          while ((maxReads < 0 || numReads < maxReads) && c.hasNext()) {
              SAMRecord x = c.next();

              // progress reporting
              idx++;
              if ((idx%100000)==0) {
                  float elapsed = (System.currentTimeMillis() - starttime);
                  elapsed /= 1000;
                  float rate = idx / elapsed;
                  if (!quietMode) {
                      System.out.println(idx + " records  " + elapsed + " seconds  " + rate + " records/second  " + chr + ":" + read.start);
                  }
              }

              // filtering
              if (!include_unmapped_reads && x.getReadUnmappedFlag()) continue;   // unmapped reads
              if (x.getDuplicateReadFlag()) continue;                              // duplicate reads
              if (x.getNotPrimaryAlignmentFlag()) continue;                        // non-primary alignments

              parse(x);
              if (read.isBlacklisted) continue;

              // store read
              if (numReads+1 > numReadsAllocated) growReads(numReadsAllocated + readsGrowStep);
              int i = numReads;
              readgroup[i] = read.readgroup;
              readname[i] = read.name;
              namenumber[i] = read.namenumber;
              readstart[i] = read.start;
              readend[i] = read.end;
              readstrand[i] = read.strand;
              readmapqual[i] = read.mapqual;
              whichpairmate[i] = read.whichpairmate;
              pairmatechr[i] = read.pairmatechr;
              pairmatestart[i] = read.pairmatestart;
              pairmatestrand[i] = read.pairmatestrand;
              numReads++;

              // store bases; grow by at least enough to fit this read even if it exceeds one step
              if (numBases+read.numBases > numBasesAllocated) {
                  growBases(Math.max(numBasesAllocated + basesGrowStep, numBases + read.numBases));
              }
              baseindex[i] = numBases;
              for (int j=0;j<read.numBases;j++) {
                  base[numBases] = read.base[j];
                  basequal[numBases] = read.basequal[j];
                  numBases++;
              }
              nmismatches[i] = read.nmismatches;
          }

          growReads(numReads);   // trim arrays to just the right size
          growBases(numBases);
      } finally {
          // always release the iterator, or later queries on this reader are blocked
          c.close();
      }

  } // end of loadRegion()

  //////////////////////////////////////////////////////////////////////
  //////////////////////////////////////////////////////////////////////

  public void parse(SAMRecord x) throws Exception {
      parse(x,true);
  }

  public void parse(SAMRecord x, boolean parseBases) throws Exception {
      parse(x,parseBases,true);
  }

  /**
   * Fills the shared {@link #read} record from SAM record {@code x}.
   * If the read's group is blacklisted, {@code read.isBlacklisted} is set and the method returns
   * immediately, leaving the remaining fields of {@code read} from the previous record.
   *
   * @param parseBases when true, CIGAR/sequence are copied and Read.parse_cigar()/parse_bases() run
   * @param parseQuals when true (and parseBases), base qualities are copied too
   */
  public void parse(SAMRecord x, boolean parseBases, boolean parseQuals) throws Exception {

      read.empty = false;

      // read group and blacklist status
      Object tmp = x.getAttribute("RG");
      if (tmp == null) {
          read.isBlacklisted = false;
          read.readgroup = 0;
      } else {
          read.rgstring = tmp.toString();
          if (!readGroupIsBlacklisted.containsKey(read.rgstring)) {
              // malformed bam: a read is annotated with RG that was not present in the header
              read.isBlacklisted = false;
          } else {
              read.isBlacklisted = readGroupIsBlacklisted.get(read.rgstring);
          }
          if (read.isBlacklisted) return;
          if (readGroupNum.containsKey(read.rgstring)) {
              read.readgroup = (short)(int)readGroupNum.get(read.rgstring);
          } else {
              read.readgroup = -1;
          }
      }

      if (parseBases) {
          read.cigarstring = x.getCigarString().toUpperCase();
          read.seqstring = x.getReadString().toUpperCase().toCharArray();
          read.qualstring = parseQuals ? x.getBaseQualityString().toCharArray() : null;
      }
      read.seqlength = x.getReadLength();
      if (!parseBases) {
          read.numBases = read.seqlength;   // estimate (exact number comes from parsing cigar string)
      }

      read.name = x.getReadName();
      read.namenumber = read.name.hashCode();

      // Alignments such as 25S75I / 57S44I cover no reference bases (end < start). They are treated
      // as unmapped, because parse_bases() would otherwise fail on ref.get().
      read.mapped = !x.getReadUnmappedFlag() && x.getAlignmentEnd() >= x.getAlignmentStart();

      if (read.mapped) {
          read.chr = (byte)(int)dictSeqIndex2Chr.get(x.getReferenceIndex());
          read.start = x.getAlignmentStart();
          read.end = x.getAlignmentEnd();
          read.strand = x.getReadNegativeStrandFlag() ? (byte)1 : (byte)0;
          read.mapqual = (byte)x.getMappingQuality();
      } else {   // unmapped read
          read.chr = -200;
          read.start = -200;
          read.end = -200+read.seqlength-1;
          read.strand = -1;
          read.mapqual = -1;
      }

      read.paired = x.getReadPairedFlag();
      if (read.paired) {
          read.whichpairmate = x.getFirstOfPairFlag() ? (byte)1 : (byte)2;
          read.pairmatemapped = !x.getMateUnmappedFlag();
          if (read.pairmatemapped) {
              read.pairmatechr = (byte)(int)dictSeqIndex2Chr.get(x.getMateReferenceIndex());
              read.pairmatestart = x.getMateAlignmentStart();
              read.pairmatestrand = x.getMateNegativeStrandFlag() ? (byte)1 : (byte)0;
              read.insertsize = Math.abs(x.getInferredInsertSize());
          } else {
              read.pairmatechr = -1;     // pairmate not mapped
              read.pairmatestart = -1;
              read.pairmatestrand = -1;
              read.insertsize = -1;
          }
      } else {
          read.whichpairmate = -1;   // not paired
          read.pairmatechr = -1;
          read.pairmatestart = -1;
          read.pairmatestrand = -1;
          read.insertsize = -1;
      }

      read.check_if_weird();

      if (parseBases) {
          read.parse_cigar();
          read.parse_bases(ref);
      }

  } // end of parse()

  //////////////////////////////////////////////////////////////////////
  //////////////////////////////////////////////////////////////////////

  // accessor functions (return the internal arrays, not copies)

  public int[] getReadgroup() { return readgroup; }
  public int[] getNamenumber() { return namenumber; }
  public int[] getWhichpairmate() { return whichpairmate; }
  public int[] getReadstart() { return readstart; }
  public int[] getReadend() { return readend; }
  public int[] getReadstrand() { return readstrand; }
  public int[] getReadmapqual() { return readmapqual; }
  public int[] getBaseindex() { return baseindex; }
  public int[] getBase() { return base; }
  public int[] getBasequal() { return basequal; }
  public int[] getNmismatches() { return nmismatches; }
  public int[] getPairmatechr() { return pairmatechr; }
  public int[] getPairmatestart() { return pairmatestart; }
  public int[] getPairmatestrand() { return pairmatestrand; }
  public String[] getReadname() { return readname; }

  /** Returns the reference sequence of the region last passed to loadRegion(). */
  public char[] getReference() throws Exception {
      if (ref == null || query_chr < 0) throw new Exception("call loadRegion() first");
      return ref.get(query_chr,query_start,query_end).toCharArray();
  }

  /////////////////////////////////////////////////////////////////////
  /////////////////////////////////////////////////////////////////////

  // functions to calculate the last two columns of B as previously done in Matlab

  /**
   * For every stored base, returns the 1-based index of the read it belongs to, or null when no
   * bases are loaded. Reads that contribute no bases are counted correctly.
   */
  public int[] getBaseReadIndex() {
      int[] x = null;
      if (numBases>0) {
          x = new int[numBases];
          for (int i=0; i<numReads; i++) {
              // "+=" so that reads with zero bases (sharing a start offset) still advance the count;
              // a trailing empty read starts at numBases and has no slot to mark.
              if (baseindex[i] < numBases) { x[baseindex[i]] += 1; }
          }
          for (int i=1; i<numBases; i++) { x[i] += x[i-1]; }
      }
      return x;
  }

  /**
   * For every stored base, returns a position built by cumulative summation: each read restarts at
   * readstart[i] (relative to readend[i-1]) and subsequent bases advance by 1. Returns null when no
   * bases are loaded. NOTE(review): presumably assumes each read contributes at least one base and
   * readend == readstart + bases - 1 -- confirm.
   */
  public int[] getBasePosition() {
      int[] x = null;
      if (numBases>0) {
          x = new int[numBases];
          for (int i=0; i<numBases; i++) { x[i] = 1; }
          for (int i=0; i<numReads; i++) {
              if (baseindex[i] >= numBases) continue;   // trailing read with no bases
              x[baseindex[i]] = readstart[i] - (i==0 ? 0 : readend[i-1]);
          }
          for (int i=1; i<numBases; i++) { x[i] += x[i-1]; }
      }
      return x;
  }

} // end of class
