/**
 * Stores a numeric value at row {@code i}, column {@code j}.
 *
 * <p>The number is first wrapped in a scalar array via {@code Nd4j.scalar} and then handed to
 * the superclass {@code put}; the superclass result is cast back to {@code IComplexNDArray}.
 *
 * @param i the row to insert into
 * @param j the column to insert into
 * @param element the value to store
 * @return the superclass {@code put} result, cast to {@code IComplexNDArray}
 */
@Override
public IComplexNDArray put(int i, int j, Number element) {
    return (IComplexNDArray) super.put(i, j, Nd4j.scalar(element));
}
/**
 * Reads the value at a multi-dimensional position and returns it as a new scalar array.
 *
 * <p>The value is read through {@code getDouble(indexes)}, so it passes through a
 * {@code double} before being wrapped by {@code Nd4j.scalar}.
 *
 * @param indexes one index per dimension identifying the element
 * @return a scalar array holding the element at {@code indexes}
 */
@Override
public INDArray getScalar(int... indexes) {
    return Nd4j.scalar(
            getDouble(indexes));
}
/**
 * Creates a variable named {@code name} whose array is a scalar holding {@code value}.
 *
 * <p>Equivalent to {@code var(name, Nd4j.scalar(value))}.
 *
 * @param name the name to register the variable under
 * @param value the scalar value the variable's array holds
 * @return the variable returned by {@code var(String, ...)}
 */
public SDVariable scalar(String name, double value) {
    return var(name,
            Nd4j.scalar(value));
}
/**
 * Returns the complex element at index {@code i}, wrapped in a scalar complex array.
 *
 * <p>The element is read with {@code getComplex(i)}; {@code i} is a single (presumably
 * linear) index — TODO confirm against {@code getComplex}.
 *
 * @param i the index of the element to return
 * @return a scalar complex array holding the element at {@code i}
 */
@Override
public IComplexNDArray getScalar(long i) {
    return Nd4j.scalar(
            getComplex(i));
}
/**
 * Convenience overload that wraps {@code element} in a scalar array and delegates to the
 * array-valued {@code put(INDArrayIndex[], ...)} overload.
 *
 * @param indices the indices selecting where to write
 * @param element the numeric value to write
 * @return the result of the delegated {@code put}
 */
@Override
public IComplexNDArray put(INDArrayIndex[] indices, Number element) {
    return put(indices,
            Nd4j.scalar(element));
}
/**
 * Convenience overload that accepts the comparison value as a plain number.
 *
 * <p>{@code comp} is wrapped via {@code Nd4j.scalar} and the call is forwarded to the
 * array-valued {@code putWhere} overload with {@code put} and {@code condition} unchanged.
 *
 * @param comp the comparison value
 * @param put the array of values to put
 * @param condition the condition to apply
 * @return the result of the array-valued {@code putWhere}
 */
@Override
public INDArray putWhere(Number comp, INDArray put, Condition condition) {
    return putWhere(Nd4j.scalar(comp), put, condition);
}
/**
 * Convenience overload that wraps a complex number in a scalar complex array and delegates to
 * the array-valued {@code put(INDArrayIndex[], ...)} overload.
 *
 * @param indices the indices selecting where to write
 * @param element the complex value to write
 * @return the result of the delegated {@code put}
 */
@Override
public IComplexNDArray put(INDArrayIndex[] indices, IComplexNumber element) {
    return put(indices,
            Nd4j.scalar(element));
}
/**
 * Writes a real value at index {@code i}.
 *
 * <p>The value is wrapped via {@code Nd4j.scalar} and passed to {@code put(i, ...)}.
 *
 * @param i the index to write
 * @param value the value to store
 * @return the result of {@code put(i, ...)}
 */
@Override
public IComplexNDArray putScalar(int i, double value) {
    return put(i,
            Nd4j.scalar(value));
}
/**
 * Returns the element at index {@code i} as a new scalar array.
 *
 * <p>The value is read with {@code getDouble(i)} before wrapping.
 *
 * @param i the index of the element
 * @return a scalar array holding the element at {@code i}
 */
