/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.genee.heatmap.menu;

import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.distribution.HypergeometricDistribution;
import org.broadinstitute.genee.clustering.hierarchical.Dendrogram;
import org.broadinstitute.genee.clustering.hierarchical.Node;
import org.broadinstitute.genee.math.adjust.LogScaleUnivariateFunction;
import org.broadinstitute.genee.matrix.Dataset;
import org.broadinstitute.genee.matrix.Vector;
import org.broadinstitute.genee.matrix.VectorUtil;

/**
 * Walks the nodes of a {@link Dendrogram} and assigns each visited node an array of
 * scores derived from a hypergeometric distribution, handing the result to a
 * {@link NodeScoreSetter}.
 *
 * <p>Two modes are offered:
 * <ul>
 *   <li>{@link #setNodePValues(Dataset, Dendrogram, NodeScoreSetter)} scores every node
 *       against every dataset column, treating non-zero cells as "successes";</li>
 *   <li>{@link #setNodePValues(Dendrogram, Dataset, Vector[], NodeScoreSetter)} scores
 *       nodes against groups of categorical annotations and stops descending once a
 *       node is considered significant or too distant.</li>
 * </ul>
 *
 * <p>Instances keep a running count of tested hypotheses and are therefore not safe for
 * concurrent use.
 */
public class NodeScoreCalculator {
    /** Children of a node are only visited while the node's distance is at most this value. */
    private final float maxNodeCorrelation;
    /** Nodes spanning fewer indices than this keep their default scores. */
    private final int minClusterMembers;
    /** A category is only tested in a node when it occurs at least this many times there. */
    private final int minClusterMembersForCategory;
    /** A p-value at or below this threshold stops the descent below the node. */
    private final float minPValue;
    /** Number of hypergeometric tests run by the most recent category-based scoring call. */
    private int numberOfHypothesesTested;

    /**
     * @param minClusterMembers            minimum index span of a node for it to be scored
     * @param minClusterMembersForCategory minimum occurrences of a category inside a node
     *                                     for that category to be tested
     * @param minPValue                    p-value threshold that stops traversal below a node
     * @param maxNodeCorrelation           largest node distance for which children are visited
     */
    public NodeScoreCalculator(int minClusterMembers, int minClusterMembersForCategory, float minPValue, float maxNodeCorrelation) {
        this.minClusterMembers = minClusterMembers;
        this.minClusterMembersForCategory = minClusterMembersForCategory;
        this.minPValue = minPValue;
        this.maxNodeCorrelation = maxNodeCorrelation;
    }

    /**
     * Scores every node of the dendrogram against every column of {@code dataset}.
     * A cell counts as a success when its value is not {@code 0.0f}.
     *
     * @param dataset         matrix whose rows are indexed by the nodes' min/max indices
     * @param dendrogram      tree whose root nodes are traversed
     * @param nodeScoreSetter receives one score array (one entry per column) per node
     */
    public void setNodePValues(Dataset dataset, Dendrogram dendrogram, NodeScoreSetter nodeScoreSetter) {
        int cols = dataset.getColumnCount();
        int rows = dataset.getRowCount();
        // Per-column number of non-zero cells over the whole dataset (the "population").
        int[] numberOfSuccesses = new int[cols];
        for (int j = 0; j < cols; ++j) {
            int count = 0;
            for (int i = 0; i < rows; ++i) {
                if (dataset.getValue(i, j) != 0.0f) {
                    ++count;
                }
            }
            numberOfSuccesses[j] = count;
        }
        for (Node root : dendrogram.getRootNodes()) {
            this.computeNodeScores(root, dataset, nodeScoreSetter, numberOfSuccesses, rows);
        }
    }

