/** * Calculate the gradient of the network with respect to some external errors. * Note that this is typically used for things like reinforcement learning, not typical networks that include * an OutputLayer or RnnOutputLayer * * @param epsilons Epsilons (errors) at the output. Same order with which the output layers are defined in configuration setOutputs(String...) * @return Gradient for the network */ public Gradient backpropGradient(INDArray... epsilons) { if (epsilons == null || epsilons.length != numOutputArrays) throw new IllegalArgumentException( "Invalid input: must have epsilons length equal to number of output arrays"); calcBackpropGradients(configuration.getBackpropType() == BackpropType.TruncatedBPTT, epsilons); return gradient; }
public GraphBuilder(ComputationGraphConfiguration newConf, NeuralNetConfiguration.Builder globalConfiguration) { ComputationGraphConfiguration clonedConf = newConf.clone(); this.vertices = clonedConf.getVertices(); this.vertexInputs = clonedConf.getVertexInputs(); this.networkInputs = clonedConf.getNetworkInputs(); this.networkOutputs = clonedConf.getNetworkOutputs(); this.pretrain = clonedConf.isPretrain(); this.backprop = clonedConf.isBackprop(); this.backpropType = clonedConf.getBackpropType(); this.tbpttFwdLength = clonedConf.getTbpttFwdLength(); this.tbpttBackLength = clonedConf.getTbpttBackLength(); this.globalConfiguration = globalConfiguration; //this.getGlobalConfiguration().setSeed(clonedConf.getDefaultConfiguration().getSeed()); }
@Override public void computeGradientAndScore() { if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) { Map<String, INDArray> activations = rnnActivateUsingStoredState(inputs, true, true); if (trainingListeners.size() > 0) {
if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) { doTruncatedBPTT(new INDArray[] {next.getFeatures()}, new INDArray[] {next.getLabels()}, (hasMaskArrays ? new INDArray[] {next.getFeaturesMaskArray()} : null),
if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) { doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) { doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays); } else {