@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { //cond is only part of while loops if(nodeDef.getName().contains("/cond/")) return; //usually should be a merge node for a conditional val ifNodes = TFGraphMapper.getInstance().nodesForIf(nodeDef,graph); val trueScopeGraphDefBuilder = GraphDef.newBuilder(); for(val node : ifNodes.getTrueNodes()) { trueScopeGraphDefBuilder.addNode(node); } val trueScope = TFGraphMapper.getInstance().importGraph(trueScopeGraphDefBuilder.build()); val falseScopeGraphDefBuilder = GraphDef.newBuilder(); for(val node : ifNodes.getFalseNodes()) { falseScopeGraphDefBuilder.addNode(node); } val falseScope = TFGraphMapper.getInstance().importGraph(falseScopeGraphDefBuilder.build()); val condScopeGraphDefBuilder = GraphDef.newBuilder(); for(val node : ifNodes.getCondNodes()) { condScopeGraphDefBuilder.addNode(node); } val condScope = TFGraphMapper.getInstance().importGraph(condScopeGraphDefBuilder.build()); initWith.putSubFunction(ifNodes.getTrueBodyScopeName(),trueScope); initWith.putSubFunction(ifNodes.getFalseBodyScopeName(),falseScope); initWith.putSubFunction(ifNodes.getConditionBodyScopeName(),condScope); this.loopBodyExecution = trueScope; this.falseBodyExecution = falseScope; this.predicateExecution = condScope; }
/**
 * Replaces the partition graph at {@code index} with the message built from
 * {@code builderForValue}.
 *
 * <pre>
 * Graphs of the partitions executed by executors.
 * </pre>
 *
 * <code>repeated .tensorflow.GraphDef partition_graphs = 3;</code>
 */
public Builder setPartitionGraphs(
    int index, org.tensorflow.framework.GraphDef.Builder builderForValue) {
  if (partitionGraphsBuilder_ != null) {
    // The repeated-field builder exists, so delegate to it.
    partitionGraphsBuilder_.setMessage(index, builderForValue.build());
    return this;
  }
  // Make the list mutable before building, as in the generated original.
  ensurePartitionGraphsIsMutable();
  org.tensorflow.framework.GraphDef built = builderForValue.build();
  partitionGraphs_.set(index, built);
  onChanged();
  return this;
}
/**
/**
 * Appends the message built from {@code builderForValue} to the partition graphs.
 *
 * <pre>
 * Graphs of the partitions executed by executors.
 * </pre>
 *
 * <code>repeated .tensorflow.GraphDef partition_graphs = 3;</code>
 */
public Builder addPartitionGraphs(
    org.tensorflow.framework.GraphDef.Builder builderForValue) {
  if (partitionGraphsBuilder_ != null) {
    // The repeated-field builder exists, so delegate to it.
    partitionGraphsBuilder_.addMessage(builderForValue.build());
    return this;
  }
  // Make the list mutable before building, as in the generated original.
  ensurePartitionGraphsIsMutable();
  org.tensorflow.framework.GraphDef built = builderForValue.build();
  partitionGraphs_.add(built);
  onChanged();
  return this;
}
/**
/**
 * Inserts the message built from {@code builderForValue} into the partition graphs at
 * {@code index}.
 *
 * <pre>
 * Graphs of the partitions executed by executors.
 * </pre>
 *
 * <code>repeated .tensorflow.GraphDef partition_graphs = 3;</code>
 */
public Builder addPartitionGraphs(
    int index, org.tensorflow.framework.GraphDef.Builder builderForValue) {
  if (partitionGraphsBuilder_ != null) {
    // The repeated-field builder exists, so delegate to it.
    partitionGraphsBuilder_.addMessage(index, builderForValue.build());
    return this;
  }
  // Make the list mutable before building, as in the generated original.
  ensurePartitionGraphsIsMutable();
  org.tensorflow.framework.GraphDef built = builderForValue.build();
  partitionGraphs_.add(index, built);
  onChanged();
  return this;
}
/**
/**
 * Sets the remote graph to the message built from {@code builderForValue}.
 *
 * <pre>
 * Definition of remote graph
 * </pre>
 *
 * <code>.tensorflow.GraphDef remote_graph = 1;</code>
 */
public Builder setRemoteGraph(
    org.tensorflow.framework.GraphDef.Builder builderForValue) {
  // Both paths build the message first, so build it once up front.
  org.tensorflow.framework.GraphDef built = builderForValue.build();
  if (remoteGraphBuilder_ != null) {
    remoteGraphBuilder_.setMessage(built);
    return this;
  }
  remoteGraph_ = built;
  onChanged();
  return this;
}
/**
/**
 * Sets the graph definition to the message built from {@code builderForValue}.
 *
 * <pre>
 * GraphDef.
 * </pre>
 *
 * <code>.tensorflow.GraphDef graph_def = 2;</code>
 */
public Builder setGraphDef(
    org.tensorflow.framework.GraphDef.Builder builderForValue) {
  // Both paths build the message first, so build it once up front.
  org.tensorflow.framework.GraphDef built = builderForValue.build();
  if (graphDefBuilder_ != null) {
    graphDefBuilder_.setMessage(built);
    return this;
  }
  graphDef_ = built;
  onChanged();
  return this;
}
/**
/**
 * Stores the message built from {@code builderForValue} as the remote graph.
 *
 * <pre>
 * Definition of remote graph
 * </pre>
 *
 * <code>.tensorflow.GraphDef remote_graph = 1;</code>
 */
public Builder setRemoteGraph(
    org.tensorflow.framework.GraphDef.Builder builderForValue) {
  org.tensorflow.framework.GraphDef message = builderForValue.build();
  boolean hasFieldBuilder = remoteGraphBuilder_ != null;
  if (!hasFieldBuilder) {
    remoteGraph_ = message;
    onChanged();
  } else {
    remoteGraphBuilder_.setMessage(message);
  }
  return this;
}
/**