/**
 * Returns the mapper-level name of a TensorFlow node, which is simply the node's own name.
 *
 * @param nodeDef the node whose name is requested
 * @return {@code nodeDef.getName()}
 */
@Override
public String getName(NodeDef nodeDef) {
    final String nodeName = nodeDef.getName();
    return nodeName;
}
/**
 * Reports whether a node with the same name has already been recorded in {@code seenNodes}.
 *
 * @param nodeDef the node to check
 * @return {@code true} when {@code seenNodes} contains the node's name
 */
@Override
public boolean alreadySeen(NodeDef nodeDef) {
    final String nodeName = nodeDef.getName();
    return seenNodes.contains(nodeName);
}
/**
 * Collects the graph's nodes keyed by their translated SameDiff name, in graph order.
 *
 * <p>Nodes whose name ends with {@code "/read"} are left out (presumably TensorFlow
 * variable-read identity nodes; confirm). If two nodes translate to the same SameDiff
 * name, the later node's entry replaces the earlier one.
 *
 * @param graphDef the TensorFlow graph to scan
 * @return an insertion-ordered map from SameDiff name to node definition
 */
@Override
public Map<String, NodeDef> variablesForGraph(GraphDef graphDef) {
    final Map<String, NodeDef> byName = new LinkedHashMap<>();
    for (NodeDef candidate : graphDef.getNodeList()) {
        final boolean isReadNode = candidate.getName().endsWith("/read");
        if (!isReadNode) {
            byName.put(translateToSameDiffName(candidate.getName(), candidate), candidate);
        }
    }
    return byName;
}
/**
 * Decides whether a node should be skipped during import.
 *
 * @param opType the node to test; {@code null} is always skipped
 * @return {@code true} for {@code null}, for nodes whose name ends with {@code "/read"},
 *         or for nodes whose op type ends with {@code "/reduction_indices"}
 */
@Override
public boolean shouldSkip(NodeDef opType) {
    if (opType == null) {
        return true;
    }
    // NOTE(review): the reduction-indices test inspects getOp() (the op type), not
    // getName(). Op type names normally contain no '/', so this branch likely never
    // matches. Changing it to getName() would start skipping those nodes, so verify
    // how reduction ops are imported before altering it.
    return opType.getName().endsWith("/read")
            || opType.getOp().endsWith("/reduction_indices");
}
/**
 * Extracts an array for the given node by delegating to {@code getNDArrayFromTensor},
 * using the node's own name as the tensor name.
 *
 * @param nodeDef the node to read; may be {@code null}
 * @param graph   the graph the node belongs to
 * @return the extracted array, or {@code null} when {@code nodeDef} is {@code null}
 *         (or when {@code getNDArrayFromTensor} itself returns {@code null})
 */
@Override
public INDArray getArrayFrom(NodeDef nodeDef, GraphDef graph) {
    return nodeDef == null
            ? null
            : getNDArrayFromTensor(nodeDef.getName(), nodeDef, graph);
}
/**
 * Finds the first node in the graph whose name equals {@code name}.
 *
 * @param graph the graph to search
 * @param name  the exact node name to look for
 * @return the first matching node in graph order, or {@code null} if none matches
 */
@Override
public NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) {
    for (NodeDef candidate : graph.getNodeList()) {
        if (candidate.getName().equals(name)) {
            return candidate;
        }
    }
    return null;
}
/**
 * <pre>
 * The name given to this operator. Used for naming inputs,
 * logging, visualization, etc. Unique within a single GraphDef.
 * Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
 * </pre>
 *
 * <code>string name = 1;</code>
 *
 * <p>Resets {@code name} to the value held by the default instance, calls
 * {@code onChanged()}, and returns this builder for chaining.
 * NOTE(review): this appears to be protoc-generated code; prefer regenerating
 * from the .proto file over editing by hand.
 */
public Builder clearName() {
    name_ = getDefaultInstance().getName();
    onChanged();
    return this;
}
/**
/**
 * Computes a hash code over this message's fields and caches it.
 *
 * <p>If {@code memoizedHashCode} is non-zero it is returned directly. Otherwise the
 * hash mixes in the descriptor. It then mixes in the field number followed by the
 * field value for name, op, input (only when non-empty), device and attr (only when
 * non-empty), and finally the unknown fields. The result is stored in
 * {@code memoizedHashCode}.
 * NOTE(review): the multiplier/ordering scheme appears generated; keep it consistent
 * with {@code equals}, which compares the same fields.
 */
