package org.broadinstitute.cga.tools.seq;

// Mike Lawrence 2010-2011

import org.broadinstitute.cga.tools.seq.Genome;
import org.broadinstitute.cga.tools.seq.Wiggle;
import java.io.*;
import java.lang.*;
import java.util.*;

public class FixedWidthBinary {
    
    // When true, the read path prints verbose trace output to System.out.
    private static final boolean debug = false;

    // status
    private boolean open = false;           // true while a data file is open
    private int width = 0;                  // bits per stored value; assigned by open()
    private RandomAccessFile raf = null;    // handle on the data file

    // index storage: parallel arrays with one entry per index region (zero-based)
    private boolean index_loaded = false; // true once loadIndex() has completed
    private int num_regions = 0;
    private int[] region_chr = null;        // chromosome number of each region
    private int[] region_start = null;      // first position of each region (inclusive)
    private int[] region_end = null;        // last position of each region (inclusive)
    private long[] region_offset = null;    // number of values stored before each region

    // index hash: chromosome -> (position / indexHashResolution) -> list of region indices.
    // Lets findRegionInIndex() look at only the regions registered near a query
    // instead of scanning every region.
    private static final boolean useIndexHash = true;
    private static final int indexHashResolution = 1000000;
    public HashMap<Integer,HashMap<Integer,ArrayList<Integer>>> indexHash =
        new HashMap<Integer,HashMap<Integer,ArrayList<Integer>>>();

    /**
     * Returns the number of regions in the currently loaded index.
     *
     * @return the region count
     * @throws Exception if no index has been loaded
     */
    public int getNumRegions() throws Exception {
        if (index_loaded) {
            return num_regions;
        }
        throw new Exception("No index loaded");
    }
    /**
     * Returns the chromosome number of one index region.
     *
     * @param r zero-based region index; valid values are 0 .. getNumRegions()-1
     * @return the chromosome number stored for region r
     * @throws Exception if no index has been loaded or r is out of range
     */
    public int getRegionChr(int r) throws Exception {
        if (!index_loaded) {
            throw new Exception("No index loaded");
        }
        final boolean inRange = (r >= 0) && (r < num_regions);
        if (!inRange) {
            throw new Exception("no such region");
        }
        return region_chr[r];
    }
    /**
     * Returns the start position of one index region.
     *
     * @param r zero-based region index; valid values are 0 .. getNumRegions()-1
     * @return the start position stored for region r
     * @throws Exception if no index has been loaded or r is out of range
     */
    public int getRegionStart(int r) throws Exception {
        if (!index_loaded) {
            throw new Exception("No index loaded");
        }
        final boolean inRange = (r >= 0) && (r < num_regions);
        if (!inRange) {
            throw new Exception("no such region");
        }
        return region_start[r];
    }
    /**
     * Returns the end position of one index region.
     *
     * @param r zero-based region index; valid values are 0 .. getNumRegions()-1
     * @return the end position stored for region r
     * @throws Exception if no index has been loaded or r is out of range
     */
    public int getRegionEnd(int r) throws Exception {
        if (!index_loaded) {
            throw new Exception("No index loaded");
        }
        final boolean inRange = (r >= 0) && (r < num_regions);
        if (!inRange) {
            throw new Exception("no such region");
        }
        return region_end[r];
    }
    
    // Legal widths (in bits) of one stored value; isLegalWidth() accepts only these.
    private static final int[] legalWidth = { 1, 2, 4, 8, 16, 24, 32 };
    
    /**
     * Returns the width, in bits per value, of the currently open file.
     *
     * @return the width in bits
     * @throws Exception if no file is open
     */
    public int getWidth() throws Exception {
        if (open) {
            return width;
        }
        throw new Exception("No file open");
    }