@Override
public INDArray getScalar(long i) {
    return Nd4j.scalar(
            getDouble(i));
}
/**
 * Creates distribution statistics from per-feature means and standard deviations.
 *
 * <p>{@code std} is modified in place: every entry smaller than {@code Nd4j.EPS_THRESHOLD} is
 * raised to that threshold, so later division by {@code std} cannot produce NaN or infinity.
 *
 * @param mean row vector of means
 * @param std row vector of standard deviations; clamped in place to at least
 *     {@code Nd4j.EPS_THRESHOLD}
 */
public DistributionStats(@NonNull INDArray mean, @NonNull INDArray std) {
    // Element-wise max(std, EPS); dup=false means the clamp is applied to std itself.
    Transforms.max(std, Nd4j.EPS_THRESHOLD, false);
    // Compare the actual minimum value. The previous code compared two INDArray references with
    // ==, which is always false, so this message could never be logged. After the clamp above,
    // the minimum is at most EPS exactly when some entry was clamped. "<=" (rather than "==")
    // also tolerates float-typed arrays, where EPS rounds to a value slightly below the double.
    if (std.minNumber().doubleValue() <= Nd4j.EPS_THRESHOLD) {
        logger.info("API_INFO: Std deviation found to be zero. Transform will round up to epsilon to avoid nans.");
    }
    this.mean = mean;
    this.std = std;
}
/**
 * Computes an absolute step for {@code x} using the default relative step
 * {@code sqrt(Nd4j.EPS_THRESHOLD)}.
 *
 * <p>The relative step is built as {@code pow(Nd4j.scalar(EPS_THRESHOLD), 0.5)} and passed,
 * together with {@code x}, to {@code computeAbsoluteStep(INDArray, INDArray)}.
 *
 * @param x the point(s) to compute the step for
 * @return the array returned by the two-argument {@code computeAbsoluteStep}
 */
public static INDArray computeAbsoluteStep(INDArray x) {
    return computeAbsoluteStep(
            pow(Nd4j.scalar(Nd4j.EPS_THRESHOLD), 0.5),
            x);
}
/**
 * Multiplies the feature array by {@code num} in place.
 *
 * <p>Uses the {@code muli(Number)} overload directly instead of allocating a temporary scalar
 * array just to carry the multiplier.
 *
 * @param num the factor to multiply every feature value by
 */
@Override
public void multiplyBy(double num) {
    getFeatures().muli(num);
}
/**
 * Computes column-wise mean and standard deviation of the data set's features.
 *
 * <p>{@code Nd4j.EPS_THRESHOLD} is added to every standard deviation so that later division by
 * {@code std} cannot produce NaN. An info message is logged when a column had zero deviation.
 *
 * @param dataSet the data whose feature matrix is summarized
 */
public void fit(DataSet dataSet) {
    INDArray features = dataSet.getFeatureMatrix();
    mean = features.mean(0);
    std = features.std(0);
    std.addi(Nd4j.EPS_THRESHOLD);
    // Compare the actual minimum value. The previous code compared two INDArray references with
    // ==, which is always false, so the message never fired. A zero deviation becomes exactly
    // EPS after the addition. "<=" also covers float arrays, where EPS rounds slightly below
    // the double.
    if (std.minNumber().doubleValue() <= Nd4j.EPS_THRESHOLD) {
        logger.info("API_INFO: Std deviation found to be zero. Transform will round upto epsilon to avoid nans.");
    }
}
/**
 * Standardizes {@code toNormalize} column by column, in place.
 *
 * <p>Column means are subtracted first. The standard deviation of the centered columns, plus
 * {@code 1e-12} to avoid division by zero, is then divided out.
 *
 * @param toNormalize the matrix to standardize; modified in place
 */
public static void normalizeMatrix(INDArray toNormalize) {
    INDArray colMean = toNormalize.mean(0);
    toNormalize.subiRowVector(colMean);

    INDArray colStd = toNormalize.std(0);
    colStd.addi(Nd4j.scalar(1e-12));
    toNormalize.diviRowVector(colStd);
}
/**
 * Builds a {@code Choose} op over {@code inputs} driven by {@code condition}.
 *
 * <p>After delegating to the superclass constructor (with a null third argument), this
 * registers {@code inputs} as input arguments and {@code condition.condtionNum()} as an integer
 * argument. It then registers two output arrays: a new array of length
 * {@code inputs[0].length()} and a scalar initialized to 1.0. NOTE(review): the meaning of the
 * two outputs (presumably selected values and a count) is not shown here — confirm against the
 * op implementation.
 *
 * @param opName the op name passed to the superclass
 * @param inputs the input arrays; {@code inputs[0]} must exist, since its length sizes the
 *     first output
 * @param condition the selection condition; must not be null
 * @throws ND4JIllegalArgumentException if {@code condition} is null
 */
