/*
 * Decompiled with CFR 0.152.
 */
package chemaxon.calculations.training;

import chemaxon.calculations.training.DescriptorMatrix;
import chemaxon.calculations.training.FittingAlgorithm;
import chemaxon.calculations.training.FullMatrix;
import chemaxon.calculations.training.FutureCalculation;
import chemaxon.calculations.training.FutureTrainer;
import chemaxon.calculations.training.LongRunningTask;
import chemaxon.calculations.training.MatrixWrapper;
import chemaxon.calculations.training.PredictionAlgorithm;
import chemaxon.calculations.training.Stat;
import chemaxon.calculations.training.TrainingModel;
import chemaxon.calculations.training.Validator;
import java.beans.PropertyChangeListener;
import java.util.ArrayList;
import java.util.List;

/**
 * Validator that predicts every collected row with a model trained on all the
 * other rows, then scores the predictions against the experimental values.
 *
 * <p>Rows are collected through {@link #add(Object, double)}; the actual work is
 * done by a {@link CrossValidationTask} wrapped in a
 * {@link CrossValidationCalculation}. Most public methods simply delegate to that
 * calculation.
 *
 * @param <TT> type of the items handed to the {@code TrainingModel} to obtain descriptors
 * @param <CT> type of the coefficients produced by the trainer and consumed by the
 *             {@code PredictionAlgorithm}
 */
