/**
 * Returns this optimizer's {@link ComputationGraphUpdater}.
 *
 * <p>If none exists yet and the optimized model is a {@link ComputationGraph}, a new updater is
 * created for it and cached. For any other model type the cached value (possibly {@code null})
 * is returned unchanged.
 *
 * @return the cached or newly created graph updater; may be {@code null}
 */
@Override
public ComputationGraphUpdater getComputationGraphUpdater() {
    boolean mustCreate = computationGraphUpdater == null && model instanceof ComputationGraph;
    if (mustCreate) {
        ComputationGraph graph = (ComputationGraph) model;
        computationGraphUpdater = new ComputationGraphUpdater(graph);
    }
    return computationGraphUpdater;
}
/**
 * Returns the updater's state view array.
 *
 * @return the updater state view, or {@code null} when {@link #getUpdater()} returns {@code null}
 */
@Override
public INDArray updaterState() {
    // getUpdater() is a lazily-initializing getter; resolve it once and reuse the reference.
    ComputationGraphUpdater current = getUpdater();
    if (current == null) {
        return null;
    }
    return current.getUpdaterStateViewArray();
}
/**
 * Creates an updater for {@code graph}, seeded with the given updater state.
 *
 * <p>After delegating to the superclass, builds {@code layersByName}, a lookup from each layer's
 * configured name to the layer, in the order returned by {@code getOrderedLayers()}. If two layers
 * share a name, the later one wins (plain {@code HashMap.put} semantics).
 *
 * @param graph        the computation graph this updater serves
 * @param updaterState initial updater state passed through to the superclass
 */
public ComputationGraphUpdater(ComputationGraph graph, INDArray updaterState) {
    super(graph, updaterState);
    layersByName = new HashMap<>();
    Layer[] ordered = getOrderedLayers();
    for (int i = 0; i < ordered.length; i++) {
        Layer current = ordered[i];
        layersByName.put(current.conf().getLayer().getLayerName(), current);
    }
}
@Override public ComputationGraph clone() { ComputationGraph cg = new ComputationGraph(configuration.clone()); cg.init(params().dup(), false); if (solver != null) { //If solver is null: updater hasn't been initialized -> getUpdater call will force initialization, however ComputationGraphUpdater u = this.getUpdater(); INDArray updaterState = u.getStateViewArray(); if (updaterState != null) { cg.getUpdater().setStateViewArray(updaterState.dup()); } } cg.listeners = this.listeners; for (int i = 0; i < topologicalOrder.length; i++) { if (!vertices[topologicalOrder[i]].hasLayer()) continue; String layerName = vertices[topologicalOrder[i]].getVertexName(); if (getLayer(layerName) instanceof FrozenLayer) { cg.getVertex(layerName).setLayerAsFrozen(); } } return cg; }
/**
 * Writes the updater state of {@code net} to {@code modelPath.resolve(name + ".bin.gz")}.
 *
 * <p>The state is taken from the updater of a {@link MultiLayerNetwork} or
 * {@link ComputationGraph}. It is written with {@code Nd4j.write} into the gzip output stream of
 * the target resource. Nothing is written when {@code net} is neither type or has no updater
 * state.
 *
 * <p>I/O failures are logged, with the full stack trace, rather than rethrown.
 *
 * @param modelPath directory-like resource the file is resolved against
 * @param name      base file name; {@code ".bin.gz"} is appended
 */
@Deprecated
public void saveUpdater(Resource modelPath, String name) {
    Resource modelFile = modelPath.resolve(name + ".bin.gz");

    INDArray updaterState = null;
    if (net instanceof MultiLayerNetwork) {
        updaterState = ((MultiLayerNetwork) net).getUpdater().getStateViewArray();
    } else if (net instanceof ComputationGraph) {
        updaterState = ((ComputationGraph) net).getUpdater().getStateViewArray();
    }

    if (updaterState == null) {
        return;
    }

    try (DataOutputStream dos = new DataOutputStream(modelFile.getGZIPOutputStream())) {
        Nd4j.write(updaterState, dos);
        dos.flush();
    } catch (IOException ex) {
        // Keep the original message, and pass the exception so its stack trace and cause are logged.
        log.error(ex.toString(), ex);
    }
}
cg.getUpdater().setStateViewArray(updaterState); } else if (gotOldUpdater && updater != null) { cg.setUpdater(updater);
/**
 * Applies the updater to {@code gradient} for the given model.
 *
 * <p>For a {@link ComputationGraph} the graph-specific updater is used. For any other model a
 * generic updater from {@code UpdaterCreator} is used, and the model is treated as a
 * {@link Layer}. Either updater is created lazily on first use, inside a
 * {@code scopeOutOfWorkspaces()} scope. That is presumably so its state is not allocated in an
 * active workspace — TODO confirm.
 *
 * @param gradient  gradient to update in place
 * @param model     model being optimized
 * @param batchSize minibatch size forwarded to the updater
 */
@Override
public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize) {
    if (!(model instanceof ComputationGraph)) {
        if (updater == null) {
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                updater = UpdaterCreator.getUpdater(model);
            }
        }
        updater.update((Layer) model, gradient, getIterationCount(model), batchSize);
        return;
    }

    if (computationGraphUpdater == null) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            computationGraphUpdater = new ComputationGraphUpdater((ComputationGraph) model);
        }
    }
    computationGraphUpdater.update(gradient, getIterationCount(model), batchSize);
}
boolean paramsEquals = network.params().equals(params()); boolean confEquals = getConfiguration().equals(network.getConfiguration()); boolean updaterEquals = getUpdater().equals(network.getUpdater()); return paramsEquals && confEquals && updaterEquals;
/**
 * Copies the parameters, and the updater state when present, of {@code model} into
 * {@code replicatedModel}, after setting {@code shouldUpdate} to {@code true}.
 *
 * <p>Parameters and updater state are both copied via {@code dup()}, so the replica gets its own
 * buffers. NOTE(review): {@code model} is cast to the same concrete type as
 * {@code replicatedModel}. A model of a different type would throw {@code ClassCastException}.
 *
 * @param model source network whose parameters and updater state are copied
 */
