package org.deeplearning4j.parallelism.factory;

import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.parallelism.trainer.DefaultTrainer;
import org.deeplearning4j.parallelism.trainer.Trainer;

/* loaded from: input_file:org/deeplearning4j/parallelism/factory/DefaultTrainerContext.class */
/**
 * {@link TrainerContext} implementation that hands out plain {@link DefaultTrainer} worker threads.
 *
 * <p>This context keeps no state of its own. {@link #init} and {@link #finalizeRound} are no-ops.
 * {@link #finalizeTraining} simply forwards to {@link #finalizeRound}.
 */
public class DefaultTrainerContext implements TrainerContext {

    /**
     * No-op: this context requires no initialisation.
     *
     * @param model the model being trained (ignored)
     * @param args  extra initialisation arguments (ignored)
     */
    @Override
    public void init(Model model, Object... args) {
        // intentionally empty
    }

    /**
     * Builds a {@link DefaultTrainer} for one worker slot.
     *
     * <p>The very same {@code model} reference is supplied as both the original and the replicated
     * model. The returned thread is named {@code "DefaultTrainer thread <threadId>"} and is marked
     * as a daemon. It is NOT started by this method.
     *
     * @param threadId           worker index; used as the trainer's thread id and in its thread name
     * @param model              model handed to the trainer
     * @param rootDevice         ignored by this implementation (presumably a device index; confirm
     *                           against the {@code TrainerContext} contract)
     * @param useMDS             forwarded to the trainer's {@code useMDS} setting
     * @param wrapper            {@link ParallelWrapper} passed through to the trainer
     * @param mode               workspace mode passed through to the trainer
     * @param averagingFrequency forwarded to the trainer's {@code averagingFrequency} setting
     * @return the configured, not-yet-started trainer thread
     */
    @Override
    public Trainer create(int threadId, Model model, int rootDevice, boolean useMDS,
                          ParallelWrapper wrapper, WorkspaceMode mode, int averagingFrequency) {
        DefaultTrainer trainer = DefaultTrainer.builder()
                .originalModel(model)
                .replicatedModel(model)
                .threadId(threadId)
                .parallelWrapper(wrapper)
                .workspaceMode(mode)
                .useMDS(useMDS)
                .averagingFrequency(averagingFrequency)
                .build();

        String threadName = "DefaultTrainer thread " + threadId;
        trainer.setName(threadName);
        // Daemon threads do not prevent the JVM from exiting.
        trainer.setDaemon(true);

        return trainer;
    }

    /**
     * No-op in this implementation.
     *
     * @param model  the model being trained (ignored)
     * @param models additional models (ignored)
     */
    @Override
    public void finalizeRound(Model model, Model... models) {
        // intentionally empty
    }

    /**
     * Forwards to {@link #finalizeRound(Model, Model...)}.
     *
     * <p>That method is a no-op here unless a subclass overrides it.
     *
     * @param model  the model being trained
     * @param models additional models, passed through unchanged
     */
    @Override
    public void finalizeTraining(Model model, Model... models) {
        finalizeRound(model, models);
    }
}
