/**
 * Assembles the computation graph configuration for a multi-width 1D convolutional classifier.
 *
 * <p>For every entry of {@code ngramFilters} the graph gets a {@code Convolution1DLayer}
 * named {@code ngram<N>} (kernel size N, {@code numFilters} outputs, ReLU) fed from the
 * input "in", followed by a global max-pooling layer named {@code pool<N>}. All pooling
 * outputs are merged in the "concat" vertex, passed through the "predict" dense layer
 * (softmax, {@code numClasses} outputs) and finally the "loss" layer (MCXENT), which is
 * the single graph output.
 *
 * @return the built configuration, with the input type declared as
 *         {@code InputType.recurrent(W2V_VECTOR_SIZE, 1000)}
 */
public static ComputationGraphConfiguration getConf() {
    ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .updater(new Adam(0.01))
            .weightInit(WeightInit.RELU)
            .graphBuilder()
            .addInputs("in");

    final int branchCount = ngramFilters.length;
    String[] poolNames = new String[branchCount];

    // One convolution + global-max-pool branch per configured n-gram width.
    for (int idx = 0; idx < branchCount; idx++) {
        int ngram = ngramFilters[idx];
        String convName = String.format("ngram%d", ngram);
        poolNames[idx] = String.format("pool%d", ngram);

        graph = graph.addLayer(convName,
                new Convolution1DLayer.Builder()
                        .nOut(numFilters)
                        .kernelSize(ngram)
                        .activation(Activation.RELU)
                        .build(),
                "in");
        graph = graph.addLayer(poolNames[idx],
                new GlobalPoolingLayer.Builder(PoolingType.MAX).build(),
                convName);
    }

    // Merge every pooled branch, classify, and attach the loss as the graph output.
    return graph
            .addVertex("concat", new MergeVertex(), poolNames)
            .addLayer("predict",
                    new DenseLayer.Builder()
                            .nOut(numClasses)
                            .dropOut(dropoutRetain)
                            .activation(Activation.SOFTMAX)
                            .build(),
                    "concat")
            .addLayer("loss",
                    new LossLayer.Builder(LossFunctions.LossFunction.MCXENT).build(),
                    "predict")
            .setOutputs("loss")
            .setInputTypes(InputType.recurrent(W2V_VECTOR_SIZE, 1000))
            .build();
}
} // closes the enclosing type, whose header is outside this excerpt
/**
 * Reports the activation type produced by this 1D convolution layer.
 *
 * @param layerIndex index of this layer, used only in the error message
 * @param inputType  type of the incoming activations; must be of type RNN
 * @return a recurrent input type of size {@code nOut} (no time series length specified)
 * @throws IllegalStateException if {@code inputType} is null or not of type RNN
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    final boolean rnnInput = inputType != null && inputType.getType() == InputType.Type.RNN;
    if (rnnInput) {
        return InputType.recurrent(nOut);
    }
    throw new IllegalStateException("Invalid input for 1D CNN layer (layer index = " + layerIndex
            + ", layer name = \"" + getLayerName() + "\"): expect RNN input type with size > 0. Got: "
            + inputType);
}
@Override public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { //Infer output shape from specified shape: switch (newShape.length) { case 2: return InputType.feedForward(newShape[1]); case 3: return InputType.recurrent(newShape[1]); case 4: return InputType.convolutional(newShape[1], newShape[2], newShape[0]); //[mb,d,h,w] for activations default: throw new UnsupportedOperationException( "Cannot infer input type for reshape array " + Arrays.toString(newShape)); } }
/**
 * Reports the activation type produced by this recurrent layer.
 *
 * @param layerIndex index of this layer, used only in the error message
 * @param inputType  type of the incoming activations; must be of type RNN
 * @return a recurrent input type of size {@code nOut} that carries over the time series
 *         length of the input
 * @throws IllegalStateException if {@code inputType} is null or not of type RNN
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    final boolean rnnInput = inputType != null && inputType.getType() == InputType.Type.RNN;
    if (!rnnInput) {
        throw new IllegalStateException("Invalid input for RNN layer (layer index = " + layerIndex
                + ", layer name = \"" + getLayerName() + "\"): expect RNN input type with size > 0. Got: "
                + inputType);
    }

    // Only the feature size changes; the sequence length passes through untouched.
    InputType.InputTypeRecurrent recurrentIn = (InputType.InputTypeRecurrent) inputType;
    return InputType.recurrent(nOut, recurrentIn.getTimeSeriesLength());
}
/**
 * Reports the activation type produced by this RNN output layer.
 *
 * @param layerIndex index of this layer, used only in the error message
 * @param inputType  type of the incoming activations; must be of type RNN
 * @return a recurrent input type of size {@code nOut} with the same time series length
 *         as the input
 * @throws IllegalStateException if {@code inputType} is null or not of type RNN
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    if (inputType != null && inputType.getType() == InputType.Type.RNN) {
        InputType.InputTypeRecurrent recurrentIn = (InputType.InputTypeRecurrent) inputType;
        return InputType.recurrent(nOut, recurrentIn.getTimeSeriesLength());
    }

    throw new IllegalStateException("Invalid input type for RnnOutputLayer (layer index = " + layerIndex
            + ", layer name=\"" + getLayerName() + "\"): Expected RNN input, got " + inputType);
}
/**
 * Maps a feed-forward style input type to the recurrent type produced by this preprocessor.
 *
 * @param inputType incoming type; must be of type FF or CNNFlat
 * @return a recurrent input type whose size is the feed-forward size, or the flattened
 *         size for CNNFlat input
 * @throws IllegalStateException if {@code inputType} is null or of any other type
 */
