/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.genee.cmap.scripts;

import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.hash.TIntHashSet;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.broadinstitute.genee.heatmap.DefaultProject;
import org.broadinstitute.genee.heatmap.Project;
import org.broadinstitute.genee.io.util.IOUtil;
import org.broadinstitute.genee.io.util.ProjectIO;
import org.broadinstitute.genee.io.util.ToStringUtil;
import org.broadinstitute.genee.marker.permutation.CombinationGenerator;
import org.broadinstitute.genee.matrix.Dataset;
import org.broadinstitute.genee.matrix.DatasetRowView;
import org.broadinstitute.genee.matrix.DatasetUtil;
import org.broadinstitute.genee.matrix.RowMajorArray2DDataset;
import org.broadinstitute.genee.matrix.Vector;
import org.broadinstitute.genee.stats.Sorting;
import uk.co.flamingpenguin.jewel.cli.Cli;
import uk.co.flamingpenguin.jewel.cli.CliFactory;
import uk.co.flamingpenguin.jewel.cli.CommandLineInterface;
import uk.co.flamingpenguin.jewel.cli.Option;

/**
 * Command-line tool that chooses a subset of samples (columns, e.g. cell lines) from an expression
 * dataset so that as many genes (rows) as possible are expressed in at least one chosen sample.
 * A value counts as "expressed" when it is strictly greater than the cutoff.
 *
 * <p>Strategies: exhaustive search over all column combinations ({@code complete}), best of N
 * random subsets ({@code random}), and two greedy heuristics ({@code greedy1}, {@code greedy2}).
 */
public class CellLineSelector {
    /** Expression cutoff used by {@link #main}; values strictly above it count as expressed. */
    private static final float DEFAULT_CUTOFF_FOR_EXPRESSED = 6.0f;

    /** Number of leading selected samples that {@link #main} annotates separately. */
    private static final int PREFIX_SAMPLE_COUNT = 10;

    /**
     * Evaluates every combination of {@code r} columns and keeps those covering the most genes.
     *
     * @param d dataset with genes as rows and samples as columns
     * @param r number of columns in each combination
     * @param cutoffForExpressed values strictly greater than this count as expressed
     * @return every column combination tied for the highest number of covered genes
     * @throws Exception as declared; presumably from {@code CombinationGenerator} or slicing
     */
    public static Selection combo(Dataset d, int r, float cutoffForExpressed) throws Exception {
        CombinationGenerator gen = new CombinationGenerator(d.getColumnCount(), r);
        int best = 0;
        List<int[]> theBest = new ArrayList<int[]>();
        while (gen.hasMore()) {
            int[] columnIndices = gen.getNext();
            Dataset subset = DatasetUtil.sliceView(d, null, columnIndices);
            int numFound = getNumGenesExpressedInAtLeastOneCellLine(subset, cutoffForExpressed);
            if (numFound > best) {
                best = numFound;
                theBest.clear();
            }
            if (numFound == best) {
                // Copy defensively: the generator may reuse its array between calls.
                theBest.add(columnIndices.clone());
            }
        }
        return new Selection(theBest);
    }

    /**
     * Overwrites every value in place with 1 if it is strictly greater than {@code cutoff},
     * otherwise 0.
     */
    public static void discretize(Dataset dataset, float cutoff) {
        int nrows = dataset.getRowCount();
        int ncols = dataset.getColumnCount();
        for (int i = 0; i < nrows; ++i) {
            for (int j = 0; j < ncols; ++j) {
                dataset.setValue(i, j, dataset.getValue(i, j) > cutoff ? 1.0f : 0.0f);
            }
        }
    }

