/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.genee.clustering.hierarchical.algorithm;

import java.io.PrintWriter;
import java.util.Arrays;
import org.broadinstitute.genee.clustering.hierarchical.algorithm.ClosestPairResult;
import org.broadinstitute.genee.clustering.hierarchical.algorithm.LinkageMethod;
import org.broadinstitute.genee.clustering.hierarchical.algorithm.Node;
import org.broadinstitute.genee.clustering.hierarchical.algorithm.linkage.AverageLinkage;
import org.broadinstitute.genee.clustering.hierarchical.algorithm.linkage.CompleteLinkage;
import org.broadinstitute.genee.clustering.hierarchical.algorithm.linkage.LinkageAlgorithm;
import org.broadinstitute.genee.clustering.hierarchical.algorithm.linkage.SingleLinkage;
import org.broadinstitute.genee.clustering.hierarchical.algorithm.metrics.DistanceFunction;
import org.broadinstitute.genee.matrix.Dataset;
import org.broadinstitute.genee.stats.Sorting;

/**
 * Builds a hierarchical clustering tree from a lower-triangular distance
 * matrix using a selectable linkage algorithm, assigns string identifiers to
 * the tree nodes and their children, and derives an ordering of the clustered
 * elements from the tree.
 *
 * <p>Children of a {@link Node} are encoded as integers: a non-negative value
 * is the index of an original element, a negative value {@code c} refers to the
 * previously created node at index {@code -c - 1}.
 *
 * <p>Identifiers are produced by {@link #makeID(String, int)}. With the default
 * implementation node {@code i} is named {@code "NODE" + (i + 1) + "X"} and
 * element {@code k} is named {@code k + "X"}.
 *
 * <p>All array getters return the internal arrays, not copies.
 */
public class HCLCluster {
    private final DistanceFunction distanceFunction;
    private final float[][] distmatrix;
    /** Element ordering computed by {@link #treeSort}; assigned during construction. */
    private int[] indices;
    /** Prefix used for element identifiers; always the empty string in this class. */
    private final String keyword = "";
    private final String[] leftIds;
    private final LinkageAlgorithm linkageAlgorithm;
    private final String[] nodeID;
    private final String[] rightIds;
    private final Node[] tree;

    /**
     * Clusters the rows of {@code data}: computes the pairwise row distance
     * matrix with {@code distanceFunction} and delegates to the matrix-based
     * constructor.
     *
     * @param data             dataset whose rows are clustered
     * @param linkage          linkage method used to build the tree
     * @param distanceFunction function used to compute row-to-row distances
     * @throws IllegalArgumentException if {@code data} has no rows
     */
    public HCLCluster(Dataset data, LinkageMethod linkage, DistanceFunction distanceFunction) {
        this(HCLCluster.computeDistanceMatrix(data, distanceFunction), linkage, distanceFunction, false);
    }

    /**
     * Clusters the elements described by a precomputed distance matrix.
     *
     * <p>Note: this constructor calls the overridable {@link #makeID(String, int)};
     * subclasses overriding it must not rely on their own fields being
     * initialised at that point.
     *
     * @param distanceMatrix      lower-triangular distance matrix; row {@code i}
     *                            is expected to hold the distances to elements
     *                            {@code 0..i-1}. The array is stored, not copied,
     *                            and is handed to the linkage algorithm.
     * @param linkage             linkage method used to build the tree
     * @param distanceFunction    stored and returned by {@link #getDistanceFunction()};
     *                            not evaluated by this constructor
     * @param scaleDistancesToOne currently ignored by this constructor
     * @throws NullPointerException     if {@code distanceMatrix} or {@code linkage} is null
     * @throws IllegalArgumentException if {@code distanceMatrix} has no rows or
     *                                  {@code linkage} is not a supported method
     */
    public HCLCluster(float[][] distanceMatrix, LinkageMethod linkage, DistanceFunction distanceFunction, boolean scaleDistancesToOne) {
        if (distanceMatrix == null) {
            throw new NullPointerException("distanceMatrix must not be null");
        }
        if (distanceMatrix.length == 0) {
            // Without this check nNodes would be -1 and the array allocations
            // below would fail with a NegativeArraySizeException.
            throw new IllegalArgumentException("distanceMatrix must contain at least one row");
        }
        int nelements = distanceMatrix.length;
        int nNodes = nelements - 1;
        this.linkageAlgorithm = HCLCluster.createLinkageAlgorithm(linkage);
        this.distanceFunction = distanceFunction;
        this.distmatrix = distanceMatrix;
        this.tree = this.linkageAlgorithm.cluster(nelements, this.distmatrix);

        // Per-node average position and leaf count, filled in as nodes are visited.
        float[] nodeorder = new float[nNodes];
        int[] nodecounts = new int[nNodes];
        this.nodeID = new String[nNodes];
        this.leftIds = new String[nNodes];
        this.rightIds = new String[nNodes];

        // Initial position of every element is simply its index.
        float[] order = new float[nelements];
        for (int i = 0; i < nelements; ++i) {
            order[i] = i;
        }

        for (int i = 0; i < nNodes; ++i) {
            int left = this.tree[i].left;
            int right = this.tree[i].right;
            this.nodeID[i] = this.makeID("NODE", i + 1);

            float leftOrder;
            int leftCount;
            String leftId;
            if (left < 0) {
                int leftNode = -left - 1;
                leftOrder = nodeorder[leftNode];
                leftCount = nodecounts[leftNode];
                leftId = this.nodeID[leftNode];
                // A node's distance is never allowed to be smaller than its child's.
                this.tree[i].distance = Math.max(this.tree[i].distance, this.tree[leftNode].distance);
            } else {
                leftOrder = order[left];
                leftCount = 1;
                leftId = this.makeID(this.keyword, left);
            }

            float rightOrder;
            int rightCount;
            String rightId;
            if (right < 0) {
                int rightNode = -right - 1;
                rightOrder = nodeorder[rightNode];
                rightCount = nodecounts[rightNode];
                rightId = this.nodeID[rightNode];
                this.tree[i].distance = Math.max(this.tree[i].distance, this.tree[rightNode].distance);
            } else {
                rightOrder = order[right];
                rightCount = 1;
                rightId = this.makeID(this.keyword, right);
            }

            this.leftIds[i] = leftId;
            this.rightIds[i] = rightId;
            nodecounts[i] = leftCount + rightCount;
            // Count-weighted mean of the children's positions.
            nodeorder[i] = ((float) leftCount * leftOrder + (float) rightCount * rightOrder) / (float) (leftCount + rightCount);
        }
        this.treeSort(nNodes, order, nodeorder, nodecounts, this.tree);
    }