@Override
public InputType getOutputType(InputType inputType) {
    final InputType.Type kind = (inputType == null) ? null : inputType.getType();

    if (kind == InputType.Type.FF) {
        InputType.InputTypeFeedForward dense = (InputType.InputTypeFeedForward) inputType;
        return InputType.recurrent(dense.getSize());
    }
    if (kind == InputType.Type.CNNFlat) {
        InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
        return InputType.recurrent(flat.getFlattenedSize());
    }

    throw new IllegalStateException("Invalid input: expected input of type FeedForward, got " + inputType);
}
/**
 * Maps a convolutional input type to the recurrent type produced by this preprocessor.
 *
 * @param inputType incoming type; must be of type CNN
 * @return a recurrent input type whose size is depth * height * width of the input
 * @throws IllegalStateException if {@code inputType} is null or not of type CNN
 */
@Override
public InputType getOutputType(InputType inputType) {
    final boolean cnnInput = inputType != null && inputType.getType() == InputType.Type.CNN;
    if (!cnnInput) {
        throw new IllegalStateException("Invalid input type: Expected input of type CNN, got " + inputType);
    }

    InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
    // Each time step carries the fully flattened activation volume.
    return InputType.recurrent(conv.getDepth() * conv.getHeight() * conv.getWidth());
}
@Override public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { if (vertexInputs.length != 1) throw new InvalidInputTypeException("Invalid input type: cannot duplicate more than 1 input"); int tsLength = 1; //TODO work this out properly if (vertexInputs[0].getType() == InputType.Type.FF) { return InputType.recurrent(((InputType.InputTypeFeedForward) vertexInputs[0]).getSize(), tsLength); } else if (vertexInputs[0].getType() != InputType.Type.CNNFlat) { return InputType.recurrent(((InputType.InputTypeConvolutionalFlat) vertexInputs[0]).getFlattenedSize(), tsLength); } else { throw new InvalidInputTypeException( "Invalid input type: cannot duplicate to time series non feed forward (or CNN flat) input (got: " + vertexInputs[0] + ")"); } }
/**
 * Builds the input type for every graph input returned by {@code getInputNames()},
 * in the same order.
 *
 * <p>"input" is declared feed-forward; "indel" and "trueGenotypeInput" are declared
 * recurrent. In each case the size is the first element of
 * {@code domainDescriptor.getNumInputs(name)}.
 *
 * @param domainDescriptor source of the per-input sizes
 * @return one input type per input name
 * @throws RuntimeException if an input name is not one of the names listed above
 */
private InputType[] getInputTypes(DomainDescriptor domainDescriptor) {
    final String[] names = getInputNames();
    final InputType[] types = new InputType[names.length];

    int slot = 0;
    for (String name : names) {
        switch (name) {
            case "input":
                types[slot] = InputType.feedForward(domainDescriptor.getNumInputs(name)[0]);
                break;
            case "indel":
            case "trueGenotypeInput":
                types[slot] = InputType.recurrent(domainDescriptor.getNumInputs(name)[0]);
                break;
            default:
                throw new RuntimeException("Invalid input to computation graph");
        }
        slot++;
    }
    return types;
}
/**
 * Builds the input type for every graph input returned by {@code getInputNames()},
 * in the same order.
 *
 * <p>"input" is declared feed-forward; "from", "G1", "G2", "G3" and "trueGenotypeInput"
 * are declared recurrent. In each case the size is the first element of
 * {@code domainDescriptor.getNumInputs(name)}.
 *
 * @param domainDescriptor source of the per-input sizes
 * @return one input type per input name
 * @throws RuntimeException if an input name is not one of the names listed above
 */
private InputType[] getInputTypes(DomainDescriptor domainDescriptor) {
    final String[] names = getInputNames();
    final InputType[] types = new InputType[names.length];

    int slot = 0;
    for (String name : names) {
        switch (name) {
            case "input":
                types[slot] = InputType.feedForward(domainDescriptor.getNumInputs(name)[0]);
                break;
            case "from":
            case "G1":
            case "G2":
            case "G3":
            case "trueGenotypeInput":
                types[slot] = InputType.recurrent(domainDescriptor.getNumInputs(name)[0]);
                break;
            default:
                throw new RuntimeException("Invalid input to computation graph");
        }
        slot++;
    }
    return types;
}
/**
 * Derives the recurrent input type from the dimensions of the "input" feature mapper and
 * prints those dimensions to standard output.
 *
 * @param domainDescriptor provides the feature mapper registered under "input"
 * @return {@code InputType.recurrent(numElements(1), numElements(2))} of the mapper's dimensions
 */
private InputType getInputTypes(DomainDescriptor domainDescriptor) {
    MappedDimensions dims = domainDescriptor.getFeatureMapper("input").dimensions();

    System.out.printf("GenotypeSegmentsLSTM dimensions: sequence-length=%d, num-float-per-base=%d%n",
            dims.numElements(1),
            dims.numElements(2));

    // NOTE(review): the log labels numElements(1) as the sequence length, yet it is passed as
    // the first argument of InputType.recurrent, which other callers use for the feature size
    // (with the time series length second). Confirm the intended order against the mapper.
    return InputType.recurrent(dims.numElements(1), dims.numElements(2));
}
public static InputType inferInputType(INDArray inputArray) { //Note: ConvolutionalFlat and FeedForward look identical... but either should work OK if using something // like FeedForwardToCnnPreProcessor switch (inputArray.rank()) { case 2: return InputType.feedForward(inputArray.size(1)); case 3: return InputType.recurrent(inputArray.size(1), inputArray.size(2)); case 4: //Order: [minibatch, depth, height, width] -> [h, w, d] return InputType.convolutional(inputArray.size(2), inputArray.size(3), inputArray.size(1)); default: throw new IllegalArgumentException( "Cannot infer input type for array with shape: " + Arrays.toString(inputArray.shape())); } }
@Override public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { if (vertexInputs.length != 1) { throw new InvalidInputTypeException( "SubsetVertex expects single input type. Received: " + Arrays.toString(vertexInputs)); } switch (vertexInputs[0].getType()) { case FF: return InputType.feedForward(to - from + 1); case RNN: return InputType.recurrent(to - from + 1); case CNN: InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) vertexInputs[0]; int depth = conv.getDepth(); if (to >= depth) { throw new InvalidInputTypeException("Invalid range: Cannot select depth subset [" + from + "," + to + "] inclusive from CNN activations with " + " [depth,width,height] = [" + depth + "," + conv.getWidth() + "," + conv.getHeight() + "]"); } return InputType.convolutional(conv.getHeight(), conv.getWidth(), from - to + 1); case CNNFlat: //TODO work out how to do this - could be difficult... throw new UnsupportedOperationException( "Subsetting data in flattened convolutional format not yet supported"); default: throw new RuntimeException("Unknown input type: " + vertexInputs[0]); } }
int nIn = brl.getNIn(); if (nIn > 0) { inputType = InputType.recurrent(nIn);
build.setInputTypes(InputType.recurrent(numLSTMInputs, numTimeSteps)); String lstmInputName = "input"; String lstmLayerName = "no layer";