/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.genee.compound;

import com.jgoodies.forms.factories.CC;
import com.jgoodies.forms.layout.FormLayout;
import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.awt.Component;
import java.awt.LayoutManager;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import javax.swing.JCheckBox;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.tree.MutableTreeNode;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optimization.ConvergenceChecker;
import org.apache.commons.math3.optimization.GoalType;
import org.apache.commons.math3.optimization.PointValuePair;
import org.apache.commons.math3.optimization.direct.AbstractSimplex;
import org.apache.commons.math3.optimization.direct.NelderMeadSimplex;
import org.apache.commons.math3.optimization.direct.SimplexOptimizer;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.descriptive.rank.Max;
import org.apache.commons.math3.stat.descriptive.rank.Min;
import org.broadinstitute.genee.application.Application;
import org.broadinstitute.genee.application.ComponentCustomizer;
import org.broadinstitute.genee.application.GENEEFolderNode;
import org.broadinstitute.genee.application.History;
import org.broadinstitute.genee.application.ProjectGENEEResultTreeNode;
import org.broadinstitute.genee.gui.AbstractInputAction;
import org.broadinstitute.genee.gui.ProgressNotifier;
import org.broadinstitute.genee.gui.UIUtil;
import org.broadinstitute.genee.gui.parameters.CheckBoxInputLabel;
import org.broadinstitute.genee.gui.parameters.CheckBoxListParameter;
import org.broadinstitute.genee.gui.parameters.CheckBoxParameter;
import org.broadinstitute.genee.gui.parameters.HiddenInputLabel;
import org.broadinstitute.genee.gui.parameters.MetadataComboBoxInputParameter;
import org.broadinstitute.genee.gui.parameters.TextFieldInputParameter;
import org.broadinstitute.genee.heatmap.CurveElementList;
import org.broadinstitute.genee.heatmap.CurveElementPainter;
import org.broadinstitute.genee.heatmap.DefaultCurveElement;
import org.broadinstitute.genee.heatmap.DefaultProject;
import org.broadinstitute.genee.heatmap.HeatMapPanel;
import org.broadinstitute.genee.heatmap.Project;
import org.broadinstitute.genee.io.util.Formatter;
import org.broadinstitute.genee.io.util.IOUtil;
import org.broadinstitute.genee.io.util.ParserHelper;
import org.broadinstitute.genee.math.FloatListStatUtils;
import org.broadinstitute.genee.math.stat.CollapseDataset;
import org.broadinstitute.genee.math.stat.function.UnivariateFloatFunction;
import org.broadinstitute.genee.matrix.Dataset;
import org.broadinstitute.genee.matrix.Identifier;
import org.broadinstitute.genee.matrix.MetadataUtil;
import org.broadinstitute.genee.matrix.RowMajorArray2DDataset;
import org.broadinstitute.genee.matrix.TroveFloatList;
import org.broadinstitute.genee.matrix.Vector;
import org.broadinstitute.genee.matrix.VectorUtil;

/**
 * Summarizes dose- or time-response data per curve. For each curve it reports either the area
 * under the curve (trapezoid rule) or the IC50 of a fitted four-parameter dose-response model.
 * <p>
 * Rows of the input dataset are grouped into curves by the values of the selected row metadata
 * fields. When no field is selected, every row is its own curve. For every curve and column:
 * <ul>
 * <li>the non-NaN values are grouped by the selected numeric x variable (e.g. dose);</li>
 * <li>replicates at the same x are reduced to their median.</li>
 * </ul>
 * The fitted model is {@code y = (A - D) / (1 + (x / C)^B) + D}. It is reported with {@code A}
 * as "Bottom", {@code D} as "Top", {@code B} as "HillSlope" and {@code C} as "IC50". The user may
 * fix the bottom and/or the top.
 */
