/**
 * Reports the number of inputs declared on the given TensorFlow node.
 *
 * @param nodeDef the node definition to inspect
 * @return the value of {@code nodeDef.getInputCount()}
 */
@Override
public int numInputsFor(NodeDef nodeDef) {
    final int declaredInputs = nodeDef.getInputCount();
    return declaredInputs;
}
/**
 * Tells whether any input name of the node contains the text
 * {@code "reduction_indices"}.
 *
 * @param nodeDef the node definition whose input names are scanned
 * @return {@code true} as soon as one input name contains the marker text,
 *         {@code false} when none does (including when there are no inputs)
 */
protected boolean hasReductionIndices(NodeDef nodeDef) {
    for (String inputName : nodeDef.getInputList()) {
        if (inputName.contains("reduction_indices")) {
            return true;
        }
    }
    return false;
}
/**
 * Returns a hash code derived from the descriptor, the name, op and device
 * fields, the input list (only when at least one input is present), the
 * attribute map (only when it is non-empty) and the unknown fields.
 *
 * <p>The result is stored in {@code memoizedHashCode}; a stored value of
 * {@code 0} is treated as "not computed yet", so a hash that happens to be
 * {@code 0} is recomputed on every call.
 *
 * <p>NOTE(review): this looks like protoc-generated code - confirm before
 * hand-editing, since regeneration would overwrite manual changes.
 */
