/**
 * Finds a node in the given graph by name.
 * <p>
 * Nodes are examined in index order and the first one whose name equals
 * {@code name} is returned.
 *
 * @param graph the graph whose nodes are searched
 * @param name  the node name to look for
 * @return the first matching node, or {@code null} when no node has that name
 */
@Override
public NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) {
    int idx = 0;
    while (idx < graph.getNodeCount()) {
        NodeDef candidate = graph.getNode(idx);
        if (candidate.getName().equals(name)) {
            return candidate;
        }
        idx++;
    }
    // Scanned every node without finding the requested name.
    return null;
}
if(graph.getNode(i).getName().equals(trueDefName)) { onFalseDefinition = false; onTrueDefinition = true; if(graph.getNode(i).getName().contains("pred_id")) { onTrueDefinition = false; if(onTrueDefinition && !graph.getNode(i).equals(from)) { trueBodyNodes.add(graph.getNode(i)); else if(onFalseDefinition && !graph.getNode(i).equals(from)) { falseBodyNodes.add(graph.getNode(i)); val currNode = graph.getNode(i); if(currNode.equals(from)) continue; if(!seenNames.contains(graph.getNode(i).getName()) && !graph.getNode(i).getName().contains("pred_id")) { break; seenNames.add(graph.getNode(i).getName()); conditionNodes.add(graph.getNode(i));
/**
 * Initializes this op's integer argument from a TensorFlow node definition.
 * <p>
 * The last input of {@code nodeDef} is taken as the name of another node in
 * {@code graph}. That node's {@code "value"} tensor is read through
 * {@code TFGraphMapper}, and its first element is registered with
 * {@code addIArgument}. When no array can be read from the node, no argument
 * is added.
 *
 * @param nodeDef           the node being imported; its last input names the node to read
 * @param initWith          not used by this implementation
 * @param attributesForNode not used by this implementation
 * @param graph             the graph searched for the node named by the last input
 * @throws IllegalStateException if {@code nodeDef} has no inputs, or if its last input
 *                               does not name any node in {@code graph}
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final int inputCount = nodeDef.getInputCount();
    if (inputCount == 0) {
        // Without this check getInput(-1) fails with a bare index error.
        throw new IllegalStateException("Node \"" + nodeDef.getName()
                + "\" has no inputs; cannot locate the node holding the integer argument");
    }

    val idd = nodeDef.getInput(inputCount - 1);

    // NOTE(review): this is an exact-name match; an input written with an
    // output-port suffix or control prefix would not be found - confirm the
    // inputs reaching this method are bare node names.
    NodeDef iddNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        final NodeDef candidate = graph.getNode(i);
        if (candidate.getName().equals(idd)) {
            iddNode = candidate;
            // Node names are unique within a GraphDef, so the first hit is the only hit.
            break;
        }
    }

    if (iddNode == null) {
        // Previously a null node was handed to getNDArrayFromTensor, hiding the real cause.
        throw new IllegalStateException("Input \"" + idd + "\" of node \"" + nodeDef.getName()
                + "\" was not found in the graph");
    }

    val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", iddNode, graph);
    if (arr != null) {
        int idx = arr.getInt(0);
        addIArgument(idx);
    }
}