    /**
     * Scores dendrogram nodes against groups of categorical annotations. For each node
     * and each category group the smallest hypergeometric probability over the group's
     * categories is reported. The hypothesis counter returned by
     * {@link #getNumberOfTests()} is reset once per call and accumulates over all roots.
     *
     * @param dendrogram      tree whose root nodes are traversed
     * @param dataset         supplies the row count used as the population extent
     * @param categoryGroups  one annotation vector per category group
     * @param nodeScoreSetter receives one score array (one entry per group) per visited node
     */
    public void setNodePValues(Dendrogram dendrogram, Dataset dataset, Vector[] categoryGroups, NodeScoreSetter nodeScoreSetter) {
        // Reset once per call (not once per root) so the count covers the whole dendrogram.
        this.numberOfHypothesesTested = 0;
        Node[] roots = dendrogram.getRootNodes();
        if (roots.length == 0) {
            return;
        }

        // Population statistics do not depend on the root, so compute them a single time.
        int groupCount = categoryGroups.length;
        Object[][] categoryValues = new Object[groupCount][];
        int[][] numberOfSuccessesInPopulation = new int[groupCount][];
        int[] populationSize = new int[groupCount];
        int lastRow = dataset.getRowCount() - 1;
        for (int catIdx = 0; catIdx < groupCount; ++catIdx) {
            Vector categoryGroup = categoryGroups[catIdx];
            // Freeze the category order in an array so population and per-node counts
            // are always matched by the same index.
            Object[] values = VectorUtil.getValues(categoryGroup).toArray();
            Map<Object, Integer> counts = NodeScoreCalculator.getCounts(dataset, 0, lastRow, categoryGroup);
            int[] successes = new int[values.length];
            int sum = 0;
            for (int i = 0; i < values.length; ++i) {
                int count = counts.get(values[i]).intValue();
                successes[i] = count;
                sum += count;
            }
            categoryValues[catIdx] = values;
            numberOfSuccessesInPopulation[catIdx] = successes;
            populationSize[catIdx] = sum;
        }

        for (Node root : roots) {
            this.computePValues(root, dataset, categoryGroups, nodeScoreSetter, categoryValues, numberOfSuccessesInPopulation, populationSize);
        }
    }

    /**
     * Depth-first traversal from {@code root} that scores each visited node and descends
     * into a node's children only while the node's distance is at most
     * {@code maxNodeCorrelation} and none of its p-values is at or below {@code minPValue}.
     */
    private void computePValues(Node root, Dataset dataset, Vector[] categoryGroups, NodeScoreSetter nodeScoreSetter, Object[][] categoryValues, int[][] numberOfSuccessesInPopulation, int[] populationSize) {
        LinkedList<Node> stack = new LinkedList<Node>();
        stack.addFirst(root);
        while (!stack.isEmpty()) {
            Node n = stack.removeFirst();
            float[] pValues = this.setNodePValue(n, dataset, categoryGroups, nodeScoreSetter, categoryValues, numberOfSuccessesInPopulation, populationSize);
            if (n.isLeaf()) {
                continue;
            }
            boolean descend = n.getDistance() <= this.maxNodeCorrelation && !this.hasSignificantPValue(pValues);
            if (descend) {
                stack.addFirst(n.getLeft());
                stack.addFirst(n.getRight());
            }
        }
    }

    /** Returns true when at least one entry is at or below {@code minPValue}. */
    private boolean hasSignificantPValue(float[] pValues) {
        for (float p : pValues) {
            if (p <= this.minPValue) {
                return true;
            }
        }
        return false;
    }

    /**
     * @return the number of hypergeometric tests performed by the most recent call to
     *         {@link #setNodePValues(Dendrogram, Dataset, Vector[], NodeScoreSetter)}
     */
    public int getNumberOfTests() {
        return this.numberOfHypothesesTested;
    }

    /** Depth-first traversal from {@code root} that scores every node, leaves included. */
    private void computeNodeScores(Node root, Dataset dataset, NodeScoreSetter nodeScoreSetter, int[] numberOfSuccessesInPopulation, int populationSize) {
        LinkedList<Node> stack = new LinkedList<Node>();
        stack.addFirst(root);
        while (!stack.isEmpty()) {
            Node n = stack.removeFirst();
            this.setNodeScore(n, dataset, nodeScoreSetter, numberOfSuccessesInPopulation, populationSize);
            if (!n.isLeaf()) {
                stack.addFirst(n.getLeft());
                stack.addFirst(n.getRight());
            }
        }
    }

