/**
 * Computes the memory report for the given input type by delegating to {@code layer}.
 *
 * @param inputType input type to estimate memory for
 * @return the report produced by {@code layer.getMemoryReport(inputType)}
 */
@Override
public LayerMemoryReport getMemoryReport(InputType inputType) {
    LayerMemoryReport delegated = layer.getMemoryReport(inputType);
    return delegated;
}
@Override public MemoryReport getMemoryReport(InputType... inputTypes) { //TODO preprocessor memory return layerConf.getLayer().getMemoryReport(inputTypes[0]); } }
/** * Get a {@link MemoryReport} for the given MultiLayerConfiguration. This is used to estimate the * memory requirements for the given network configuration and input * * @param inputType Input types for the network * @return Memory report for the network */ public NetworkMemoryReport getMemoryReport(InputType inputType) { Map<String, MemoryReport> memoryReportMap = new LinkedHashMap<>(); int nLayers = confs.size(); for (int i = 0; i < nLayers; i++) { String layerName = confs.get(i).getLayer().getLayerName(); if (layerName == null) { layerName = String.valueOf(i); } //Pass input type through preprocessor, if necessary InputPreProcessor preproc = getInputPreProcess(0); //TODO memory requirements for preprocessor if (preproc != null) { inputType = preproc.getOutputType(inputType); } LayerMemoryReport report = confs.get(i).getLayer().getMemoryReport(inputType); memoryReportMap.put(layerName, report); inputType = confs.get(i).getLayer().getOutputType(i, inputType); } return new NetworkMemoryReport(memoryReportMap, MultiLayerConfiguration.class, "MultiLayerNetwork", inputType); }