public class CrossValidator<TT, CT>
implements Validator<TT> {
    /** Calculation wrapper that owns the task, its matrix and the experimental values. */
    protected CrossValidationCalculation<TT, CT> crossValidation;

    /**
     * Creates a validator around a freshly built {@link CrossValidationTask}.
     *
     * @param fittingAlgorithms    one fitting algorithm per worker slot
     * @param trainingModels       one training model per worker slot; the first one is
     *                             also used by {@link #add(Object, double)}
     * @param descriptorMatrix     matrix that receives one descriptor row per added item
     * @param predictionAlgorithms one prediction algorithm per worker slot
     */
    protected CrossValidator(List<FittingAlgorithm<CT>> fittingAlgorithms, List<TrainingModel<TT>> trainingModels, DescriptorMatrix descriptorMatrix, List<PredictionAlgorithm<CT>> predictionAlgorithms) {
        this(new CrossValidationTask<TT, CT>(fittingAlgorithms, trainingModels, descriptorMatrix, predictionAlgorithms));
    }

    /**
     * Creates a validator around an existing task.
     *
     * @param task the task to wrap
     */
    protected CrossValidator(CrossValidationTask<TT, CT> task) {
        this.crossValidation = new CrossValidationCalculation<TT, CT>(task);
    }

    /**
     * Appends one observation: the descriptors of {@code data} become a new matrix row
     * and {@code experimentalValue} is appended to the experimental value list.
     *
     * @param data              item whose descriptors are requested from the first training model
     * @param experimentalValue measured value belonging to {@code data}
     * @return {@code true} if a row was added, {@code false} if the model produced no
     *         descriptors (null or empty array) and nothing was stored
     */
    @Override
    public boolean add(TT data, double experimentalValue) {
        double[] descriptors = this.crossValidation.getModel().getDescriptors(data, false);
        // A missing descriptor array is treated like an empty one instead of failing with an NPE.
        if (descriptors == null || descriptors.length == 0) {
            return false;
        }
        DescriptorMatrix matrix = this.crossValidation.getMatrix();
        matrix.insertRow(matrix.getRowCount(), descriptors);
        this.crossValidation.getList().add(experimentalValue);
        return true;
    }

    /**
     * Returns the result of the wrapped calculation.
     *
     * @return the predicted values as delivered by the calculation's {@code getAndWait()};
     *         may be {@code null} ({@link #validate()} treats that as "no result")
     */
    public double[] getPrediction() {
        return (double[])this.crossValidation.getAndWait();
    }

    /**
     * Scores the predictions against the experimental values with
     * {@code Stat.pearsonR2}.
     *
     * @return the score, or {@code Double.NaN} when no prediction is available
     */
    @Override
    public double validate() {
        double[] prediction = this.getPrediction();
        if (prediction == null) {
            return Double.NaN;
        }
        return Stat.pearsonR2(this.crossValidation.getExperimentalValues(), prediction);
    }

    /**
     * Alias of {@link #validate()}.
     *
     * @return the same value as {@link #validate()}
     */
    public double getQ2() {
        return this.validate();
    }

    /** @return the progress reported by the wrapped calculation */
    public float getProgress() {
        return this.crossValidation.getProgress();
    }

    /** @return whether the wrapped calculation reports itself as cancelled */
    public boolean isCancelled() {
        return this.crossValidation.isCancelled();
    }

    /** Forwards the cancel request to the wrapped calculation. */
    public void cancel() {
        this.crossValidation.cancel();
    }

    /** @return whether the wrapped calculation reports itself as done */
    public boolean isDone() {
        return this.crossValidation.isDone();
    }

    /**
     * Registers a listener on the wrapped calculation.
     *
     * @param listener listener to add
     */
    public void addPropertyChangeListener(PropertyChangeListener listener) {
        this.crossValidation.addPropertyChangeListener(listener);
    }

    /**
     * Registers a listener for one property on the wrapped calculation.
     *
     * @param propertyName name of the property to listen to
     * @param listener     listener to add
     */
    public void addPropertyChangeListener(String propertyName, PropertyChangeListener listener) {
        this.crossValidation.addPropertyChangeListener(propertyName, listener);
    }

    /** @return the listeners registered on the wrapped calculation */
    public PropertyChangeListener[] getPropertyChangeListeners() {
        return this.crossValidation.getPropertyChangeListeners();
    }

    /**
     * @param propertyName name of the property
     * @return the listeners registered for {@code propertyName} on the wrapped calculation
     */
    public PropertyChangeListener[] getPropertyChangeListeners(String propertyName) {
        return this.crossValidation.getPropertyChangeListeners(propertyName);
    }

    /**
     * Removes a listener from the wrapped calculation.
     *
     * @param listener listener to remove
     */
    public void removePropertyChangeListener(PropertyChangeListener listener) {
        this.crossValidation.removePropertyChangeListener(listener);
    }

    /**
     * Removes a listener for one property from the wrapped calculation.
     *
     * @param propertyName name of the property
     * @param listener     listener to remove
     */
    public void removePropertyChangeListener(String propertyName, PropertyChangeListener listener) {
        this.crossValidation.removePropertyChangeListener(propertyName, listener);
    }

    /**
     * Task that computes one prediction per matrix row, each from a model trained
     * without that row. Rows are processed in batches of up to {@link #threadCount}
     * threads; worker {@code i} of a batch uses the {@code i}-th fitting algorithm,
     * training model and prediction algorithm, so no instance is used by two workers
     * of the same batch.
     *
     * @param <TT> type of the training items
     * @param <CT> type of the fitted coefficients
     */
    protected static class CrossValidationTask<TT, CT>
    implements LongRunningTask<double[]> {
        protected final List<FittingAlgorithm<CT>> fittingAlgorithms;
        protected final List<TrainingModel<TT>> trainingModels;
        protected final DescriptorMatrix descriptorMatrix;
        protected final List<PredictionAlgorithm<CT>> predictionAlgorithms;
        /** Experimental values, index-aligned with the rows of {@link #descriptorMatrix}. */
        protected final List<Double> experimentalValues = new ArrayList<Double>();
        /** Total number of rows of the current run; volatile because progress is polled from other threads. */
        protected volatile int iteration = 0;
        /** Number of rows whose worker has been started; -1 until {@link #call()} begins. */
        protected volatile int counter = -1;
        /** Maximum number of worker threads per batch. */
        protected int threadCount = 1;

        /**
         * @param fittingAlgorithms    fitting algorithms, one per worker slot; its size becomes
         *                             the thread count ({@code null} means one thread)
         * @param trainingModels       training models, one per worker slot
         * @param descriptorMatrix     matrix holding one descriptor row per observation
         * @param predictionAlgorithms prediction algorithms, one per worker slot
         */
        public CrossValidationTask(List<FittingAlgorithm<CT>> fittingAlgorithms, List<TrainingModel<TT>> trainingModels, DescriptorMatrix descriptorMatrix, List<PredictionAlgorithm<CT>> predictionAlgorithms) {
            this.fittingAlgorithms = fittingAlgorithms;
            this.trainingModels = trainingModels;
            this.descriptorMatrix = descriptorMatrix;
            this.predictionAlgorithms = predictionAlgorithms;
            this.threadCount = fittingAlgorithms == null ? 1 : fittingAlgorithms.size();
        }

        /**
         * Predicts every row of the matrix and returns the predictions in row order.
         *
         * <p>An interrupt received while waiting for workers does not abort the run:
         * the workers of the batch are still awaited (so a worker slot is never reused
         * while its previous worker is running) and the interrupt status of the calling
         * thread is restored before returning.
         *
         * @return one predicted value per matrix row
         * @throws IllegalStateException if there are rows to process but {@link #threadCount}
         *                               is smaller than 1 (for example an empty algorithm list)
         */
        @Override
        public double[] call() throws Exception {
            // Snapshot the row count once so the array size and the loop bound always agree.
            final int rowTotal = this.descriptorMatrix.getRowCount();
            final double[] predictedValues = new double[rowTotal];
            this.iteration = rowTotal;
            this.counter = 0;
            if (rowTotal > 0 && this.threadCount < 1) {
                // A batch size of zero would never advance the counter and loop forever.
                throw new IllegalStateException("Cross-validation needs at least one worker, threadCount=" + this.threadCount);
            }
            boolean interrupted = false;
            int nextRow = 0;
            try {
                while (nextRow < rowTotal) {
                    int batchSize = Math.min(this.threadCount, rowTotal - nextRow);
                    Thread[] workers = new Thread[batchSize];
                    for (int i = 0; i < batchSize; ++i) {
                        final int threadIndex = i;
                        final int rowIndex = nextRow++;
                        this.counter = nextRow;
                        workers[i] = new Thread(new Runnable(){

                            @Override
                            public void run() {
                                predictedValues[rowIndex] = CrossValidationTask.this.getPredictedValue(rowIndex, threadIndex);
                            }
                        });
                        workers[i].start();
                    }
                    for (Thread worker : workers) {
                        if (CrossValidationTask.joinUninterruptibly(worker)) {
                            interrupted = true;
                        }
                    }
                }
            }
            finally {
                if (interrupted) {
                    // Hand the interrupt back to whoever owns the calling thread.
                    Thread.currentThread().interrupt();
                }
            }
            return predictedValues;
        }

        /**
         * Waits until {@code worker} has terminated, retrying when interrupted.
         *
         * @param worker thread to wait for
         * @return {@code true} if the calling thread was interrupted while waiting
         */
        private static boolean joinUninterruptibly(Thread worker) {
            boolean interrupted = false;
            while (true) {
                try {
                    worker.join();
                    return interrupted;
                }
                catch (InterruptedException e) {
                    // Remember the interrupt and keep waiting; the caller restores the flag.
                    interrupted = true;
                }
            }
        }

        /**
         * Predicts the value of one row from a model trained on all other rows.
         *
         * @param rowIndex    index of the row to leave out and predict
         * @param threadIndex worker slot; selects the fitting algorithm, training model and
         *                    prediction algorithm to use
         * @return the predicted value, or {@code -1.0} if training or prediction threw an exception
         */
        protected double getPredictedValue(int rowIndex, int threadIndex) {
            FullMatrix reducedMatrix = new FullMatrix(this.descriptorMatrix.getColumnCount());
            FutureTrainer<TT, CT> trainer = new FutureTrainer<TT, CT>(this.fittingAlgorithms.get(threadIndex), this.trainingModels.get(threadIndex), reducedMatrix);
            PredictionAlgorithm<CT> predictionAlgorithm = this.predictionAlgorithms.get(threadIndex);
            double prediction = -1.0;
            MatrixWrapper filteredMatrix = new MatrixWrapper(this.descriptorMatrix);
            List<Double> filteredValues = new ArrayList<Double>(this.experimentalValues);
            // Read the left-out row before it is switched off in the wrapper.
            double[] descriptors = filteredMatrix.getRow(rowIndex);
            // NOTE(review): setState(row, false) presumably hides the row from the wrapper - confirm.
            filteredMatrix.setState(rowIndex, false);
            // int argument: removes by index, keeping the values aligned with the remaining rows.
            filteredValues.remove(rowIndex);
            try {
                for (int j = 0; j < filteredMatrix.getRowCount(); ++j) {
                    trainer.add(filteredMatrix.getRow(j), filteredValues.get(j).doubleValue());
                }
                CT coeff = trainer.getResults();
                predictionAlgorithm.setCoefficients(coeff);
                prediction = predictionAlgorithm.predict(descriptors);
            }
            catch (Exception ignored) {
                // Best effort: a row that cannot be trained or predicted keeps the -1.0 marker.
            }
            return prediction;
        }

        /**
         * Reports the fraction of rows whose worker has been started.
         *
         * @return {@code 0} before {@link #call()} has started, {@code 1} when the run had no
         *         rows, otherwise started rows divided by total rows
         */
        @Override
        public float getProgress() {
            int started = this.counter;
            int total = this.iteration;
            if (started < 0) {
                return 0.0f;
            }
            if (total <= 0) {
                return 1.0f;
            }
            return (float)started / (float)total;
        }
    }

    /**
     * {@code FutureCalculation} specialisation that exposes the pieces of its
     * {@link CrossValidationTask} needed by {@link CrossValidator}.
     *
     * @param <TT> type of the training items
     * @param <CT> type of the fitted coefficients
     */
    protected static class CrossValidationCalculation<TT, CT>
    extends FutureCalculation<double[]> {
        /**
         * @param task task stored as this calculation's long running task
         */
        protected CrossValidationCalculation(CrossValidationTask<TT, CT> task) {
            this.longRunningTask = task;
        }

        /** @return the long running task, viewed as a {@link CrossValidationTask} */
        public CrossValidationTask<TT, CT> getTask() {
            // The constructor in this class only ever stores a CrossValidationTask<TT, CT>.
            @SuppressWarnings("unchecked")
            CrossValidationTask<TT, CT> task = (CrossValidationTask<TT, CT>)this.longRunningTask;
            return task;
        }

        /** @return the first training model of the task */
        protected TrainingModel<TT> getModel() {
            return this.getTask().trainingModels.get(0);
        }

        /** @return the descriptor matrix of the task */
        protected DescriptorMatrix getMatrix() {
            return this.getTask().descriptorMatrix;
        }

        /** @return the live (mutable) experimental value list of the task */
        protected List<Double> getList() {
            return this.getTask().experimentalValues;
        }

        /** @return a primitive array copy of the experimental values, in insertion order */
        protected double[] getExperimentalValues() {
            List<Double> list = this.getList();
            double[] result = new double[list.size()];
            for (int i = 0; i < result.length; ++i) {
                result[i] = list.get(i);
            }
            return result;
        }
    }
}

