/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.WeightedDataset;
import edu.stanford.nlp.classify.WeightedRVFDataset;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.math.ADMath;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.DoubleAD;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffUpdateFunction;
import edu.stanford.nlp.optimization.StochasticCalculateMethods;
import edu.stanford.nlp.util.ArgumentParser;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.RuntimeInterruptedException;
import edu.stanford.nlp.util.logging.Redwood;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Collection;
import java.util.concurrent.CountDownLatch;

public class LogConditionalObjectiveFunction<L, F>
extends AbstractStochasticCachingDiffUpdateFunction {
    private static Redwood.RedwoodChannels log = Redwood.channels(LogConditionalObjectiveFunction.class);
    // Regularization prior; the result of prior.compute(x, grad) is added to the objective value.
    protected final LogPrior prior;
    // Parameters form a numFeatures x numClasses table, flattened by indexOf(f, c) = f * numClasses + c.
    protected final int numFeatures;
    protected final int numClasses;
    // data[d] holds the active feature indices of datum d (null for the Iterable-based constructor).
    protected final int[][] data;
    // Alternative data source, set only by the Iterable-based constructor; null otherwise.
    protected final Iterable<Datum<L, F>> dataIterable;
    // Map Datum labels/features from dataIterable to integer ids (null for array-based constructors).
    protected final Index<L> labelIndex;
    protected final Index<F> featureIndex;
    // Optional real feature values parallel to data: values[d][f] multiplies feature data[d][f]; null for binary features.
    protected final double[][] values;
    // labels[d] is the gold class index of datum d.
    protected final int[] labels;
    // Optional per-datum weights; null means all data are treated with weight 1.
    protected final float[] dataWeights;
    // If true, calculate() uses the summed conditional likelihood objective instead of the log conditional likelihood.
    protected final boolean useSummedConditionalLikelihood;
    // Lazily built, cached observed-feature (empirical) part of the gradient, reused by later calculate() calls.
    protected double[] derivativeNumerator = null;
    // Scratch gradient buffer handed to prior.compute(...) by the stochastic routines.
    protected double[] priorDerivative = null;
    // Whether gradient computations may be split across worker threads (requires threads > 1).
    protected boolean parallelGradientCalculation = true;
    protected int threads = ArgumentParser.threads;

    /**
     * Returns the number of model parameters: one weight for every (feature, class) pair.
     */
    @Override
    public int domainDimension() {
        int weightsPerFeature = this.numClasses;
        return this.numFeatures * weightsPerFeature;
    }

    /**
     * Returns the number of training examples in the array-backed data.
     * Note that {@code data} is null for instances built from an {@code Iterable}.
     */
    @Override
    public int dataDimension() {
        final int[][] examples = this.data;
        return examples.length;
    }

    /** Extracts the class component of a flattened parameter index (f * numClasses + c yields c). */
    private int classOf(int index) {
        final int stride = this.numClasses;
        return index % stride;
    }

    /** Extracts the feature component of a flattened parameter index (f * numClasses + c yields f). */
    private int featureOf(int index) {
        final int stride = this.numClasses;
        return index / stride;
    }

    /**
     * Maps a (feature, class) pair to its position in the flattened weight vector.
     *
     * @param f feature index
     * @param c class index
     * @return {@code f * numClasses + c}
     */
    protected int indexOf(int f, int c) {
        int rowStart = f * this.numClasses;
        return rowStart + c;
    }

    /**
     * Unflattens a weight vector into a {@code [numFeatures][numClasses]} matrix.
     *
     * @param x flattened weights, laid out as described by {@link #indexOf(int, int)}
     * @return a new matrix whose entry {@code [f][c]} is {@code x[indexOf(f, c)]}
     */
    public double[][] to2D(double[] x) {
        final int rows = this.numFeatures;
        final int cols = this.numClasses;
        double[][] matrix = new double[rows][cols];
        for (int f = 0; f < rows; f++) {
            double[] row = matrix[f];
            for (int c = 0; c < cols; c++) {
                row[c] = x[this.indexOf(f, c)];
            }
        }
        return matrix;
    }

    /**
     * Computes the full-data objective value and gradient at {@code x}, choosing the summed
     * conditional likelihood or the log conditional likelihood according to
     * {@code useSummedConditionalLikelihood}.
     */
    @Override
    protected void calculate(double[] x) {
        if (!this.useSummedConditionalLikelihood) {
            this.calculateCL(x);
            return;
        }
        this.calculateSCL(x);
    }

    /**
     * Stochastic evaluation over {@code batch}. It applies only when the configured method computes
     * Hessian-vector products and a direction {@code v} is given. In that case it dispatches to the
     * algorithmic-differentiation or incorporated-finite-difference routine; any other method
     * does nothing. Otherwise only the value and gradient are computed.
     */
    @Override
    public void calculateStochastic(double[] x, double[] v, int[] batch) {
        boolean hessianVectorRequested = this.method.calculatesHessianVectorProduct() && v != null;
        if (!hessianVectorRequested) {
            this.calculateStochasticGradientLocal(x, batch);
            return;
        }
        if (this.method.equals((Object)StochasticCalculateMethods.AlgorithmicDifferentiation)) {
            this.calculateStochasticAlgorithmicDifferentiation(x, v, batch);
            return;
        }
        if (this.method.equals((Object)StochasticCalculateMethods.IncorporatedFiniteDifference)) {
            this.calculateStochasticFiniteDifference(x, v, this.finiteDifferenceStepSize, batch);
        }
    }

    /**
     * Computes the negated summed conditional likelihood objective and its gradient over all data.
     *
     * <p>value = -sum_d p(labels[d] | data[d]) + sum_i x[i]^2 / 2, where p is the softmax of the
     * per-class activations. Results are written to {@code value} and {@code derivative}.
     *
     * @param x flattened weight vector, indexed by {@link #indexOf(int, int)}
     */
    private void calculateSCL(double[] x) {
        this.value = 0.0;
        Arrays.fill(this.derivative, 0.0);
        double[] sums = new double[this.numClasses];
        double[] probs = new double[this.numClasses];
        for (int d = 0; d < this.data.length; ++d) {
            int[] features = this.data[d];
            // Per-class activations.
            Arrays.fill(sums, 0.0);
            for (int c = 0; c < this.numClasses; ++c) {
                for (int feature : features) {
                    sums[c] += x[this.indexOf(feature, c)];
                }
            }
            double total = ArrayMath.logSum(sums);
            // Fill the whole distribution first. The previous code read probs[ld] while still
            // filling probs, so for every c < ld it used the previous datum's probability.
            for (int c = 0; c < this.numClasses; ++c) {
                probs[c] = Math.exp(sums[c] - total);
            }
            int ld = this.labels[d];
            double pGold = probs[ld];
            // d(-p_ld)/dx[f,c] = p_ld * p_c - [c == ld] * p_ld
            for (int c = 0; c < this.numClasses; ++c) {
                double expected = pGold * probs[c];
                for (int feature : features) {
                    this.derivative[this.indexOf(feature, c)] += expected;
                }
            }
            for (int feature : features) {
                this.derivative[this.indexOf(feature, ld)] -= pGold;
            }
            this.value -= pGold;
        }
        // Unit-weight quadratic prior: value += w^2 / 2, gradient += w.
        for (int i = 0; i < x.length; ++i) {
            double w = x[i];
            this.value += w * w / 2.0;
            this.derivative[i] += w;
        }
    }

    /**
     * Computes the log conditional likelihood objective, dispatching on the data representation.
     * Real-valued data goes to {@code rvfcalculate}, Iterable-backed data to
     * {@code calculateCLiterable}, and array-backed binary data to {@code calculateCLbatch}.
     */
    private void calculateCL(double[] x) {
        if (this.values == null) {
            if (this.dataIterable == null) {
                this.calculateCLbatch(x);
            } else {
                this.calculateCLiterable(x);
            }
            return;
        }
        this.rvfcalculate(x);
    }

    /**
     * Computes the (optionally weighted) negative log conditional likelihood of the array-backed
     * binary-feature data plus the prior. Results are stored in {@code value} and {@code derivative}.
     *
     * <p>The observed-feature part of the gradient is computed once and cached in
     * {@code derivativeNumerator}. The expected-feature part is computed either serially or, when
     * {@code parallelGradientCalculation} is set and {@code threads > 1}, by
     * {@code CLBatchDerivativeCalculation} workers whose partial results are summed here.
     *
     * @param x flattened weight vector, indexed by {@link #indexOf(int, int)}
     */
    private void calculateCLbatch(double[] x) {
        this.value = 0.0;
        if (this.derivativeNumerator == null) {
            this.derivativeNumerator = new double[x.length];
            for (int d = 0; d < this.data.length; ++d) {
                double weight = this.dataWeights == null ? 1.0 : (double) this.dataWeights[d];
                for (int feature : this.data[d]) {
                    this.derivativeNumerator[this.indexOf(feature, this.labels[d])] -= weight;
                }
            }
        }
        LogConditionalObjectiveFunction.copy(this.derivative, this.derivativeNumerator);
        if (this.parallelGradientCalculation && this.threads > 1) {
            // Array.newInstance is used because the inner worker class is implicitly generic.
            CLBatchDerivativeCalculation[] runnables = (CLBatchDerivativeCalculation[]) Array.newInstance(CLBatchDerivativeCalculation.class, this.threads);
            CountDownLatch latch = new CountDownLatch(this.threads);
            for (int t = 0; t < this.threads; ++t) {
                runnables[t] = new CLBatchDerivativeCalculation(this.threads, t, null, x, this.derivative.length, latch);
                new Thread(runnables[t]).start();
            }
            try {
                latch.await();
            } catch (InterruptedException e) {
                throw new RuntimeInterruptedException(e);
            }
            for (CLBatchDerivativeCalculation runnable : runnables) {
                this.value += runnable.localValue;
                for (int j = 0; j < this.derivative.length; ++j) {
                    this.derivative[j] += runnable.localDerivative[j];
                }
            }
        } else {
            double[] sums = new double[this.numClasses];
            double[] probs = new double[this.numClasses];
            for (int d = 0; d < this.data.length; ++d) {
                int[] features = this.data[d];
                // Per-class activations.
                Arrays.fill(sums, 0.0);
                for (int feature : features) {
                    for (int c = 0; c < this.numClasses; ++c) {
                        sums[c] += x[this.indexOf(feature, c)];
                    }
                }
                double total = ArrayMath.logSum(sums);
                // (Weighted) class posteriors.
                for (int c = 0; c < this.numClasses; ++c) {
                    probs[c] = Math.exp(sums[c] - total);
                    if (this.dataWeights != null) {
                        probs[c] *= (double) this.dataWeights[d];
                    }
                }
                // Expected-feature part of the gradient.
                for (int feature : features) {
                    for (int c = 0; c < this.numClasses; ++c) {
                        this.derivative[this.indexOf(feature, c)] += probs[c];
                    }
                }
                double dV = sums[this.labels[d]] - total;
                if (this.dataWeights != null) {
                    dV *= (double) this.dataWeights[d];
                }
                this.value -= dV;
            }
        }
        this.value += this.prior.compute(x, this.derivative);
    }

    /**
     * Computes the negative log conditional likelihood plus the prior for Iterable-backed data.
     * Each {@link Datum} is mapped to integer ids via {@code featureIndex} and {@code labelIndex}.
     * Results are stored in {@code value} and {@code derivative}; the observed-feature gradient
     * is cached in {@code derivativeNumerator} on first use.
     *
     * @param x flattened weight vector, indexed by {@link #indexOf(int, int)}
     */
    private void calculateCLiterable(double[] x) {
        this.value = 0.0;
        if (this.derivativeNumerator == null) {
            this.derivativeNumerator = new double[x.length];
            for (Datum<L, F> datum : this.dataIterable) {
                int label = this.labelIndex.indexOf(datum.label());
                for (F feature : datum.asFeatures()) {
                    int i = this.indexOf(this.featureIndex.indexOf(feature), label);
                    // Nothing is subtracted when dataWeights is non-null; the Iterable
                    // constructor always sets dataWeights to null.
                    if (this.dataWeights == null) {
                        this.derivativeNumerator[i] -= 1.0;
                    }
                }
            }
        }
        LogConditionalObjectiveFunction.copy(this.derivative, this.derivativeNumerator);
        double[] sums = new double[this.numClasses];
        double[] probs = new double[this.numClasses];
        for (Datum<L, F> datum : this.dataIterable) {
            Arrays.fill(sums, 0.0);
            Collection<F> features = datum.asFeatures();
            // Per-class activations; the feature id lookup is done once per feature.
            for (F feature : features) {
                int featureId = this.featureIndex.indexOf(feature);
                for (int c = 0; c < this.numClasses; ++c) {
                    sums[c] += x[this.indexOf(featureId, c)];
                }
            }
            double total = ArrayMath.logSum(sums);
            for (int c = 0; c < this.numClasses; ++c) {
                probs[c] = Math.exp(sums[c] - total);
            }
            // Expected-feature part of the gradient.
            for (F feature : features) {
                int featureId = this.featureIndex.indexOf(feature);
                for (int c = 0; c < this.numClasses; ++c) {
                    this.derivative[this.indexOf(featureId, c)] += probs[c];
                }
            }
            int label = this.labelIndex.indexOf(datum.label());
            double dV = sums[label] - total;
            this.value -= dV;
        }
        this.value += this.prior.compute(x, this.derivative);
    }

    /**
     * Computes the batch objective, its gradient, and a finite-difference Hessian-vector product
     * {@code HdotV ~= (grad(x + h*v) - grad(x)) / h} for binary-feature data. Real-valued data is
     * delegated to {@code rvfcalculate}, which computes the full-data value and gradient only.
     *
     * <p>The gradient and HdotV are seeded with {@code x} and {@code v} scaled by
     * {@code |batch| / (N * sigma^2)}, matching the gradient of a quadratic prior scaled to the batch.
     * When {@code dataWeights} is set, every per-example term (value, expected features, and now
     * also the observed-feature term) is multiplied by that example's weight, so the gradient
     * agrees with the value.
     *
     * @param x current weights
     * @param v direction for the Hessian-vector product
     * @param h finite-difference step size
     * @param batch indices of the examples to use
     */
    public void calculateStochasticFiniteDifference(double[] x, double[] v, double h, int[] batch) {
        if (this.values != null) {
            this.rvfcalculate(x);
            return;
        }
        this.value = 0.0;
        if (this.priorDerivative == null) {
            this.priorDerivative = new double[x.length];
        }
        double priorFactor = (double) batch.length / ((double) this.data.length * this.prior.getSigma() * this.prior.getSigma());
        this.derivative = ArrayMath.multiply(x, priorFactor);
        this.HdotV = ArrayMath.multiply(v, priorFactor);
        double[] sums = new double[this.numClasses];
        double[] sumsV = new double[this.numClasses];
        double[] probs = new double[this.numClasses];
        double[] probsV = new double[this.numClasses];
        for (int m : batch) {
            int[] features = this.data[m];
            int label = this.labels[m];
            double weight = this.dataWeights == null ? 1.0 : (double) this.dataWeights[m];
            Arrays.fill(sums, 0.0);
            Arrays.fill(sumsV, 0.0);
            // Activations at x and at the perturbed point x + h*v.
            for (int c = 0; c < this.numClasses; ++c) {
                for (int feature : features) {
                    int i = this.indexOf(feature, c);
                    sums[c] += x[i];
                    sumsV[c] += x[i] + h * v[i];
                }
            }
            double total = ArrayMath.logSum(sums);
            double totalV = ArrayMath.logSum(sumsV);
            for (int c = 0; c < this.numClasses; ++c) {
                probs[c] = Math.exp(sums[c] - total);
                probsV[c] = Math.exp(sumsV[c] - totalV);
                if (this.dataWeights != null) {
                    probs[c] *= weight;
                    probsV[c] *= weight;
                }
                for (int feature : features) {
                    int i = this.indexOf(feature, c);
                    this.derivative[i] += probs[c];
                    this.HdotV[i] += (probsV[c] - probs[c]) / h;
                    if (c == label) {
                        // Observed-feature term, weighted like the value and the expected term
                        // (previously always 1.0, inconsistent when dataWeights is set).
                        this.derivative[i] -= weight;
                    }
                }
            }
            double dV = sums[label] - total;
            if (this.dataWeights != null) {
                dV *= weight;
            }
            this.value -= dV;
        }
        // priorDerivative is only a scratch buffer for prior.compute; the gradient's prior term
        // comes from the seeding above.
        this.value += (double) batch.length / (double) this.data.length * this.prior.compute(x, this.priorDerivative);
    }

    /**
     * Computes the batch objective value and gradient for binary-feature data. Real-valued data is
     * delegated to {@code rvfcalculate}, which computes the full-data value and gradient.
     *
     * <p>The gradient is seeded with {@code x * |batch| / (N * sigma^2)}, matching the gradient of a
     * quadratic prior scaled to the batch. When {@code dataWeights} is set, the value, the
     * expected-feature term and the observed-feature term are all multiplied by the example's
     * weight. The observed term was previously always 1.0, which made the gradient disagree with
     * the value.
     *
     * @param x current weights
     * @param batch indices of the examples to use
     */
    public void calculateStochasticGradientLocal(double[] x, int[] batch) {
        if (this.values != null) {
            this.rvfcalculate(x);
            return;
        }
        this.value = 0.0;
        int batchSize = batch.length;
        if (this.priorDerivative == null) {
            this.priorDerivative = new double[x.length];
        }
        double priorFactor = (double) batchSize / ((double) this.data.length * this.prior.getSigma() * this.prior.getSigma());
        this.derivative = ArrayMath.multiply(x, priorFactor);
        double[] sums = new double[this.numClasses];
        double[] probs = new double[this.numClasses];
        for (int m : batch) {
            int[] features = this.data[m];
            int label = this.labels[m];
            double weight = this.dataWeights == null ? 1.0 : (double) this.dataWeights[m];
            Arrays.fill(sums, 0.0);
            for (int c = 0; c < this.numClasses; ++c) {
                for (int feature : features) {
                    sums[c] += x[this.indexOf(feature, c)];
                }
            }
            double total = ArrayMath.logSum(sums);
            for (int c = 0; c < this.numClasses; ++c) {
                probs[c] = Math.exp(sums[c] - total);
                if (this.dataWeights != null) {
                    probs[c] *= weight;
                }
                for (int feature : features) {
                    int i = this.indexOf(feature, c);
                    this.derivative[i] += probs[c];
                    if (c == label) {
                        this.derivative[i] -= weight;
                    }
                }
            }
            double dV = sums[label] - total;
            if (this.dataWeights != null) {
                dV *= weight;
            }
            this.value -= dV;
        }
        // priorDerivative is only a scratch buffer for prior.compute.
        this.value += (double) batchSize / (double) this.data.length * this.prior.compute(x, this.priorDerivative);
    }

    /**
     * Returns the (optionally weighted) negative log conditional likelihood of the examples in
     * {@code batch}, evaluated at weights {@code x * xscale}. No prior term is included. The
     * result is also stored in {@code value}.
     *
     * @param x weight vector
     * @param xscale scalar multiplier applied to every weight
     * @param batch indices of the examples to score
     * @return the accumulated objective value
     */
    @Override
    public double valueAt(double[] x, double xscale, int[] batch) {
        this.value = 0.0;
        double[] activations = new double[this.numClasses];
        for (int m : batch) {
            int[] features = this.data[m];
            Arrays.fill(activations, 0.0);
            for (int c = 0; c < this.numClasses; ++c) {
                for (int f = 0; f < features.length; ++f) {
                    double contribution = x[this.indexOf(features[f], c)] * xscale;
                    if (this.values != null) {
                        contribution *= this.values[m][f];
                    }
                    activations[c] += contribution;
                }
            }
            double logZ = ArrayMath.logSum(activations);
            double logLikelihood = activations[this.labels[m]] - logZ;
            if (this.dataWeights != null) {
                logLikelihood *= (double) this.dataWeights[m];
            }
            this.value -= logLikelihood;
        }
        return this.value;
    }

    /**
     * Applies an in-place stochastic update to {@code x} over the examples in {@code batch}. It
     * returns the (optionally weighted) negative log-likelihood accumulated while doing so.
     * Activations use {@code x * xscale}.
     *
     * <p>If parallel computation is enabled but the batch holds no more than 50 examples per
     * available processor, a warning is logged and {@code parallelGradientCalculation} is switched
     * off for this and all later calls.
     *
     * <p>NOTE(review): the parallel path scales worker gradients by {@code xscale * gain}, while the
     * serial path scales its updates by {@code gain} only. Confirm which is intended.
     */
    @Override
    public double calculateStochasticUpdate(double[] x, double xscale, int[] batch, double gain) {
        this.value = 0.0;
        int examplesPerProcessor = 50;
        if (this.parallelGradientCalculation && this.threads > 1
                && batch.length <= Runtime.getRuntime().availableProcessors() * examplesPerProcessor) {
            log.info("\n\n***************");
            log.info("CONFIGURATION ERROR: YOUR BATCH SIZE DOESN'T MEET PARALLEL MINIMUM SIZE FOR PERFORMANCE");
            log.info("Batch size: " + batch.length);
            log.info("CPUS: " + Runtime.getRuntime().availableProcessors());
            log.info("Minimum batch size per CPU: " + examplesPerProcessor);
            log.info("MINIMIM BATCH SIZE ON THIS MACHINE: " + Runtime.getRuntime().availableProcessors() * examplesPerProcessor);
            log.info("TURNING OFF PARALLEL GRADIENT COMPUTATION");
            log.info("***************\n");
            this.parallelGradientCalculation = false;
        }
        if (this.parallelGradientCalculation && this.threads > 1) {
            CLBatchDerivativeCalculation[] workers = (CLBatchDerivativeCalculation[]) Array.newInstance(CLBatchDerivativeCalculation.class, this.threads);
            CountDownLatch done = new CountDownLatch(this.threads);
            for (int t = 0; t < this.threads; t++) {
                workers[t] = new CLBatchDerivativeCalculation(this.threads, t, batch, x, x.length, done);
                new Thread(workers[t]).start();
            }
            try {
                done.await();
            } catch (InterruptedException e) {
                throw new RuntimeInterruptedException(e);
            }
            for (CLBatchDerivativeCalculation worker : workers) {
                this.value += worker.localValue;
                for (int j = 0; j < x.length; j++) {
                    x[j] += worker.localDerivative[j] * xscale * gain;
                }
            }
            return this.value;
        }
        double[] activations = new double[this.numClasses];
        double[] posteriors = new double[this.numClasses];
        for (int m : batch) {
            int[] features = this.data[m];
            Arrays.fill(activations, 0.0);
            for (int c = 0; c < this.numClasses; c++) {
                for (int f = 0; f < features.length; f++) {
                    double contribution = x[this.indexOf(features[f], c)] * xscale;
                    if (this.values != null) {
                        contribution *= this.values[m][f];
                    }
                    activations[c] += contribution;
                }
            }
            // Move toward the observed features of the gold label.
            for (int f = 0; f < features.length; f++) {
                int slot = this.indexOf(features[f], this.labels[m]);
                double featureValue = this.values != null ? this.values[m][f] : 1.0;
                double step;
                if (this.dataWeights != null) {
                    step = (double) this.dataWeights[m] * featureValue;
                } else {
                    step = featureValue;
                }
                x[slot] += step * gain;
            }
            double logZ = ArrayMath.logSum(activations);
            // Move away from the model's expected features.
            for (int c = 0; c < this.numClasses; c++) {
                posteriors[c] = Math.exp(activations[c] - logZ);
                if (this.dataWeights != null) {
                    posteriors[c] *= (double) this.dataWeights[m];
                }
                for (int f = 0; f < features.length; f++) {
                    int slot = this.indexOf(features[f], c);
                    double featureValue = this.values != null ? this.values[m][f] : 1.0;
                    x[slot] -= posteriors[c] * featureValue * gain;
                }
            }
            double logLikelihood = activations[this.labels[m]] - logZ;
            if (this.dataWeights != null) {
                logLikelihood *= (double) this.dataWeights[m];
            }
            this.value -= logLikelihood;
        }
        return this.value;
    }

    /**
     * Writes into {@code derivative} the gradient of the negated summed conditional likelihood,
     * {@code -sum_d p(labels[d] | data[d])}, over the examples in {@code batch}. This happens
     * regardless of {@code useSummedConditionalLikelihood}. No prior term is added and
     * {@code value} is not touched.
     *
     * @param x current weights
     * @param batch indices of the examples to use
     */
    @Override
    public void calculateStochasticGradient(double[] x, int[] batch) {
        if (this.derivative == null) {
            this.derivative = new double[this.domainDimension()];
        }
        Arrays.fill(this.derivative, 0.0);
        double[] sums = new double[this.numClasses];
        double[] probs = new double[this.numClasses];
        for (int d : batch) {
            int[] features = this.data[d];
            Arrays.fill(sums, 0.0);
            for (int c = 0; c < this.numClasses; ++c) {
                for (int feature : features) {
                    sums[c] += x[this.indexOf(feature, c)];
                }
            }
            double total = ArrayMath.logSum(sums);
            // Fill the whole distribution before reading probs[ld]. The previous code read it
            // mid-fill, so for c < ld it used the previous example's probability.
            for (int c = 0; c < this.numClasses; ++c) {
                probs[c] = Math.exp(sums[c] - total);
            }
            int ld = this.labels[d];
            double pGold = probs[ld];
            // d(-p_ld)/dx[f,c] = p_ld * p_c - [c == ld] * p_ld
            for (int c = 0; c < this.numClasses; ++c) {
                double expected = pGold * probs[c];
                for (int feature : features) {
                    this.derivative[this.indexOf(feature, c)] += expected;
                }
            }
            for (int feature : features) {
                this.derivative[this.indexOf(feature, ld)] -= pGold;
            }
        }
    }

    /**
     * Computes the batch objective, gradient, and an exact Hessian-vector product {@code HdotV}
     * using forward-mode dual numbers ({@link DoubleAD}) seeded with {@code (x[i], v[i])}.
     *
     * <p>Per-example weights are now looked up with the example index {@code m}; previously they
     * used the batch position {@code d}. They are also applied to the observed-feature term, so
     * the gradient agrees with the weighted value. A quadratic-prior-shaped term
     * {@code x * |batch| / (N * sigma^2)} is added to the gradient.
     *
     * <p>NOTE(review): examples are taken as consecutive indices starting at {@code curElement}
     * rather than from {@code batch}; only {@code batch.length} is used. This is preserved as-is.
     *
     * @param x current weights
     * @param v direction for the Hessian-vector product
     * @param batch examples to process (only its length is used)
     */
    protected void calculateStochasticAlgorithmicDifferentiation(double[] x, double[] v, int[] batch) {
        log.info("*");
        this.value = 0.0;
        DoubleAD[] derivativeAD = new DoubleAD[x.length];
        DoubleAD[] xAD = new DoubleAD[x.length];
        for (int i = 0; i < x.length; ++i) {
            derivativeAD[i] = new DoubleAD(0.0, 0.0);
            xAD[i] = new DoubleAD(x[i], v[i]);
        }
        DoubleAD[] sums = new DoubleAD[this.numClasses];
        DoubleAD[] probs = new DoubleAD[this.numClasses];
        for (int c = 0; c < this.numClasses; ++c) {
            sums[c] = new DoubleAD(0.0, 0.0);
            probs[c] = new DoubleAD(0.0, 0.0);
        }
        for (int d = 0; d < batch.length; ++d) {
            int m = (this.curElement + d) % this.data.length;
            int[] features = this.data[m];
            int label = this.labels[m];
            double weight = this.dataWeights == null ? 1.0 : (double) this.dataWeights[m];
            for (int c = 0; c < this.numClasses; ++c) {
                sums[c].set(0.0, 0.0);
            }
            for (int c = 0; c < this.numClasses; ++c) {
                for (int feature : features) {
                    sums[c] = ADMath.plus(sums[c], xAD[this.indexOf(feature, c)]);
                }
            }
            DoubleAD total = ADMath.logSum(sums);
            for (int c = 0; c < this.numClasses; ++c) {
                probs[c] = ADMath.exp(ADMath.minus(sums[c], total));
                if (this.dataWeights != null) {
                    probs[c] = ADMath.multConst(probs[c], this.dataWeights[m]);
                }
                for (int feature : features) {
                    int i = this.indexOf(feature, c);
                    if (c == label) {
                        derivativeAD[i].plusEqualsConst(-weight);
                    }
                    derivativeAD[i].plusEquals(probs[c]);
                }
            }
            double dV = sums[label].getval() - total.getval();
            if (this.dataWeights != null) {
                dV *= weight;
            }
            this.value -= dV;
        }
        double priorScale = (double) batch.length / ((double) this.data.length * this.prior.getSigma() * this.prior.getSigma());
        // tmp receives the data-only gradient and is handed to prior.compute for the value term.
        double[] tmp = new double[x.length];
        for (int i = 0; i < x.length; ++i) {
            tmp[i] = derivativeAD[i].getval();
            derivativeAD[i].plusEquals(ADMath.multConst(xAD[i], priorScale));
            this.derivative[i] = derivativeAD[i].getval();
            this.HdotV[i] = derivativeAD[i].getdot();
        }
        this.value += (double) batch.length / (double) this.data.length * this.prior.compute(x, tmp);
    }

    /**
     * Computes the (optionally weighted) negative log conditional likelihood for real-valued
     * features ({@code values}) plus the prior. Results are stored in {@code value} and
     * {@code derivative}.
     *
     * <p>The observed-feature gradient is built once and cached in {@code derivativeNumerator}.
     * The expected-feature part is computed serially, or by {@code RVFDerivativeCalculation}
     * workers when parallel computation is enabled and {@code threads > 1}.
     *
     * @param x flattened weight vector, indexed by {@link #indexOf(int, int)}
     */
    protected void rvfcalculate(double[] x) {
        this.value = 0.0;
        if (this.derivativeNumerator == null) {
            this.derivativeNumerator = new double[x.length];
            int datum = 0;
            while (datum < this.data.length) {
                int[] feats = this.data[datum];
                double[] featVals = this.values[datum];
                for (int k = 0; k < feats.length; k++) {
                    int slot = this.indexOf(feats[k], this.labels[datum]);
                    double observed = this.dataWeights == null
                            ? featVals[k]
                            : (double) this.dataWeights[datum] * featVals[k];
                    this.derivativeNumerator[slot] -= observed;
                }
                datum++;
            }
        }
        LogConditionalObjectiveFunction.copy(this.derivative, this.derivativeNumerator);
        boolean runParallel = this.parallelGradientCalculation && this.threads > 1;
        if (runParallel) {
            int workerCount = this.threads;
            RVFDerivativeCalculation[] workers = (RVFDerivativeCalculation[]) Array.newInstance(RVFDerivativeCalculation.class, workerCount);
            CountDownLatch done = new CountDownLatch(workerCount);
            for (int t = 0; t < workerCount; t++) {
                workers[t] = new RVFDerivativeCalculation(workerCount, t, x, this.derivative.length, done);
                new Thread(workers[t]).start();
            }
            try {
                done.await();
            } catch (InterruptedException e) {
                throw new RuntimeInterruptedException(e);
            }
            for (RVFDerivativeCalculation worker : workers) {
                this.value += worker.localValue;
                for (int j = 0; j < this.derivative.length; j++) {
                    this.derivative[j] += worker.localDerivative[j];
                }
            }
        } else {
            double[] activations = new double[this.numClasses];
            double[] posteriors = new double[this.numClasses];
            for (int d = 0; d < this.data.length; d++) {
                int[] feats = this.data[d];
                double[] featVals = this.values[d];
                Arrays.fill(activations, 0.0);
                for (int k = 0; k < feats.length; k++) {
                    int feature = feats[k];
                    double featureValue = featVals[k];
                    for (int c = 0; c < this.numClasses; c++) {
                        activations[c] += x[this.indexOf(feature, c)] * featureValue;
                    }
                }
                double logZ = ArrayMath.logSum(activations);
                for (int c = 0; c < this.numClasses; c++) {
                    posteriors[c] = Math.exp(activations[c] - logZ);
                    if (this.dataWeights != null) {
                        posteriors[c] *= (double) this.dataWeights[d];
                    }
                }
                for (int k = 0; k < feats.length; k++) {
                    int feature = feats[k];
                    double featureValue = featVals[k];
                    for (int c = 0; c < this.numClasses; c++) {
                        this.derivative[this.indexOf(feature, c)] += posteriors[c] * featureValue;
                    }
                }
                double logLikelihood = activations[this.labels[d]] - logZ;
                if (this.dataWeights != null) {
                    logLikelihood *= (double) this.dataWeights[d];
                }
                this.value -= logLikelihood;
            }
        }
        this.value += this.prior.compute(x, this.derivative);
    }

    /**
     * Creates a log conditional likelihood objective over {@code dataset} with a default
     * quadratic {@link LogPrior}.
     */
    public LogConditionalObjectiveFunction(GeneralDataset<L, F> dataset) {
        this(dataset, new LogPrior(LogPrior.LogPriorType.QUADRATIC));
    }

    /**
     * Creates a log conditional likelihood objective over {@code dataset} with the given prior.
     * The summed conditional likelihood is not used ({@code useSumCondObjFun = false}).
     */
    public LogConditionalObjectiveFunction(GeneralDataset<L, F> dataset, LogPrior prior) {
        this(dataset, prior, false);
    }

    /**
     * Creates a log conditional likelihood objective over {@code dataset} with explicit per-datum
     * weights and the given prior.
     *
     * @param dataWeights per-datum weights; if null, weights stored in a weighted dataset are used
     */
    public LogConditionalObjectiveFunction(GeneralDataset<L, F> dataset, float[] dataWeights, LogPrior prior) {
        this(dataset, prior, false, dataWeights);
    }

    /**
     * Creates an objective over {@code dataset} with the given prior, choosing between the summed
     * and the log conditional likelihood. No explicit data weights are passed, so weights come from
     * the dataset when it is a weighted dataset.
     */
    public LogConditionalObjectiveFunction(GeneralDataset<L, F> dataset, LogPrior prior, boolean useSumCondObjFun) {
        this(dataset, prior, useSumCondObjFun, null);
    }

    public LogConditionalObjectiveFunction(GeneralDataset<L, F> dataset, LogPrior prior, boolean useSumCondObjFun, float[] dataWeights) {
        this.prior = prior;
        this.useSummedConditionalLikelihood = useSumCondObjFun;
        this.numFeatures = dataset.numFeatures();
        this.numClasses = dataset.numClasses();
        this.data = dataset.getDataArray();
        this.labels = dataset.getLabelsArray();
        this.values = dataset.getValuesArray();
        this.dataWeights = dataWeights != null ? dataWeights : (dataset instanceof WeightedDataset ? ((WeightedDataset)dataset).getWeights() : (float[])(dataset instanceof WeightedRVFDataset ? ((WeightedRVFDataset)dataset).getWeights() : null));
        this.labelIndex = null;
        this.featureIndex = null;
        this.dataIterable = null;
    }

    /**
     * Creates a log conditional likelihood objective over an {@code Iterable} of {@link Datum}s.
     * Features and labels are mapped to ids through the given indices. This variant has no
     * array-backed data, feature values, or data weights, and never uses the summed conditional
     * likelihood.
     */
    public LogConditionalObjectiveFunction(Iterable<Datum<L, F>> dataIterable, LogPrior logPrior, Index<F> featureIndex, Index<L> labelIndex) {
        // Dimensions come from the indices (feature size queried first, as before).
        this.numFeatures = featureIndex.size();
        this.numClasses = labelIndex.size();
        // Iterable-backed data source.
        this.dataIterable = dataIterable;
        this.featureIndex = featureIndex;
        this.labelIndex = labelIndex;
        // Objective configuration.
        this.prior = logPrior;
        this.useSummedConditionalLikelihood = false;
        // No array-backed representation.
        this.data = null;
        this.labels = null;
        this.values = null;
        this.dataWeights = null;
    }

    /**
     * Creates an objective over binary-feature array data with no data weights and a default
     * quadratic prior, optionally using the summed conditional likelihood.
     */
    public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels, boolean useSumCondObjFun) {
        this(numFeatures, numClasses, data, labels, null, new LogPrior(LogPrior.LogPriorType.QUADRATIC), useSumCondObjFun);
    }

    /**
     * Creates a log conditional likelihood objective over binary-feature array data with a default
     * quadratic prior.
     */
    public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels) {
        this(numFeatures, numClasses, data, labels, new LogPrior(LogPrior.LogPriorType.QUADRATIC));
    }

    /**
     * Creates an unweighted objective with the given prior and the summed conditional
     * likelihood option disabled.
     *
     * @param numFeatures number of features
     * @param numClasses number of classes
     * @param data per-datum feature indices
     * @param labels per-datum gold class
     * @param prior regularization prior
     */
    public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels, LogPrior prior) {
        this(numFeatures, numClasses, data, labels, (float[]) null, prior, false);
    }

    /**
     * Creates a (possibly) weighted objective with a QUADRATIC prior and the summed
     * conditional likelihood option disabled.
     *
     * @param numFeatures number of features
     * @param numClasses number of classes
     * @param data per-datum feature indices
     * @param labels per-datum gold class
     * @param dataWeights per-datum weights, or {@code null} for unweighted data
     */
    public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels, float[] dataWeights) {
        this(numFeatures, numClasses, data, labels, dataWeights, new LogPrior(LogPrior.LogPriorType.QUADRATIC), false);
    }

    /**
     * Creates a (possibly) weighted objective with the given prior and the summed
     * conditional likelihood option disabled.
     *
     * @param numFeatures number of features
     * @param numClasses number of classes
     * @param data per-datum feature indices
     * @param labels per-datum gold class
     * @param dataWeights per-datum weights, or {@code null} for unweighted data
     * @param prior regularization prior
     */
    public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels, float[] dataWeights, LogPrior prior) {
        this(numFeatures, numClasses, data, labels, dataWeights, prior, /* useSummedConditionalLikelihood= */ false);
    }

    /**
     * Primary array-backed constructor for binary-valued features; all other
     * {@code int[][]}-only constructors delegate here. No feature values are stored
     * ({@code values} is {@code null}), and the index/iterable fields are left unset.
     *
     * @param numFeatures number of features
     * @param numClasses number of classes
     * @param data per-datum feature indices
     * @param labels per-datum gold class
     * @param dataWeights per-datum weights, or {@code null} for unweighted data
     * @param prior regularization prior
     * @param useSummedConditionalLikelihood whether the summed conditional likelihood is used
     */
    public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels, float[] dataWeights, LogPrior prior, boolean useSummedConditionalLikelihood) {
        // Problem dimensions and training arrays.
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        this.labels = labels;
        this.data = data;
        this.dataWeights = dataWeights;
        this.values = null;
        // Objective configuration.
        this.useSummedConditionalLikelihood = useSummedConditionalLikelihood;
        this.prior = prior;
        // Only used by the Iterable-based constructor.
        this.dataIterable = null;
        this.featureIndex = null;
        this.labelIndex = null;
    }

    /**
     * Creates an unweighted objective over binary-valued features (no value array) whose
     * prior is built from {@code intPrior}, {@code sigma} and {@code epsilon}.
     *
     * @param numFeatures number of features
     * @param numClasses number of classes
     * @param data per-datum feature indices
     * @param labels per-datum gold class
     * @param intPrior prior type code passed to {@code new LogPrior(int, double, double)}
     * @param sigma passed through to the {@link LogPrior} constructor
     * @param epsilon passed through to the {@link LogPrior} constructor
     */
    public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels, int intPrior, double sigma, double epsilon) {
        this(numFeatures, numClasses, data, (double[][]) null, labels, intPrior, sigma, epsilon);
    }

    /**
     * Creates an unweighted objective, optionally with real feature values, whose prior is
     * built as {@code new LogPrior(intPrior, sigma, epsilon)}. The summed conditional
     * likelihood option is disabled.
     *
     * @param numFeatures number of features
     * @param numClasses number of classes
     * @param data per-datum feature indices
     * @param values per-datum feature values parallel to {@code data}, or {@code null}
     * @param labels per-datum gold class
     * @param intPrior prior type code passed to the {@link LogPrior} constructor
     * @param sigma passed through to the {@link LogPrior} constructor
     * @param epsilon passed through to the {@link LogPrior} constructor
     */
    public LogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, double[][] values, int[] labels, int intPrior, double sigma, double epsilon) {
        this.prior = new LogPrior(intPrior, sigma, epsilon);
        this.useSummedConditionalLikelihood = false;
        // Problem dimensions and training arrays; no per-datum weights in this mode.
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        this.labels = labels;
        this.data = data;
        this.values = values;
        this.dataWeights = null;
        // Only used by the Iterable-based constructor.
        this.dataIterable = null;
        this.featureIndex = null;
        this.labelIndex = null;
    }

    /**
     * Worker that processes a strided share of the real-valued (RVF) training data for a
     * parameter vector {@code x}.
     * <p>
     * Worker {@code threadIdx} of {@code numThreads} visits datums {@code threadIdx},
     * {@code threadIdx + numThreads}, ... For each datum it computes per-class scores
     * (sum of {@code x[indexOf(feature, class)] * value}) and their softmax. It then
     * <ul>
     *   <li>adds the (optionally datum-weighted) class probability times the feature value
     *       into {@link #localDerivative} at {@code indexOf(feature, class)}, and</li>
     *   <li>subtracts the (optionally weighted) log softmax probability of the gold label
     *       from {@link #localValue}.</li>
     * </ul>
     * How these partial results are combined is up to the caller (not visible here).
     */
    private class RVFDerivativeCalculation
    implements Runnable {
        int numThreads;
        int threadIdx;
        double localValue = 0.0;
        double[] x;
        double[] localDerivative;
        CountDownLatch latch;

        public RVFDerivativeCalculation(int numThreads, int threadIdx, double[] x, int derivativeSize, CountDownLatch latch) {
            this.threadIdx = threadIdx;
            this.numThreads = numThreads;
            this.latch = latch;
            this.x = x;
            this.localDerivative = new double[derivativeSize];
        }

        @Override
        public void run() {
            final int k = numClasses;
            final double[] scores = new double[k];
            final double[] expected = new double[k];
            int datum = threadIdx;
            while (datum < data.length) {
                final int[] feats = data[datum];
                final double[] featVals = values[datum];
                Arrays.fill(scores, 0.0);
                // Unnormalized log score of every class.
                for (int cls = 0; cls < k; ++cls) {
                    for (int j = 0; j < feats.length; ++j) {
                        final int feature = feats[j];
                        final double v = featVals[j];
                        scores[cls] += x[indexOf(feature, cls)] * v;
                    }
                }
                final double logZ = ArrayMath.logSum(scores);
                // Softmax probability, scaled by the datum weight when weights exist.
                for (int cls = 0; cls < k; ++cls) {
                    double p = Math.exp(scores[cls] - logZ);
                    if (dataWeights != null) {
                        p *= dataWeights[datum];
                    }
                    expected[cls] = p;
                }
                // Accumulate expected feature values per (feature, class) parameter.
                for (int cls = 0; cls < k; ++cls) {
                    for (int j = 0; j < feats.length; ++j) {
                        final int slot = indexOf(feats[j], cls);
                        localDerivative[slot] += expected[cls] * featVals[j];
                    }
                }
                double goldLogProb = scores[labels[datum]] - logZ;
                if (dataWeights != null) {
                    goldLogProb *= dataWeights[datum];
                }
                localValue -= goldLogProb;
                datum += numThreads;
            }
            // NOTE(review): only reached on normal completion; if the loop throws, the latch
            // is never counted down and a thread awaiting it will block.
            latch.countDown();
        }
    }

    /**
     * Worker that processes a strided share of a batch of binary-valued training datums for
     * a parameter vector {@code x}.
     * <p>
     * When {@code batch} is {@code null} the whole data array is the batch; otherwise
     * {@code batch[m]} gives the datum index at position {@code m}. Worker
     * {@code threadIdx} of {@code numThreads} visits positions {@code threadIdx},
     * {@code threadIdx + numThreads}, ... For each datum it computes per-class scores (sum
     * of {@code x[indexOf(feature, class)]}) and their softmax. It then
     * <ul>
     *   <li>adds the (optionally datum-weighted) class probability into
     *       {@link #localDerivative} at {@code indexOf(feature, class)}, and</li>
     *   <li>subtracts the (optionally weighted) log softmax probability of the gold label
     *       from {@link #localValue}.</li>
     * </ul>
     * How these partial results are combined is up to the caller (not visible here).
     */
    private class CLBatchDerivativeCalculation
    implements Runnable {
        int numThreads;
        int threadIdx;
        double localValue = 0.0;
        double[] x;
        int[] batch;
        double[] localDerivative;
        CountDownLatch latch;

        public CLBatchDerivativeCalculation(int numThreads, int threadIdx, int[] batch, double[] x, int derivativeSize, CountDownLatch latch) {
            this.threadIdx = threadIdx;
            this.numThreads = numThreads;
            this.batch = batch;
            this.latch = latch;
            this.x = x;
            this.localDerivative = new double[derivativeSize];
        }

        @Override
        public void run() {
            final int k = numClasses;
            final double[] scores = new double[k];
            final double[] expected = new double[k];
            final int count = (batch != null) ? batch.length : data.length;
            int pos = threadIdx;
            while (pos < count) {
                final int datum = (batch != null) ? batch[pos] : pos;
                Arrays.fill(scores, 0.0);
                final int[] feats = data[datum];
                // Unnormalized log score of every class.
                for (int cls = 0; cls < k; ++cls) {
                    for (int j = 0; j < feats.length; ++j) {
                        scores[cls] += x[indexOf(feats[j], cls)];
                    }
                }
                final double logZ = ArrayMath.logSum(scores);
                // Softmax probability, scaled by the datum weight when weights exist.
                for (int cls = 0; cls < k; ++cls) {
                    double p = Math.exp(scores[cls] - logZ);
                    if (dataWeights != null) {
                        p *= dataWeights[datum];
                    }
                    expected[cls] = p;
                }
                // Accumulate expected feature counts per (feature, class) parameter.
                for (int cls = 0; cls < k; ++cls) {
                    for (int j = 0; j < feats.length; ++j) {
                        localDerivative[indexOf(feats[j], cls)] += expected[cls];
                    }
                }
                final int gold = labels[datum];
                double goldLogProb = scores[gold] - logZ;
                if (dataWeights != null) {
                    goldLogProb *= dataWeights[datum];
                }
                localValue -= goldLogProb;
                pos += numThreads;
            }
            // NOTE(review): only reached on normal completion; if the loop throws, the latch
            // is never counted down and a thread awaiting it will block.
            latch.countDown();
        }
    }
}