@java.lang.Override
public int hashCode() {
  final int cached = memoizedHashCode;
  if (cached != 0) {
    return cached;
  }

  // Each field contributes its field number first, then its value hash.
  int result = 41;
  result = 19 * result + getDescriptor().hashCode();
  result = 37 * result + NAME_FIELD_NUMBER;
  result = 53 * result + getName().hashCode();
  result = 37 * result + OP_FIELD_NUMBER;
  result = 53 * result + getOp().hashCode();

  final boolean hasInputs = getInputCount() > 0;
  if (hasInputs) {
    result = 37 * result + INPUT_FIELD_NUMBER;
    result = 53 * result + getInputList().hashCode();
  }

  result = 37 * result + DEVICE_FIELD_NUMBER;
  result = 53 * result + getDevice().hashCode();

  final boolean hasAttributes = !internalGetAttr().getMap().isEmpty();
  if (hasAttributes) {
    result = 37 * result + ATTR_FIELD_NUMBER;
    result = 53 * result + internalGetAttr().hashCode();
  }

  result = 29 * result + unknownFields.hashCode();

  memoizedHashCode = result;
  return result;
}
int concatDimension = -1; String input = null; for(int i = 0; i < nodeDef.getInputCount(); i++) { if(nodeDef.getInput(i).contains("/concat_dim")) { input = nodeDef.getInput(i); input = nodeDef.getInput(nodeDef.getInputCount() - 1); if(inputArguments().length == nodeDef.getInputCount()) { val inputArgs = inputArguments(); removeInputArgument(inputArgs[inputArguments().length - 1]);
/**
 * Initialises this op's integer argument from a TensorFlow graph.
 *
 * <p>The name held in the last input of {@code nodeDef} is looked up among
 * the nodes of {@code graph}. The {@code "value"} tensor of that node is
 * fetched through {@link TFGraphMapper}; when an array is returned, its
 * first element is read as an int and registered via {@code addIArgument}.
 *
 * <p>Nothing is registered (and no exception is raised) when the node has
 * no inputs, when no graph node carries the referenced name, or when the
 * mapper returns no array.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          not used by this implementation
 * @param attributesForNode not used by this implementation
 * @param graph             the graph searched for the referenced node
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final int inputCount = nodeDef.getInputCount();
    if (inputCount == 0) {
        // Without inputs there is no "last input" to resolve; getInput(-1) would throw.
        return;
    }

    final String idd = nodeDef.getInput(inputCount - 1);

    NodeDef iddNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        final NodeDef candidate = graph.getNode(i);
        if (candidate.getName().equals(idd)) {
            iddNode = candidate;
            // Stop at the first match instead of scanning the rest of the graph.
            break;
        }
    }

    if (iddNode == null) {
        // The input does not name a node of this graph: treat it like a missing value
        // rather than handing a null node to the mapper.
        return;
    }

    val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", iddNode, graph);
    if (arr != null) {
        int idx = arr.getInt(0);
        addIArgument(idx);
    }
}
for(int inputIdx = 0; inputIdx < currNode.getInputCount(); inputIdx++) { seenNames.add(currNode.getInput(inputIdx));
/**
 * Initialises {@code axis} from a TensorFlow graph.
 *
 * <p>The node named by the last input of {@code nodeDef} is resolved through
 * {@link TFGraphMapper}, and its {@code "value"} tensor is requested. When an
 * array comes back, {@code axis} is set to {@code data().asInt()} of that
 * array and {@code addArguments()} is invoked; otherwise nothing happens.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          not used by this implementation
 * @param attributesForNode not used by this implementation
 * @param graph             the graph in which the referenced node is looked up
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val mapper = TFGraphMapper.getInstance();

    val lastInputName = nodeDef.getInput(nodeDef.getInputCount() - 1);
    val axisNode = mapper.getNodeWithNameFromGraph(graph, lastInputName);
    val axisValues = mapper.getNDArrayFromTensor("value", axisNode, graph);

    if (axisValues == null) {
        return;
    }

    this.axis = axisValues.data().asInt();
    addArguments();
}
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { if (!nodeDef.containsAttr("TShape") && nodeDef.getInputCount() == 1) { this.shape = new long[]{}; return; } else if (nodeDef.getInputCount() > 1) { val shapeNode = nodeDef.getInput(1); NodeDef shapeNodeInGraph = null;
val vars = new SDVariable[tfNode.getInputCount()]; for (int e = 0; e < tfNode.getInputCount(); e++) { val input = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(e)); vars[e] = initWith.getVariable(input) == null ? initWith.var(input,null,new ZeroInitScheme()) : initWith.getVariable(input); val variables = new SDVariable[tfNode.getInputCount()]; for(int i = 0; i < tfNode.getInputCount(); i++) { val testVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i))); if(testVar == null) { val variables = new SDVariable[tfNode.getInputCount()]; for(int i = 0; i < tfNode.getInputCount(); i++) { val name = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)); variables[i] = scopeCondition.getVariable(name);
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { if(nodeDef.getInputCount() == 2) { val targetNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,nodeDef.getInput(1)); val mapper = TFGraphMapper.getInstance(); val secondInputAsScalar = mapper.getNDArrayFromTensor("value",targetNode,graph); //must be scalar if(secondInputAsScalar.length() == 1) { addTArgument(secondInputAsScalar.getDouble(0)); } else { throw new ND4JIllegalStateException("Second input to node " + nodeDef + " should be scalar!"); } } }
position += node.getInputCount();
if(mapping.getTfInputPosition() != null && mapping.getTfInputPosition() < node.getInputCount()) { int tfMappingIdx = mapping.getTfInputPosition(); if(tfMappingIdx < 0) tfMappingIdx += node.getInputCount();
val args = new SDVariable[tfNode.getInputCount()]; newInstance.setOwnName(tfNode.getName()); for(int i = 0; i < tfNode.getInputCount(); i++) { val name = getNodeName(tfNode.getInput(i)); args[i] = diff.getVariable(name);
super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph); if (nodeDef.getInputCount() < 2) return; NodeDef permuteDimsNode = null;
/**
 * Computes (and caches in {@code memoizedHashCode}) the hash code of this
 * message.
 *
 * <p>Contributions, in order: the descriptor; name; op; the input list when
 * {@code getInputCount() > 0}; device; the attribute map when it is not
 * empty; and finally the unknown fields. A cached value of {@code 0} means
 * "not yet computed", so the hash is recalculated whenever it evaluates to
 * {@code 0}.
 *
 * <p>NOTE(review): appears to be protoc-generated - verify before editing
 * by hand.
 */
@java.lang.Override
public int hashCode() {
  if (memoizedHashCode != 0) {
    return memoizedHashCode;
  }

  int h = 41;
  h = h * 19 + getDescriptor().hashCode();

  // Always-present scalar fields.
  h = h * 37 + NAME_FIELD_NUMBER;
  h = h * 53 + getName().hashCode();
  h = h * 37 + OP_FIELD_NUMBER;
  h = h * 53 + getOp().hashCode();

  // Repeated field: only mixed in when it has elements.
  if (getInputCount() > 0) {
    h = h * 37 + INPUT_FIELD_NUMBER;
    h = h * 53 + getInputList().hashCode();
  }

  h = h * 37 + DEVICE_FIELD_NUMBER;
  h = h * 53 + getDevice().hashCode();

  // Map field: only mixed in when it has entries.
  if (!internalGetAttr().getMap().isEmpty()) {
    h = h * 37 + ATTR_FIELD_NUMBER;
    h = h * 53 + internalGetAttr().hashCode();
  }

  h = h * 29 + unknownFields.hashCode();

  memoizedHashCode = h;
  return h;
}
for(int i = 0; i < graphDef1.getNodeCount(); i++) { NodeDef node = graphDef1.getNode(i); for(int input = 0; input < node.getInputCount(); input++) { seenAsInput.add(node.getInput(input));
@Procedure(value = "load.tensorflow", mode = Mode.WRITE) public Stream<LoadResult> loadTensorFlow(@Name("file") String url) throws IOException { GraphDef graphDef = GraphDef.parseFrom(new BufferedInputStream(new URL(url).openStream())); Map<String, Node> nodes = new HashMap<>(); // tod model node, layer nodes for (NodeDef nodeDef : graphDef.getNodeList()) { Node node = db.createNode(Types.Neuron); node.setProperty("name", nodeDef.getName()); if (nodeDef.getDevice() != null) node.setProperty("device", nodeDef.getDevice()); node.setProperty("op", nodeDef.getOp()); nodeDef.getAttrMap().forEach((k, v) -> { Object value = getValue(v); if (value != null) { node.setProperty(k, value); } }); nodes.put(nodeDef.getName(), node); } long rels = 0; for (NodeDef nodeDef : graphDef.getNodeList()) { Node target = nodes.get(nodeDef.getName()); nodeDef.getInputList().forEach(name -> nodes.get(name).createRelationshipTo(target, RelTypes.INPUT)); // todo weights rels += nodeDef.getInputCount(); } return Stream.of(new LoadResult(url,"tensorflow",nodes.size(), rels)); }