/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.genee.clustering.cc;

import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import org.broadinstitute.genee.clustering.cc.ClusteringAlgorithm;
import org.broadinstitute.genee.clustering.hierarchical.algorithm.metrics.DistanceFunction;
import org.broadinstitute.genee.clustering.nmf.NMF;
import org.broadinstitute.genee.matrix.Dataset;
import org.broadinstitute.genee.matrix.DatasetUtil;

/**
 * {@link ClusteringAlgorithm} implementation that derives clusters from an
 * {@link NMF} run over the values of a {@link Dataset}.
 *
 * <p>Typical usage: call {@link #setDataset(Dataset)} and
 * {@link #setNumberOfClusters(int)} (optionally {@link #setIterations(int)}),
 * then {@link #execute()}, then read each cluster with {@link #getIndices(int)}.
 *
 * <p>NOTE(review): {@link #containsNegativeValues(Dataset)} is provided here,
 * presumably because NMF expects non-negative input; {@link #execute()} itself
 * does not perform that check -- confirm whether callers are expected to.
 */
public class NMFClusterer
implements ClusteringAlgorithm {
    /** Per-cluster member indices produced by the last {@link #execute()} call; {@code null} until then. */
    private int[][] clusters;
    /** Dataset whose values are copied into the matrix handed to NMF. */
    private Dataset dataset;
    /** Value passed to {@code NMF.setMaxNumIter}; defaults to 2000. */
    private int iterations = 2000;
    /** Number of clusters requested; also passed as the second argument to the NMF constructor. */
    private int k;

    /**
     * Copies the dataset into a {@code float[rows][columns]} matrix, runs NMF on it
     * and groups the indices of the resulting membership array by assigned cluster.
     *
     * <p>After this method returns, {@link #getIndices(int)} yields, for each cluster
     * number in {@code [0, k)}, the positions {@code i} of the membership array
     * (as returned by {@code NMF.whichMaxW()}) whose value equals that cluster number,
     * in ascending order.
     *
     * @throws IllegalStateException if no dataset has been set or the number of
     *         clusters is not positive
     */
    @Override
    public void execute() {
        if (this.dataset == null) {
            throw new IllegalStateException("setDataset must be called before execute");
        }
        if (this.k < 1) {
            throw new IllegalStateException("Number of clusters must be positive, was " + this.k);
        }

        float[][] array = new float[this.dataset.getRowCount()][this.dataset.getColumnCount()];
        DatasetUtil.copy(this.dataset, array);

        // Third constructor argument is presumably a random seed (TODO confirm against NMF);
        // using the current time means successive runs are not reproducible.
        NMF nmf = new NMF(array, this.k, System.currentTimeMillis());
        nmf.setMaxNumIter(this.iterations);
        nmf.runNMF(true);

        // membership[i] is used below as the cluster number assigned to index i.
        // NOTE(review): whether i refers to dataset rows or columns depends on NMF.whichMaxW() -- confirm.
        int[] membership = nmf.whichMaxW();

        // One growable bucket per cluster so that empty clusters still yield an empty array.
        TIntArrayList[] buckets = new TIntArrayList[this.k];
        for (int cluster = 0; cluster < buckets.length; ++cluster) {
            buckets[cluster] = new TIntArrayList();
        }
        for (int index = 0; index < membership.length; ++index) {
            buckets[membership[index]].add(index);
        }

        int[][] result = new int[buckets.length][];
        for (int cluster = 0; cluster < buckets.length; ++cluster) {
            result[cluster] = buckets[cluster].toArray();
        }
        this.clusters = result;
    }

    /**
     * Returns the member indices of one cluster computed by the last {@link #execute()} call.
     *
     * <p>The returned array is the internal one, not a copy.
     *
     * @param clusterIndex cluster number in {@code [0, k)}
     * @return indices assigned to that cluster (possibly empty)
     * @throws IllegalStateException if {@link #execute()} has not been called yet
     * @throws ArrayIndexOutOfBoundsException if {@code clusterIndex} is out of range
     */
    @Override
    public int[] getIndices(int clusterIndex) {
        if (this.clusters == null) {
            throw new IllegalStateException("execute must be called before getIndices");
        }
        return this.clusters[clusterIndex];
    }

    /**
     * Sets the dataset to cluster on the next {@link #execute()} call.
     *
     * @param dataset source of the values copied into the NMF input matrix
     */
    @Override
    public void setDataset(Dataset dataset) {
        this.dataset = dataset;
    }

    /**
     * Intentionally does nothing: this implementation does not use a distance function.
     *
     * @param distanceFunction ignored
     */
    @Override
    public void setDistanceFunction(DistanceFunction distanceFunction) {
    }

    /**
     * Sets the value passed to {@code NMF.setMaxNumIter} on the next {@link #execute()} call.
     *
     * @param iterations maximum number of NMF iterations
     */
    @Override
    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    /**
     * Sets the number of clusters to produce.
     *
     * @param k number of clusters; must be positive by the time {@link #execute()} is called
     */
    @Override
    public void setNumberOfClusters(int k) {
        this.k = k;
    }

    /**
     * Reports whether any cell of the dataset is strictly less than zero.
     *
     * <p>NaN cells are not counted as negative, because {@code NaN < 0} is false.
     *
     * @param dataset dataset to scan
     * @return {@code true} if at least one value is below zero, {@code false} otherwise
     *         (including for a dataset with no rows or no columns)
     */
    public static boolean containsNegativeValues(Dataset dataset) {
        int rows = dataset.getRowCount();
        // The column count does not change during the scan, so query it once.
        int cols = dataset.getColumnCount();
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                if (dataset.getValue(i, j) < 0.0f) {
                    return true;
                }
            }
        }
        return false;
    }
}

