package org.nd4j.aeron.ipc;

import java.beans.ConstructorProperties;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.time.Instant;
import java.time.ZoneOffset;
import java.util.Arrays;
import java.util.UUID;
import org.agrona.DirectBuffer;
import org.agrona.concurrent.UnsafeBuffer;
import org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * A message carrying an {@link INDArray} together with metadata (send time, index, dimensions),
 * plus helpers to serialize it into a single buffer, split that buffer into fixed-size chunks and
 * reassemble it.
 *
 * <p>Layout written by {@link #toBuffer(NDArrayMessage)} (native byte order):
 * <pre>
 *   int    message type ordinal (always {@link MessageType#WHOLE})
 *   ...    the array, as written by AeronNDArraySerde (compressed or uncompressed)
 *   long   sent
 *   long   index
 *   int    number of dimensions (n)
 *   int[n] dimensions
 * </pre>
 */
public class NDArrayMessage implements Serializable {
    private INDArray arr;
    private long sent;
    private long index;
    private int[] dimensions;
    private byte[] chunk;
    private int numChunks;

    /**
     * Dimensions marker for a whole-array update. Arrays are mutable, so this is never handed out
     * directly; callers always receive a copy.
     */
    private static final int[] WHOLE_ARRAY_UPDATE = {-1};
    /** Index marker for a whole-array update. */
    private static final int WHOLE_ARRAY_INDEX = -1;

    /** Kind of payload encoded in a serialized message or chunk. */
    public enum MessageType {
        CHUNKED,
        WHOLE
    }

    /** Result of {@link #validMessage(NDArrayMessage)}. */
    public enum MessageValidity {
        VALID,
        NULL_VALUE,
        INCONSISTENT_DIMENSIONS
    }

    /** Fluent builder for {@link NDArrayMessage}; obtain one via {@link #builder()}. */
    public static class NDArrayMessageBuilder {
        private INDArray arr;
        private long sent;
        private long index;
        private int[] dimensions;
        private byte[] chunk;
        private int numChunks;

        NDArrayMessageBuilder() {
        }

        public NDArrayMessageBuilder arr(INDArray arr) {
            this.arr = arr;
            return this;
        }

        public NDArrayMessageBuilder sent(long sent) {
            this.sent = sent;
            return this;
        }

        public NDArrayMessageBuilder index(long index) {
            this.index = index;
            return this;
        }

        public NDArrayMessageBuilder dimensions(int[] dimensions) {
            this.dimensions = dimensions;
            return this;
        }

        public NDArrayMessageBuilder chunk(byte[] chunk) {
            this.chunk = chunk;
            return this;
        }

        public NDArrayMessageBuilder numChunks(int numChunks) {
            this.numChunks = numChunks;
            return this;
        }

        public NDArrayMessage build() {
            return new NDArrayMessage(this.arr, this.sent, this.index, this.dimensions, this.chunk, this.numChunks);
        }

        @Override
        public String toString() {
            return "NDArrayMessage.NDArrayMessageBuilder(arr=" + this.arr + ", sent=" + this.sent + ", index=" + this.index + ", dimensions=" + Arrays.toString(this.dimensions) + ", chunk=" + Arrays.toString(this.chunk) + ", numChunks=" + this.numChunks + ")";
        }
    }

    /**
     * Computes how many chunks of {@code chunkSize} bytes are needed to hold the serialized message.
     *
     * @param message   the message to measure
     * @param chunkSize size of one chunk in bytes; must be positive
     * @return ceil(serialized size / chunkSize)
     * @throws IllegalArgumentException if {@code chunkSize <= 0}
     */
    public static int numChunksForMessage(NDArrayMessage message, int chunkSize) {
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("Chunk size must be positive but was " + chunkSize);
        }
        int sizeOfMessage = byteBufferSizeForMessage(message);
        int numChunks = sizeOfMessage / chunkSize;
        if (numChunks * chunkSize < sizeOfMessage) {
            // Round up so the trailing partial chunk is not dropped.
            numChunks++;
        }
        return numChunks;
    }

    /**
     * Splits the serialized form of {@code message} into messages whose {@code chunk} field holds
     * consecutive {@code chunkSize}-byte slices of that serialized form.
     *
     * <p>Every chunk array is exactly {@code chunkSize} bytes long; the last one is zero-padded
     * when the serialized size is not a multiple of {@code chunkSize}.
     *
     * @param message   the message to split
     * @param chunkSize size of each chunk in bytes; must be positive
     * @return the chunk messages in order, each carrying the total chunk count
     * @throws IllegalArgumentException if {@code chunkSize <= 0}
     */
    public static NDArrayMessage[] chunkedMessages(NDArrayMessage message, int chunkSize) {
        int numChunks = numChunksForMessage(message, chunkSize);
        ByteBuffer serialized = toBuffer(message).byteBuffer();
        NDArrayMessage[] ret = new NDArrayMessage[numChunks];
        for (int chunkIndex = 0; chunkIndex < numChunks; chunkIndex++) {
            byte[] chunkBytes = new byte[chunkSize];
            // Read sequentially: the offset argument of get() indexes the destination array,
            // so it must be 0 here (the source position advances by itself).
            int length = Math.min(chunkSize, serialized.remaining());
            serialized.get(chunkBytes, 0, length);
            ret[chunkIndex] = builder().chunk(chunkBytes).numChunks(numChunks).build();
        }
        return ret;
    }

    /**
     * Creates a message that updates the whole array, stamped with the current time.
     *
     * @param arr the array to send
     * @return a message with whole-array dimensions and index markers
     */
    public static NDArrayMessage wholeArrayUpdate(INDArray arr) {
        return builder()
                .arr(arr)
                .dimensions(WHOLE_ARRAY_UPDATE.clone())
                .index(WHOLE_ARRAY_INDEX)
                .sent(getCurrentTimeUtc())
                .build();
    }

    /**
     * Creates a message for {@code arr}, stamped with the current time.
     *
     * @param arr        the array to send
     * @param dimensions dimensions for the update; {@code null} means whole-array update
     * @param index      the index of the update
     * @return the new message
     * @throws IllegalArgumentException if {@code index > 0} while {@code dimensions} describe a
     *                                  partial update (more than one entry, or a single entry other than -1)
     */
    public static NDArrayMessage of(INDArray arr, int[] dimensions, long index) {
        // Copy the marker so callers can never mutate the shared constant through the message.
        int[] dims = dimensions == null ? WHOLE_ARRAY_UPDATE.clone() : dimensions;
        boolean partialUpdate = dims.length > 1 || (dims.length == 1 && dims[0] != -1);
        if (index > 0 && partialUpdate) {
            throw new IllegalArgumentException("Inconsistent message. Your index is > 0 indicating you want to send a whole ndarray message but your dimensions indicate you are trying to send a partial update. Please ensure you use a 1 length int array with negative 1 as an element or use NDArrayMesage.wholeArrayUpdate(ndarray) for creation instead");
        }
        return builder().index(index).dimensions(dims).sent(getCurrentTimeUtc()).arr(arr).build();
    }

    /**
     * Checks a message for missing fields and for an index/dimensions mismatch.
     *
     * <p>NOTE(review): this check differs from the one in {@link #of}. Here any index other than -1
     * combined with a single non -1 dimension is flagged, and multi-entry dimensions are not.
     * Confirm which rule is intended.
     *
     * @param message the message to check
     * @return {@link MessageValidity#NULL_VALUE} if dimensions or array are null,
     *         {@link MessageValidity#INCONSISTENT_DIMENSIONS} for the mismatch described above,
     *         otherwise {@link MessageValidity#VALID}
     */
    public static MessageValidity validMessage(NDArrayMessage message) {
        int[] dims = message.getDimensions();
        if (dims == null || message.getArr() == null) {
            return MessageValidity.NULL_VALUE;
        }
        boolean singlePartialDimension = dims.length == 1 && dims[0] != -1;
        if (message.getIndex() != -1 && singlePartialDimension) {
            return MessageValidity.INCONSISTENT_DIMENSIONS;
        }
        return MessageValidity.VALID;
    }

    /**
     * Returns the current time as milliseconds since the Unix epoch. Epoch milliseconds are
     * time-zone independent, so this is UTC by definition.
     */
    public static long getCurrentTimeUtc() {
        return Instant.now().toEpochMilli();
    }

    /**
     * Returns the number of bytes {@link #toBuffer(NDArrayMessage)} allocates for {@code message}.
     * The size of the array portion comes from {@code AeronNDArraySerde.byteBufferSizeFor}.
     */
    public static int byteBufferSizeForMessage(NDArrayMessage message) {
        return Integer.BYTES                                   // message type
                + Integer.BYTES * message.getDimensions().length // dimensions
                + Integer.BYTES                                // dimension count
                + Long.BYTES                                   // sent
                + Long.BYTES                                   // index
                + AeronNDArraySerde.byteBufferSizeFor(message.getArr());
    }

    /**
     * Reassembles a message from chunks produced by {@link #chunks(NDArrayMessage, int)}.
     *
     * <p>Chunks are concatenated in array order. NOTE(review): they are not reordered by chunk
     * index here; callers presumably pass them already sorted. The caller's chunk buffers are not
     * modified.
     *
     * @param chunks the chunks, in order; all are assumed to share {@code chunks[0].getChunkSize()}
     * @return the decoded message
     * @throws IllegalArgumentException if {@code chunks} is null or empty
     */
    public static NDArrayMessage fromChunks(NDArrayMessageChunk[] chunks) {
        if (chunks == null || chunks.length == 0) {
            throw new IllegalArgumentException("At least one chunk is required to reassemble a message");
        }
        int chunkSize = chunks[0].getChunkSize();
        ByteBuffer assembled = ByteBuffer.allocateDirect(chunkSize * chunks.length).order(ByteOrder.nativeOrder());
        for (NDArrayMessageChunk chunk : chunks) {
            // duplicate() has its own position/limit, so copying does not consume the chunk's buffer.
            ByteBuffer data = chunk.getData().duplicate();
            if (data.capacity() > chunkSize) {
                // Oversized backing buffer: only the first chunkSize bytes belong to this chunk.
                data.position(0);
                data.limit(chunkSize);
            }
            assembled.put(data);
        }
        assembled.rewind();
        return fromBuffer(new UnsafeBuffer(assembled), 0);
    }

    /**
     * Serializes {@code message} and splits it into read-only chunks of at most {@code chunkSize}
     * bytes. All chunks share a random UUID id.
     *
     * @param message   the message to split
     * @param chunkSize maximum chunk size in bytes; must be positive
     * @return the chunks in order; the last one may be shorter than {@code chunkSize}
     * @throws IllegalArgumentException if {@code chunkSize <= 0}
     */
    public static NDArrayMessageChunk[] chunks(NDArrayMessage message, int chunkSize) {
        int numChunks = numChunksForMessage(message, chunkSize);
        NDArrayMessageChunk[] ret = new NDArrayMessageChunk[numChunks];
        DirectBuffer serialized = toBuffer(message);
        ByteBuffer source = serialized.byteBuffer();
        String id = UUID.randomUUID().toString();
        for (int chunkIndex = 0; chunkIndex < numChunks; chunkIndex++) {
            int start = chunkIndex * chunkSize;
            int end = Math.min(start + chunkSize, serialized.capacity());
            ByteBuffer view = source.asReadOnlyBuffer();
            view.position(start);
            view.limit(end);
            // slice() resets byte order to big-endian, so native order is applied to the slice itself.
            ByteBuffer data = view.slice().order(ByteOrder.nativeOrder());
            ret[chunkIndex] = NDArrayMessageChunk.builder()
                    .id(id)
                    .chunkSize(chunkSize)
                    .numChunks(numChunks)
                    .messageType(MessageType.CHUNKED)
                    .chunkIndex(chunkIndex)
                    .data(data)
                    .build();
        }
        return ret;
    }

    /**
     * Serializes {@code message} into a freshly allocated direct buffer in native byte order,
     * using the layout described on this class. The returned buffer wraps the whole allocation.
     */
    public static DirectBuffer toBuffer(NDArrayMessage message) {
        ByteBuffer out = ByteBuffer.allocateDirect(byteBufferSizeForMessage(message)).order(ByteOrder.nativeOrder());
        out.putInt(MessageType.WHOLE.ordinal());
        INDArray array = message.getArr();
        if (array.isCompressed()) {
            AeronNDArraySerde.doByteBufferPutCompressed(array, out, false);
        } else {
            AeronNDArraySerde.doByteBufferPutUnCompressed(array, out, false);
        }
        out.putLong(message.getSent());
        out.putLong(message.getIndex());
        int[] dims = message.getDimensions();
        out.putInt(dims.length);
        for (int dim : dims) {
            out.putInt(dim);
        }
        out.rewind();
        return new UnsafeBuffer(out);
    }

    /**
     * Decodes a message previously written by {@link #toBuffer(NDArrayMessage)}.
     *
     * @param buffer the buffer to read
     * @param offset offset of the message (its leading message-type int) within {@code buffer}
     * @return the decoded message; its array is decompressed in place
     * @throws IllegalArgumentException if the encoded dimension count is not positive
     */
    public static NDArrayMessage fromBuffer(DirectBuffer buffer, int offset) {
        // Skip the leading message-type int.
        Pair<INDArray, ByteBuffer> arrayAndRest = AeronNDArraySerde.toArrayAndByteBuffer(buffer, offset + Integer.BYTES);
        INDArray array = (INDArray) arrayAndRest.getKey();
        Nd4j.getCompressor().decompressi(array);
        ByteBuffer rest = (ByteBuffer) arrayAndRest.getRight();
        long sentTime = rest.getLong();
        long messageIndex = rest.getLong();
        int numDims = rest.getInt();
        if (numDims <= 0) {
            throw new IllegalArgumentException("Invalid dimension length " + numDims);
        }
        int[] dims = new int[numDims];
        for (int d = 0; d < numDims; d++) {
            dims[d] = rest.getInt();
        }
        return builder().sent(sentTime).arr(array).index(messageIndex).dimensions(dims).build();
    }

    public static NDArrayMessageBuilder builder() {
        return new NDArrayMessageBuilder();
    }

    public INDArray getArr() {
        return this.arr;
    }

    public long getSent() {
        return this.sent;
    }

    public long getIndex() {
        return this.index;
    }

    public int[] getDimensions() {
        return this.dimensions;
    }

    public byte[] getChunk() {
        return this.chunk;
    }

    public int getNumChunks() {
        return this.numChunks;
    }

    public void setArr(INDArray arr) {
        this.arr = arr;
    }

    public void setSent(long sent) {
        this.sent = sent;
    }

    public void setIndex(long index) {
        this.index = index;
    }

    public void setDimensions(int[] dimensions) {
        this.dimensions = dimensions;
    }

    public void setChunk(byte[] chunk) {
        this.chunk = chunk;
    }

    public void setNumChunks(int numChunks) {
        this.numChunks = numChunks;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof NDArrayMessage)) {
            return false;
        }
        NDArrayMessage other = (NDArrayMessage) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray mine = getArr();
        INDArray theirs = other.getArr();
        if (mine == null ? theirs != null : !mine.equals(theirs)) {
            return false;
        }
        return getSent() == other.getSent()
                && getIndex() == other.getIndex()
                && Arrays.equals(getDimensions(), other.getDimensions())
                && Arrays.equals(getChunk(), other.getChunk())
                && getNumChunks() == other.getNumChunks();
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof NDArrayMessage;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        INDArray array = getArr();
        result = result * prime + (array == null ? 43 : array.hashCode());
        long sentTime = getSent();
        result = result * prime + (int) ((sentTime >>> 32) ^ sentTime);
        long messageIndex = getIndex();
        result = result * prime + (int) ((messageIndex >>> 32) ^ messageIndex);
        result = result * prime + Arrays.hashCode(getDimensions());
        result = result * prime + Arrays.hashCode(getChunk());
        result = result * prime + getNumChunks();
        return result;
    }

    @Override
    public String toString() {
        return "NDArrayMessage(arr=" + getArr() + ", sent=" + getSent() + ", index=" + getIndex() + ", dimensions=" + Arrays.toString(getDimensions()) + ", chunk=" + Arrays.toString(getChunk()) + ", numChunks=" + getNumChunks() + ")";
    }

    @ConstructorProperties({"arr", "sent", "index", "dimensions", "chunk", "numChunks"})
    public NDArrayMessage(INDArray arr, long sent, long index, int[] dimensions, byte[] chunk, int numChunks) {
        this.arr = arr;
        this.sent = sent;
        this.index = index;
        this.dimensions = dimensions;
        this.chunk = chunk;
        this.numChunks = numChunks;
    }

    public NDArrayMessage() {
        // All fields keep their Java defaults (null / 0).
    }
}