public Choose(String opName, INDArray[] inputs, Condition condition) {
    super(opName, inputs, null);
    if(condition == null) {
        throw new ND4JIllegalArgumentException("Must specify a condition.");
    }
    // NOTE(review): confirm the superclass constructor does not already register these inputs.
    addInputArgument(inputs);
    // "condtionNum" is the method name as spelled in the Condition API.
    addIArgument(condition.condtionNum());
    addOutputArgument(Nd4j.create(inputs[0].length()),Nd4j.scalar(1.0));
}
/**
 * Matrix-multiplies this array by {@code other}.
 *
 * <p>An uninitialized {@code 'f'}-ordered result of shape {@code [rows(), other.columns()]} is
 * allocated. If that result is a scalar, the product is computed as a BLAS dot product and
 * returned as {@code Nd4j.scalar}. Otherwise the product is written into the allocated result
 * via {@code mmuli}.
 *
 * @param other the right-hand operand
 * @return the matrix product
 */
@Override
public INDArray mmul(INDArray other) {
    long[] outShape = new long[] {rows(), other.columns()};
    INDArray out = createUninitialized(outShape, 'f');
    if (!out.isScalar()) {
        return mmuli(other, out);
    }
    return Nd4j.scalar(Nd4j.getBlasWrapper().dot(this, other));
}
/**
 * Centers every feature column on its mean and divides it by its standard deviation.
 *
 * <p>The standard deviations are computed before centering. {@code Nd4j.EPS_THRESHOLD} is
 * added to them so that constant columns do not cause a division by zero.
 *
 * @deprecated marked deprecated in the original source; the intended replacement is not shown
 *     here.
 */
@Deprecated
@Override
public void normalizeZeroMeanZeroUnitVariance() {
    INDArray means = getFeatures().mean(0);
    INDArray stds = getFeatureMatrix().std(0);

    INDArray centered = getFeatures().subiRowVector(means);
    setFeatures(centered);

    stds.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
    INDArray scaled = getFeatures().diviRowVector(stds);
    setFeatures(scaled);
}
/**
 * Rounds every feature value in place using {@code MathUtils.roundDouble(value, roundTo)}.
 *
 * <p>Values are read with {@code getDouble} and written with {@code putScalar}. The previous
 * {@code (double) getScalar(i).element()} cast an {@code Object}, which compiles as a cast to
 * {@code Double}. That cast throws {@code ClassCastException} whenever the element is boxed as a
 * {@code Float} (float-typed arrays). It also allocated two scalar arrays per element.
 *
 * @param roundTo the rounding argument passed to {@code MathUtils.roundDouble}
 */
@Override
public void roundToTheNearest(int roundTo) {
    INDArray features = getFeatures();
    for (long i = 0; i < features.length(); i++) {
        double curr = features.getDouble(i);
        features.putScalar(i, MathUtils.roundDouble(curr, roundTo));
    }
}
/**
 * Converts a relative step into an absolute step for each element of {@code x}.
 *
 * <p>When {@code relStep} is null, {@code pow(Nd4j.scalar(getEpsRelativeTo(x)), 0.5)} is used.
 * The result is {@code sign(x) * relStep * max(|x|, 1.0)}, where sign is +1 for
 * {@code x >= 0} and -1 otherwise.
 *
 * @param relStep the relative step, or null to derive one from {@code x}
 * @param x the point(s) to compute steps for
 * @return the absolute step array
 */
public static INDArray computeAbsoluteStep(INDArray relStep,INDArray x) {
    INDArray step = (relStep != null)
            ? relStep
            : pow(Nd4j.scalar(getEpsRelativeTo(x)), 0.5);
    // (x >= 0) * 2 - 1 maps true/false to +1/-1.
    INDArray sign = x.gte(0).muli(2).subi(1);
    INDArray magnitude = max(abs(x), 1.0);
    return sign.mul(step).muli(magnitude);
}