/*
 * Decompiled with CFR 0.152.
 */
package chemaxon.calculations.training.pls;

import chemaxon.calculations.training.DescriptorMatrix;
import chemaxon.calculations.training.FittingAlgorithm;
import chemaxon.calculations.training.Stat;
import chemaxon.calculations.training.pls.PLSFittingAlgorithm;
import chemaxon.calculations.training.pls.PLSPredictionAlgorithm;
import chemaxon.calculations.training.pls.PLSUtil;
import chemaxon.common.util.DoubleVector;
import chemaxon.marvin.modelling.debug.ErrPrint;
import java.util.ArrayList;
import java.util.BitSet;

/**
 * Greedy backward-elimination descriptor selector built on PLS regression.
 *
 * <p>Starting from every column of the training descriptor matrix, the selector repeatedly
 * removes the single descriptor whose removal gives the highest q2 (as computed by
 * {@code PLSUtil.calculateQ2}) and refits a PLS model on the remaining columns.
 *
 * <p>Coefficient vectors produced from a selection have {@code descriptorCount + 1} entries: one
 * per original descriptor (0.0 for descriptors that were removed) followed by the last element of
 * the PLS fit result (presumably the intercept; confirm against {@code PLSFittingAlgorithm}).
 *
 * <p>Instances hold mutable training data and perform no synchronization.
 */