    /**
     * Tells whether a width is one of the supported widths listed in legalWidth.
     *
     * @param testWidth candidate width in bits
     * @return true if testWidth is supported, false otherwise
     */
    public static boolean isLegalWidth(int testWidth) {
        for (int candidate : legalWidth) {
            if (candidate == testWidth) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the largest value that can be stored at the width of the open file.
     *
     * @return the result of maxValForWidth(int) for the open file's width
     * @throws Exception if no file is open
     */
    public int maxValForWidth() throws Exception {
        if (open) {
            return maxValForWidth(width);
        }
        throw new Exception("Please specify width, or else open a file (and its width will be used)");
    }

    /**
     * Returns the largest value that can be stored at the given width.
     *
     * For widths below 31 bits this is 2^width - 1. Values are handled as
     * (signed) int throughout this class, so for a width of 32 the largest
     * storable value is Integer.MAX_VALUE. The previous floating-point
     * formula saturated at Integer.MAX_VALUE and then subtracted one, which
     * needlessly rejected Integer.MAX_VALUE itself.
     *
     * @param testWidth width in bits; must be a legal width
     * @return the largest storable non-negative value
     * @throws Exception if testWidth is not a legal width
     */
    public static int maxValForWidth(int testWidth) throws Exception {
        if (!isLegalWidth(testWidth)) throw new Exception("Illegal width "+testWidth);
        if (testWidth >= 31) {
            return Integer.MAX_VALUE;
        }
        // exact integer arithmetic; no round trip through double
        return (1 << testWidth) - 1;
    }

    // Value returned in cases of "no data": the get() methods pre-fill their result
    // with it, so positions not covered by any index region keep this value.
    private int nullVal = -1;

    /**
     * Creates an instance with no index loaded and no file open.
     */
    public FixedWidthBinary() {
        // nothing to set up: all fields start from their declared defaults
    }

    /**
     * Creates an instance and immediately calls
     * open(fwbName, indexName, setWidth, chrcol, startcol, endcol).
     *
     * @throws Exception if that open call fails
     */
    public FixedWidthBinary(String fwbName, String indexName, int setWidth,
                            int chrcol, int startcol, int endcol) throws Exception {
        this.open(fwbName, indexName, setWidth, chrcol, startcol, endcol);
    }

    /**
     * Creates an instance and immediately calls open(fwbName, indexName, setWidth).
     *
     * @throws Exception if that open call fails
     */
    public FixedWidthBinary(String fwbName, String indexName, int setWidth) throws Exception {
        this.open(fwbName, indexName, setWidth);
    }

    /**
     * Creates an instance and immediately calls open(fwbName, setWidth).
     *
     * @throws Exception if that open call fails
     */
    public FixedWidthBinary(String fwbName, int setWidth) throws Exception {
        this.open(fwbName, setWidth);
    }

    /**
     * Creates an instance and immediately calls open(fwbName).
     *
     * @throws Exception if that open call fails
     */
    public FixedWidthBinary(String fwbName) throws Exception {
        this.open(fwbName);
    }

    /**
     * Reports whether a data file is currently open.
     *
     * @return true if a file is open
     */
    public boolean isOpen() {
        return this.open;
    }


    /**
     * Loads the given index using the given column numbers, then opens the
     * data file without an explicit width (-1 is passed as the width).
     *
     * @param fwbName   path of the data file
     * @param indexName path of the index file
     * @param chrcol    1-based column of the chromosome field in the index
     * @param startcol  1-based column of the start field in the index
     * @param endcol    1-based column of the end field in the index
     * @throws Exception if the index cannot be loaded or the file cannot be opened
     */
    public void open(String fwbName, String indexName,
                     int chrcol, int startcol, int endcol) throws Exception {
        final int unspecifiedWidth = -1;
        loadIndex(indexName, chrcol, startcol, endcol);
        open(fwbName, unspecifiedWidth);
    }

    /**
     * Loads the given index from columns 1, 2 and 3, then opens the data file
     * without an explicit width (-1 is passed as the width).
     *
     * @param fwbName   path of the data file
     * @param indexName path of the index file
     * @throws Exception if the index cannot be loaded or the file cannot be opened
     */
    public void open(String fwbName, String indexName) throws Exception {
        final int unspecifiedWidth = -1;
        loadIndex(indexName, 1, 2, 3);
        open(fwbName, unspecifiedWidth);
    }

    /**
     * Loads the given index from columns 1, 2 and 3, then calls
     * open(fwbName, setWidth).
     *
     * @param fwbName   path of the data file
     * @param indexName path of the index file
     * @param setWidth  width argument forwarded to open(String,int)
     * @throws Exception if the index cannot be loaded or the file cannot be opened
     */
    public void open(String fwbName, String indexName, int setWidth) throws Exception {
        final int chrColumn = 1;
        final int startColumn = 2;
        final int endColumn = 3;
        loadIndex(indexName, chrColumn, startColumn, endColumn);
        open(fwbName, setWidth);
    }

    /**
     * Loads the given index using the given column numbers, then calls
     * open(fwbName, setWidth).
     *
     * @param fwbName   path of the data file
     * @param indexName path of the index file
     * @param setWidth  width argument forwarded to open(String,int)
     * @param chrcol    1-based column of the chromosome field in the index
     * @param startcol  1-based column of the start field in the index
     * @param endcol    1-based column of the end field in the index
     * @throws Exception if the index cannot be loaded or the file cannot be opened
     */
    public void open(String fwbName, String indexName, int setWidth,
                     int chrcol, int startcol, int endcol) throws Exception {
        this.loadIndex(indexName, chrcol, startcol, endcol);
        this.open(fwbName, setWidth);
    }

    /**
     * Opens the data file without an explicit width (-1 is passed as the width).
     *
     * @param fwbName path of the data file
     * @throws Exception if the file cannot be opened
     */
    public void open(String fwbName) throws Exception {
        final int unspecifiedWidth = -1;
        open(fwbName, unspecifiedWidth);
    }

    /**
     * Opens the data file with the given width, using index columns 1, 2 and 3
     * should the index still need to be loaded.
     *
     * The width is now actually forwarded; previously -1 was passed
     * unconditionally, so an explicitly requested width was silently replaced
     * by a guess.
     *
     * @param fwbName  path of the data file
     * @param setWidth width in bits, or -1 to have the width guessed
     * @throws Exception if the file cannot be opened
     */
    public void open(String fwbName, int setWidth) throws Exception {
        open(fwbName,setWidth,1,2,3);
    }

    /**
     * Opens a data file for reading and checks that its length matches the index.
     *
     * If no index is loaded yet, the index name is derived from fwbName by
     * replacing the ".fwb" suffix with ".fwi" and that index is loaded with
     * the given column numbers. If setWidth is -1, each legal width is tried
     * in turn and the first one whose implied file length equals the actual
     * file length is used.
     *
     * The file handle is closed again if anything fails after it was opened,
     * and the instance's state is only updated once every check has passed.
     *
     * @param fwbName  path of the data file
     * @param setWidth width in bits, or -1 to have the width guessed
     * @param chrcol   1-based column of the chromosome field (used only if the index must be loaded)
     * @param startcol 1-based column of the start field (used only if the index must be loaded)
     * @param endcol   1-based column of the end field (used only if the index must be loaded)
     * @throws Exception if the index name cannot be derived, a file is already open,
     *                   the width is illegal, or file and index do not match
     */
    public void open(String fwbName, int setWidth, int chrcol, int startcol, int endcol) throws Exception {
        if (!index_loaded) { // try to guess index name
            if (!fwbName.toLowerCase().endsWith(".fwb")) {
                throw new Exception("fwbName needs to end in '.fwb' in order to generate indexName");
            }
            String indexName = fwbName.substring(0,fwbName.length()-4) + ".fwi";
            loadIndex(indexName,chrcol,startcol,endcol);
        }
        if (open) throw new Exception("Close file before opening another one");
        if (setWidth!=-1 && !isLegalWidth(setWidth)) throw new Exception("Illegal width "+setWidth);

        RandomAccessFile file = new RandomAccessFile(fwbName, "r");
        boolean verified = false;
        try {
            long raflen = file.length();
            if (setWidth==-1) {   // try to guess the width
                for (int i=0; i<legalWidth.length; i++) {
                    if (raflen == calculateIndexImpliedFileLength(legalWidth[i])) {
                        setWidth = legalWidth[i];
                        break;
                    }
                }
                if (setWidth==-1) {
                    throw new Exception("File and index are mismatched, no matter what width is tried");
                }
            } else if (raflen != calculateIndexImpliedFileLength(setWidth)) {
                throw new Exception("File and index are mismatched, using width = "+setWidth);
            }
            verified = true;
        } finally {
            if (!verified) {
                // do not leak the handle; keep the original failure as the reported one
                try {
                    file.close();
                } catch (IOException ignored) {
                    // the exception already in flight is the more useful one
                }
            }
        }
        raf = file;
        width = setWidth;
        open = true;
    }

    /**
     * Closes the open data file and marks the index as no longer loaded.
     * If no file is open, only prints a notice to System.out.
     *
     * @throws Exception if closing the underlying file fails
     */
    public void close() throws Exception {
        if (!open) {
            System.out.println("No file open.");
            return;
        }
        raf.close();
        open = false;
        index_loaded = false;
    }

    /**
     * Loads an index whose chromosome, start and end fields are in columns 1, 2 and 3.
     *
     * @param indexName path of the index file
     * @throws Exception if the index cannot be loaded
     */
    public void loadIndex(String indexName) throws Exception {
        final int chrColumn = 1;
        final int startColumn = 2;
        final int endColumn = 3;
        loadIndex(indexName, chrColumn, startColumn, endColumn);
    }

    /**
     * Loads a tab-delimited index file describing the regions stored in a data file.
     *
     * Each non-blank line supplies a chromosome, a start and an end position
     * (both inclusive) in the given 1-based columns. A leading "chr" is
     * stripped from the chromosome, and M/MT, X and Y are mapped to 0, 23 and
     * 24; everything else must parse as an integer. Regions are assumed to be
     * stored back to back in file order, so each region's offset is the
     * number of positions covered by the regions before it.
     *
     * If a data file is open it is closed first. The file is read twice:
     * once to count lines so the arrays can be sized, once to fill them.
     * Both readers are closed even when loading fails.
     *
     * @param indexName path of the index file
     * @param chrcol    1-based column of the chromosome field
     * @param startcol  1-based column of the start field
     * @param endcol    1-based column of the end field
     * @throws Exception if a column number is below 1, the file is blank or
     *                   malformed, changes between the two passes, or the
     *                   index arrays cannot be allocated
     */
    public void loadIndex(String indexName, int chrcol, int startcol, int endcol) throws Exception {
        if (chrcol<1 || startcol<1 || endcol<1) {
            throw new Exception("Index column numbers are 1-based (got chrcol="+chrcol
                                +", startcol="+startcol+", endcol="+endcol+")");
        }
        if (open) {
            System.out.println("First, closing the open file.");
            close();
        }
        final int maxcol = Math.max(chrcol, Math.max(startcol, endcol));

        index_loaded = false;
        num_regions = 0;

        // first, count lines
        int nr = 0;
        BufferedReader input = new BufferedReader(new FileReader(indexName));
        try {
            String line;
            while ((line = input.readLine()) != null) {
                if (line.equals("")) continue;   // skip blank lines
                nr++;
            }
        } finally {
            input.close();
        }
        if (nr==0) throw new Exception("Blank index file!");

        // allocate space for index
        try {
            region_chr = new int[nr];
            region_start = new int[nr];
            region_end = new int[nr];
            region_offset = new long[nr];
        } catch (OutOfMemoryError e) {
            // allocation failure is an Error, not an Exception, so it must be caught as such
            throw new Exception("Index too long: memory allocation failed", e);
        }

        // now load it for real
        long offset = 0;
        int lineno = 0;
        int idx = 0;
        input = new BufferedReader(new FileReader(indexName));
        try {
            String line;
            while ((line = input.readLine()) != null) {
                lineno++;
                if (line.equals("")) continue;   // skip blank lines
                if (idx >= nr) {
                    throw new Exception("File length changed!");
                }
                final String[] fields = line.split("\t");
                if (fields.length<maxcol) {
                    throw new Exception("No column "+maxcol+" in index line "+lineno+":\n"+line);
                }
                String chrstr = fields[chrcol-1];
                if (chrstr.startsWith("chr")) chrstr=chrstr.substring(3);
                if (chrstr.equalsIgnoreCase("M")) chrstr="0";
                if (chrstr.equalsIgnoreCase("MT")) chrstr="0";
                if (chrstr.equalsIgnoreCase("X")) chrstr="23";
                if (chrstr.equalsIgnoreCase("Y")) chrstr="24";
                int chr, start, end;
                try {
                    chr = Integer.parseInt(chrstr);
                } catch (NumberFormatException e) {
                    throw new Exception("Invalid chromosome "+fields[chrcol-1]+" in index line "+lineno+":\n"+line, e);
                }
                try {
                    start = Integer.parseInt(fields[startcol-1]);
                } catch (NumberFormatException e) {
                    throw new Exception("Invalid start "+fields[startcol-1]+" in index line "+lineno+":\n"+line, e);
                }
                try {
                    end = Integer.parseInt(fields[endcol-1]);
                } catch (NumberFormatException e) {
                    throw new Exception("Invalid end "+fields[endcol-1]+" in index line "+lineno+":\n"+line, e);
                }
                if (end<start) throw new Exception("end<start in index line "+lineno+":\n"+line);
                region_chr[idx] = chr;
                region_start[idx] = start;
                region_end[idx] = end;
                region_offset[idx] = offset;

                // widen before subtracting so a very long region cannot overflow int
                offset += (long)end - (long)start + 1;
                idx++;
            }
        } finally {
            input.close();
        }
        if (idx != nr) {
            // fewer lines than were counted: trailing array slots would be bogus regions
            throw new Exception("File length changed!");
        }
        index_loaded = true;
        num_regions = nr;
        if (useIndexHash) createIndexHash();
    }

    /**
     * Rebuilds indexHash from the loaded index.
     *
     * Every region is registered under its chromosome in each bin it
     * touches, where a position's bin is position / indexHashResolution;
     * i.e. in all bins from start/resolution to end/resolution inclusive.
     * Within a bin, regions are listed in increasing region-index order.
     *
     * Bins are enumerated directly rather than by stepping a position, which
     * avoids both registering a region in a bin beyond its end and integer
     * overflow for regions ending near Integer.MAX_VALUE.
     *
     * @throws Exception if no index has been loaded
     */
    private void createIndexHash() throws Exception {
        if (!index_loaded) throw new Exception("No index loaded");
        indexHash.clear();
        for (int i=0; i<num_regions; i++) {
            Integer regChr = Integer.valueOf(region_chr[i]);
            HashMap<Integer,ArrayList<Integer>> chrHash = indexHash.get(regChr);
            if (chrHash == null) {
                chrHash = new HashMap<Integer,ArrayList<Integer>>();
                indexHash.put(regChr,chrHash);
            }
            int firstBin = region_start[i]/indexHashResolution;
            int lastBin = region_end[i]/indexHashResolution;
            for (int bin=firstBin; bin<=lastBin; bin++) {
                Integer regPos = Integer.valueOf(bin);
                ArrayList<Integer> posList = chrHash.get(regPos);
                if (posList == null) {
                    posList = new ArrayList<Integer>();
                    chrHash.put(regPos,posList);
                }
                posList.add(Integer.valueOf(i));
            }
        }
    }




    /**
     * Computes how many bytes the data file must have if every position
     * covered by the index is stored with the given width.
     *
     * The number of stored values is the last region's offset plus the last
     * region's length; the bit count is rounded up to whole bytes.
     *
     * @param testWidth width in bits to assume
     * @return the implied file length in bytes
     * @throws Exception if no index is loaded, the width is illegal, or the index is empty
     */
    private long calculateIndexImpliedFileLength(int testWidth) throws Exception {
        if (!index_loaded) {
            throw new Exception("No index loaded");
        }
        if (!isLegalWidth(testWidth)) {
            throw new Exception("Illegal width "+testWidth);
        }
        if (num_regions==0) {
            throw new Exception("No regions in index");
        }
        final int last = num_regions-1;
        final long valueCount = region_offset[last] + region_end[last] - region_start[last] + 1;
        final long bitCount = valueCount * testWidth;
        long byteCount = bitCount/8;
        if (bitCount%8 > 0) {
            byteCount++;   // a partially used final byte still occupies a whole byte
        }
        return byteCount;
    }

    /**
     * Sets the value that the get() methods report for positions without data.
     *
     * @param val the new "no data" value
     */
    public void setNullVal(int val) {
        this.nullVal = val;
    }

    /**
     * Returns the value that the get() methods report for positions without data.
     *
     * @return the current "no data" value
     */
    public int getNullVal() {
        return this.nullVal;
    }

    /**
     * Returns the value stored at a single position.
     *
     * @param chr chromosome number
     * @param pos position on that chromosome
     * @return the first element of get(int[],int[],int[]) for the one-position interval pos..pos
     * @throws Exception if the underlying query fails
     */
    public int get(int chr, int pos) throws Exception {
        final int[] singleChr = { chr };
        final int[] singlePos = { pos };
        // the same array serves as both starts and ends: a one-position interval
        return get(singleChr, singlePos, singlePos)[0];
    }

    /**
     * Returns the values stored at several positions of one chromosome.
     *
     * @param chr   chromosome number shared by all positions
     * @param poses positions to query
     * @return the result of get(int[],int[],int[]) with each position used as a one-position interval
     * @throws Exception if the underlying query fails
     */
    public int[] get(int chr, int[] poses) throws Exception {
        final int[] sameChr = new int[poses.length];
        for (int k = 0; k < sameChr.length; k++) {
            sameChr[k] = chr;
        }
        return get(sameChr, poses, poses);
    }

    /**
     * Returns the values stored at several (chromosome, position) pairs.
     *
     * @param chrs  chromosome number of each query
     * @param poses position of each query
     * @return the result of get(int[],int[],int[]) with poses used as both starts and ends
     * @throws Exception if the underlying query fails
     */
    public int[] get(int[] chrs, int[] poses) throws Exception {
        final int[] starts = poses;
        final int[] ends = poses;
        return get(chrs, starts, ends);
    }

    /**
     * Returns the values stored for one interval.
     *
     * @param chr   chromosome number
     * @param start first position of the interval
     * @param end   last position of the interval
     * @return the result of get(int[],int[],int[]) for this single interval
     * @throws Exception if the underlying query fails
     */
    public int[] get(int chr, int start, int end) throws Exception {
        return get(new int[] { chr }, new int[] { start }, new int[] { end });
    }

    /**
     * Returns the values stored for several intervals of one chromosome.
     *
     * The chromosome is replicated once per interval so that the three
     * arrays handed to get(int[],int[],int[]) have the same length.
     * Previously a one-element chromosome array was passed, which made every
     * call with more than one interval fail with "Inconsistent query length".
     *
     * @param chr    chromosome number shared by all intervals
     * @param starts first position of each interval
     * @param ends   last position of each interval
     * @return the result of get(int[],int[],int[]) for these intervals
     * @throws Exception if the underlying query fails
     */
    public int[] get(int chr, int[] starts, int[] ends) throws Exception {
        int[] chrs = new int[starts.length];
        Arrays.fill(chrs,chr);
        return(get(chrs,starts,ends));
    }

    /**
     * Returns the stored values for a list of intervals, concatenated in query order.
     *
     * The result has one element per queried position (interval ends are
     * inclusive). Positions not covered by any index region keep the
     * "no data" value (see setNullVal/getNullVal).
     *
     * All intervals are validated before the result buffer is allocated, and
     * the total length is accumulated as a long, so a reversed interval or an
     * oversized request yields a descriptive exception rather than a
     * NegativeArraySizeException or a wrongly sized buffer.
     *
     * @param chrs   chromosome number of each interval
     * @param starts first position of each interval
     * @param ends   last position of each interval
     * @return the values for every queried position
     * @throws Exception if no file is open, the arrays differ in length, an
     *                   interval has end &lt; start, the request exceeds the
     *                   maximum array size, or reading fails
     */
    public int[] get(int[] chrs, int[] starts, int[] ends) throws Exception {
        if (!open) throw new Exception("File not open");
        int nq = chrs.length;
        if (starts.length!=nq || ends.length!=nq) throw new Exception("Inconsistent query length");

        // validate every interval and calculate total query length
        long totlen = 0;
        for (int qi=0; qi<nq; qi++) {
            if (ends[qi]<starts[qi]) throw new Exception("end "+ends[qi]+" < start "+starts[qi]);
            totlen += (long)ends[qi] - (long)starts[qi] + 1;
        }
        if (totlen > Integer.MAX_VALUE) {
            throw new Exception("Query too long: "+totlen+" positions requested");
        }
        int[] result = new int[(int)totlen];
        Arrays.fill(result,nullVal);  // initialize with nullVal

        // load each query region
        int result_idx = 0;
        for (int qi=0; qi<nq; qi++) {
            int chr = chrs[qi];
            int pos = starts[qi];
            int end = ends[qi];
            while (pos<=end) {
                if (debug)
                    System.out.println("Seeking "+chr+":"+pos+"-"+end);
                int ri = findRegionInIndex(chr,pos,end);
                if (ri==-1) {  // rest of query region not found
                    if (debug) System.out.println("  Not found");
                    int len = (end-pos+1);
                    pos += len;
                    result_idx += len;
                } else {
                    if (debug)
                        System.out.println("  Found: region "+ri+" = "+region_chr[ri]+":"+region_start[ri]+"-"+region_end[ri]);
                    if (region_start[ri]>pos) { // index region starts after query region: need to advance some
                        int gap = (region_start[ri]-pos);
                        pos += gap;
                        result_idx += gap;   // skipped positions keep nullVal
                    }
                    // determine which stretch of region to add (offsets relative to region start)
                    int region_add_first, region_add_last;
                    if (region_start[ri]<pos) {
                        region_add_first = pos-region_start[ri];
                    } else {
                        region_add_first = 0;
                    }
                    if (region_end[ri]>end) {
                        if (debug) System.out.println("Using til before end");
                        region_add_last = end-region_start[ri];
                    } else {
                        if (debug) System.out.println("Using til end");
                        region_add_last = region_end[ri]-region_start[ri];
                    }
                    getRegion(result,result_idx,ri,region_add_first,region_add_last);
                    int len = (region_add_last-region_add_first+1);
                    pos += len;
                    result_idx += len;
                }
            }
        }
        return(result);
    }

    /**
     * Finds the first index region that overlaps the given interval by at
     * least one position.
     *
     * With the index hash enabled, every bin from start/indexHashResolution
     * to end/indexHashResolution (inclusive) is probed in increasing order,
     * and within a bin regions are tested in list order. Probing by bin
     * number, rather than by stepping a position from start in bin-sized
     * steps, ensures the bin containing end is examined even when the query
     * crosses a bin boundary but is shorter than one bin.
     *
     * @param chr   chromosome number
     * @param start first position of the interval
     * @param end   last position of the interval
     * @return the zero-based index of the first overlapping region, or -1 if there is none
     * @throws Exception if no index is loaded or the index has no regions
     */
    private int findRegionInIndex(int chr,int start,int end) throws Exception {
        if (!index_loaded) throw new Exception("Index not loaded");
        if (num_regions==0) throw new Exception("No regions in index");
        if (useIndexHash) {
            HashMap<Integer,ArrayList<Integer>> chrHash = indexHash.get(Integer.valueOf(chr));
            if (chrHash != null) {
                int firstBin = start/indexHashResolution;
                int lastBin = end/indexHashResolution;
                for (int bin=firstBin; bin<=lastBin; bin++) {
                    ArrayList<Integer> posList = chrHash.get(Integer.valueOf(bin));
                    if (posList == null) continue;
                    for (Integer candidate : posList) {
                        int ri = candidate.intValue();
                        if (region_chr[ri]==chr && region_start[ri]<=end && region_end[ri]>=start) return(ri);
                    }
                }
            }
        } else {
            // no hash: linear scan over all regions
            for (int ri=0;ri<num_regions;ri++) {
                if (region_chr[ri]==chr && region_start[ri]<=end && region_end[ri]>=start) return(ri);
            }
        }
        return(-1);
    }

    /**
     * Calculates the absolute file offset, in bits, of one value of a region.
     *
     * @param ri        zero-based region index
     * @param ri_offset offset of the value within the region, counted in values
     * @return (values stored before the region + ri_offset) * width
     * @throws Exception if no index is loaded or the current width is illegal
     */
    private long calculateBitOffset(int ri, int ri_offset) throws Exception {
        // NOTE(review): the condition tests index_loaded but the message talks about the file - confirm which was intended
        if (!index_loaded) {
            throw new Exception("File not open");
        }
        if (!isLegalWidth(width)) {
            throw new Exception("Illegal width set = "+width);
        }
        final long valueIndex = region_offset[ri] + ri_offset;
        return valueIndex * width;
    }

    /**
     * Loads a stretch of one region from the data file and decodes it into dest.
     *
     * The bytes spanning the requested values are read in one go and then
     * unpacked: for widths of 8 bits or more, whole bytes are accumulated
     * most-significant first; for narrower widths, values are taken from the
     * high bits of each byte downwards.
     *
     * The buffer is filled with readFully, because a plain read() may return
     * fewer bytes than requested and leave the tail of the buffer zeroed.
     *
     * @param dest             destination buffer
     * @param dest_offset      index in dest at which the first decoded value is stored
     * @param ri               zero-based region index
     * @param region_add_first offset within the region of the first value to load
     * @param region_add_last  offset within the region of the last value to load (inclusive)
     * @throws Exception if no file is open or reading fails
     */
    private void getRegion(int[] dest, int dest_offset, int ri, int region_add_first, int region_add_last) throws Exception {
        if (debug) System.out.println("Getting region "+ri+" "+region_add_first+"-"+region_add_last);
        if (!open) throw new Exception("File not open");

        // locate the first and last bit of the requested stretch within the file
        long start_bit = calculateBitOffset(ri,region_add_first);
        long end_bit = calculateBitOffset(ri,region_add_last)+width-1;
        if (debug) System.out.println("start_bit "+start_bit+"  end_bit "+end_bit);
        int start_byte_bit = (int)(start_bit%8);
        long start_byte = start_bit/8;
        int end_byte_bit = (int)(end_bit%8);
        long end_byte = end_bit/8;
        if (debug) System.out.println("start_byte "+start_byte+"  start_byte_bit "+start_byte_bit);
        if (debug) System.out.println("end_byte "+end_byte+"  end_byte_bit "+end_byte_bit);

        // read every byte touched by the stretch
        byte[] buffer = new byte[(int)((end_byte-start_byte)+1)];
        if (debug) System.out.println("Reading bytes "+start_byte+" to "+end_byte);
        raf.seek(start_byte);
        raf.readFully(buffer);

        // decode
        int from_byte = 0;
        int from_bit = start_byte_bit;
        int to_idx = 0;
        int num_to_add = (region_add_last-region_add_first+1);
        int bitbuffer = 0;
        int bitbuffercount = 0;
        int byte_to_unpack = -1, bitmask = -1;  // (only used if width<8)
        // mask selecting the top 'width' bits of a byte
        if (width==1) bitmask=128;
        else if (width==2) bitmask=128+64;
        else if (width==4) bitmask=128+64+32+16;
        while (to_idx<num_to_add) {
            while (bitbuffercount<width) {
                if (debug) System.out.println("bitbuffer "+bitbuffer+"  bitbuffercount "+bitbuffercount);
                if (width>=8) {
                    int val = (int)buffer[from_byte];
                    if (val<0) val += 256;   // convert from signed
                    if (debug) System.out.println("adding from_byte "+from_byte+" = "+val);
                    bitbuffer <<= 8;
                    bitbuffer += val;
                    bitbuffercount += 8;
                    from_byte++;
                } else {   // width<8
                    if (byte_to_unpack==-1) {
                        // fetch the next byte and discard the bits before the first wanted value
                        byte_to_unpack = (int)buffer[from_byte];
                        if (byte_to_unpack<0) byte_to_unpack += 256;  // convert from signed
                        byte_to_unpack <<= from_bit;
                    }
                    if (debug) System.out.println("Extracting value from byte_to_unpack "+byte_to_unpack);
                    bitbuffer = (byte_to_unpack & bitmask) >> (8-width);
                    if (debug) System.out.println("bitbuffer = "+bitbuffer);
                    byte_to_unpack <<= width;   // bring the next value into the masked bits
                    bitbuffercount = width;
                    from_bit += width;
                    if (from_bit>=8) {
                        from_bit = 0;
                        from_byte++;
                        byte_to_unpack = -1;
                    }
                }
            }
            if (debug) System.out.println("ADDING FINAL VALUE "+bitbuffer);
            dest[dest_offset+to_idx] = bitbuffer;
            bitbuffer = 0;
            bitbuffercount = 0;
            to_idx++;
        }
    }

    ///////////////////////////////////////////////////////////////////////////////////////////////////
    ///////////////////////////////////////////////////////////////////////////////////////////////////

    /**
     * Converts a wiggle file, deriving both output names from the wiggle name.
     *
     * The ".wig" or ".wig.txt" suffix (matched case-insensitively) is
     * replaced by ".fwb" for the data file and ".fwi" for the index file.
     *
     * @param wigName  path of the wiggle file
     * @param outWidth width in bits of each stored value
     * @throws Exception if wigName has neither suffix, or the conversion fails
     */
    public static void CreateFromWiggle(String wigName,int outWidth) throws Exception {
        final String lowered = wigName.toLowerCase();
        final int suffixLength;
        if (lowered.endsWith(".wig")) {
            suffixLength = 4;
        } else if (lowered.endsWith(".wig.txt")) {
            suffixLength = 8;
        } else {
            throw new Exception("wigName needs to end in '.wig' or '.wig.txt' in order to generate outName, IndexName");
        }
        final String outStem = wigName.substring(0, wigName.length()-suffixLength);
        CreateFromWiggle(wigName, outWidth, outStem + ".fwb", outStem + ".fwi");
    }

    /**
     * Converts a wiggle file, deriving the index name from the data file name.
     *
     * The ".fwb" suffix of outName (matched case-insensitively) is replaced
     * by ".fwi" to form the index name.
     *
     * @param wigName  path of the wiggle file
     * @param outWidth width in bits of each stored value
     * @param outName  path of the data file to write
     * @throws Exception if outName does not end in ".fwb", or the conversion fails
     */
    public static void CreateFromWiggle(String wigName,int outWidth,String outName) throws Exception {
        if (!outName.toLowerCase().endsWith(".fwb")) {
            throw new Exception("outName needs to end in '.fwb' in order to generate indexName");
        }
        final int stemLength = outName.length()-4;
        final String indexName = outName.substring(0, stemLength) + ".fwi";
        CreateFromWiggle(wigName, outWidth, outName, indexName);
    }

    /**
     * Converts a wiggle file into a fixed-width binary data file plus its index.
     *
     * Values are written in the order the wiggle reader delivers them. A new
     * index region (chromosome, start, end; tab-separated) is started
     * whenever the chromosome changes or the position is not exactly one
     * past the previous position. Widths below 8 bits are packed most
     * significant bits first, continuing across region boundaries; a final
     * partial byte is padded with zero bits.
     *
     * Both output streams and the wiggle reader are released even when the
     * conversion fails part-way, and an input without any data produces an
     * empty index rather than a bogus "-1 -1 -1" region.
     *
     * @param wigName   path of the wiggle file
     * @param outWidth  width in bits of each stored value
     * @param outName   path of the data file to write
     * @param indexName path of the index file to write
     * @throws Exception if the width is illegal, a value is negative or too
     *                   large for the width, or reading/writing fails
     */
    public static void CreateFromWiggle(String wigName,int outWidth,String outName,String indexName) throws Exception {
        if (!isLegalWidth(outWidth)) throw new Exception("Illegal width "+outWidth);
        int maxval = maxValForWidth(outWidth);
        Wiggle wig = new Wiggle(wigName);
        boolean wigClosed = false;
        DataOutputStream fwbOut = null;
        BufferedWriter indexOut = null;
        try {
            fwbOut = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(outName),10000000));
            indexOut = new BufferedWriter(new FileWriter(indexName),1000000);

            int this_chr = -1;      // chromosome of the region being written
            int this_start = -1;    // start of the region being written
            int old_pos = -1;       // position of the value written last
            boolean region_open = false;

            int bitbuffer = 0;
            int bitbuffercount = 0;

            while (wig.next()) {
                int chr = wig.getChr();
                int pos = wig.getPos();
                int val = wig.getValue();
                if (val<0) throw new Exception("Can't store negative values in FWB");
                if (val>maxval) throw new Exception("Can't store values above "+maxval+" (e.g. "+val+") using width="+outWidth);

                // a jump in chromosome or position ends the current region
                if (region_open && !(chr==this_chr && pos==old_pos+1)) {
                    indexOut.write(this_chr+"\t"+this_start+"\t"+old_pos+"\n");
                    region_open = false;
                }
                if (!region_open) {
                    this_chr = chr;
                    this_start = pos;
                    region_open = true;
                }

                // output the value that was just read from the wiggle file
                if (outWidth<8) {
                    bitbuffer <<= outWidth;
                    bitbuffer += val;
                    bitbuffercount += outWidth;
                    if (bitbuffercount==8) {
                        fwbOut.writeByte(bitbuffer);
                        bitbuffer = 0;
                        bitbuffercount = 0;
                    }
                } else if (outWidth==8) {
                    fwbOut.writeByte(val);
                } else if (outWidth==16) {
                    fwbOut.writeShort(val);
                } else if (outWidth==24) {
                    int lowshort = val % 65536;
                    int highbyte = (val-lowshort) / 65536;
                    fwbOut.writeByte(highbyte);
                    fwbOut.writeShort(lowshort);
                } else if (outWidth==32) {
                    fwbOut.writeInt(val);
                } else {
                    throw new Exception("Impossible: width became invalid after initial check");
                }
                old_pos = pos;
            }

            // close the last region, if any data was seen at all
            if (region_open) {
                indexOut.write(this_chr+"\t"+this_start+"\t"+old_pos+"\n");
            }
            if (bitbuffercount>0) { // flush bitbuffer
                bitbuffer <<= (8-bitbuffercount);
                fwbOut.writeByte(bitbuffer);
            }

            // normal completion: close explicitly so that flush failures are reported
            DataOutputStream finishedFwb = fwbOut;
            fwbOut = null;
            finishedFwb.close();
            BufferedWriter finishedIndex = indexOut;
            indexOut = null;
            finishedIndex.close();
            wigClosed = true;
            wig.close();
        } finally {
            // failure path: release whatever is still open without masking the original error
            if (fwbOut != null) {
                try { fwbOut.close(); } catch (IOException ignored) { /* original failure takes precedence */ }
            }
            if (indexOut != null) {
                try { indexOut.close(); } catch (IOException ignored) { /* original failure takes precedence */ }
            }
            if (!wigClosed) {
                try { wig.close(); } catch (Exception ignored) { /* original failure takes precedence */ }
            }
        }
        System.out.print("Read:\n\t"+wigName+"\nWrote (width = "+outWidth+"):\n\t"+outName+"\n\t"+indexName+"\n");
    }


    ///////////////////////////////////////////////////////////////////////////////////////////////////
    ///////////////////////////////////////////////////////////////////////////////////////////////////


}
