/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import java.util.Objects;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.AbstractDataSetNormalizer;
import org.nd4j.linalg.dataset.api.preprocessor.StandardizeStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
import org.nd4j.linalg.dataset.api.preprocessor.stats.DistributionStats;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;

/**
 * Data-set normalizer backed by mean / standard-deviation statistics ({@link DistributionStats}).
 *
 * <p>The actual transformation is delegated to {@link StandardizeStrategy} (presumably
 * {@code (x - mean) / std}; see that class to confirm). Feature statistics are always used;
 * label statistics are used only when label fitting is enabled ({@code fitLabel(true)}).
 */
public class NormalizerStandardize
extends AbstractDataSetNormalizer<DistributionStats> {

    /**
     * Creates a normalizer with pre-computed feature statistics and label fitting disabled.
     *
     * @param featureMean feature means
     * @param featureStd  feature standard deviations
     * @throws NullPointerException if either argument is {@code null}
     */
    public NormalizerStandardize(@NonNull INDArray featureMean, @NonNull INDArray featureStd) {
        this();
        Objects.requireNonNull(featureMean, "featureMean");
        Objects.requireNonNull(featureStd, "featureStd");
        this.setFeatureStats(new DistributionStats(featureMean, featureStd));
        this.fitLabel(false);
    }

    /**
     * Creates a normalizer with pre-computed feature and label statistics and label fitting enabled.
     *
     * @param featureMean feature means
     * @param featureStd  feature standard deviations
     * @param labelMean   label means
     * @param labelStd    label standard deviations
     * @throws NullPointerException if any argument is {@code null}
     */
    public NormalizerStandardize(@NonNull INDArray featureMean, @NonNull INDArray featureStd, @NonNull INDArray labelMean, @NonNull INDArray labelStd) {
        this();
        Objects.requireNonNull(featureMean, "featureMean");
        Objects.requireNonNull(featureStd, "featureStd");
        Objects.requireNonNull(labelMean, "labelMean");
        Objects.requireNonNull(labelStd, "labelStd");
        this.setFeatureStats(new DistributionStats(featureMean, featureStd));
        this.setLabelStats(new DistributionStats(labelMean, labelStd));
        this.fitLabel(true);
    }

    /** Creates an unfitted normalizer that uses {@link StandardizeStrategy}. */
    public NormalizerStandardize() {
        super(new StandardizeStrategy());
    }

    /**
     * Replaces the label statistics with the given mean and standard deviation.
     * Does not change whether label fitting is enabled.
     *
     * @param labelMean label means
     * @param labelStd  label standard deviations
     * @throws NullPointerException if either argument is {@code null}
     */
    public void setLabelStats(@NonNull INDArray labelMean, @NonNull INDArray labelStd) {
        Objects.requireNonNull(labelMean, "labelMean");
        Objects.requireNonNull(labelStd, "labelStd");
        this.setLabelStats(new DistributionStats(labelMean, labelStd));
    }

    /** @return the mean stored in the current feature statistics */
    public INDArray getMean() {
        return ((DistributionStats)this.getFeatureStats()).getMean();
    }

    /** @return the mean stored in the current label statistics */
    public INDArray getLabelMean() {
        return ((DistributionStats)this.getLabelStats()).getMean();
    }

    /** @return the standard deviation stored in the current feature statistics */
    public INDArray getStd() {
        return ((DistributionStats)this.getFeatureStats()).getStd();
    }

    /** @return the standard deviation stored in the current label statistics */
    public INDArray getLabelStd() {
        return ((DistributionStats)this.getLabelStats()).getStd();
    }

    /**
     * Loads statistics from files, in the order: feature mean, feature std and, when label
     * fitting is enabled, label mean, label std.
     *
     * <p>Both statistics objects are read before either is installed, so a failure part-way
     * through leaves this normalizer's existing statistics untouched.
     *
     * @param files two files, or four when label fitting is enabled
     * @throws IllegalArgumentException if fewer files than required are supplied
     * @throws IOException              if reading any file fails
     */
    public void load(File ... files) throws IOException {
        this.requireFileCount(files);
        DistributionStats featureStats = DistributionStats.load(files[0], files[1]);
        DistributionStats labelStats = null;
        if (this.isFitLabel()) {
            labelStats = DistributionStats.load(files[2], files[3]);
        }
        // Install only after every load succeeded, to avoid a half-updated normalizer.
        this.setFeatureStats(featureStats);
        if (labelStats != null) {
            this.setLabelStats(labelStats);
        }
    }

    /**
     * Saves statistics to files, in the order: feature mean, feature std and, when label
     * fitting is enabled, label mean, label std.
     *
     * @param files two files, or four when label fitting is enabled
     * @throws IllegalArgumentException if fewer files than required are supplied
     * @throws IOException              if writing any file fails
     */
    public void save(File ... files) throws IOException {
        // Validate before writing anything, so a short array cannot leave partial output.
        this.requireFileCount(files);
        ((DistributionStats)this.getFeatureStats()).save(files[0], files[1]);
        if (this.isFitLabel()) {
            ((DistributionStats)this.getLabelStats()).save(files[2], files[3]);
        }
    }

    /** Ensures {@code files} holds enough entries for the feature (and, if fitted, label) stats. */
    private void requireFileCount(File[] files) {
        boolean withLabels = this.isFitLabel();
        int required = withLabels ? 4 : 2;
        if (files.length < required) {
            throw new IllegalArgumentException("Expected at least " + required + " files ("
                    + (withLabels ? "feature mean, feature std, label mean, label std"
                                  : "feature mean, feature std")
                    + ") but got " + files.length);
        }
    }

    @Override
    protected NormalizerStats.Builder newBuilder() {
        return new DistributionStats.Builder();
    }

    @Override
    public NormalizerType getType() {
        return NormalizerType.STANDARDIZE;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof NormalizerStandardize)) {
            return false;
        }
        NormalizerStandardize other = (NormalizerStandardize)o;
        // Symmetric-equality guard: the other side must also accept this type.
        if (!other.canEqual(this)) {
            return false;
        }
        return super.equals(o);
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof NormalizerStandardize;
    }

    @Override
    public int hashCode() {
        // Same value as the previous `1 * 59 + super.hashCode()` accumulation; this class
        // declares no fields, so only superclass state contributes.
        return 59 + super.hashCode();
    }
}