@java.lang.Override
public int hashCode() {
    if (memoizedHashCode != 0) {
        return memoizedHashCode;
    }
    int hash = 41;
    hash = (19 * hash) + getDescriptor().hashCode();
    hash = (37 * hash) + NAME_FIELD_NUMBER;
    hash = (53 * hash) + getName().hashCode();
    hash = (37 * hash) + OP_FIELD_NUMBER;
    hash = (53 * hash) + getOp().hashCode();
    // Repeated field: only contributes when it has elements.
    if (getInputCount() > 0) {
        hash = (37 * hash) + INPUT_FIELD_NUMBER;
        hash = (53 * hash) + getInputList().hashCode();
    }
    hash = (37 * hash) + DEVICE_FIELD_NUMBER;
    hash = (53 * hash) + getDevice().hashCode();
    // Map field: only contributes when it has entries.
    if (!internalGetAttr().getMap().isEmpty()) {
        hash = (37 * hash) + ATTR_FIELD_NUMBER;
        hash = (53 * hash) + internalGetAttr().hashCode();
    }
    hash = (29 * hash) + unknownFields.hashCode();
    memoizedHashCode = hash;
    return hash;
}
/**
 * Compares this message with another object field by field.
 *
 * <p>Identity short-circuits to {@code true}. Objects that are not a
 * {@code NodeDef} are delegated to {@code super.equals(obj)}. Otherwise two
 * messages are equal when their name, op, input list, device, attr map and unknown
 * fields are all equal.
 *
 * @param obj the object to compare against
 * @return whether {@code obj} is an equal message
 */