    /**
     * Computes, per category group, the minimum hypergeometric probability over the
     * categories that occur at least {@code minClusterMembersForCategory} times within
     * the node's index range, passes the result to {@code nodeScoreSetter} and returns it.
     * Groups without a tested category, and nodes smaller than {@code minClusterMembers},
     * keep the default value {@code 1.0f}.
     */
    private float[] setNodePValue(Node node, Dataset dataset, Vector[] categoryGroups, NodeScoreSetter nodeScoreSetter, Object[][] categoryValues, int[][] numberOfSuccessesInPopulation, int[] populationSize) {
        int min = node.getMinIndex();
        int max = node.getMaxIndex();
        int sampleSize = max - min + 1;
        float[] scores = new float[categoryGroups.length];
        Arrays.fill(scores, 1.0f);
        if (sampleSize >= this.minClusterMembers) {
            for (int catIdx = 0; catIdx < categoryGroups.length; ++catIdx) {
                Map<Object, Integer> counts = NodeScoreCalculator.getCounts(dataset, min, max, categoryGroups[catIdx]);
                Object[] values = categoryValues[catIdx];
                float minP = 1.0f;
                for (int i = 0; i < values.length; ++i) {
                    int count = counts.get(values[i]);
                    if (count < this.minClusterMembersForCategory) {
                        continue;
                    }
                    // NOTE(review): populationSize only counts non-null annotations while
                    // sampleSize spans every index of the node; if the vector contains
                    // nulls the sample can exceed the population, which the distribution
                    // constructor rejects - confirm vectors are null-free here.
                    HypergeometricDistribution dist = new HypergeometricDistribution(populationSize[catIdx], numberOfSuccessesInPopulation[catIdx][i], sampleSize);
                    // NOTE(review): cumulativeProbability(count) is the lower tail
                    // P(X <= count); confirm that this, rather than the upper tail,
                    // is the intended statistic.
                    float p = (float)dist.cumulativeProbability(count);
                    minP = Math.min(minP, p);
                    ++this.numberOfHypothesesTested;
                }
                scores[catIdx] = minP;
            }
        }
        nodeScoreSetter.setScore(node, scores);
        return scores;
    }

    /**
     * Computes one score per dataset column for {@code node}: the negated base-2 log of
     * the hypergeometric probability {@code P(X <= count)}, where {@code count} is the
     * number of non-zero cells of that column inside the node's index range. Nodes
     * smaller than {@code minClusterMembers} receive all-zero scores.
     */
    private void setNodeScore(Node node, Dataset dataset, NodeScoreSetter nodeScoreSetter, int[] numberOfSuccessesInPopulation, int populationSize) {
        int min = node.getMinIndex();
        int max = node.getMaxIndex();
        int sampleSize = max - min + 1;
        int cols = dataset.getColumnCount();
        float[] scores = new float[cols];
        if (sampleSize >= this.minClusterMembers) {
            for (int j = 0; j < cols; ++j) {
                int count = 0;
                for (int i = min; i <= max; ++i) {
                    if (dataset.getValue(i, j) != 0.0f) {
                        ++count;
                    }
                }
                HypergeometricDistribution dist = new HypergeometricDistribution(populationSize, numberOfSuccessesInPopulation[j], sampleSize);
                float p = (float)dist.cumulativeProbability(count);
                scores[j] = -LogScaleUnivariateFunction.log2(p);
            }
        }
        nodeScoreSetter.setScore(node, scores);
    }

    /**
     * Counts how often each category occurs in {@code categoryGroup} between indices
     * {@code min} and {@code max} (both inclusive). Every value reported by
     * {@code VectorUtil.getValues(categoryGroup)} is present in the result, with a count
     * of zero when it does not occur in the range; {@code null} entries are skipped.
     *
     * @param dataset       not used by this method
     * @param min           first index to inspect
     * @param max           last index to inspect
     * @param categoryGroup vector holding one category value per index
     * @return a new mutable map from category value to occurrence count
     */
    public static Map<Object, Integer> getCounts(Dataset dataset, int min, int max, Vector categoryGroup) {
        HashMap<Object, Integer> counts = new HashMap<Object, Integer>();
        Set<?> categories = VectorUtil.getValues(categoryGroup);
        for (Object category : categories) {
            counts.put(category, 0);
        }
        for (int i = min; i <= max; ++i) {
            Object value = categoryGroup.getValue(i);
            if (value == null) {
                continue;
            }
            // Tolerate a value that was not pre-seeded instead of failing on unboxing null.
            Integer previous = counts.get(value);
            counts.put(value, previous == null ? 1 : previous + 1);
        }
        return counts;
    }

    /** Callback that receives the scores computed for a node. */
    public static interface NodeScoreSetter {
        /**
         * @param node   the node that was scored
         * @param scores the scores computed for {@code node}
         */
        public void setScore(Node node, float[] scores);
    }
}

