ComputationGraphUpdater

How to use ComputationGraphUpdater in org.deeplearning4j.nn.updater.graph

Best Java code snippets using org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater (Showing top 18 results out of 315)
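Before the indexed snippets, a minimal sketch of the pattern most of them revolve around: reading a graph's flattened updater state and copying it into a second, identically configured graph. This is illustrative only; copyUpdaterState is a hypothetical helper, not DL4J API, and both graphs are assumed to be already initialized.

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;

public class UpdaterStateCopySketch {

  /**
   * Copies the flattened updater state from one ComputationGraph into another.
   * Hypothetical helper; assumes both graphs share the same configuration.
   */
  public static void copyUpdaterState(ComputationGraph source, ComputationGraph target) {
    ComputationGraphUpdater updater = source.getUpdater();  // lazily initializes the updater if needed
    INDArray state = updater.getStateViewArray();           // single flattened view of all updater state (may be null)
    if (state != null) {
      // dup() so the target's state is independent of the source's view array
      target.getUpdater().setStateViewArray(state.dup());
    }
  }
}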

origin: de.datexis/texoo-core

@Deprecated
public void saveUpdater(Resource modelPath, String name) {
  Resource modelFile = modelPath.resolve(name + ".bin.gz");
  INDArray updaterState = null;
  if (net instanceof MultiLayerNetwork) {
    updaterState = ((MultiLayerNetwork) net).getUpdater().getStateViewArray();
  } else if (net instanceof ComputationGraph) {
    updaterState = ((ComputationGraph) net).getUpdater().getStateViewArray();
  }
  if (updaterState != null) {
    try (DataOutputStream dos = new DataOutputStream(modelFile.getGZIPOutputStream())) {
      Nd4j.write(updaterState, dos);
      dos.flush();
    } catch (IOException ex) {
      log.error(ex.toString());
    }
  }
}

origin: org.deeplearning4j/deeplearning4j-nn

// Excerpt: either restore the saved updater state into the loaded graph, or fall back to a
// previously deserialized Updater instance
  cg.getUpdater().setStateViewArray(updaterState);
} else if (gotOldUpdater && updater != null) {
  cg.setUpdater(updater);
origin: org.deeplearning4j/deeplearning4j-nn

@Override
public ComputationGraphUpdater getComputationGraphUpdater() {
  if (computationGraphUpdater == null && model instanceof ComputationGraph) {
    computationGraphUpdater = new ComputationGraphUpdater((ComputationGraph) model);
  }
  return computationGraphUpdater;
}
origin: org.deeplearning4j/deeplearning4j-nn

@Override
public ComputationGraph clone() {
  ComputationGraph cg = new ComputationGraph(configuration.clone());
  cg.init(params().dup(), false);
  if (solver != null) {
    //If solver is null the updater hasn't been initialized; calling getUpdater() would force
    //initialization, so updater state is only copied when a solver already exists
    ComputationGraphUpdater u = this.getUpdater();
    INDArray updaterState = u.getStateViewArray();
    if (updaterState != null) {
      cg.getUpdater().setStateViewArray(updaterState.dup());
    }
  }
  cg.listeners = this.listeners;
  for (int i = 0; i < topologicalOrder.length; i++) {
    if (!vertices[topologicalOrder[i]].hasLayer())
      continue;
    String layerName = vertices[topologicalOrder[i]].getVertexName();
    if (getLayer(layerName) instanceof FrozenLayer) {
      cg.getVertex(layerName).setLayerAsFrozen();
    }
  }
  return cg;
}
origin: org.deeplearning4j/deeplearning4j-nn

@Override
public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize) {
  if (model instanceof ComputationGraph) {
    ComputationGraph graph = (ComputationGraph) model;
    if (computationGraphUpdater == null) {
      try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
        computationGraphUpdater = new ComputationGraphUpdater(graph);
      }
    }
    computationGraphUpdater.update(gradient, getIterationCount(model), batchSize);
  } else {
    if (updater == null) {
      try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
        updater = UpdaterCreator.getUpdater(model);
      }
    }
    Layer layer = (Layer) model;
    updater.update(layer, gradient, getIterationCount(model), batchSize);
  }
}
origin: org.deeplearning4j/deeplearning4j-nn

public ComputationGraphUpdater(ComputationGraph graph, INDArray updaterState) {
  super(graph, updaterState);
  layersByName = new HashMap<>();
  Layer[] layers = getOrderedLayers();
  for (Layer l : layers) {
    layersByName.put(l.conf().getLayer().getLayerName(), l);
  }
}
origin: org.deeplearning4j/deeplearning4j-nn

@Override
public INDArray updaterState() {
  return getUpdater() != null ? getUpdater().getUpdaterStateViewArray() : null;
}
origin: org.deeplearning4j/deeplearning4j-nn

boolean paramsEquals = network.params().equals(params());
boolean confEquals = getConfiguration().equals(network.getConfiguration());
boolean updaterEquals = getUpdater().equals(network.getUpdater());
return paramsEquals && confEquals && updaterEquals;
origin: org.deeplearning4j/deeplearning4j-parallel-wrapper

@Override
public void updateModel(@NonNull Model model) {
  this.shouldUpdate.set(true);
  if (replicatedModel instanceof MultiLayerNetwork) {
    replicatedModel.setParams(model.params().dup());
    Updater updater = ((MultiLayerNetwork) model).getUpdater();
    INDArray view = updater.getStateViewArray();
    if (view != null) {
      updater = ((MultiLayerNetwork) replicatedModel).getUpdater();
      INDArray viewD = view.dup();
      Nd4j.getExecutioner().commit();
      updater.setStateViewArray((MultiLayerNetwork) replicatedModel, viewD, false);
    }
  } else if (replicatedModel instanceof ComputationGraph) {
    replicatedModel.setParams(model.params().dup());
    ComputationGraphUpdater updater = ((ComputationGraph) model).getUpdater();
    INDArray view = updater.getStateViewArray();
    if (view != null) {
      INDArray viewD = view.dup();
      Nd4j.getExecutioner().commit();
      updater = ((ComputationGraph) replicatedModel).getUpdater();
      updater.setStateViewArray(viewD);
    }
  }
  Nd4j.getExecutioner().commit();
}
origin: org.deeplearning4j/deeplearning4j-nn

Pair<Gradient, Double> gradAndScore = graph.gradientAndScore();
ComputationGraphUpdater updater = new ComputationGraphUpdater(graph);
updater.update(gradAndScore.getFirst(), 0, graph.batchSize());
origin: org.deeplearning4j/deeplearning4j-nn

INDArray updaterState = null;
if (model instanceof MultiLayerNetwork) {
  updaterState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
} else if (model instanceof ComputationGraph) {
  updaterState = ((ComputationGraph) model).getUpdater().getStateViewArray();
}
origin: org.deeplearning4j/deeplearning4j-nn

public static org.deeplearning4j.nn.api.Updater getUpdater(Model layer) {
  if (layer instanceof MultiLayerNetwork) {
    return new MultiLayerUpdater((MultiLayerNetwork) layer);
  } else if (layer instanceof ComputationGraph) {
    return new ComputationGraphUpdater((ComputationGraph) layer);
  } else {
    return new LayerUpdater((Layer) layer);
  }
}
origin: org.deeplearning4j/deeplearning4j-parallel-wrapper

ComputationGraphUpdater updaterReplica = ((ComputationGraph) replicatedModel).getUpdater();
ComputationGraphUpdater updaterOrigina = ((ComputationGraph) originalModel).getUpdater();
if (updaterOrigina != null && updaterOrigina.getStateViewArray() != null)
  updaterReplica.setStateViewArray(
          updaterOrigina.getStateViewArray().unsafeDuplication(true));
if (updaterReplica.getStateViewArray() != null)
  Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(),
          AffinityManager.Location.HOST);
origin: org.deeplearning4j/deeplearning4j-parallel-wrapper

int batchSize = 0;
// Excerpt: collect each worker's updater state view (and batch size) so the state
// can be averaged across the replicated models
if (updater != null && updater.getStateViewArray() != null) {
  List<INDArray> updaters = new ArrayList<>();
  for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
    ComputationGraph workerModel = (ComputationGraph) zoo[cnt].getModel();
    updaters.add(workerModel.getUpdater().getStateViewArray());
    batchSize += workerModel.batchSize();
  }
}
origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Get the ComputationGraphUpdater for the network
 */
public ComputationGraphUpdater getUpdater() {
  if (solver == null) {
    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
    solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
  }
  return solver.getOptimizer().getComputationGraphUpdater();
}
org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater

Javadoc

Gradient updater for ComputationGraph. Most of the functionality is shared with org.deeplearning4j.nn.updater.MultiLayerUpdater via BaseMultiLayerUpdater.
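A minimal sketch of what that responsibility looks like in practice: given a Gradient computed for the graph, update(...) applies the configured learning rule (SGD, Adam, ...) to the gradient in place before the parameters are modified. applyOneStep is a hypothetical helper; obtaining the Gradient (for example via graph.gradientAndScore(), as in the snippet further up) is left to the caller.

import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;

public class ManualUpdateSketch {

  /**
   * Applies one updater step to an externally computed gradient, mirroring the
   * gradientAndScore()/update() usage shown in the snippets above. The gradient is
   * modified in place; applying it to the parameters happens elsewhere in DL4J.
   */
  public static void applyOneStep(ComputationGraph graph, Gradient gradient, int iteration) {
    // Constructing the updater directly, as in the gradientAndScore() snippet above
    ComputationGraphUpdater updater = new ComputationGraphUpdater(graph);
    updater.update(gradient, iteration, graph.batchSize());
  }
}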

Most used methods (a combined usage sketch follows this list)

  • getStateViewArray
  • setStateViewArray
  • <init>
  • equals
  • getOrderedLayers
  • getUpdaterStateViewArray
  • update
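As referenced above, a combined sketch tying several of these methods together, in the spirit of the equality excerpt earlier on the page. sameTrainingState is a hypothetical helper, not part of DL4J; both graphs are assumed to be initialized.

import org.deeplearning4j.nn.graph.ComputationGraph;

public class UpdaterEqualitySketch {

  /**
   * Two graphs are only considered to have the same training state when parameters,
   * configuration and updater all match, as in the equality excerpt above.
   */
  public static boolean sameTrainingState(ComputationGraph a, ComputationGraph b) {
    boolean paramsEqual  = a.params().equals(b.params());
    boolean confEqual    = a.getConfiguration().equals(b.getConfiguration());
    // ComputationGraphUpdater.equals(...) compares the updaters' state
    boolean updaterEqual = a.getUpdater().equals(b.getUpdater());
    return paramsEqual && confEqual && updaterEqual;
  }
}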
