case CSR: BaseSparseNDArrayCSR csrArray = (BaseSparseNDArrayCSR) s; if (csrArray.rows() == rows() && csrArray.columns() == columns() && csrArray.getVectorCoordinates().equals(getVectorCoordinates()) && csrArray.data().equals(data()) && csrArray.getPointerB().equals(getPointerB()) && csrArray.getPointerE().equals(getPointerE())) { return true; INDArray dense = toDense(); INDArray oDense = s.toDense(); return dense.equals(oDense); INDArray dense = toDense(); return dense.equals(o);
/**
 * Returns the minor-dimension indices of the stored entries; for this CSR layout those are the
 * column indices of the non-zero values.
 *
 * @return a buffer obtained from the data buffer factory over {@code columnsPointers}, starting at
 *         offset 0 and spanning {@code length()} elements
 */
public DataBuffer getVectorCoordinates() {
    DataBuffer coordinates = Nd4j.getDataBufferFactory().create(columnsPointers, 0, length());
    return coordinates;
}
/**
 * Returns the part of this array selected by {@code indexes}.
 *
 * <p>Short-circuits and returns {@code this} when the indices select the whole array:
 * <ul>
 *   <li>a single {@code NDArrayIndexAll};</li>
 *   <li>{@code (point(0), all)} on a row vector;</li>
 *   <li>{@code (all, point(0))} on a column vector.</li>
 * </ul>
 * Otherwise the indices are resolved against this array's shape information, and the resulting
 * {@code ShapeOffsetResolution} is passed to {@code subArray}.
 *
 * @param indexes the indices to apply
 * @return {@code this} for a whole-array selection, otherwise the result of {@code subArray}
 */
public INDArray get(INDArrayIndex... indexes) {
    boolean selectsAll = indexes.length == 1 && indexes[0] instanceof NDArrayIndexAll;
    // For a row vector, row 0 + all columns is the whole array.
    boolean firstRowOfRowVector = indexes.length == 2 && isRowVector()
            && indexes[0] instanceof PointIndex && indexes[0].offset() == 0
            && indexes[1] instanceof NDArrayIndexAll;
    // For a column vector, all rows + column 0 is the whole array. The offset checked is the point
    // index's, i.e. indexes[1], not the "all" index at position 0.
    boolean firstColumnOfColumnVector = indexes.length == 2 && isColumnVector()
            && indexes[0] instanceof NDArrayIndexAll
            && indexes[1] instanceof PointIndex && indexes[1].offset() == 0;
    if (selectsAll || firstRowOfRowVector || firstColumnOfColumnVector) {
        return this;
    }
    indexes = NDArrayIndex.resolve(shapeInfoDataBuffer(), indexes);
    ShapeOffsetResolution resolution = new ShapeOffsetResolution(this);
    resolution.exec(indexes);
    return subArray(resolution);
public BaseSparseNDArrayCSR(DataBuffer data, int[] columnsPointers, int[] pointerB, int[] pointerE, int[] shape) { checkArgument(pointerB.length == pointerE.length); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape)); init(shape); this.values = data; this.columnsPointers = Nd4j.getDataBufferFactory().createInt(data.length()); this.columnsPointers.setData(columnsPointers); this.length = columnsPointers.length; // The size of these pointers are constant int pointersSpace = rows; this.pointerB = Nd4j.getDataBufferFactory().createInt(pointersSpace); this.pointerB.setData(pointerB); this.pointerE = Nd4j.getDataBufferFactory().createInt(pointersSpace); this.pointerE.setData(pointerE); }
public INDArray putScalar(int row, int col, double value) { checkArgument(row < rows && 0 <= rows); checkArgument(col < columns && 0 <= columns); int idx = pointerB.getInt(row); int idxNextRow = pointerE.getInt(row); while (columnsPointers.getInt(idx) < col && columnsPointers.getInt(idx) < idxNextRow) { idx++; } if (columnsPointers.getInt(idx) == col) { values.put(idx, value); } else { //Add a new entry in both buffers at a given position values = addAtPosition(values, length, idx, value); columnsPointers = addAtPosition(columnsPointers, length, idx, col); length++; // shift the indices of the next rows pointerE.put(row, pointerE.getInt(row) + 1); for (int i = row + 1; i < rows; i++) { pointerB.put(i, pointerB.getInt(i) + 1); pointerE.put(i, pointerE.getInt(i) + 1); } } return this; }
checkArgument(pointerB.length == pointerE.length); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape)); init(shape); int valuesSpace = (int) (data.length * THRESHOLD_MEMORY_ALLOCATION); this.values = Nd4j.getDataBufferFactory().createDouble(valuesSpace);
/**
 * Returns the stored non-zero values.
 *
 * @return a buffer obtained from the data buffer factory over {@code values}, starting at offset 0
 *         and spanning {@code length()} elements
 */
@Override
public DataBuffer data() {
    DataBuffer storedValues = Nd4j.getDataBufferFactory().create(values, 0, length());
    return storedValues;
}