/**
 * Builds the text-CNN computation graph configuration: one parallel
 * Convolution1D + global-max-pooling branch per n-gram filter width,
 * merged by depth concatenation into a softmax classifier with an
 * MCXENT loss layer.
 *
 * @return the assembled {@link ComputationGraphConfiguration}
 */
public static ComputationGraphConfiguration getConf() {
    ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .updater(new Adam(0.01))
            .weightInit(WeightInit.RELU)
            .graphBuilder()
            .addInputs("in");

    // One convolution + pooling branch per n-gram width; pool layer names are
    // collected so they can be fed to the merge vertex below.
    String[] poolNames = new String[ngramFilters.length];
    for (int idx = 0; idx < ngramFilters.length; idx++) {
        int ngram = ngramFilters[idx];
        String filterName = String.format("ngram%d", ngram);
        poolNames[idx] = String.format("pool%d", ngram);
        builder.addLayer(filterName,
                new Convolution1DLayer.Builder()
                        .nOut(numFilters)
                        .kernelSize(ngram)
                        .activation(Activation.RELU)
                        .build(),
                "in");
        builder.addLayer(poolNames[idx],
                new GlobalPoolingLayer.Builder(PoolingType.MAX).build(),
                filterName);
    }

    // Concatenate all pooled branch outputs, classify, and attach the loss.
    builder.addVertex("concat", new MergeVertex(), poolNames)
            .addLayer("predict",
                    new DenseLayer.Builder()
                            .nOut(numClasses)
                            .dropOut(dropoutRetain)
                            .activation(Activation.SOFTMAX)
                            .build(),
                    "concat")
            .addLayer("loss",
                    new LossLayer.Builder(LossFunctions.LossFunction.MCXENT).build(),
                    "predict")
            .setOutputs("loss")
            // NOTE(review): sequence length hard-coded to 1000 — presumably the
            // padded/truncated document length; confirm against the data pipeline.
            .setInputTypes(InputType.recurrent(W2V_VECTOR_SIZE, 1000));
    return builder.build();
}
}
/**
 * Get layer output type by delegating to the backing DL4J graph vertex.
 *
 * @param inputType array of InputTypes
 * @return output type as InputType, as computed by the wrapped vertex
 *         (layer index -1, i.e. not tied to a specific layer position)
 */
@Override
public InputType getOutputType(InputType... inputType) {
    return this.vertex.getOutputType(-1, inputType);
}
}
/**
 * Creates an independent copy of this vertex configuration sharing the
 * same element-wise operation.
 *
 * @return a new {@code ElementWiseVertex} with the same op
 */
@Override
public ElementWiseVertex clone() {
    ElementWiseVertex copy = new ElementWiseVertex(op);
    return copy;
}
/**
 * Constructor from parsed Keras layer configuration dictionary.
 * A null merge mode maps to plain concatenation ({@link MergeVertex});
 * any non-null mode maps to the corresponding element-wise operation.
 *
 * @param layerConfig dictionary containing Keras layer configuration
 * @param enforceTrainingConfig whether to enforce training-related configuration options
 * @throws InvalidKerasConfigurationException on invalid Keras configuration
 * @throws UnsupportedKerasConfigurationException on unsupported Keras configuration
 */
public KerasMerge(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
        throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    super(layerConfig, enforceTrainingConfig);
    this.mergeMode = getMergeMode(layerConfig);
    if (this.mergeMode != null) {
        this.vertex = new ElementWiseVertex(mergeMode);
    } else {
        this.vertex = new MergeVertex();
    }
}
/**
 * Constructor from parsed Keras layer configuration dictionary.
 * Always backs the layer with a {@link PoolHelperVertex}; the base class
 * handles validation of the raw configuration map.
 *
 * @param layerConfig dictionary containing Keras layer configuration
 * @param enforceTrainingConfig whether to enforce training-related configuration options
 * @throws InvalidKerasConfigurationException on invalid Keras configuration
 * @throws UnsupportedKerasConfigurationException on unsupported Keras configuration
 */
public KerasPoolHelper(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
        throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    super(layerConfig, enforceTrainingConfig);
    this.vertex = new PoolHelperVertex();
}
/**
 * Creates an independent copy of this vertex configuration with the same
 * scale factor.
 *
 * @return a new {@code ScaleVertex} with identical configuration
 */
@Override
public ScaleVertex clone() {
    ScaleVertex copy = new ScaleVertex(scaleFactor);
    return copy;
}
/**
 * Creates an independent copy of this vertex configuration with the same
 * normalization dimension and epsilon.
 *
 * @return a new {@code L2NormalizeVertex} with identical configuration
 */