@java.lang.Override
public boolean equals(final java.lang.Object obj) {
    if (obj == this) {
        return true;
    }
    if (!(obj instanceof org.tensorflow.framework.NodeDef)) {
        return super.equals(obj);
    }
    org.tensorflow.framework.NodeDef other = (org.tensorflow.framework.NodeDef) obj;

    boolean result = true;
    result = result && getName()
        .equals(other.getName());
    result = result && getOp()
        .equals(other.getOp());
    result = result && getInputList()
        .equals(other.getInputList());
    result = result && getDevice()
        .equals(other.getDevice());
    result = result && internalGetAttr().equals(
        other.internalGetAttr());
    result = result && unknownFields.equals(other.unknownFields);
    return result;
}
if(graph.getNode(i).getName().equals(trueDefName)) { onFalseDefinition = false; onTrueDefinition = true; if(graph.getNode(i).getName().contains("pred_id")) { onTrueDefinition = false; if(!seenNames.contains(graph.getNode(i).getName()) && !graph.getNode(i).getName().contains("pred_id")) { break; seenNames.add(graph.getNode(i).getName()); conditionNodes.add(graph.getNode(i));
/**
 * Reads an integer argument from the node referenced by this node's last input.
 *
 * <p>The last input name is resolved to a node in {@code graph}. That node's
 * {@code "value"} tensor is extracted, and its first element is added as an integer
 * argument. If the extracted array is {@code null}, no argument is added.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being populated (unused here)
 * @param attributesForNode the node's attributes (unused here)
 * @param graph             the graph containing {@code nodeDef} and its inputs
 * @throws IllegalStateException if {@code nodeDef} has no inputs, or if its last input
 *                               does not name a node in {@code graph}
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    if (nodeDef.getInputCount() == 0) {
        throw new IllegalStateException("Node \"" + nodeDef.getName()
                + "\" has no inputs; expected its last input to reference an index node");
    }
    val idd = nodeDef.getInput(nodeDef.getInputCount() - 1);

    NodeDef iddNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        if (graph.getNode(i).getName().equals(idd)) {
            iddNode = graph.getNode(i);
            // Node names are unique within a GraphDef, so the first match is the one.
            break;
        }
    }
    if (iddNode == null) {
        // Fail with context instead of an NPE deep inside tensor extraction.
        throw new IllegalStateException("Input \"" + idd + "\" of node \"" + nodeDef.getName()
                + "\" was not found in the graph");
    }

    val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", iddNode, graph);
    if (arr != null) {
        addIArgument(arr.getInt(0));
    }
}
if(node.getName().equals(nodeDef.getInput(0))) { startNode = node; if(node.getName().equals(nodeDef.getInput(1))) { endNode = node; if(node.getName().equals(nodeDef.getInput(2))) { deltaNode = node; val fromVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(startNode.getName())); val toVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(endNode.getName())); val deltaVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(deltaNode.getName()));
/**
 * Merges the fields of {@code other} into this builder.
 *
 * <p>Returns immediately when {@code other} is the default instance. Non-empty name,
 * op and device values from {@code other} overwrite this builder's values. Inputs from
 * {@code other} are appended: if this builder has no inputs yet, it takes
 * {@code other}'s input list directly. The attr map and unknown fields are merged.
 * {@code onChanged()} is called before returning this builder.
 *
 * @param other the message to merge from
 * @return this builder
 */
public Builder mergeFrom(org.tensorflow.framework.NodeDef other) {
    if (other == org.tensorflow.framework.NodeDef.getDefaultInstance()) return this;
    if (!other.getName().isEmpty()) {
        name_ = other.name_;
        onChanged();
    }
    if (!other.getOp().isEmpty()) {
        op_ = other.op_;
        onChanged();
    }
    if (!other.input_.isEmpty()) {
        if (input_.isEmpty()) {
            input_ = other.input_;
            // NOTE(review): clearing bit 0x4 presumably marks input_ as not owned/mutable,
            // so ensureInputIsMutable() copies it before any later write (protobuf convention; confirm).
            bitField0_ = (bitField0_ & ~0x00000004);
        } else {
            ensureInputIsMutable();
            input_.addAll(other.input_);
        }
        onChanged();
    }
    if (!other.getDevice().isEmpty()) {
        device_ = other.device_;
        onChanged();
    }
    internalGetMutableAttr().mergeFrom(
        other.internalGetAttr());
    this.mergeUnknownFields(other.unknownFields);
    onChanged();
    return this;
}
/**
 * Stores, as a scalar, the length of the shape recorded for this node's variable.
 *
 * <p>Nothing happens when the variable is a placeholder or has no known shape yet.
 * Otherwise the variable's array becomes {@code Nd4j.scalar(shape.length)} and its
 * shape is recorded as {@code {1, 1}}.
 * NOTE(review): the variable looked up is named after this node itself, not after
 * {@code nodeDef.getInput(0)}, so the shape measured may be this op's own output
 * rather than its input. Confirm this is intended.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance holding variables, shapes and arrays
 * @param attributesForNode the node's attributes (unused here)
 * @param graph             the enclosing graph (unused here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val variable = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(nodeDef.getName()));
    val varName = variable.getVarName();

    if (initWith.isPlaceHolder(varName)) {
        return;
    }
    if (!initWith.shapeAlreadyExistsForVarName(varName)) {
        return;
    }

    val knownShape = initWith.getShapeForVarName(varName);
    initWith.putArrayForVarName(varName, Nd4j.scalar(knownShape.length));
    initWith.putShapeForVarName(varName, new long[]{1, 1});
}
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { //cond is only part of while loops if(nodeDef.getName().contains("/cond/")) return; //usually should be a merge node for a conditional val ifNodes = TFGraphMapper.getInstance().nodesForIf(nodeDef,graph); val trueScopeGraphDefBuilder = GraphDef.newBuilder(); for(val node : ifNodes.getTrueNodes()) { trueScopeGraphDefBuilder.addNode(node); } val trueScope = TFGraphMapper.getInstance().importGraph(trueScopeGraphDefBuilder.build()); val falseScopeGraphDefBuilder = GraphDef.newBuilder(); for(val node : ifNodes.getFalseNodes()) { falseScopeGraphDefBuilder.addNode(node); } val falseScope = TFGraphMapper.getInstance().importGraph(falseScopeGraphDefBuilder.build()); val condScopeGraphDefBuilder = GraphDef.newBuilder(); for(val node : ifNodes.getCondNodes()) { condScopeGraphDefBuilder.addNode(node); } val condScope = TFGraphMapper.getInstance().importGraph(condScopeGraphDefBuilder.build()); initWith.putSubFunction(ifNodes.getTrueBodyScopeName(),trueScope); initWith.putSubFunction(ifNodes.getFalseBodyScopeName(),falseScope); initWith.putSubFunction(ifNodes.getConditionBodyScopeName(),condScope); this.loopBodyExecution = trueScope; this.falseBodyExecution = falseScope; this.predicateExecution = condScope; }
NodeDef permuteDimsNode = null; for (int i = 0; i < graph.getNodeCount(); i++) { if (graph.getNode(i).getName().equals(nodeDef.getInput(1))) { permuteDimsNode = graph.getNode(i);