/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.iterator;

import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * A {@link DataSetIterator} that performs k-fold cross-validation over a single {@link DataSet}.
 *
 * <p>The examples are split into {@code k} contiguous, non-overlapping test folds whose sizes
 * differ by at most one: the first {@code numExamples % k} folds hold one extra example. Each call
 * to {@link #next()} advances to the next fold. The returned {@link DataSet} contains every example
 * outside that fold (the training set), and {@link #testFold()} returns the held-out fold. After
 * {@code k} calls, {@link #hasNext()} returns {@code false}. {@link #reset()} calls
 * {@code shuffle()} on the internal copy and starts again from the first fold.
 *
 * <p>The input data set is copied once at construction, and all folds are taken from that copy.
 */
public class KFoldIterator implements DataSetIterator {
    private final DataSet singleFold;
    private final int k;
    /** Size of the largest test fold. */
    private final int batch;
    /** Size of the last test fold. */
    private final int lastBatch;
    /**
     * Start offset of each test fold. Entry {@code i} is where fold {@code i} begins, and entry
     * {@code k} equals the total number of examples.
     */
    private final int[] foldBoundaries;
    /** Number of folds already returned by {@link #next()}. */
    private int kCursor = 0;
    private DataSet test;
    private DataSet train;
    protected DataSetPreProcessor preProcessor;

    /**
     * Creates a 10-fold iterator over a copy of {@code singleFold}.
     *
     * @param singleFold the data to split; must contain at least 10 examples
     * @throws IllegalArgumentException if {@code singleFold} has fewer than 10 examples
     */
    public KFoldIterator(DataSet singleFold) {
        this(10, singleFold);
    }

    /**
     * Creates a k-fold iterator over a copy of {@code singleFold}.
     *
     * @param k the number of folds; must be greater than 1
     * @param singleFold the data to split; must contain at least {@code k} examples
     * @throws IllegalArgumentException if {@code k <= 1} or there are fewer examples than folds
     */
    public KFoldIterator(int k, DataSet singleFold) {
        // Validate before paying for the copy.
        if (k <= 1) {
            throw new IllegalArgumentException("k must be greater than 1, got " + k);
        }
        int numExamples = singleFold.numExamples();
        if (numExamples < k) {
            // Otherwise some test folds would be empty.
            throw new IllegalArgumentException(
                    "Cannot split " + numExamples + " examples into " + k + " non-empty folds");
        }
        this.k = k;
        this.singleFold = singleFold.copy();

        // Spread the remainder over the first folds so that no fold is empty and
        // fold sizes differ by at most one.
        int baseSize = numExamples / k;
        int remainder = numExamples % k;
        this.batch = remainder == 0 ? baseSize : baseSize + 1;
        this.lastBatch = baseSize;
        this.foldBoundaries = new int[k + 1];
        for (int fold = 1; fold <= k; fold++) {
            int foldSize = fold <= remainder ? baseSize + 1 : baseSize;
            this.foldBoundaries[fold] = this.foldBoundaries[fold - 1] + foldSize;
        }
    }

    /**
     * Not supported: fold sizes are fixed by {@code k}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public DataSet next(int num) throws UnsupportedOperationException {
        throw new UnsupportedOperationException("KFoldIterator does not support next(int); fold sizes are fixed by k");
    }

    /** Returns the number of examples, taken as the row count of the labels array. */
    @Override
    public int totalExamples() {
        return this.singleFold.getLabels().size(0);
    }

    /** Returns the number of feature columns ({@code size(1)} of the features array). */
    @Override
    public int inputColumns() {
        return this.singleFold.getFeatures().size(1);
    }

    /** Returns the number of label columns ({@code size(1)} of the labels array). */
    @Override
    public int totalOutcomes() {
        return this.singleFold.getLabels().size(1);
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    /** Calls {@code shuffle()} on the internal copy of the data and rewinds to the first fold. */
    @Override
    public void reset() {
        this.singleFold.shuffle();
        this.kCursor = 0;
    }

    /** Returns the size of the largest test fold. */
    @Override
    public int batch() {
        return this.batch;
    }

    /** Returns the size of the last test fold. */
    public int lastBatch() {
        return this.lastBatch;
    }

    /** Returns the number of folds already returned by {@link #next()} since construction or reset. */
    @Override
    public int cursor() {
        return this.kCursor;
    }

    @Override
    public int numExamples() {
        return this.totalExamples();
    }

    /**
     * Stores a pre-processor, which {@link #getPreProcessor()} then returns.
     *
     * <p>NOTE(review): this class does not currently apply the pre-processor to the folds it
     * returns.
     */
    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    /** Returns the label names of the underlying data set. */
    @Override
    public List<String> getLabels() {
        return this.singleFold.getLabelNamesList();
    }

    @Override
    public boolean hasNext() {
        return this.kCursor < this.k;
    }

    /**
     * Advances to the next fold and returns its training set, which is every example outside the
     * current test fold.
     *
     * @throws NoSuchElementException if all {@code k} folds have already been returned
     */
    @Override
    public DataSet next() {
        if (!this.hasNext()) {
            throw new NoSuchElementException(
                    "All " + this.k + " folds have been returned; call reset() to start over");
        }
        this.nextFold();
        return this.train;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void remove() {
        throw new UnsupportedOperationException("KFoldIterator does not support remove()");
    }

    /** Computes the train and test sets for fold {@code kCursor} and advances the cursor. */
    private void nextFold() {
        int left = this.foldBoundaries[this.kCursor];
        int right = this.foldBoundaries[this.kCursor + 1];
        int total = this.foldBoundaries[this.k];
        if (right < total) {
            // The test fold is not at the end: training = [0, left) + [right, total).
            List<DataSet> kMinusOneFoldList = new ArrayList<DataSet>(2);
            if (left > 0) {
                kMinusOneFoldList.add((DataSet) this.singleFold.getRange(0, left));
            }
            kMinusOneFoldList.add((DataSet) this.singleFold.getRange(right, total));
            this.train = DataSet.merge(kMinusOneFoldList);
        } else {
            // The test fold is the tail, so training is the contiguous prefix.
            this.train = (DataSet) this.singleFold.getRange(0, left);
        }
        this.test = (DataSet) this.singleFold.getRange(left, right);
        ++this.kCursor;
    }

    /**
     * Returns the held-out test fold that matches the training set most recently returned by
     * {@link #next()}, or {@code null} if {@link #next()} has not been called yet.
     */
    public DataSet testFold() {
        return this.test;
    }
}

