package org.deeplearning4j.parallelism.trainer;

import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.SharedGradient;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/trainer/SymmetricTrainer.class */
/**
 * Worker trainer used by {@link ParallelWrapper} that shares gradients through the wrapper's
 * {@link GradientsAccumulator} instead of requesting parameter averaging.
 *
 * <p>The accumulator is fetched from the wrapper at construction time and, in {@link #postInit()},
 * attached to the replicated model when that model is a {@link ComputationGraph} or a
 * {@link MultiLayerNetwork}.
 */
public class SymmetricTrainer extends DefaultTrainer implements CommunicativeTrainer {
    private static final Logger log = LoggerFactory.getLogger(SymmetricTrainer.class);

    /**
     * Accumulator obtained from {@link ParallelWrapper#getGradientsAccumulator()}.
     * It may be {@code null}; in that case gradient sharing is skipped (see {@link #postInit()}).
     */
    protected GradientsAccumulator accumulator;

    /**
     * Creates a trainer bound to the given wrapper.
     *
     * @param model the original model; stored in {@code originalModel} (the per-worker replica is
     *     presumably built from it by {@link DefaultTrainer} — TODO confirm)
     * @param threadId index of this worker; stored in {@code threadId}
     * @param workspaceMode workspace mode for this trainer; stored in {@code workspaceMode}
     * @param parallelWrapper owning wrapper; also the source of the gradients accumulator
     * @param useMDS stored in {@code useMDS} (presumably selects MultiDataSet input — verify)
     * @throws NullPointerException if {@code model}, {@code workspaceMode} or
     *     {@code parallelWrapper} is {@code null}
     */
    public SymmetricTrainer(@NonNull Model model, int threadId, @NonNull WorkspaceMode workspaceMode,
                    @NonNull ParallelWrapper parallelWrapper, boolean useMDS) {
        // Explicit checks mirror what Lombok's @NonNull generates; messages are kept as-is.
        if (model == null) {
            throw new NullPointerException("originalModel");
        }
        if (workspaceMode == null) {
            throw new NullPointerException("mode");
        }
        if (parallelWrapper == null) {
            throw new NullPointerException("wrapper");
        }
        this.useMDS = useMDS;
        this.originalModel = model;
        this.threadId = threadId;
        this.workspaceMode = workspaceMode;
        this.parallelWrapper = parallelWrapper;
        this.accumulator = parallelWrapper.getGradientsAccumulator();
    }

    /**
     * Intentionally a no-op.
     *
     * <p>This trainer does not take gradients through this entry point; the accumulator attached
     * in {@link #postInit()} is used instead.
     *
     * @param sharedGradient ignored
     */
    @Override // org.deeplearning4j.parallelism.trainer.CommunicativeTrainer
    @Deprecated
    public void enqueueGradient(SharedGradient sharedGradient) {
        // no-op by design (deprecated API)
    }

    /**
     * Reports that this trainer never needs parameter averaging.
     *
     * @return always {@code false}
     */
    @Override // org.deeplearning4j.parallelism.trainer.DefaultTrainer, org.deeplearning4j.parallelism.trainer.Trainer
    public boolean averagingRequired() {
        return false;
    }

    /**
     * Runs the base initialisation, then wires the gradients accumulator into the replicated model.
     *
     * <p>If no accumulator is available, a warning is logged and nothing else happens. Otherwise
     * the accumulator is attached to the replica when it is a {@link ComputationGraph} or a
     * {@link MultiLayerNetwork}, and {@link GradientsAccumulator#touch()} is called. {@code touch()}
     * presumably registers the current worker with the accumulator — TODO confirm.
     */
    @Override // org.deeplearning4j.parallelism.trainer.DefaultTrainer
    protected void postInit() {
        super.postInit();
        if (this.accumulator == null) {
            log.warn("GradientsAccumulator is undefined, gradients sharing will be skipped");
            return;
        }
        // replicatedModel is typed as Model, which does not expose setGradientsAccumulator;
        // the concrete-class setters must be reached through explicit casts.
        if (this.replicatedModel instanceof ComputationGraph) {
            ((ComputationGraph) this.replicatedModel).setGradientsAccumulator(this.accumulator);
        } else if (this.replicatedModel instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) this.replicatedModel).setGradientsAccumulator(this.accumulator);
        } else {
            // Previously this case was silently ignored; make it visible.
            log.warn("Replicated model of type {} does not support GradientsAccumulator, gradients sharing will be skipped",
                            this.replicatedModel == null ? "null" : this.replicatedModel.getClass().getName());
        }
        this.accumulator.touch();
    }
}