    /**
     * Returns {@code true} when no value in the row is {@code <= cutoff} (NaN values therefore
     * count as expressed).
     */
    public static boolean isExpressedInAll(DatasetRowView dataset, float cutoff) {
        int ncols = dataset.size();
        for (int j = 0; j < ncols; ++j) {
            if (dataset.getValue(j) <= cutoff) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns {@code true} when no value in the row is strictly greater than {@code cutoff}.
     */
    public static boolean isNotExpressed(DatasetRowView dataset, float cutoff) {
        int ncols = dataset.size();
        for (int j = 0; j < ncols; ++j) {
            if (dataset.getValue(j) > cutoff) {
                return false;
            }
        }
        return true;
    }

    /**
     * Entry point: reads the project given by {@code --file}, selects {@code --nsamples} columns
     * using {@code --method}, prints each solution and writes every solution as a GCT project.
     * Prints the help message and exits when called without arguments.
     *
     * @throws IllegalArgumentException if {@code --nsamples} is outside {@code [1, columnCount]}
     *     or {@code --method} is not one of random, greedy1, greedy2, complete
     */
    public static void main(String[] args) throws Exception {
        Cli cli = CliFactory.createCli(CellLineSelectorArgs.class);
        if (args.length == 0) {
            System.out.println(cli.getHelpMessage());
            System.exit(0);
        }
        CellLineSelectorArgs argsp = (CellLineSelectorArgs) cli.parseArguments(args);
        String method = argsp.getMethod();
        int nsamples = argsp.getNSamplesToChoose();
        float cutoffForExpressed = DEFAULT_CUTOFF_FOR_EXPRESSED;
        Dataset originalDataset = ProjectIO.readProject(argsp.getFile()).getOriginalDataset();
        int ncols = originalDataset.getColumnCount();
        if (nsamples < 1 || nsamples > ncols) {
            throw new IllegalArgumentException("nsamples must be between 1 and " + ncols + " but was " + nsamples);
        }
        Selection selection;
        if (method.equalsIgnoreCase("random")) {
            selection = pickBestRandom(originalDataset, nsamples, cutoffForExpressed, argsp.getNumPermutations(), argsp.getSeed());
        } else if (method.equalsIgnoreCase("greedy1")) {
            selection = pickBestGreedy1(originalDataset, nsamples, cutoffForExpressed);
        } else if (method.equalsIgnoreCase("greedy2")) {
            selection = pickBestGreedy2(originalDataset, nsamples, cutoffForExpressed);
        } else if (method.equalsIgnoreCase("complete")) {
            selection = combo(originalDataset, nsamples, cutoffForExpressed);
        } else {
            throw new IllegalArgumentException("Unknown method '" + method + "'; expected random, greedy1, greedy2, or complete");
        }
        List<int[]> solutions = selection.columnIndices;
        int size = solutions.size();
        System.out.println(size + " solutions found.");
        for (int datasetIndex = 0; datasetIndex < size; ++datasetIndex) {
            int[] columnIndices = solutions.get(datasetIndex);
            Dataset subset = DatasetUtil.sliceView(originalDataset, null, columnIndices);
            int numFound = getNumGenesExpressedInAtLeastOneCellLine(subset, cutoffForExpressed);
            System.out.println(numFound + " = " + ToStringUtil.toString(columnIndices, ", "));
            String name = IOUtil.getBaseFileName(argsp.getFile()) + "_" + numFound + "_out_of_" + originalDataset.getRowCount() + "_genes_" + nsamples + "_samples";
            if (method.equalsIgnoreCase("random")) {
                name = name + "_" + argsp.getSeed() + "_seed";
            }
            if (size > 1) {
                name = name + "-" + (datasetIndex + 1);
            }
            // Annotate the leading samples of the selection. Never request more than were
            // selected: Arrays.copyOf would pad with 0, i.e. silently repeat column 0.
            int prefixLength = Math.min(PREFIX_SAMPLE_COUNT, columnIndices.length);
            addExpressionInfo(DatasetUtil.sliceView(originalDataset, null, Arrays.copyOf(columnIndices, prefixLength)), prefixLength, cutoffForExpressed);
            addExpressionInfo(subset, nsamples, cutoffForExpressed);
            addExpressionInfo(originalDataset, originalDataset.getColumnCount(), cutoffForExpressed);
            ProjectIO.writeProject((Project) new DefaultProject(subset), "gct", name, true);
        }
    }

    /**
     * Draws {@code nperms} random subsets of {@code maxSamples} distinct columns and keeps those
     * covering the most genes.
     *
     * @param dataset dataset with genes as rows and samples as columns
     * @param maxSamples number of distinct columns per random subset
     * @param cutoff values strictly greater than this count as expressed
     * @param nperms number of random subsets to evaluate
     * @param seed seed for the random number generator (results are reproducible per seed)
     * @return every drawn subset tied for the highest number of covered genes
     * @throws IllegalArgumentException if {@code maxSamples} exceeds the column count, which
     *     previously caused an infinite loop
     */
    public static Selection pickBestRandom(Dataset dataset, int maxSamples, float cutoff, int nperms, long seed) throws IOException {
        int ncols = dataset.getColumnCount();
        if (maxSamples > ncols) {
            throw new IllegalArgumentException("Cannot draw " + maxSamples + " distinct samples from " + ncols + " columns");
        }
        Random rnd = new Random(seed);
        int bestNumGenes = 0;
        List<int[]> bestIndicesList = new ArrayList<int[]>();
        for (int i = 0; i < nperms; ++i) {
            TIntHashSet set = new TIntHashSet();
            while (set.size() < maxSamples) {
                set.add(rnd.nextInt(ncols));
            }
            int[] indices = set.toArray();
            Dataset subset = DatasetUtil.sliceView(dataset, null, indices);
            int numGenes = getNumGenesExpressedInAtLeastOneCellLine(subset, cutoff);
            if (numGenes > bestNumGenes) {
                bestNumGenes = numGenes;
                bestIndicesList.clear();
            }
            if (numGenes == bestNumGenes) {
                bestIndicesList.add(indices);
            }
        }
        return new Selection(bestIndicesList);
    }

    /** Runs {@link #main} with a hard-coded developer-local input file and 15 samples. */
    public static void runDefault() throws Exception {
        CellLineSelector.main(new String[]{"--file", "/Users/jgould/datasets/CCLE_Expression_Entrez_2010-09-29.res", "--nsamples", "15"});
    }

    /**
     * Adds two Y/N row-metadata columns to {@code dataset}: whether each gene is expressed in all
     * of its columns, and whether it is expressed in none of them.
     *
     * @param nsamples used only to label the metadata columns
     */
    private static void addExpressionInfo(Dataset dataset, int nsamples, float cutoff) {
        Vector isExpressedInAll = dataset.getRowMetadata().add("Expressed In All " + nsamples, String.class);
        Vector notExpressed = dataset.getRowMetadata().add("Not Expressed In Any " + nsamples, String.class);
        DatasetRowView rowView = new DatasetRowView(dataset);
        int nrows = dataset.getRowCount();
        for (int i = 0; i < nrows; ++i) {
            rowView.setIndex(i);
            isExpressedInAll.setValue(i, isExpressedInAll(rowView, cutoff) ? "Y" : "N");
            notExpressed.setValue(i, isNotExpressed(rowView, cutoff) ? "Y" : "N");
        }
    }

    /** Returns, for each column, the number of rows whose value is strictly greater than cutoff. */
    private static float[] getCountsPerSample(Dataset dataset, float cutoff) {
        int nrows = dataset.getRowCount();
        float[] countsPerSample = new float[dataset.getColumnCount()];
        for (int j = 0; j < countsPerSample.length; ++j) {
            for (int i = 0; i < nrows; ++i) {
                if (dataset.getValue(i, j) > cutoff) {
                    countsPerSample[j] += 1.0f;
                }
            }
        }
        return countsPerSample;
    }

    /** Counts rows that have at least one value strictly greater than {@code cutoff}. */
    private static int getNumGenesExpressedInAtLeastOneCellLine(Dataset dataset, float cutoff) {
        int numFound = 0;
        int nrows = dataset.getRowCount();
        int ncols = dataset.getColumnCount();
        for (int i = 0; i < nrows; ++i) {
            for (int j = 0; j < ncols; ++j) {
                if (dataset.getValue(i, j) > cutoff) {
                    ++numFound;
                    break;
                }
            }
        }
        return numFound;
    }

    /**
     * Picks the not-yet-selected sample with the highest count, breaking ties uniformly at random
     * and reporting ties on stdout.
     *
     * @param countsPerSample number of still-uncovered genes expressed in each sample
     * @param selected flags for samples already chosen; these are never returned
     * @param numSelected number of samples chosen so far (used only in the tie message)
     * @param random source of randomness for tie-breaking
     * @return index of the chosen sample, or -1 if every sample is already selected
     */
    private static int pickNextSample(float[] countsPerSample, boolean[] selected, int numSelected, Random random) {
        float bestCount = Float.NEGATIVE_INFINITY;
        TIntArrayList matches = new TIntArrayList();
        for (int j = 0; j < countsPerSample.length; ++j) {
            if (selected[j]) {
                continue;
            }
            float count = countsPerSample[j];
            if (count > bestCount) {
                bestCount = count;
                matches.clear();
                matches.add(j);
            } else if (count == bestCount) {
                matches.add(j);
            }
        }
        if (matches.isEmpty()) {
            return -1;
        }
        if (matches.size() == 1) {
            return matches.getQuick(0);
        }
        int chosen = matches.getQuick(random.nextInt(matches.size()));
        System.out.println(matches.size() + " ties for " + numSelected + " selected " + chosen);
        return chosen;
    }

    /**
     * Greedy set cover. Repeatedly adds the sample that expresses the most still-uncovered genes,
     * then marks that sample's genes as covered. Ties are broken randomly (unseeded). Never
     * selects the same sample twice, and never selects more samples than there are columns.
     */
    private static Selection pickBestGreedy1(Dataset originalDataset, int nsamples, float cutoff) {
        RowMajorArray2DDataset dataset = DatasetUtil.deepCopy(originalDataset);
        int ncols = dataset.getColumnCount();
        int target = Math.min(nsamples, ncols);
        boolean[] selected = new boolean[ncols];
        TIntArrayList bestIndices = new TIntArrayList();
        Random random = new Random();
        while (bestIndices.size() < target) {
            float[] countsPerSample = getCountsPerSample(dataset, cutoff);
            int indexToAdd = pickNextSample(countsPerSample, selected, bestIndices.size(), random);
            selected[indexToAdd] = true;
            bestIndices.add(indexToAdd);
            zeroCountsForSample(dataset, indexToAdd, cutoff);
        }
        return new Selection(Arrays.asList(new int[][]{bestIndices.toArray()}));
    }

    /**
     * Gene-driven greedy variant. Genes are visited in the order given by
     * {@code Sorting.index(countsPerGene, false)}, presumably most widely expressed first
     * (TODO confirm sort direction). For each gene not yet expressed in the current selection,
     * the sample expressing the most still-uncovered genes is added. Stops once
     * {@code min(nsamples, columnCount)} samples are chosen or all genes are visited.
     */
    private static Selection pickBestGreedy2(Dataset originalDataset, int nsamples, float cutoff) {
        RowMajorArray2DDataset dataset = DatasetUtil.deepCopy(originalDataset);
        int nrows = dataset.getRowCount();
        int ncols = dataset.getColumnCount();
        float[] countsPerGene = new float[nrows];
        for (int i = 0; i < nrows; ++i) {
            for (int j = 0; j < ncols; ++j) {
                if (dataset.getValue(i, j) > cutoff) {
                    countsPerGene[i] += 1.0f;
                }
            }
        }
        int target = Math.min(nsamples, ncols);
        boolean[] selected = new boolean[ncols];
        TIntArrayList bestIndices = new TIntArrayList();
        Random random = new Random();
        int[] geneIndexTable = Sorting.index(countsPerGene, false);
        // Bound by the number of genes: bounding by nsamples made covered genes use up
        // iterations, so fewer than nsamples samples were returned.
        for (int i = 0; i < geneIndexTable.length && bestIndices.size() < target; ++i) {
            int geneIndex = geneIndexTable[i];
            boolean covered = false;
            for (int k = 0; k < bestIndices.size() && !covered; ++k) {
                covered = originalDataset.getValue(geneIndex, bestIndices.getQuick(k)) > cutoff;
            }
            if (covered) {
                continue;
            }
            float[] countsPerSample = getCountsPerSample(dataset, cutoff);
            int indexToAdd = pickNextSample(countsPerSample, selected, bestIndices.size(), random);
            selected[indexToAdd] = true;
            bestIndices.add(indexToAdd);
            // Remove the genes covered by the sample actually added (previously the pre-tie-break
            // candidate's genes were removed instead).
            zeroCountsForSample(dataset, indexToAdd, cutoff);
        }
        return new Selection(Arrays.asList(new int[][]{bestIndices.toArray()}));
    }

    /**
     * Marks every gene expressed in {@code sampleIndex} as covered by overwriting its whole row,
     * so that it no longer contributes to any sample's count. Rows are set to negative infinity
     * rather than 0 so that the removal also holds when {@code cutoff} is negative.
     */
    private static void zeroCountsForSample(Dataset dataset, int sampleIndex, float cutoff) {
        int nrows = dataset.getRowCount();
        int ncols = dataset.getColumnCount();
        for (int i = 0; i < nrows; ++i) {
            if (dataset.getValue(i, sampleIndex) > cutoff) {
                for (int j = 0; j < ncols; ++j) {
                    dataset.setValue(i, j, Float.NEGATIVE_INFINITY);
                }
            }
        }
    }

    /** Result of a selection strategy: one or more equally good column-index arrays. */
    private static class Selection {
        private final List<int[]> columnIndices;

        public Selection(List<int[]> columnIndices) {
            this.columnIndices = columnIndices;
        }
    }

    /** Command-line options for {@link CellLineSelector#main}. */
    @CommandLineInterface(application="cls")
    public static interface CellLineSelectorArgs {
        @Option(longName={"file"}, description="input dataset")
        public String getFile();

        @Option(longName={"nsamples"}, description="number of samples to select", defaultValue={"10"})
        public int getNSamplesToChoose();

        @Option(longName={"method"}, description="random, greedy1, greedy2, or complete", defaultValue={"greedy1"})
        public String getMethod();

        @Option(longName={"permutations"}, description="number of permutations when random", defaultValue={"10000"})
        public int getNumPermutations();

        @Option(longName={"seed"}, description="seed for rng", defaultValue={"1234567"})
        public long getSeed();
    }
}

