/**
 * Sets nIn (the number of input channels) for this convolution layer from the input type.
 *
 * @param inputType the input type; must be non-null and of type CNN
 * @param override  if true, overwrite any previously configured nIn
 * @throws IllegalStateException if the input is null or not convolutional
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    boolean validCnnInput = inputType != null && inputType.getType() == InputType.Type.CNN;
    if (!validCnnInput) {
        throw new IllegalStateException("Invalid input for Convolution layer (layer name=\"" + getLayerName()
                        + "\"): Expected CNN input, got " + inputType);
    }
    if (override || nIn <= 0) {
        // The channel depth of the incoming activations becomes nIn
        this.nIn = ((InputType.InputTypeConvolutional) inputType).getDepth();
    }
}
public static InputPreProcessor getPreprocessorForInputTypeRnnLayers(InputType inputType, String layerName) { if (inputType == null) { throw new IllegalStateException( "Invalid input for RNN layer (layer name = \"" + layerName + "\"): input type is null"); } switch (inputType.getType()) { case FF: case CNNFlat: //FF -> RNN or CNNFlat -> RNN //In either case, input data format is a row vector per example return new FeedForwardToRnnPreProcessor(); case RNN: //RNN -> RNN: No preprocessor necessary return null; case CNN: //CNN -> RNN InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; return new CnnToRnnPreProcessor(c.getHeight(), c.getWidth(), c.getDepth()); default: throw new RuntimeException("Unknown input type: " + inputType); } }
@Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { if (inputType == null) { throw new IllegalStateException( "Invalid input for layer (layer name = \"" + getLayerName() + "\"): input type is null"); } switch (inputType.getType()) { case FF: case CNNFlat: //FF -> FF and CNN (flattened format) -> FF: no preprocessor necessary return null; case RNN: //RNN -> FF return new RnnToFeedForwardPreProcessor(); case CNN: //CNN -> FF InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; return new CnnToFeedForwardPreProcessor(c.getHeight(), c.getWidth(), c.getDepth()); default: throw new RuntimeException("Unknown input type: " + inputType); } }
// Im2col buffer size per example: one entry per (kernel element x output spatial location),
// across all input depth channels.
// NOTE(review): fragment only - the assignment target / enclosing method is outside this view.
c.getDepth() * outputType.getHeight() * outputType.getWidth() * kernelSize[0] * kernelSize[1];
// Validate that the incoming CNN activations match the expected dimensions.
// NOTE(review): fragment - the enclosing method and closing braces are outside this view;
// numChannels/inputHeight/inputWidth are presumably fields fixed at construction - TODO confirm.
InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
if (c2.getDepth() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
    // NOTE(review): the "got" tuple is printed as (depth,width,height) per the "(d,w,h)" label,
    // but the "expected" tuple is printed as (numChannels, inputHeight, inputWidth), i.e. (d,h,w).
    // The orderings look inconsistent - verify intended field order before relying on this message.
    throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getDepth() + ","
                    + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels + ","
                    + inputHeight + "," + inputWidth + ")");
@Override public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { if (vertexInputs.length != 1) { throw new InvalidInputTypeException( "SubsetVertex expects single input type. Received: " + Arrays.toString(vertexInputs)); } switch (vertexInputs[0].getType()) { case FF: return InputType.feedForward(to - from + 1); case RNN: return InputType.recurrent(to - from + 1); case CNN: InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) vertexInputs[0]; int depth = conv.getDepth(); if (to >= depth) { throw new InvalidInputTypeException("Invalid range: Cannot select depth subset [" + from + "," + to + "] inclusive from CNN activations with " + " [depth,width,height] = [" + depth + "," + conv.getWidth() + "," + conv.getHeight() + "]"); } return InputType.convolutional(conv.getHeight(), conv.getWidth(), from - to + 1); case CNNFlat: //TODO work out how to do this - could be difficult... throw new UnsupportedOperationException( "Subsetting data in flattened convolutional format not yet supported"); default: throw new RuntimeException("Unknown input type: " + vertexInputs[0]); } }
/**
 * Computes the output type after zero padding: height grows by padding[0] + padding[1],
 * width by padding[2] + padding[3]; depth is unchanged.
 *
 * @param layerIndex index of the layer (unused here)
 * @param inputType  convolutional or flattened-convolutional input type
 * @return the padded convolutional output type
 * @throws IllegalStateException for any other input type
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    final int height;
    final int width;
    final int depth;
    if (inputType instanceof InputType.InputTypeConvolutional) {
        InputType.InputTypeConvolutional in = (InputType.InputTypeConvolutional) inputType;
        height = in.getHeight();
        width = in.getWidth();
        depth = in.getDepth();
    } else if (inputType instanceof InputType.InputTypeConvolutionalFlat) {
        InputType.InputTypeConvolutionalFlat in = (InputType.InputTypeConvolutionalFlat) inputType;
        height = in.getHeight();
        width = in.getWidth();
        depth = in.getDepth();
    } else {
        throw new IllegalStateException(
                "Invalid input type: expected InputTypeConvolutional or InputTypeConvolutionalFlat."
                        + " Got: " + inputType);
    }
    // Padding adds to the spatial dimensions only
    return InputType.convolutional(height + padding[0] + padding[1], width + padding[2] + padding[3], depth);
}
@Override public LayerMemoryReport getMemoryReport(InputType inputType) { InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; InputType.InputTypeConvolutional outputType = (InputType.InputTypeConvolutional) getOutputType(-1, inputType); int actElementsPerEx = outputType.arrayElementsPerExample(); //TODO Subsampling helper memory use... (CuDNN etc) //During forward pass: im2col array + reduce. Reduce is counted as activations, so only im2col is working mem int im2colSizePerEx = c.getDepth() * outputType.getHeight() * outputType.getWidth() * kernelSize[0] * kernelSize[1]; //Current implementation does NOT cache im2col etc... which means: it's recalculated on each backward pass int trainingWorkingSizePerEx = im2colSizePerEx; if (getDropOut() > 0) { //Dup on the input before dropout, but only for training trainingWorkingSizePerEx += inputType.arrayElementsPerExample(); } return new LayerMemoryReport.Builder(layerName, SubsamplingLayer.class, inputType, outputType) .standardMemory(0, 0) //No params .workingMemory(0, im2colSizePerEx, 0, trainingWorkingSizePerEx) .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching .build(); }
/**
 * Sets nIn for this batch normalization layer from the given input type; for batch
 * normalization, nOut is always set equal to nIn.
 *
 * @param inputType the input type: FF (size), CNN (depth), or CNNFlat (depth)
 * @param override  if true, overwrite any previously configured nIn
 * @throws IllegalStateException if the input type is not FF, CNN, or CNNFlat
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    if (nIn <= 0 || override) {
        switch (inputType.getType()) {
            case FF:
                nIn = ((InputType.InputTypeFeedForward) inputType).getSize();
                break;
            case CNN:
                nIn = ((InputType.InputTypeConvolutional) inputType).getDepth();
                break;
            case CNNFlat:
                nIn = ((InputType.InputTypeConvolutionalFlat) inputType).getDepth();
                // BUG FIX: break was missing here; valid CNNFlat input fell through
                // into the default case and incorrectly threw IllegalStateException
                break;
            default:
                throw new IllegalStateException(
                        "Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got "
                                        + inputType + " for layer " + getLayerName() + "\"");
        }
        nOut = nIn;
    }
}
/**
 * Computes the feed-forward output type produced by flattening CNN activations:
 * output size = depth * height * width.
 *
 * @param inputType the input type; must be of type CNN
 * @return a feed-forward input type of the flattened size
 * @throws IllegalStateException if the input is null or not convolutional
 */
