package org.deeplearning4j.parallelism.inference.observers;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservable.class */
/**
 * {@link InferenceObservable} that accumulates individual inference requests into one batch.
 *
 * <p>Each {@link #setInput(INDArray...)} call appends one request and stores, in a thread-local,
 * the index that request occupies in the batch. {@link #getInput()} combines the accumulated
 * requests with {@code Nd4j.pile}. {@link #setOutput(INDArray...)} splits the batched result back
 * into per-request outputs with {@code Nd4j.tear}, so each submitting thread can fetch its own
 * result through {@link #getOutput()}.
 *
 * <p>Locking, as implemented here:
 * <ul>
 *   <li>When {@link #isLocked()} returns {@code false}, the calling thread holds a read lock on this
 *       batch. That thread's next {@link #setInput(INDArray...)} releases it.</li>
 *   <li>{@link #getInput()} takes the write lock, so it waits for such pending appends.</li>
 *   <li>{@link #getInput()} then marks the batch as consumed, after which {@link #isLocked()} always
 *       returns {@code true}.</li>
 * </ul>
 */
public class BatchedInferenceObservable extends BasicInferenceObservable implements InferenceObservable {
    private static final Logger log = LoggerFactory.getLogger(BatchedInferenceObservable.class);

    /** Inputs of each request, in submission order; appends are serialized by {@code locker}. */
    private final List<INDArray[]> inputs = new ArrayList<>();
    /** Outputs of each request, filled by setOutput; index i belongs to the request at position i. */
    private final List<INDArray[]> outputs = new ArrayList<>();
    /** Number of requests accumulated in this batch. */
    private final AtomicInteger counter = new AtomicInteger(0);
    /** Batch index of the request submitted by the current thread. */
    private final ThreadLocal<Integer> position = new ThreadLocal<>();
    /** Serializes concurrent appends in setInput. */
    private final Object locker = new Object();
    /** Read hold spans isLocked() -> setInput(); write lock is held while getInput() builds the batch. */
    private final ReentrantReadWriteLock realLocker = new ReentrantReadWriteLock();
    /** Set once getInput() has run; from then on the batch reports itself as locked. */
    private final AtomicBoolean isLocked = new AtomicBoolean(false);

    public BatchedInferenceObservable() {
        super(new INDArray[0]);
    }

    /**
     * Appends one request to this batch and records its batch index for the calling thread.
     *
     * <p>If the calling thread holds a read lock on this batch (taken by {@link #isLocked()}), one
     * read hold is released here.
     *
     * @param input the request's input arrays
     */
    @Override // org.deeplearning4j.parallelism.inference.observers.BasicInferenceObservable, org.deeplearning4j.parallelism.inference.InferenceObservable
    public void setInput(INDArray... input) {
        synchronized (locker) {
            inputs.add(input);
            position.set(counter.getAndIncrement());
            // Read holds are per-thread, so release only if *this* thread owns one. A shared flag
            // cannot express that, and unlocking an unheld read lock throws
            // IllegalMonitorStateException.
            if (realLocker.getReadHoldCount() > 0) {
                realLocker.readLock().unlock();
            }
        }
    }

    /**
     * Builds the batched input and marks this batch as consumed.
     *
     * <p>With at most one request, that request's input arrays are returned as-is. Otherwise, for
     * each input slot i, the i-th array of every request is combined with {@code Nd4j.pile}. The
     * write lock is always released, even if batching throws.
     *
     * @return the batched input arrays
     */
    @Override // org.deeplearning4j.parallelism.inference.observers.BasicInferenceObservable, org.deeplearning4j.parallelism.inference.InferenceObservable
    public INDArray[] getInput() {
        realLocker.writeLock().lock();
        try {
            isLocked.set(true);
            if (counter.get() <= 1) {
                return inputs.get(0);
            }
            INDArray[] batch = new INDArray[inputs.get(0).length];
            for (int slot = 0; slot < batch.length; slot++) {
                List<INDArray> column = new ArrayList<>(inputs.size());
                for (INDArray[] request : inputs) {
                    column.add(request[slot]);
                }
                batch[slot] = Nd4j.pile(column);
            }
            return batch;
        } finally {
            realLocker.writeLock().unlock();
        }
    }

    /**
     * Splits the batched output into per-request outputs, then notifies observers.
     *
     * <p>With more than one request, each output array is torn along every dimension except 0. The
     * resulting slices are distributed so that request q receives slice q of every output.
     *
     * @param output the batched output arrays
     * @throws ND4JIllegalStateException if an output does not split into exactly one slice per request
     */
    @Override // org.deeplearning4j.parallelism.inference.observers.BasicInferenceObservable, org.deeplearning4j.parallelism.inference.InferenceObservable
    public void setOutput(INDArray... output) {
        int queries = counter.get();
        if (queries > 1) {
            for (int q = 0; q < queries; q++) {
                outputs.add(new INDArray[output.length]);
            }
            for (int slot = 0; slot < output.length; slot++) {
                INDArray batched = output[slot];
                // Dimensions 1..rank-1: presumably yields one tensor per index of dimension 0,
                // i.e. one per request. The length check below enforces that.
                int[] dimensions = new int[batched.rank() - 1];
                for (int d = 1; d < batched.rank(); d++) {
                    dimensions[d - 1] = d;
                }
                INDArray[] split = Nd4j.tear(batched, dimensions);
                if (split.length != queries) {
                    throw new ND4JIllegalStateException("Number of splits [" + split.length + "] doesn't match number of queries [" + queries + "]");
                }
                for (int q = 0; q < queries; q++) {
                    outputs.get(q)[slot] = split[q];
                }
            }
        } else {
            outputs.add(output);
        }
        setChanged();
        notifyObservers();
    }

    /** Returns the live (mutable) list of per-request outputs. */
    protected List<INDArray[]> getOutputs() {
        return outputs;
    }

    /** Overrides the number of requests considered part of this batch. */
    protected void setCounter(int value) {
        counter.set(value);
    }

    /** Sets the batch index whose output {@link #getOutput()} returns for the calling thread. */
    public void setPosition(int pos) {
        position.set(pos);
    }

    /** Returns the number of requests accumulated so far. */
    public int getCounter() {
        return counter.get();
    }

    /**
     * Reports whether this batch can no longer accept requests.
     *
     * <p>Returns {@code true} in two cases:
     * <ul>
     *   <li>the write lock is held, meaning {@link #getInput()} is running;</li>
     *   <li>{@link #getInput()} has already run.</li>
     * </ul>
     * Otherwise it returns {@code false}, and the calling thread keeps a read hold until its next
     * {@link #setInput(INDArray...)}.
     *
     * @return {@code true} if the batch is locked
     */
    public boolean isLocked() {
        if (!realLocker.readLock().tryLock()) {
            return true;
        }
        if (isLocked.get()) {
            // Batch already consumed: don't keep the hold we just took, or it would never be released.
            realLocker.readLock().unlock();
            return true;
        }
        // Read hold intentionally kept; released by setInput() on this thread.
        return false;
    }

    /**
     * Returns the output for the calling thread's request.
     *
     * <p>Throws {@code NullPointerException} if this thread never called
     * {@link #setInput(INDArray...)} or {@link #setPosition(int)}.
     */
    @Override // org.deeplearning4j.parallelism.inference.observers.BasicInferenceObservable, org.deeplearning4j.parallelism.inference.InferenceObservable
    public INDArray[] getOutput() {
        return outputs.get(position.get());
    }
}