@Override
public void updateModel(@NonNull Model model) {
    this.shouldUpdate.set(true);

    if (replicatedModel instanceof MultiLayerNetwork) {
        replicatedModel.setParams(model.params().dup());
        INDArray sourceState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            Updater replicaUpdater = ((MultiLayerNetwork) replicatedModel).getUpdater();
            INDArray stateCopy = sourceState.dup();
            // Presumably flushes queued ops so the dup is materialized before use — confirm.
            Nd4j.getExecutioner().commit();
            replicaUpdater.setStateViewArray((MultiLayerNetwork) replicatedModel, stateCopy, false);
        }
    } else if (replicatedModel instanceof ComputationGraph) {
        replicatedModel.setParams(model.params().dup());
        INDArray sourceState = ((ComputationGraph) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            INDArray stateCopy = sourceState.dup();
            Nd4j.getExecutioner().commit();
            ((ComputationGraph) replicatedModel).getUpdater().setStateViewArray(stateCopy);
        }
    }

    Nd4j.getExecutioner().commit();
}
int batchSize = 0; if (updater != null && updater.getStateViewArray() != null) { List<INDArray> updaters = new ArrayList<>(); for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) { ComputationGraph workerModel = (ComputationGraph) zoo[cnt].getModel(); updaters.add(workerModel.getUpdater().getStateViewArray()); batchSize += workerModel.batchSize();
Pair<Gradient, Double> gradAndScore = graph.gradientAndScore(); ComputationGraphUpdater updater = new ComputationGraphUpdater(graph); updater.update(gradAndScore.getFirst(), 0, graph.batchSize());
/**
 * Creates the updater that matches the concrete type of {@code layer}.
 *
 * <ul>
 *   <li>{@link MultiLayerNetwork}: a {@code MultiLayerUpdater}</li>
 *   <li>{@link ComputationGraph}: a {@code ComputationGraphUpdater}</li>
 *   <li>anything else: a {@code LayerUpdater}, with the model cast to {@link Layer}</li>
 * </ul>
 *
 * @param layer model to create an updater for
 * @return a new updater instance
 */
public static org.deeplearning4j.nn.api.Updater getUpdater(Model layer) {
    if (layer instanceof MultiLayerNetwork) {
        return new MultiLayerUpdater((MultiLayerNetwork) layer);
    }
    if (layer instanceof ComputationGraph) {
        return new ComputationGraphUpdater((ComputationGraph) layer);
    }
    return new LayerUpdater((Layer) layer);
}
/**
 * Copies the parameters, and the updater state when present, of {@code model} into
 * {@code replicatedModel}, after setting {@code shouldUpdate} to {@code true}.
 *
 * <p>Parameters and updater state are both copied via {@code dup()}, so the replica gets its own
 * buffers. NOTE(review): {@code model} is cast to the same concrete type as
 * {@code replicatedModel}. A model of a different type would throw {@code ClassCastException}.
 *
 * @param model source network whose parameters and updater state are copied
 */
@Override
public void updateModel(@NonNull Model model) {
    this.shouldUpdate.set(true);

    if (replicatedModel instanceof MultiLayerNetwork) {
        replicatedModel.setParams(model.params().dup());
        INDArray sourceState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            Updater replicaUpdater = ((MultiLayerNetwork) replicatedModel).getUpdater();
            INDArray stateCopy = sourceState.dup();
            // Presumably flushes queued ops so the dup is materialized before use — confirm.
            Nd4j.getExecutioner().commit();
            replicaUpdater.setStateViewArray((MultiLayerNetwork) replicatedModel, stateCopy, false);
        }
    } else if (replicatedModel instanceof ComputationGraph) {
        replicatedModel.setParams(model.params().dup());
        INDArray sourceState = ((ComputationGraph) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            INDArray stateCopy = sourceState.dup();
            Nd4j.getExecutioner().commit();
            ((ComputationGraph) replicatedModel).getUpdater().setStateViewArray(stateCopy);
        }
    }

    Nd4j.getExecutioner().commit();
}
int batchSize = 0; if (updater != null && updater.getStateViewArray() != null) { List<INDArray> updaters = new ArrayList<>(); for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) { ComputationGraph workerModel = (ComputationGraph) zoo[cnt].getModel(); updaters.add(workerModel.getUpdater().getStateViewArray()); batchSize += workerModel.batchSize();
/**
 * Returns the {@link ComputationGraphUpdater} for this network.
 *
 * <p>On first access no solver exists yet. In that case a {@code Solver} is built for this
 * network, using its configuration and listeners, and a new {@code ComputationGraphUpdater} is
 * attached to the solver's optimizer. Later calls return the updater held by that optimizer.
 *
 * @return the optimizer's computation-graph updater
 */
public ComputationGraphUpdater getUpdater() {
    if (solver != null) {
        return solver.getOptimizer().getComputationGraphUpdater();
    }
    // The field is assigned before the updater is attached, matching the original initialization order.
    solver = new Solver.Builder()
            .configure(conf())
            .listeners(getListeners())
            .model(this)
            .build();
    solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
    return solver.getOptimizer().getComputationGraphUpdater();
}
ComputationGraphUpdater updaterOrigina = ((ComputationGraph) originalModel).getUpdater(); if (updaterOrigina != null && updaterOrigina.getStateViewArray() != null) updaterReplica.setStateViewArray( updaterOrigina.getStateViewArray().unsafeDuplication(true)); ((ComputationGraph) replicatedModel).getUpdater(); if (updaterReplica.getStateViewArray() != null) Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(), AffinityManager.Location.HOST); if (updaterReplica.getStateViewArray() != null) Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(), AffinityManager.Location.HOST);
updaterState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray(); } else if (model instanceof ComputationGraph) { updaterState = ((ComputationGraph) model).getUpdater().getStateViewArray();
ComputationGraphUpdater updaterOrigina = ((ComputationGraph) originalModel).getUpdater(); if (updaterOrigina != null && updaterOrigina.getStateViewArray() != null) updaterReplica.setStateViewArray( updaterOrigina.getStateViewArray().unsafeDuplication(true)); ((ComputationGraph) replicatedModel).getUpdater(); if (updaterReplica.getStateViewArray() != null) Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(), AffinityManager.Location.HOST); if (updaterReplica.getStateViewArray() != null) Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(), AffinityManager.Location.HOST);