public class CurveFittingAction
extends AbstractInputAction {
    private static final String CONCENTRATION_OR_TIME = "Curve x variable";
    private static final String CURVE_DEF_PARAMETER = "Each curve has the same";
    private static final String TOP_FIXED_PARAMETER = "Fix top to";
    private static final String BOTTOM_FIXED_PARAMETER = "Fix bottom to";
    private static final String AUC = "Compute AUC using trapezoid rule";
    /** Number of simplex iterations after which a fit is stopped, even if not yet converged. */
    private static final int MAX_ITERATIONS = 8000;
    /** Sum of squares below which the convergence checker reports convergence. */
    private static final double CONVERGED_SUM_OF_SQUARES = 1.0E-9;
    private MetadataComboBoxInputParameter curveXVariableInputParameter = new MetadataComboBoxInputParameter(true);
    private CheckBoxListParameter rowMetadataSelectorParameter;
    /** True until the x-variable parameter has been given its "dose" default once. */
    private boolean firstTime = true;

    /**
     * Registers the input parameters:
     * <ul>
     * <li>the x variable, filtered to {@link Number}-valued metadata;</li>
     * <li>the fields that define a curve;</li>
     * <li>optional fixed top/bottom values;</li>
     * <li>the AUC switch.</li>
     * </ul>
     * Selecting AUC disables the fixed top/bottom inputs, because they only apply to curve fitting.
     */
    public CurveFittingAction() {
        super("IC\u2085\u2080 (Beta)");
        this.curveXVariableInputParameter.setClassFilter(Number.class);
        this.addParameter(CONCENTRATION_OR_TIME, this.curveXVariableInputParameter, true);
        this.rowMetadataSelectorParameter = new CheckBoxListParameter();
        this.addParameter(CURVE_DEF_PARAMETER, this.rowMetadataSelectorParameter, true);
        CheckBoxInputLabel fixedTopLabel = new CheckBoxInputLabel(TOP_FIXED_PARAMETER);
        this.addParameter(fixedTopLabel, new TextFieldInputParameter().setDependsOn((JCheckBox)fixedTopLabel.getComponent()), false);
        CheckBoxInputLabel fixedBottomLabel = new CheckBoxInputLabel(BOTTOM_FIXED_PARAMETER);
        this.addParameter(fixedBottomLabel, new TextFieldInputParameter().setDependsOn((JCheckBox)fixedBottomLabel.getComponent()), false);
        final CheckBoxParameter aucCheckBox = new CheckBoxParameter(AUC);
        this.addParameter(new HiddenInputLabel(AUC), aucCheckBox, true);
        aucCheckBox.addActionListener(new ActionListener(){

            @Override
            public void actionPerformed(ActionEvent e) {
                CurveFittingAction.this.getInputPanelBuilder().setEnabled(CurveFittingAction.TOP_FIXED_PARAMETER, !aucCheckBox.isSelected());
                CurveFittingAction.this.getInputPanelBuilder().setEnabled(CurveFittingAction.BOTTOM_FIXED_PARAMETER, !aucCheckBox.isSelected());
            }
        });
    }

    /**
     * Refreshes the choices offered by the input panel from the current project's row metadata.
     * <p>
     * The curve-definition list defaults to "id" when nothing is selected. The x variable is set
     * to "dose" the first time, and again whenever no x variable is selected. Nothing is refreshed
     * when there is no current project.
     *
     * @return always {@code true}
     */
    @Override
    public boolean beforeWindowShown() {
        Project project = Application.getProject();
        if (project != null) {
            String[] rowFields = MetadataUtil.getNames(project.getOriginalDataset().getRowMetadata()).toArray(new String[0]);
            Arrays.sort(rowFields, String.CASE_INSENSITIVE_ORDER);
            this.rowMetadataSelectorParameter.init(rowFields);
            if (this.rowMetadataSelectorParameter.isSelectionEmpty()) {
                this.rowMetadataSelectorParameter.setValueFromString("id", false);
            }
            this.curveXVariableInputParameter.init(project);
            if (this.firstTime || this.curveXVariableInputParameter.getSelectedIndex() == -1) {
                this.curveXVariableInputParameter.setValueFromString("dose", false);
                this.firstTime = false;
            }
        }
        return true;
    }

    /**
     * Groups the rows of the input dataset into curves and computes one value per curve and column.
     * <p>
     * For each curve/column cell:
     * <ul>
     * <li>non-NaN values whose x value is present and not NaN are grouped by x;</li>
     * <li>replicates at the same x are reduced to their median, and their standard deviation is
     * used as the error;</li>
     * <li>the cell receives either the trapezoid-rule area under the median curve, or the fitted
     * IC50 ({@code C}).</li>
     * </ul>
     * When fitting, the fitted curve is also stored in the "Curve" series. Cells with no usable
     * point are set to NaN.
     *
     * @param map    parameter values collected from the input panel, plus {@code "dataset"} and
     *               {@code "Operation"}
     * @param status progress notifier (not used by this implementation)
     * @return a folder node wrapping a new project built from the summarized dataset
     * @throws NullPointerException     if a selected metadata field is not present in the dataset
     * @throws IllegalArgumentException if a fixed top or bottom value is not a number
     */
    @Override
    protected MutableTreeNode execute(Map<String, Object> map, ProgressNotifier status) throws Exception {
        Dataset dataset = (Dataset)map.get("dataset");
        String concentrationOrTimeField = (String)map.get(CONCENTRATION_OR_TIME);
        Object[] selectedFields = (Object[])map.get(CURVE_DEF_PARAMETER);
        String[] curveDefinitionFields = new String[selectedFields.length];
        for (int i = 0; i < curveDefinitionFields.length; ++i) {
            curveDefinitionFields[i] = (String)selectedFields[i];
        }
        Vector[] curveDefinitionVectors = new Vector[curveDefinitionFields.length];
        for (int i = 0; i < curveDefinitionVectors.length; ++i) {
            curveDefinitionVectors[i] = CurveFittingAction.getRequiredRowVector(dataset, curveDefinitionFields[i]);
        }
        Vector concentrationOrTimeVector = CurveFittingAction.getRequiredRowVector(dataset, concentrationOrTimeField);
        Map<Identifier, TIntArrayList> curveDefinitionFieldsToDatasetRowIndices = CurveFittingAction.groupRowsIntoCurves(dataset, curveDefinitionVectors);
        // A missing flag is treated as "not selected" instead of failing on unboxing.
        boolean auc = Boolean.TRUE.equals(map.get(AUC));
        // Fixed plateaus only apply to curve fitting; NaN means "fit this parameter".
        double topFixed = auc ? Double.NaN : CurveFittingAction.parseFixedValue((String)map.get(TOP_FIXED_PARAMETER), "Top");
        double bottomFixed = auc ? Double.NaN : CurveFittingAction.parseFixedValue((String)map.get(BOTTOM_FIXED_PARAMETER), "Bottom");
        RowMajorArray2DDataset newDataset = new RowMajorArray2DDataset(dataset.getName(), curveDefinitionFieldsToDatasetRowIndices.size(), dataset.getColumnCount());
        // The "Curve" series holds the fitted curve objects; AUC does not produce any.
        int curveSeriesIndex = auc ? -1 : newDataset.addSeries("Curve", CurveElementList.class);
        int ncols = newDataset.getColumnCount();
        int newDatasetRowIndex = 0;
        for (TIntArrayList rowIndices : curveDefinitionFieldsToDatasetRowIndices.values()) {
            for (int j = 0; j < ncols; ++j) {
                TreeMap<Float, TFloatArrayList> valuesByX = CurveFittingAction.groupValuesByX(dataset, concentrationOrTimeVector, rowIndices, j);
                if (valuesByX.isEmpty()) {
                    newDataset.setValue(newDatasetRowIndex, j, Float.NaN);
                    continue;
                }
                // TreeMap iteration yields x in ascending order, which trapz relies on.
                double[] x = new double[valuesByX.size()];
                double[] y = new double[x.length];
                float[] yError = new float[x.length];
                int index = 0;
                for (Map.Entry<Float, TFloatArrayList> point : valuesByX.entrySet()) {
                    TroveFloatList replicates = new TroveFloatList(point.getValue());
                    x[index] = point.getKey().floatValue();
                    y[index] = FloatListStatUtils.median(replicates);
                    yError[index] = FloatListStatUtils.stdev(replicates);
                    ++index;
                }
                if (auc) {
                    newDataset.setValue(newDatasetRowIndex, j, (float)CurveFittingAction.trapz(x, y));
                    continue;
                }
                SimplexOptimizer opt = new SimplexOptimizer(new ConvergenceChecker<PointValuePair>(){

                    @Override
                    public boolean converged(int iteration, PointValuePair previous, PointValuePair current) {
                        // Stop on an essentially exact fit, or once the iteration budget is spent.
                        return current.getValue() < CONVERGED_SUM_OF_SQUARES || iteration >= MAX_ITERATIONS - 1;
                    }
                });
                double bottomStart = new Min().evaluate(y);
                double topStart = new Max().evaluate(y);
                double ic50Start = new Mean().evaluate(x);
                // max(y) >= min(y) normally, so this is usually 1.0. A -1.0 start would be rejected
                // by the non-negativity penalty in FourParamDoseResponseSumOfSquares.value.
                double hillSlopeStart = topStart - bottomStart >= 0.0 ? 1.0 : -1.0;
                FourParamDoseResponseSumOfSquares ssFunction = CurveFittingAction.createSumOfSquares(bottomFixed, topFixed, x, y);
                double[] start = ssFunction.createStartPoint(bottomStart, hillSlopeStart, ic50Start, topStart);
                opt.setSimplex(new NelderMeadSimplex(start.length));
                // The evaluation budget is unbounded; the convergence checker caps the iterations.
                PointValuePair optResult = opt.optimize(Integer.MAX_VALUE, ssFunction, GoalType.MINIMIZE, start);
                double[] optArray = optResult.getPoint();
                final double A = ssFunction.getA(optArray);
                final double B = ssFunction.getB(optArray);
                final double C = ssFunction.getC(optArray);
                final double D = ssFunction.getD(optArray);
                final double ss = ssFunction.value(optArray);
                newDataset.setValue(newDatasetRowIndex, j, (float)C);
                DefaultCurveElement element = new DefaultCurveElement(){

                    @Override
                    public String toString() {
                        StringBuilder sb = new StringBuilder(super.toString());
                        sb.append("<br>Bottom: ");
                        sb.append(Formatter.format(A));
                        sb.append("<br>Top: ");
                        sb.append(Formatter.format(D));
                        sb.append("<br>HillSlope: ");
                        sb.append(Formatter.format(B));
                        sb.append("<br>IC50: ");
                        sb.append(Formatter.format(C));
                        sb.append("<br>sum-of-squares: ");
                        sb.append(Formatter.format(ss));
                        return sb.toString();
                    }
                };
                element.setX(IOUtil.toFloat(x));
                element.setY(IOUtil.toFloat(y), yError);
                element.setFunction(new FourParamDoseResponseUnivariateFloatFunction(A, B, C, D));
                element.setSpecialX((float)C);
                newDataset.setObjectValue(newDatasetRowIndex, j, element, curveSeriesIndex);
            }
            ++newDatasetRowIndex;
        }
        MetadataUtil.copy(dataset.getColumnMetadata(), newDataset.getColumnMetadata());
        CollapseDataset.setMetadata(dataset.getRowMetadata(), newDataset.getRowMetadata(), curveDefinitionFields, curveDefinitionFieldsToDatasetRowIndices);
        History history = this.getHistory(map);
        ComponentCustomizer c = new ComponentCustomizer(){

            @Override
            public void customize(Component c) {
                HeatMapPanel heatMapPanel = (HeatMapPanel)c;
                heatMapPanel.setRowSize(50.0f);
                heatMapPanel.setColumnSize(50.0f);
                CurveElementPainter painter = new CurveElementPainter();
                heatMapPanel.setElementPainter(painter);
            }

            @Override
            public boolean inherits() {
                return false;
            }
        };
        ProjectGENEEResultTreeNode node = new ProjectGENEEResultTreeNode((Project)new DefaultProject(newDataset), history, c);
        return new GENEEFolderNode((String)map.get("Operation"), history, node);
    }

    /**
     * Looks up a row metadata vector. If it is missing, fails with a message that lists the
     * available fields.
     *
     * @throws NullPointerException if {@code field} is not a row metadata field of {@code dataset}
     */
    private static Vector getRequiredRowVector(Dataset dataset, String field) {
        Vector vector = dataset.getRowMetadata().get(field);
        if (vector == null) {
            throw new NullPointerException(field + " not found. Available metadata fields: " + MetadataUtil.getNames(dataset.getRowMetadata()));
        }
        return vector;
    }

    /**
     * Maps each curve to the dataset rows that belong to it.
     * <p>
     * With curve-definition vectors, the grouping is delegated to
     * {@link VectorUtil#createValuesToIndicesMap}. Without them, every row forms its own single-row
     * curve, keyed by its row index.
     */
    private static Map<Identifier, TIntArrayList> groupRowsIntoCurves(Dataset dataset, Vector[] curveDefinitionVectors) {
        if (curveDefinitionVectors.length > 0) {
            return VectorUtil.createValuesToIndicesMap(curveDefinitionVectors);
        }
        Map<Identifier, TIntArrayList> rowToItself = new HashMap<Identifier, TIntArrayList>();
        int nrows = dataset.getRowCount();
        for (int i = 0; i < nrows; ++i) {
            TIntArrayList list = new TIntArrayList(1);
            list.add(i);
            rowToItself.put(new Identifier<Integer>(i), list);
        }
        return rowToItself;
    }

    /**
     * Parses an optional fixed plateau value typed by the user.
     *
     * @param text  the raw text; {@code null} or empty means "not fixed"
     * @param label "Top" or "Bottom", used in the error message
     * @return the parsed value, or NaN when the parameter is not fixed
     * @throws IllegalArgumentException if {@code text} is not a number (the parse error is kept
     *                                  as the cause)
     */
    private static double parseFixedValue(String text, String label) {
        if (text == null || "".equals(text)) {
            return Double.NaN;
        }
        try {
            return ParserHelper.parseFloat(text);
        }
        catch (NumberFormatException nfe) {
            throw new IllegalArgumentException(label + " is not a number.", nfe);
        }
    }

    /**
     * Collects the non-NaN values of one column for the given rows, grouped by x and sorted by x.
     * <p>
     * Rows are skipped when their x value is missing or NaN. Such rows would otherwise make the
     * fit's sum of squares NaN.
     */
    private static TreeMap<Float, TFloatArrayList> groupValuesByX(Dataset dataset, Vector xVector, TIntArrayList rowIndices, int column) {
        TreeMap<Float, TFloatArrayList> valuesByX = new TreeMap<Float, TFloatArrayList>();
        int size = rowIndices.size();
        for (int i = 0; i < size; ++i) {
            int rowIndex = rowIndices.getQuick(i);
            float value = dataset.getValue(rowIndex, column);
            if (Float.isNaN(value)) {
                continue;
            }
            Object rawX = xVector.getValue(rowIndex);
            if (rawX == null) {
                continue;
            }
            float xValue = ((Number)rawX).floatValue();
            if (Float.isNaN(xValue)) {
                continue;
            }
            Float key = Float.valueOf(xValue);
            TFloatArrayList values = valuesByX.get(key);
            if (values == null) {
                values = new TFloatArrayList();
                valuesByX.put(key, values);
            }
            values.add(value);
        }
        return valuesByX;
    }

    /**
     * Chooses the sum-of-squares objective that matches which plateaus are fixed (NaN = free).
     */
    private static FourParamDoseResponseSumOfSquares createSumOfSquares(double bottomFixed, double topFixed, double[] x, double[] y) {
        boolean fixBottom = !Double.isNaN(bottomFixed);
        boolean fixTop = !Double.isNaN(topFixed);
        if (fixBottom && fixTop) {
            return new FixedMinMaxFourParamDoseResponseSumOfSquares(bottomFixed, topFixed, x, y);
        }
        if (fixBottom) {
            return new FixedMinFourParamDoseResponseSumOfSquares(bottomFixed, x, y);
        }
        if (fixTop) {
            return new FixedMaxFourParamDoseResponseSumOfSquares(topFixed, x, y);
        }
        return new FourParamDoseResponseSumOfSquares(x, y);
    }

    /**
     * Integrates the piecewise-linear curve through {@code (x[k], y[k])} with the trapezoid rule.
     * <p>
     * Points are taken in array order, so {@code x} should be ascending for a conventional
     * (positive-width) area. Fewer than two points yield 0.
     *
     * @param x x coordinates
     * @param y y coordinates; must have at least {@code x.length} elements
     * @return the sum of the trapezoid areas
     */
    public static double trapz(double[] x, double[] y) {
        double sum = 0.0;
        for (int k = 0; k < x.length - 1; ++k) {
            sum += 0.5 * (x[k + 1] - x[k]) * (y[k + 1] + y[k]);
        }
        return sum;
    }

    /**
     * Adds a row of the form "[checkbox] [metadata field] equals [text]" to the input panel.
     * <p>
     * The combo box and text field are enabled only while the checkbox is selected. All three
     * inputs are registered in the parameter map under the given names.
     * <p>
     * NOTE(review): this method is not called anywhere in this file; confirm whether it is still
     * needed.
     *
     * @return the metadata combo-box parameter of the new row
     */
    private MetadataComboBoxInputParameter createFields(String checkBoxFieldName, String metadataFieldName, String textFieldFieldName) {
        final CheckBoxParameter checkbox = new CheckBoxParameter(checkBoxFieldName);
        checkbox.setBorder(null);
        MetadataComboBoxInputParameter metadataParameter = new MetadataComboBoxInputParameter(true);
        final JPanel subtractBackgroundPanel = new JPanel((LayoutManager)new FormLayout("left:p, 6px, p, 6px, p, 6px, p", "p"));
        subtractBackgroundPanel.add((Component)checkbox, CC.xy(1, 1));
        subtractBackgroundPanel.add((Component)metadataParameter.getJComponent(), CC.xy(3, 1));
        subtractBackgroundPanel.add((Component)new JLabel("equals"), CC.xy(5, 1));
        TextFieldInputParameter textField = new TextFieldInputParameter(10);
        subtractBackgroundPanel.add((Component)textField, CC.xy(7, 1));
        checkbox.addActionListener(new ActionListener(){

            @Override
            public void actionPerformed(ActionEvent e) {
                UIUtil.setChildrenEnabledRecursively(subtractBackgroundPanel, checkbox.isSelected());
                // The checkbox itself must stay usable so the row can be re-enabled.
                checkbox.setEnabled(true);
            }
        });
        UIUtil.setChildrenEnabledRecursively(subtractBackgroundPanel, checkbox.isSelected());
        checkbox.setEnabled(true);
        this.getInputPanelBuilder().addToParameterMap(new HiddenInputLabel(checkBoxFieldName), checkbox, false);
        this.getInputPanelBuilder().addToParameterMap(new HiddenInputLabel(metadataFieldName), metadataParameter, false);
        this.getInputPanelBuilder().addToParameterMap(new HiddenInputLabel(textFieldFieldName), textField, false);
        this.getInputPanelBuilder().getFormBuilder().addSpanned(subtractBackgroundPanel);
        this.getInputPanelBuilder().getFormBuilder().nextRow();
        return metadataParameter;
    }

    /**
     * The fitted four-parameter dose-response curve {@code y = (A - D) / (1 + (x / C)^B) + D},
     * evaluated in double precision and returned as a float.
     */
    private static class FourParamDoseResponseUnivariateFloatFunction
    implements UnivariateFloatFunction {
        private double A;
        private double B;
        private double C;
        private double D;

        public FourParamDoseResponseUnivariateFloatFunction(double a, double b, double c, double d) {
            this.A = a;
            this.B = b;
            this.C = c;
            this.D = d;
        }

        @Override
        public float evaluate(float x) {
            double y = (this.A - this.D) / (1.0 + Math.pow((double)x / this.C, this.B)) + this.D;
            return (float)y;
        }

        /**
         * Renders the formula with the fitted coefficients. Parentheses match {@link #evaluate}:
         * {@code (A - D) / (1 + (x/C)^B) + D}.
         */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("(");
            sb.append(Formatter.format(this.A));
            sb.append(" - ");
            sb.append(Formatter.format(this.D));
            sb.append(") / (1 + (x/");
            sb.append(Formatter.format(this.C));
            sb.append(")^");
            sb.append(Formatter.format(this.B));
            sb.append(") + ");
            sb.append(Formatter.format(this.D));
            return sb.toString();
        }
    }

    /**
     * Sum of squared residuals between the observed responses and the four-parameter
     * dose-response model {@code (A - D) / (1 + (x / C)^B) + D}, as a function of the free
     * parameters.
     * <p>
     * This base class fits all four parameters, laid out as {@code [A, B, C, D]}. Subclasses fix
     * some of them and override the accessors and {@link #createStartPoint} for their shorter
     * layout.
     */
    private static class FourParamDoseResponseSumOfSquares
    implements MultivariateFunction {
        private final double[] doseOrTime;
        private final double[] response;

        public FourParamDoseResponseSumOfSquares(double[] doseOrTime, double[] response) {
            this.doseOrTime = doseOrTime;
            this.response = response;
        }

        /**
         * @param point the free parameters in this class's layout
         * @return the sum of squared residuals, or {@link Double#MAX_VALUE} if any free parameter
         *         is negative
         */
        @Override
        public double value(double[] point) {
            // A negative free parameter gets a huge cost, which steers the simplex away from it.
            // NOTE(review): this also forbids negative bottom/top plateaus and Hill slopes. Data
            // whose plateaus lie below zero can only be fitted with those plateaus fixed. Confirm
            // that this is intended.
            for (double parameter : point) {
                if (parameter < 0.0) {
                    return Double.MAX_VALUE;
                }
            }
            double A = this.getA(point);
            double B = this.getB(point);
            double C = this.getC(point);
            double D = this.getD(point);
            double ss = 0.0;
            for (int i = 0; i < this.doseOrTime.length; ++i) {
                double predicted = (A - D) / (1.0 + Math.pow(this.doseOrTime[i] / C, B)) + D;
                double residual = this.response[i] - predicted;
                ss += residual * residual;
            }
            return ss;
        }

        /** Builds the optimizer start vector from the initial guesses, in this class's layout. */
        protected double[] createStartPoint(double bottom, double hillSlope, double ic50, double top) {
            return new double[]{bottom, hillSlope, ic50, top};
        }

        /** Bottom plateau (value of the model at x = 0 when B is positive). */
        protected double getA(double[] point) {
            return point[0];
        }

        /** Hill slope. */
        protected double getB(double[] point) {
            return point[1];
        }

        /** IC50 (midpoint x). */
        protected double getC(double[] point) {
            return point[2];
        }

        /** Top plateau. */
        protected double getD(double[] point) {
            return point[3];
        }
    }

    /**
     * Objective with both plateaus fixed; the free parameters are {@code [B, C]}.
     * The fixed values are not subject to the non-negativity penalty, which inspects only the
     * free parameters.
     */
    private static class FixedMinMaxFourParamDoseResponseSumOfSquares
    extends FourParamDoseResponseSumOfSquares {
        private final double D;
        private final double A;

        public FixedMinMaxFourParamDoseResponseSumOfSquares(double A, double D, double[] doseOrTime, double[] response) {
            super(doseOrTime, response);
            this.A = A;
            this.D = D;
        }

        @Override
        protected double[] createStartPoint(double bottom, double hillSlope, double ic50, double top) {
            return new double[]{hillSlope, ic50};
        }

        @Override
        protected double getA(double[] point) {
            return this.A;
        }

        @Override
        protected double getB(double[] point) {
            return point[0];
        }

        @Override
        protected double getC(double[] point) {
            return point[1];
        }

        @Override
        protected double getD(double[] point) {
            return this.D;
        }
    }

    /**
     * Objective with the bottom plateau fixed; the free parameters are {@code [B, C, D]}.
     * The fixed bottom is not subject to the non-negativity penalty.
     */
    private static class FixedMinFourParamDoseResponseSumOfSquares
    extends FourParamDoseResponseSumOfSquares {
        private final double A;

        public FixedMinFourParamDoseResponseSumOfSquares(double A, double[] doseOrTime, double[] response) {
            super(doseOrTime, response);
            this.A = A;
        }

        @Override
        protected double[] createStartPoint(double bottom, double hillSlope, double ic50, double top) {
            return new double[]{hillSlope, ic50, top};
        }

        @Override
        protected double getA(double[] point) {
            return this.A;
        }

        @Override
        protected double getB(double[] point) {
            return point[0];
        }

        @Override
        protected double getC(double[] point) {
            return point[1];
        }

        @Override
        protected double getD(double[] point) {
            return point[2];
        }
    }

    /**
     * Objective with the top plateau fixed; the free parameters are {@code [A, B, C]}.
     * The fixed top is not subject to the non-negativity penalty.
     */
    private static class FixedMaxFourParamDoseResponseSumOfSquares
    extends FourParamDoseResponseSumOfSquares {
        private final double D;

        public FixedMaxFourParamDoseResponseSumOfSquares(double D, double[] doseOrTime, double[] response) {
            super(doseOrTime, response);
            this.D = D;
        }

        @Override
        protected double[] createStartPoint(double bottom, double hillSlope, double ic50, double top) {
            return new double[]{bottom, hillSlope, ic50};
        }

        @Override
        protected double getA(double[] point) {
            return point[0];
        }

        @Override
        protected double getB(double[] point) {
            return point[1];
        }

        @Override
        protected double getC(double[] point) {
            return point[2];
        }

        @Override
        protected double getD(double[] point) {
            return this.D;
        }
    }
}