    /** @return the distance function passed to the constructor */
    public DistanceFunction getDistanceFunction() {
        return this.distanceFunction;
    }

    /** @return the internal distance matrix (not a copy) */
    public float[][] getDistanceMatrix() {
        return this.distmatrix;
    }

    /** @return the internal element ordering computed during construction (not a copy) */
    public int[] getIndices() {
        return this.indices;
    }

    /** @return for each node, the identifier of its left child (internal array, not a copy) */
    public String[] getLeftIds() {
        return this.leftIds;
    }

    /** @return the identifier of each node (internal array, not a copy) */
    public String[] getNodeID() {
        return this.nodeID;
    }

    /**
     * Builds the element identifiers in the order given by {@link #getIndices()}.
     *
     * @return a new array whose entry {@code i} is the identifier of element
     *         {@code indices[i]}
     */
    public String[] getReorderedIds() {
        String[] ids = new String[this.indices.length];
        for (int i = 0; i < ids.length; ++i) {
            ids[i] = this.makeID(this.keyword, this.indices[i]);
        }
        return ids;
    }

    /** @return for each node, the identifier of its right child (internal array, not a copy) */
    public String[] getRightIds() {
        return this.rightIds;
    }

    /** @return the internal node array produced by the linkage algorithm (not a copy) */
    public Node[] getTree() {
        return this.tree;
    }

    /**
     * Writes one tab-separated line per node: node ID, left child ID, right
     * child ID and node distance. The writer is flushed but not closed.
     *
     * @param pw destination writer
     */
    public void writeAtrGtr(PrintWriter pw) {
        int length = this.nodeID.length;
        for (int i = 0; i < length; ++i) {
            pw.print(this.nodeID[i]);
            pw.print("\t");
            pw.print(this.leftIds[i]);
            pw.print("\t");
            pw.print(this.rightIds[i]);
            pw.print("\t");
            pw.println(this.tree[i].distance);
        }
        pw.flush();
    }

    /**
     * Creates an identifier of the form {@code name + i + "X"}.
     *
     * <p>Called from the constructor; see the constructor note before overriding.
     *
     * @param name prefix
     * @param i    numeric part
     * @return the identifier
     */
    protected String makeID(String name, int i) {
        return name + i + "X";
    }

