/**
 * Reshape the parameters view, without modifying the paramsView array values.
 * Delegates to the 3-arg overload using {@code DEFAULT_WEIGHT_INIT_ORDER}.
 *
 * @param shape      Shape to reshape the view to
 * @param paramsView Parameters array view (values are not modified)
 * @return the reshaped parameters view
 */
public static INDArray reshapeWeights(int[] shape, INDArray paramsView) {
    return reshapeWeights(shape, paramsView, DEFAULT_WEIGHT_INIT_ORDER);
}
protected INDArray createWeightMatrix(int nIn, int nOut, WeightInit weightInit, Distribution dist, INDArray weightParamView, boolean initializeParameters) { int[] shape = new int[] {nIn, nOut}; if (initializeParameters) { INDArray ret = WeightInitUtil.initWeights(nIn, //Fan in nOut, //Fan out shape, weightInit, dist, weightParamView); return ret; } else { return WeightInitUtil.reshapeWeights(shape, weightParamView); } } }
/**
 * Verifies that each WeightInit scheme set on a configuration is actually applied
 * to the layer parameters. WeightInit.IDENTITY is deliberately excluded, as in the
 * original skip list.
 */
@Test
public void testWeightInit() throws Exception {
    for (WeightInit weightInit : WeightInit.values()) {
        // IDENTITY was the sole entry in the skip list - bypass it directly
        if (weightInit == WeightInit.IDENTITY) {
            continue;
        }
        NeuralNetConfiguration conf = new NeuralNetConfiguration();
        conf.setWeightInit(weightInit);
        checkAppliedParameters(conf, weightInit, BaseLayer::getWeightInit);
    }
}
if (nParams > 0) { WeightInit wi = bl.getWeightInit(); String str = wi.toString(); if (wi == WeightInit.DISTRIBUTION) { str += bl.getDist();
public static INDArray initWeights(double fanIn, double fanOut, int[] shape, WeightInit initScheme, Distribution dist, char order, INDArray paramView) { //Note: using f order here as params get flattened to f order INDArray ret; switch (initScheme) { case DISTRIBUTION: ret = dist.sample(shape); break; case RELU: ret = Nd4j.randn(order, shape).muli(FastMath.sqrt(2.0 / fanIn)); //N(0, 2/nIn) break; case RELU_UNIFORM: double u = Math.sqrt(6.0 / fanIn); ret = Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn) break; case SIGMOID_UNIFORM: double r = 4.0 * Math.sqrt(6.0 / (fanIn + fanOut)); ret = Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-r, r)); break; case UNIFORM: double a = 1.0 / Math.sqrt(fanIn); ret = Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-a, a)); break; case XAVIER: ret = Nd4j.randn(order, shape).muli(FastMath.sqrt(2.0 / (fanIn + fanOut))); break; case XAVIER_UNIFORM: //As per Glorot and Bengio 2010: Uniform distribution U(-s,s) with s = sqrt(6/(fanIn + fanOut)) //Eq 16: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
if (nParams > 0) { WeightInit wi = bl.getWeightInit(); String str = wi.toString(); if (wi == WeightInit.DISTRIBUTION) { str += bl.getDist();
return WeightInitUtil.initWeights(fanIn, fanOut, weightsShape, layerConf.getWeightInit(), dist, 'c', weightView); } else { int[] kernel = layerConf.getKernelSize(); return WeightInitUtil.reshapeWeights( new int[] {layerConf.getNOut(), layerConf.getNIn(), kernel[0], kernel[1]}, weightView, 'c');
if (nParams > 0) { WeightInit wi = bl.getWeightInit(); String str = wi.toString(); if (wi == WeightInit.DISTRIBUTION) { str += bl.getDist();
int[] recurrentWShape = new int[] {nL, 4 * nL + 3}; params.put(INPUT_WEIGHT_KEY_FORWARDS, WeightInitUtil.initWeights(fanIn, fanOut, inputWShape, layerConf.getWeightInit(), dist, iwF)); params.put(RECURRENT_WEIGHT_KEY_FORWARDS, WeightInitUtil.initWeights(fanIn, fanOut, recurrentWShape, layerConf.getWeightInit(), dist, rwF)); params.put(BIAS_KEY_FORWARDS, bF); params.put(INPUT_WEIGHT_KEY_BACKWARDS, WeightInitUtil.initWeights(fanIn, fanOut, inputWShape, layerConf.getWeightInit(), dist, iwR)); params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, WeightInitUtil.initWeights(fanIn, fanOut, recurrentWShape, layerConf.getWeightInit(), dist, rwR)); params.put(BIAS_KEY_BACKWARDS, bR); } else { params.put(INPUT_WEIGHT_KEY_FORWARDS, WeightInitUtil.reshapeWeights(new int[] {nLast, 4 * nL}, iwF)); params.put(RECURRENT_WEIGHT_KEY_FORWARDS, WeightInitUtil.reshapeWeights(new int[] {nL, 4 * nL + 3}, rwF)); params.put(BIAS_KEY_FORWARDS, bF); params.put(INPUT_WEIGHT_KEY_BACKWARDS, WeightInitUtil.reshapeWeights(new int[] {nLast, 4 * nL}, iwR)); params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, WeightInitUtil.reshapeWeights(new int[] {nL, 4 * nL + 3}, rwR)); params.put(BIAS_KEY_BACKWARDS, bR);
int[] recurrentWShape = new int[] {nL, 4 * nL}; params.put(INPUT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, inputWShape, layerConf.getWeightInit(), dist, inputWeightView)); params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, recurrentWShape, layerConf.getWeightInit(), dist, recurrentWeightView)); biasView.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(nL, 2 * nL)}, params.put(INPUT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(new int[] {nLast, 4 * nL}, inputWeightView)); params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(new int[] {nL, 4 * nL}, recurrentWeightView)); params.put(BIAS_KEY, biasView);
int[] recurrentWShape = new int[] {nL, 4 * nL + 3}; params.put(INPUT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, inputWShape, layerConf.getWeightInit(), dist, inputWeightView)); params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, recurrentWShape, layerConf.getWeightInit(), dist, recurrentWeightView)); biasView.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(nL, 2 * nL)}, params.put(INPUT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(new int[] {nLast, 4 * nL}, inputWeightView)); params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(new int[] {nL, 4 * nL + 3}, recurrentWeightView)); params.put(BIAS_KEY, biasView);