@Override public INDArray doCreate(long[] shape, INDArray paramsView) { //As per Glorot and Bengio 2010: Uniform distribution U(-s,s) with s = sqrt(6/(fanIn + fanOut)) //Eq 16: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf double s = Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut); return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-s, s)); }
/**
 * Samples weights from the uniform distribution U(-a, a) with a = 1 / sqrt(fanIn).
 *
 * @param shape      shape of the array to create
 * @param paramsView not read by this implementation
 * @return an array of the given shape with values drawn from U(-a, a)
 */
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    double bound = 1.0 / Math.sqrt(fanIn);
    double lower = -bound;
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(lower, bound));
}
/**
 * Samples weights from the uniform distribution U(-r, r) with
 * r = 4 * sqrt(6 / (fanIn + fanOut)).
 *
 * @param shape      shape of the array to create
 * @param paramsView not read by this implementation
 * @return an array of the given shape with values drawn from U(-r, r)
 */
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    double fanSum = fanIn + fanOut;
    double baseLimit = Math.sqrt(6.0 / fanSum);
    // Scaled by 4 relative to the sqrt(6 / (fanIn + fanOut)) bound.
    double range = 4.0 * baseLimit;
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-range, range));
}
/**
 * Samples weights from the uniform distribution U(-b, b) with b = 3 / sqrt(fanIn).
 *
 * <p>NOTE(review): some libraries define a fan-in uniform bound as sqrt(3 / fanIn); this
 * implementation uses 3 / sqrt(fanIn) — confirm this is the intended definition.
 *
 * @param shape      shape of the array to create
 * @param paramsView not read by this implementation
 * @return an array of the given shape with values drawn from U(-b, b)
 */
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    double limit = 3.0 / Math.sqrt(fanIn);
    double lower = -limit;
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(lower, limit));
}
/**
 * Samples weights from the uniform distribution U(-a, a), where the scale is derived from
 * the fan-in only: a = 3 / sqrt(fanIn).
 *
 * @param shape      shape of the array to create
 * @param paramsView not read by this implementation
 * @return an array of the given shape with values drawn from U(-a, a)
 */
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    double fanInScale = 3.0 / Math.sqrt(fanIn);
    double lowerBound = -fanInScale;
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(lowerBound, fanInScale));
}
@Override public INDArray doCreate(long[] shape, INDArray paramsView) { double u = Math.sqrt(6.0 / fanIn); return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn) }
/**
 * Samples weights from the uniform distribution U(-a, a), where the scale is based on the
 * average of fan-in and fan-out: a = 3 / sqrt((fanIn + fanOut) / 2).
 *
 * @param shape      shape of the array to create
 * @param paramsView not read by this implementation
 * @return an array of the given shape with values drawn from U(-a, a)
 */
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    // Divide by 2.0 rather than 2 so the fan average is computed in floating point even if
    // fanIn/fanOut are integral types; integer division would silently truncate odd sums.
    double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2.0);
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
}
/**
 * Samples weights from the uniform distribution U(-a, a), where the scale is derived from
 * the fan-out only: a = 3 / sqrt(fanOut).
 *
 * @param shape      shape of the array to create
 * @param paramsView not read by this implementation
 * @return an array of the given shape with values drawn from U(-a, a)
 */
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    double fanOutScale = 3.0 / Math.sqrt(fanOut);
    double lowerBound = -fanOutScale;
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(lowerBound, fanOutScale));
}
@Override public INDArray rand(long[] shape, float min, float max, org.nd4j.linalg.api.rng.Random rng) { //ensure shapes that wind up being scalar end up with the write shape if (shape.length == 1 && shape[0] == 0) { shape = new long[] {1, 1}; } return Nd4j.getDistributions().createUniform(min, max).sample(shape); }
/** * Generates a random matrix between min and max * * @param shape the number of rows of the matrix * @param min the minimum number * @param max the maximum number * @param rng the rng to use * @return a random matrix of the specified shape and range */ @Override public INDArray rand(int[] shape, float min, float max, org.nd4j.linalg.api.rng.Random rng) { //ensure shapes that wind up being scalar end up with the write shape if (shape.length == 1 && shape[0] == 0) { shape = new int[] {1, 1}; } return Nd4j.getDistributions().createUniform(min, max).sample(shape); }
/**
 * Generates an array of the given shape filled with samples from a uniform distribution
 * between {@code min} and {@code max}, after reseeding {@code Nd4j.getRandom()} with the
 * seed of {@code rng}.
 *
 * <p>NOTE(review): because the generator is reseeded on every call, repeated calls with the
 * same {@code rng} presumably produce identical samples — confirm this is intended.
 *
 * @param shape shape of the array to generate
 * @param min   lower bound of the uniform distribution
 * @param max   upper bound of the uniform distribution
 * @param rng   source of the seed applied to {@code Nd4j.getRandom()}; must not be null
 * @return the sampled array
 */
@Override
public INDArray rand(int[] shape, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
    Nd4j.getRandom().setSeed(rng.getSeed());
    INDArray uniformSample = Nd4j.getDistributions()
            .createUniform(min, max)
            .sample(shape);
    return uniformSample;
}
/**
 * Generates an array of the given shape filled with samples from a uniform distribution
 * between {@code min} and {@code max}, after reseeding {@code Nd4j.getRandom()} with the
 * seed of {@code rng}.
 *
 * <p>NOTE(review): because the generator is reseeded on every call, repeated calls with the
 * same {@code rng} presumably produce identical samples — confirm this is intended.
 *
 * @param shape shape of the array to generate
 * @param min   lower bound of the uniform distribution
 * @param max   upper bound of the uniform distribution
 * @param rng   source of the seed applied to {@code Nd4j.getRandom()}; must not be null
 * @return the sampled array
 */
@Override
public INDArray rand(long[] shape, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
    Nd4j.getRandom().setSeed(rng.getSeed());
    INDArray uniformSample = Nd4j.getDistributions()
            .createUniform(min, max)
            .sample(shape);
    return uniformSample;
}
/** * Create an ndarray * of * @param seed * @param rank * @param numShapes * @return */ public static int[][] getRandomBroadCastShape(long seed, int rank, int numShapes) { Nd4j.getRandom().setSeed(seed); INDArray coinFlip = Nd4j.getDistributions().createBinomial(1, 0.5).sample(new int[] {numShapes, rank}); int[][] ret = new int[(int) coinFlip.rows()][(int) coinFlip.columns()]; for (int i = 0; i < coinFlip.rows(); i++) { for (int j = 0; j < coinFlip.columns(); j++) { int set = coinFlip.getInt(i, j); if (set > 0) ret[i][j] = set; else { //anything from 0 to 9 ret[i][j] = Nd4j.getRandom().nextInt(9) + 1; } } } return ret; }
/** * Generates a random matrix between min and max * * @param shape the number of rows of the matrix * @param min the minimum number * @param max the maximum number * @param rng the rng to use * @return a random matrix of the specified shape and range */ @Override public INDArray rand(int[] shape, float min, float max, org.nd4j.linalg.api.rng.Random rng) { //ensure shapes that wind up being scalar end up with the write shape if (shape.length == 1 && shape[0] == 0) { shape = new int[] {1, 1}; } return Nd4j.getDistributions().createUniform(min, max).sample(shape); }
/**
 * Returns a binomial sample with one trial per element, shaped like {@code input}. Each
 * element of {@code input} is used as the success probability for the corresponding output
 * element.
 *
 * @param input         per-element probabilities — presumably in [0, 1]; TODO confirm with callers
 * @param miniBatchSize not used by this implementation
 * @return the sampled array, with the same shape as {@code input}
 */
@Override
public INDArray preProcess(INDArray input, int miniBatchSize) {
    INDArray sampled = Nd4j.getDistributions()
            .createBinomial(1, input)
            .sample(input.shape());
    return sampled;
}
/**
 * Converts a {@code Distribution} description into the corresponding ND4J sampling
 * distribution built via {@code Nd4j.getDistributions()}.
 *
 * <p>Mappings:
 * <ul>
 *   <li>Normal and Gaussian map to {@code createNormal(mean, std)}.</li>
 *   <li>Uniform maps to {@code createUniform(lower, upper)}.</li>
 *   <li>Binomial maps to {@code createBinomial(numberOfTrials, probabilityOfSuccess)}.</li>
 * </ul>
 *
 * @param dist the distribution description; may be {@code null}
 * @return the ND4J distribution, or {@code null} when {@code dist} is {@code null}
 * @throws RuntimeException if {@code dist} is of an unsupported type
 */
public static org.nd4j.linalg.api.rng.distribution.Distribution createDistribution(Distribution dist) {
    if (dist == null) {
        return null;
    }
    if (dist instanceof NormalDistribution) {
        NormalDistribution normal = (NormalDistribution) dist;
        return Nd4j.getDistributions().createNormal(normal.getMean(), normal.getStd());
    } else if (dist instanceof GaussianDistribution) {
        GaussianDistribution gaussian = (GaussianDistribution) dist;
        return Nd4j.getDistributions().createNormal(gaussian.getMean(), gaussian.getStd());
    } else if (dist instanceof UniformDistribution) {
        UniformDistribution uniform = (UniformDistribution) dist;
        return Nd4j.getDistributions().createUniform(uniform.getLower(), uniform.getUpper());
    } else if (dist instanceof BinomialDistribution) {
        BinomialDistribution binomial = (BinomialDistribution) dist;
        return Nd4j.getDistributions().createBinomial(binomial.getNumberOfTrials(),
                binomial.getProbabilityOfSuccess());
    }
    throw new RuntimeException("unknown distribution type: " + dist.getClass());
}
} // closes the enclosing type, whose declaration starts outside this excerpt
/**
 * Corrupts the given input by masking it with a binomial sample. Each mask element is drawn
 * with one trial and success probability {@code 1 - corruptionLevel}, and the mask is then
 * multiplied element-wise by {@code x}.
 *
 * @param x               the input to corrupt (not modified)
 * @param corruptionLevel the corruption level; the keep probability is {@code 1 - corruptionLevel}
 * @return the masked copy of {@code x}
 */
public INDArray getCorruptedInput(INDArray x, double corruptionLevel) {
    double keepProbability = 1 - corruptionLevel;
    INDArray mask = Nd4j.getDistributions().createBinomial(1, keepProbability).sample(x.shape());
    // In-place multiply: the mask array itself becomes mask * x and is returned.
    mask.muli(x);
    return mask;
}
/**
 * Generates an array of the given shape filled with samples from a uniform distribution
 * between {@code min} and {@code max}. Before sampling, {@code Nd4j.getRandom()} is reseeded
 * with {@code rng.getSeed()}.
 *
 * <p>NOTE(review): reseeding on every call means repeated calls with the same {@code rng}
 * presumably yield the same samples — confirm this is intended.
 *
 * @param shape shape of the array to generate
 * @param min   lower bound of the uniform distribution
 * @param max   upper bound of the uniform distribution
 * @param rng   supplies the seed for {@code Nd4j.getRandom()}; must not be null
 * @return the sampled array
 */
@Override
public INDArray rand(int[] shape, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
    Nd4j.getRandom().setSeed(rng.getSeed());
    INDArray result = Nd4j.getDistributions().createUniform(min, max).sample(shape);
    return result;
}
/**
 * Displays each example of {@code mnist} next to a sampled version of its reconstruction.
 *
 * <p>For every example j, the window titled "REAL" shows the original features scaled by
 * 255. The window titled "TEST" shows row j of {@code reconstruct}, binomially sampled with
 * one trial per value (presumably treating each value as a pixel probability) and scaled by
 * 255. Both windows stay up for 1000 ms and are then disposed.
 *
 * @param mnist       the examples to display
 * @param reconstruct reconstructions; row j is paired with example j, so it must have at
 *                    least {@code mnist.numExamples()} rows
 * @throws InterruptedException if interrupted while pausing; the windows already opened for
 *                              the current example are still disposed
 */
public static void drawMnist(DataSet mnist, INDArray reconstruct) throws InterruptedException {
    for (int j = 0; j < mnist.numExamples(); j++) {
        INDArray draw1 = mnist.get(j).getFeatureMatrix().mul(255);
        INDArray reconstructed2 = reconstruct.getRow(j);
        INDArray draw2 = Nd4j.getDistributions().createBinomial(1, reconstructed2)
                .sample(reconstructed2.shape()).mul(255);

        DrawReconstruction d = new DrawReconstruction(draw1);
        d.title = "REAL";
        DrawReconstruction d2 = null;
        try {
            d.draw();
            d2 = new DrawReconstruction(draw2, 1000, 1000);
            d2.title = "TEST";
            d2.draw();
            Thread.sleep(1000);
        } finally {
            // Dispose in finally so an interrupt or a drawing failure does not leak open windows.
            if (d.frame != null) {
                d.frame.dispose();
            }
            if (d2 != null && d2.frame != null) {
                d2.frame.dispose();
            }
        }
    }
}
/**
 * Displays each example of {@code mnist} next to a sampled version of its reconstruction.
 *
 * <p>For every example j, the window titled "REAL" shows the original features scaled by
 * 255. The window titled "TEST" shows row j of {@code reconstruct}, binomially sampled with
 * one trial per value (presumably treating each value as a pixel probability) and scaled by
 * 255. Both windows stay up for 1000 ms and are then disposed.
 *
 * @param mnist       the examples to display
 * @param reconstruct reconstructions; row j is paired with example j, so it must have at
 *                    least {@code mnist.numExamples()} rows
 * @throws InterruptedException if interrupted while pausing; the windows already opened for
 *                              the current example are still disposed
 */
public static void drawMnist(DataSet mnist, INDArray reconstruct) throws InterruptedException {
    for (int j = 0; j < mnist.numExamples(); j++) {
        INDArray draw1 = mnist.get(j).getFeatureMatrix().mul(255);
        INDArray reconstructed2 = reconstruct.getRow(j);
        INDArray draw2 = Nd4j.getDistributions().createBinomial(1, reconstructed2)
                .sample(reconstructed2.shape()).mul(255);

        DrawReconstruction d = new DrawReconstruction(draw1);
        d.title = "REAL";
        DrawReconstruction d2 = null;
        try {
            d.draw();
            d2 = new DrawReconstruction(draw2, 1000, 1000);
            d2.title = "TEST";
            d2.draw();
            Thread.sleep(1000);
        } finally {
            // Dispose in finally so an interrupt or a drawing failure does not leak open windows.
            if (d.frame != null) {
                d.frame.dispose();
            }
            if (d2 != null && d2.frame != null) {
                d2.frame.dispose();
            }
        }
    }
}