@Override
public L2NormalizeVertex clone() {
    L2NormalizeVertex copy = new L2NormalizeVertex(dimension, eps);
    return copy;
}
/**
 * Creates an independent copy of this vertex configuration with the same
 * shift factor.
 *
 * @return a new {@code ShiftVertex} with identical configuration
 */
@Override
public ShiftVertex clone() {
    ShiftVertex copy = new ShiftVertex(shiftFactor);
    return copy;
}
/**
 * Creates an independent copy of this vertex configuration. The vertex is
 * stateless, so a fresh default instance is equivalent.
 *
 * @return a new {@code StackVertex}
 */
@Override
public StackVertex clone() {
    StackVertex copy = new StackVertex();
    return copy;
}
/**
 * Creates an independent copy of this vertex configuration covering the
 * same subset range.
 *
 * @return a new {@code SubsetVertex} with identical configuration
 */
@Override
public SubsetVertex clone() {
    SubsetVertex copy = new SubsetVertex(from, to);
    return copy;
}
/**
 * Creates an independent copy of this vertex configuration with the same
 * target shape.
 *
 * @return a new {@code ReshapeVertex} with identical configuration
 */
@Override
public ReshapeVertex clone() {
    ReshapeVertex copy = new ReshapeVertex(newShape);
    return copy;
}
/**
 * Hash code derived from the element-wise operation.
 * Null-safe: a vertex whose {@code op} has not been set (e.g. a
 * partially-initialised instance — TODO confirm whether deserialization
 * can produce one) hashes to 0 instead of throwing NPE; this matches
 * {@code java.util.Objects.hashCode(op)} semantics.
 *
 * @return hash of {@code op}, or 0 when {@code op} is null
 */
@Override
public int hashCode() {
    return op == null ? 0 : op.hashCode();
}
/**
 * Creates an independent copy of this vertex configuration. The vertex is
 * stateless, so a fresh default instance is equivalent.
 *
 * @return a new {@code L2Vertex}
 */
@Override
public L2Vertex clone() {
    L2Vertex copy = new L2Vertex();
    return copy;
}
/**
 * Creates an independent copy of this vertex configuration with the same
 * unstack index and stack size.
 *
 * @return a new {@code UnstackVertex} with identical configuration
 */
@Override
public UnstackVertex clone() {
    UnstackVertex copy = new UnstackVertex(from, stackSize);
    return copy;
}
@Override public MemoryReport getMemoryReport(InputType... inputTypes) { InputType outputType = getOutputType(-1, inputTypes); //TODO multiple input types return new LayerMemoryReport.Builder(null, MergeVertex.class, inputTypes[0], outputType).standardMemory(0, 0) //No params .workingMemory(0, 0, 0, 0) //No working memory in addition to activations/epsilons .cacheMemory(0, 0) //No caching .build(); } }
@Override public MemoryReport getMemoryReport(InputType... inputTypes) { //Assume it's a reshape-with-copy op. In this case: memory use is accounted for in activations InputType outputType = getOutputType(-1, inputTypes); return new LayerMemoryReport.Builder(null, ReshapeVertex.class, inputTypes[0], outputType).standardMemory(0, 0) //No params .workingMemory(0, 0, 0, 0).cacheMemory(0, 0) //No caching .build(); } }
.nOut(cnnLayerFeatureMaps) .build(), "input") .addVertex("merge", new MergeVertex(), "cnn3", "cnn4", "cnn5") //Perform depth concatenation .addLayer("globalPool", new GlobalPoolingLayer.Builder() .poolingType(globalPoolingType)
/**
 * Get layer output type by delegating to the backing DL4J graph vertex.
 *
 * @param inputType array of InputTypes
 * @return output type as InputType, as computed by the wrapped vertex
 *         (layer index -1, i.e. not tied to a specific layer position)
 */
@Override
public InputType getOutputType(InputType... inputType) {
    return this.vertex.getOutputType(-1, inputType);
}
}
/**
 * Creates an independent copy of this vertex configuration. The vertex is
 * stateless, so a fresh default instance is equivalent.
 *
 * @return a new {@code PoolHelperVertex}
 */
@Override
public PoolHelperVertex clone() {
    PoolHelperVertex copy = new PoolHelperVertex();
    return copy;
}
/**
 * Creates an independent copy of this vertex configuration. The vertex is
 * stateless, so a fresh default instance is equivalent.
 *
 * @return a new {@code MergeVertex}
 */
@Override
public MergeVertex clone() {
    MergeVertex copy = new MergeVertex();
    return copy;
}