public class PLSDescriptorSelector
implements FittingAlgorithm<double[]> {
    /** Fraction of the current descriptor count used as the PLS maximum component count. */
    private static final double MAX_COMPONENTS_PER_DESCRIPTOR = 0.6;

    /** Copy of the training descriptor matrix, one array per row. */
    double[][] trainX;
    /** Copy of the experimental (response) values. */
    double[] trainY;
    /** When true, progress dots and intermediate statistics are printed for diagnostics. */
    boolean debug = false;

    /**
     * Stores a row-by-row copy of the descriptor matrix, so later changes to {@code matrix}
     * do not affect this selector.
     *
     * @param matrix the training descriptor matrix
     */
    @Override
    public void setDescriptorMatrix(DescriptorMatrix matrix) {
        int rowCount = matrix.getRowCount();
        this.trainX = new double[rowCount][];
        for (int i = 0; i < rowCount; ++i) {
            double[] row = matrix.getRow(i);
            this.trainX[i] = new double[row.length];
            System.arraycopy(row, 0, this.trainX[i], 0, row.length);
        }
    }

    /**
     * Enables or disables diagnostic output.
     *
     * @param b {@code true} to print progress and intermediate statistics
     */
    public void setDebug(boolean b) {
        this.debug = b;
    }

    /**
     * Stores a copy of the experimental (response) values.
     *
     * @param values one response value per descriptor-matrix row
     */
    @Override
    public void setExperimentalValues(double[] values) {
        this.trainY = new double[values.length];
        System.arraycopy(values, 0, this.trainY, 0, values.length);
    }

    /**
     * Runs greedy backward elimination and returns the coefficients of the final model.
     *
     * <p>Elimination continues while removing the best remaining descriptor does not lower q2
     * (ties continue). It stops once a single descriptor is left, because removing the last one
     * would leave an empty model.
     *
     * @return coefficient vector of length {@code descriptorCount + 1}
     * @throws RuntimeException if the descriptor matrix or experimental values are missing
     */
    @Override
    public double[] getResult() {
        this.checkTrainingData();
        double q2 = PLSUtil.calculateQ2(this.trainX, this.trainY);
        BitSet selection = new BitSet();
        selection.set(0, this.trainX[0].length);
        while (selection.cardinality() > 1) {
            // Work on a copy: calculateBestQ2LeavingOneFromSelection clears the chosen bit, and
            // we only accept that removal if q2 does not get worse.
            BitSet candidate = (BitSet) selection.clone();
            double[][] workMatrix = new double[this.trainX.length][candidate.cardinality() - 1];
            double candidateQ2 = this.calculateBestQ2LeavingOneFromSelection(candidate, workMatrix);
            if (candidateQ2 < q2) {
                break;
            }
            selection = candidate;
            q2 = candidateQ2;
        }
        return this.createResultFromSelection(selection);
    }

    /**
     * Removes descriptors one at a time (always the one whose removal gives the highest q2)
     * until at most {@code minNumberOfDescriptors} remain, recording every intermediate model.
     *
     * <p>Element 0 of the result is the model fitted on all descriptors. Element {@code k} is
     * the model after {@code k} descriptors were removed. If the matrix already has at most
     * {@code minNumberOfDescriptors} columns, only the full model is returned.
     *
     * @param minNumberOfDescriptors lower bound on the number of descriptors kept (at least 1)
     * @return one coefficient vector per elimination step, starting with the full model
     * @throws RuntimeException if the descriptor matrix or experimental values are missing
     * @throws IllegalArgumentException if {@code minNumberOfDescriptors < 1}
     */
    public double[][] decreaseDescriptorCountTo(int minNumberOfDescriptors) {
        this.checkTrainingData();
        if (minNumberOfDescriptors < 1) {
            throw new IllegalArgumentException("At least one descriptor must be in the model.");
        }
        BitSet selection = new BitSet();
        selection.set(0, this.trainX[0].length);
        int leftOutDescriptorCount = 0;
        ArrayList<double[]> models = new ArrayList<double[]>();
        double[] origResult = this.buildModel(this.trainX, this.trainY);
        models.add(origResult);
        if (this.debug) {
            // q2 of the full model, computed the same way as the q2 of the reduced models below
            // (the previous code correlated the coefficient vector with the responses).
            double oq2 = PLSUtil.calculateQ2(this.trainX, this.trainY);
            this.printQ2(leftOutDescriptorCount, oq2, origResult);
        }
        // A plain while (not do-while): never drop below the requested minimum.
        while (selection.cardinality() > minNumberOfDescriptors) {
            double[][] workMatrix = new double[this.trainX.length][selection.cardinality() - 1];
            // Clears the best descriptor's bit in 'selection'.
            double q2 = this.calculateBestQ2LeavingOneFromSelection(selection, workMatrix);
            double[] runningResult = this.createResultFromSelection(selection);
            models.add(runningResult);
            ++leftOutDescriptorCount;
            if (this.debug) {
                this.printQ2(leftOutDescriptorCount, q2, runningResult);
                this.printTrainingR2(leftOutDescriptorCount, runningResult);
            }
        }
        return models.toArray(new double[models.size()][]);
    }

    /** Prints the q2 and coefficient vector of one model. */
    private void printQ2(int leftOutDescriptorCount, double q2, double[] coefficients) {
        ErrPrint.errPrint("q2 at " + leftOutDescriptorCount + "th left out descriptor: " + q2 + "." + coefficients.length + " coeffs:", coefficients);
    }

    /** Predicts every training row with {@code coefficients} and prints the squared Pearson r. */
    private void printTrainingR2(int leftOutDescriptorCount, double[] coefficients) {
        PLSPredictionAlgorithm pls = new PLSPredictionAlgorithm();
        pls.setCoefficients(coefficients);
        double[] trainPred = new double[this.trainX.length];
        for (int i = 0; i < this.trainX.length; ++i) {
            trainPred[i] = pls.predict(this.trainX[i]);
        }
        ErrPrint.errPrint("train w2 at " + leftOutDescriptorCount + "th left out descriptor: " + Stat.pearsonR2(this.trainY, trainPred) + "." + trainPred.length + " trainPred:", trainPred);
    }

    /** Throws if training data has not been set, or if the descriptor matrix has no rows. */
    private void checkTrainingData() {
        if (this.trainX == null || this.trainX.length == 0 || this.trainX[0] == null) {
            throw new RuntimeException("Null descriptormatrix");
        }
        if (this.trainY == null) {
            throw new RuntimeException("Null experimental values");
        }
    }

    /** Maximum PLS component count for a model over {@code descriptorCount} descriptors. */
    private static int maxComponentsFor(int descriptorCount) {
        // NOTE(review): yields 0 for a single descriptor; confirm PLSFittingAlgorithm handles 0.
        return (int) (descriptorCount * MAX_COMPONENTS_PER_DESCRIPTOR);
    }

    /**
     * Fits a PLS model on {@code descriptorMatrix} with the component limit derived from its
     * column count.
     *
     * @param descriptorMatrix row-major descriptor matrix with at least one row
     * @param response one response value per row
     * @return the PLS fit result
     */
    private double[] buildModel(double[][] descriptorMatrix, double[] response) {
        PLSFittingAlgorithm pls = new PLSFittingAlgorithm(descriptorMatrix, response);
        pls.setMaxComponents(maxComponentsFor(descriptorMatrix[0].length));
        return pls.getResult();
    }

    /**
     * Fits a PLS model on the selected columns only and expands its coefficients back to the
     * full descriptor width.
     *
     * @param bs indices of the selected descriptor columns
     * @return vector of length {@code descriptorCount + 1}: selected-column coefficients at their
     *     original positions, 0.0 for unselected columns, and the fit result's last element last
     */
    private double[] createResultFromSelection(BitSet bs) {
        double[][] selectedColumns = new double[this.trainX.length][bs.cardinality()];
        PLSUtil.copySelectedColoumns(this.trainX, bs, selectedColumns);
        double[] selectedResult = this.buildModel(selectedColumns, this.trainY);
        int descriptorCount = this.trainX[0].length;
        // Java zero-initializes the array, so removed descriptors keep coefficient 0.0.
        double[] result = new double[descriptorCount + 1];
        int selectionIndex = 0;
        for (int i = bs.nextSetBit(0); i >= 0 && i < descriptorCount; i = bs.nextSetBit(i + 1)) {
            result[i] = selectedResult[selectionIndex++];
        }
        result[descriptorCount] = selectedResult[selectedResult.length - 1];
        return result;
    }

    /**
     * Tries removing each selected descriptor in turn and clears, in {@code bs}, the one whose
     * removal gives the highest q2.
     *
     * @param bs current selection; modified in place (exactly one bit is cleared)
     * @param runningMatrix scratch matrix sized {@code rows x (bs.cardinality() - 1)}
     * @return the best q2 found
     * @throws IllegalStateException if {@code bs} is empty or no candidate produced a q2 greater
     *     than negative infinity (e.g. all NaN)
     */
    private double calculateBestQ2LeavingOneFromSelection(BitSet bs, double[][] runningMatrix) {
        double bestQ2 = Double.NEGATIVE_INFINITY;
        int bestQ2Index = -1;
        for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
            // Temporarily deselect descriptor i and copy the remaining columns.
            bs.clear(i);
            PLSUtil.copySelectedColoumns(this.trainX, bs, runningMatrix);
            if (this.debug) {
                System.err.print(".");
            }
            bs.set(i);
            double newq2 = PLSUtil.calculateQ2(runningMatrix, this.trainY, this.debug);
            if (newq2 > bestQ2) {
                bestQ2 = newq2;
                bestQ2Index = i;
            }
            if (this.debug) {
                System.err.println(newq2);
            }
        }
        if (this.debug) {
            System.err.print("\n");
        }
        if (bestQ2Index < 0) {
            throw new IllegalStateException("No descriptor could be removed from the selection " + bs);
        }
        bs.clear(bestQ2Index);
        return bestQ2;
    }

    @Override
    public float getProgress() {
        return 0.0f;
    }
}