@Override
public InputType getOutputType(InputType inputType) {
    if (inputType == null || inputType.getType() != InputType.Type.CNN) {
        throw new IllegalStateException("Invalid input type: Expected input of type CNN, got " + inputType);
    }
    InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
    int flattenedSize = conv.getDepth() * conv.getHeight() * conv.getWidth();
    return InputType.feedForward(flattenedSize);
}
/**
 * Computes the recurrent output type produced by flattening CNN activations at each
 * time step: feature size = depth * height * width.
 *
 * @param inputType the input type; must be of type CNN
 * @return a recurrent input type of the flattened size
 * @throws IllegalStateException if the input is null or not convolutional
 */
@Override
public InputType getOutputType(InputType inputType) {
    if (inputType == null || inputType.getType() != InputType.Type.CNN) {
        throw new IllegalStateException("Invalid input type: Expected input of type CNN, got " + inputType);
    }
    InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
    int flattenedSize = conv.getDepth() * conv.getHeight() * conv.getWidth();
    return InputType.recurrent(flattenedSize);
}
/**
 * Computes the output type of this subsampling layer. Depth is unchanged by
 * subsampling; the spatial dimensions are computed by the shared CNN utility.
 *
 * @param layerIndex index of the layer
 * @param inputType  the input type; must be of type CNN
 * @return the convolutional output type after subsampling
 * @throws IllegalStateException if the input is null or not convolutional
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    if (inputType == null || inputType.getType() != InputType.Type.CNN) {
        throw new IllegalStateException("Invalid input for Subsampling layer (layer name=\"" + getLayerName()
                        + "\"): Expected CNN input, got " + inputType);
    }
    int inDepth = ((InputType.InputTypeConvolutional) inputType).getDepth();
    return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, convolutionMode,
                    inDepth, layerIndex, getLayerName(), SubsamplingLayer.class);
}