    /**
     * Computes the element ordering implied by the tree and stores it in
     * {@code indices}.
     *
     * <p>Nodes are processed in index order. At each node the child with the
     * smaller average position stays in place and every element belonging to
     * the other child has its position increased by the size of the first
     * child. On equal positions the tie is broken by comparing the child
     * codes: if {@code left < right} the left child is shifted, otherwise the
     * right child is shifted. The accumulated positions are finally passed to
     * {@code Sorting.index(neworder, true)} (presumably an ascending index
     * sort; confirm against {@code Sorting}).
     *
     * @param nNodes     number of nodes in {@code tree}
     * @param order      initial position of each element
     * @param nodeorder  average position of each node
     * @param nodecounts number of elements under each node
     * @param tree       the cluster tree
     */
    void treeSort(int nNodes, float[] order, float[] nodeorder, int[] nodecounts, Node[] tree) {
        int nElements = nNodes + 1;
        float[] neworder = new float[nElements];
        // Cluster each element currently belongs to; starts as the element itself.
        int[] clusterids = new int[nElements];
        for (int i = 0; i < nElements; ++i) {
            clusterids[i] = i;
        }
        for (int i = 0; i < nNodes; ++i) {
            int i1 = tree[i].left;
            int i2 = tree[i].right;
            float order1 = i1 < 0 ? nodeorder[-i1 - 1] : order[i1];
            float order2 = i2 < 0 ? nodeorder[-i2 - 1] : order[i2];
            int count1 = i1 < 0 ? nodecounts[-i1 - 1] : 1;
            int count2 = i2 < 0 ? nodecounts[-i2 - 1] : 1;

            // The two flags are kept as separate comparisons (rather than one
            // being the negation of the other) so that results are unchanged
            // even if a position were NaN.
            boolean leftCodeSmaller = i1 < i2;
            boolean shiftRight = leftCodeSmaller ? order1 < order2 : order1 <= order2;
            boolean shiftLeft = leftCodeSmaller ? order1 >= order2 : order1 > order2;
            float increase = shiftRight ? (float) count1 : (float) count2;
            int mergedId = -i - 1;

            for (int j = 0; j < nElements; ++j) {
                int clusterid = clusterids[j];
                if (clusterid == i1 && shiftLeft) {
                    neworder[j] += increase;
                }
                if (clusterid == i2 && shiftRight) {
                    neworder[j] += increase;
                }
                if (clusterid == i1 || clusterid == i2) {
                    // Both children now belong to the newly formed node.
                    clusterids[j] = mergedId;
                }
            }
        }
        this.indices = Sorting.index(neworder, true);
    }

    /**
     * Computes the lower-triangular matrix of pairwise row distances.
     *
     * <p>Row {@code i} (for {@code i >= 1}) has length {@code i} and holds the
     * distances between row {@code i} and rows {@code 0..i-1}, evaluated with
     * all column weights set to 1. Row 0 is left {@code null}.
     *
     * @param data             dataset whose rows are compared
     * @param distanceFunction distance function to evaluate
     * @return the ragged distance matrix with {@code data.getRowCount()} rows
     */
    public static float[][] computeDistanceMatrix(Dataset data, DistanceFunction distanceFunction) {
        int n = data.getRowCount();
        float[][] matrix = new float[n][];
        float[] weights = new float[data.getColumnCount()];
        Arrays.fill(weights, 1.0f);
        for (int i = 1; i < n; ++i) {
            float[] row = new float[i];
            for (int j = 0; j < i; ++j) {
                row[j] = distanceFunction.evaluate(data, data, weights, i, j);
            }
            matrix[i] = row;
        }
        return matrix;
    }

    /**
     * Finds the pair {@code (ip, jp)}, {@code jp < ip < n}, with the smallest
     * entry in the lower-triangular distance matrix and stores it in {@code r}.
     *
     * <p>The search is seeded with {@code distmatrix[1][0]}, so for
     * {@code n < 2} that entry is reported with {@code ip = 1, jp = 0}. Only
     * strictly smaller values replace the current minimum, so the first
     * minimal pair in row-major order wins and NaN entries are never selected
     * over the seed.
     *
     * @param n          number of rows to scan
     * @param distmatrix lower-triangular distance matrix with at least two rows
     * @param r          receives {@code distance}, {@code ip} and {@code jp}
     * @throws IllegalArgumentException if {@code distmatrix} has fewer than two rows
     */
    public static void find_closest_pair(int n, float[][] distmatrix, ClosestPairResult r) {
        if (distmatrix.length < 2) {
            throw new IllegalArgumentException("distmatrix must contain at least two rows but has " + distmatrix.length);
        }
        float distance = distmatrix[1][0];
        int ip = 1;
        int jp = 0;
        for (int i = 1; i < n; ++i) {
            for (int j = 0; j < i; ++j) {
                float candidate = distmatrix[i][j];
                if (candidate < distance) {
                    distance = candidate;
                    ip = i;
                    jp = j;
                }
            }
        }
        r.distance = distance;
        r.ip = ip;
        r.jp = jp;
    }

    /**
     * Maps a linkage method to a new instance of its implementation.
     *
     * @param method the requested linkage method
     * @return a new linkage algorithm instance
     * @throws NullPointerException     if {@code method} is null
     * @throws IllegalArgumentException if {@code method} has no known implementation
     */
    private static LinkageAlgorithm createLinkageAlgorithm(LinkageMethod method) {
        if (method == null) {
            throw new NullPointerException("linkage method must not be null");
        }
        switch (method) {
            case SINGLE_LINKAGE: {
                return new SingleLinkage();
            }
            case COMPLETE_LINKAGE: {
                return new CompleteLinkage();
            }
            case AVERAGE_LINKAGE: {
                return new AverageLinkage();
            }
        }
        throw new IllegalArgumentException("Unknown linkage method " + method);
    }
}

