/**
 * Assigns a freshly constructed
 * {@code org.deeplearning4j.nn.conf.layers.LocalResponseNormalization} to {@code backend}.
 */
@Override
public void initializeBackend() {
    final org.deeplearning4j.nn.conf.layers.LocalResponseNormalization lrnConf =
            new org.deeplearning4j.nn.conf.layers.LocalResponseNormalization();
    backend = lrnConf;
}
/**
 * Returns the preprocessor selected by
 * {@code InputTypeUtil.getPreProcessorForInputTypeCnnLayers} for the given input type.
 *
 * @param inputType input type feeding this layer; must not be null
 * @return whatever {@code InputTypeUtil.getPreProcessorForInputTypeCnnLayers} returns for it
 * @throws IllegalStateException if {@code inputType} is null
 */
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
    final String layerName = getLayerName();
    if (inputType != null) {
        return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, layerName);
    }
    throw new IllegalStateException(
            "Invalid input type for LRN layer (layer name = \"" + layerName + "\"): null");
}
/**
 * Tries to load {@code CudnnLocalResponseNormalizationHelper} by reflection and install it
 * as {@code helper}.
 *
 * <p>The helper is kept only if it loads successfully and its {@code checkSupported} accepts
 * this layer's k, n, alpha and beta. In every other case {@code helper} is left null.
 * A missing helper class ({@code ClassNotFoundException}) is treated as expected and not
 * logged. Any other failure is logged at WARN.
 */
void initializeHelper() {
    try {
        helper = Class.forName("org.deeplearning4j.nn.layers.normalization.CudnnLocalResponseNormalizationHelper")
                .asSubclass(LocalResponseNormalizationHelper.class).newInstance();
        if (helper.checkSupported(layerConf().getK(), layerConf().getN(), layerConf().getAlpha(),
                layerConf().getBeta())) {
            // Only report success once the helper has actually been accepted.
            log.debug("CudnnLocalResponseNormalizationHelper successfully initialized");
        } else {
            log.debug("CudnnLocalResponseNormalizationHelper does not support this layer configuration; not using it");
            helper = null;
        }
    } catch (Throwable t) {
        // Reset so an instance created before a later failure (e.g. checkSupported throwing)
        // is never left installed.
        helper = null;
        if (!(t instanceof ClassNotFoundException)) {
            log.warn("Could not initialize CudnnLocalResponseNormalizationHelper", t);
        }
    }
}
/**
 * Get layer output type.
 *
 * <p>Delegates to {@code getLocalResponseNormalization().getOutputType(-1, inputType[0])}.
 *
 * @param inputType Array of InputTypes; exactly one input is expected
 * @return output type as InputType
 * @throws InvalidKerasConfigurationException if the number of inputs is not exactly one
 */
@Override
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException {
    // Check "!= 1" rather than "> 1" so an empty array gets a clear configuration error
    // instead of an ArrayIndexOutOfBoundsException on inputType[0].
    if (inputType.length != 1) {
        throw new InvalidKerasConfigurationException(
                "Keras LRN layer accepts only one input (received " + inputType.length + ")");
    }
    // -1 is passed as the layer index; presumably "index unknown" here — TODO confirm.
    return this.getLocalResponseNormalization().getOutputType(-1, inputType[0]);
}
} // closes the enclosing class (its header is outside this chunk)
/**
 * Creates and configures the runtime LRN layer
 * ({@code org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization}).
 *
 * @param conf               configuration given to the layer's constructor and to setConf
 * @param iterationListeners listeners attached to the created layer
 * @param layerIndex         index assigned to the created layer
 * @param layerParamsView    array set as the layer's params view and passed to the initializer
 * @param initializeParams   forwarded to {@code initializer().init}
 * @return the configured layer
 */
@Override
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
        Collection<IterationListener> iterationListeners, int layerIndex, INDArray layerParamsView,
        boolean initializeParams) {
    org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization layer =
            new org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization(conf);
    layer.setListeners(iterationListeners);
    layer.setIndex(layerIndex);
    layer.setParamsViewArray(layerParamsView);
    // The parameter table is built only after the params view array has been attached.
    layer.setParamTable(initializer().init(conf, layerParamsView, initializeParams));
    layer.setConf(conf);
    return layer;
}
/**
 * Constructs a new {@link LocalResponseNormalization}, passing this object to its constructor.
 *
 * @return the newly constructed LocalResponseNormalization
 */
@Override
public LocalResponseNormalization build() {
    final LocalResponseNormalization built = new LocalResponseNormalization(this);
    return built;
}
/**
 * Validates the input type and returns it unchanged as the output type.
 *
 * @param layerIndex index of this layer; used only in the error message
 * @param inputType  input type; must be non-null and of type {@code InputType.Type.CNN}
 * @return {@code inputType} itself
 * @throws IllegalStateException if {@code inputType} is null or not of type CNN
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    final boolean isCnnInput = inputType != null && inputType.getType() == InputType.Type.CNN;
    if (isCnnInput) {
        return inputType;
    }
    throw new IllegalStateException(
            "Invalid input type for LRN layer (layer index = " + layerIndex + ", layer name = \""
                    + getLayerName() + "\"): Expected input of type CNN, got " + inputType);
}