/**
 * Builds a shape-information buffer from long-typed shape parameters and marks it constant.
 *
 * @param shape             dimension sizes
 * @param stride            per-dimension strides
 * @param offset            offset into the underlying data
 * @param elementWiseStride element-wise stride
 * @param order             memory order character
 * @return the constant shape-info buffer paired with its contents read back as {@code long[]}
 */
@Override
public Pair<DataBuffer, long[]> createShapeInformation(long[] shape, long[] stride, long offset, long elementWiseStride, char order) {
    final DataBuffer shapeInfo = Shape.createShapeInformation(shape, stride, offset, elementWiseStride, order);
    // Flag the freshly built descriptor as constant before handing it out.
    shapeInfo.setConstant(true);
    final long[] hostView = shapeInfo.asLong();
    return Pair.create(shapeInfo, hostView);
}
/**
 * Computes the length (total number of elements) described by a shape info buffer,
 * i.e. the product of all of its dimensions.
 *
 * @param buffer the shape info buffer to compute the length for
 * @return the product of the dimensions; 1 when the rank is 0
 * @throws ArithmeticException if the product does not fit in an {@code int}
 */
public static int length(DataBuffer buffer) {
    DataBuffer shape = Shape.shapeOf(buffer);
    int rank = Shape.rank(buffer);
    // Accumulate in long with overflow checks: the previous int compound assignment
    // (ret *= long) narrowed implicitly at every step and silently wrapped.
    long ret = 1;
    for (int i = 0; i < rank; i++) {
        ret = Math.multiplyExact(ret, shape.getLong(i));
    }
    // Narrow once at the end so an oversized shape fails loudly instead of returning garbage.
    return Math.toIntExact(ret);
}
/**
 * Derives the output shape from the first argument's array.
 *
 * <p>Returns an empty list when there is no argument or its array is not available. A
 * single-element shape array {@code [n]} becomes {@code [1, n]}; when that value is below 1,
 * {@code [1, 1]} is used instead. Otherwise the array's data is returned as a {@code long[]}.
 *
 * @return a single-element list holding the output shape, or an empty list
 */
@Override
public List<long[]> calculateOutputShape() {
    if (args().length < 1) {
        return Collections.emptyList();
    }
    val shapeArr = args()[0].getArr();
    if (shapeArr == null) {
        return Collections.emptyList();
    }
    if (shapeArr.length() == 1) {
        return shapeArr.getDouble(0) < 1
                ? Arrays.asList(new long[]{1, 1})
                : Arrays.asList(new long[]{1, shapeArr.getInt(0)});
    }
    return Arrays.asList(shapeArr.data().asLong());
}
/**
 * Copies this vector's values into a new {@code long[]}.
 *
 * @return the vector's values, read from a duplicate of this array
 * @throws ND4JIllegalStateException if this array is not a vector
 */
@Override
public long[] toLongVector() {
    if (isVector()) {
        // Read from a dup() rather than from this array's own data buffer.
        return dup().data().asLong();
    }
    throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector!");
}
/**
 * Builds a shape-information buffer from int-typed shape parameters and marks it constant.
 * The int arguments are widened to long before the buffer is created.
 *
 * @param shape             dimension sizes
 * @param stride            per-dimension strides
 * @param offset            offset into the underlying data
 * @param elementWiseStride element-wise stride
 * @param order             memory order character
 * @return the constant shape-info buffer paired with its contents read back as {@code long[]}
 */
@Override
public Pair<DataBuffer, long[]> createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
    final long[] wideShape = ArrayUtil.toLongArray(shape);
    final long[] wideStride = ArrayUtil.toLongArray(stride);
    final long wideEws = elementWiseStride;
    final DataBuffer shapeInfo = Shape.createShapeInformation(wideShape, wideStride, offset, wideEws, order);
    shapeInfo.setConstant(true);
    return Pair.create(shapeInfo, shapeInfo.asLong());
}
/**
 * Resolves the target shape from the second argument before execution.
 *
 * <p>After delegating to the superclass, and only when {@code arrName} is set, this reads the
 * requested shape from the second argument's array. Any negative entry is replaced by the
 * dimension at the same index of the first argument's shape. The result is stored in
 * {@code this.shape} and appended as integer arguments.
 *
 * <p>NOTE(review): assumes both arguments are present and that the first input's shape has at
 * least as many dimensions as the requested shape — confirm with callers.
 */
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
    super.resolvePropertiesFromSameDiffBeforeExecution();
    if (arrName == null) {
        return;
    }
    val inputs = args();
    val sourceShape = inputs[0].getShape();
    val requested = inputs[1].getArr().data().asLong();
    int idx = 0;
    while (idx < requested.length) {
        if (requested[idx] < 0) {
            requested[idx] = sourceShape[idx];
        }
        idx++;
    }
    this.shape = requested;
    addIArgument(requested);
}
/**
 * Serializes this array into a FlatArray inside {@code builder}.
 *
 * <p>The serialized FlatArray contains:
 * <ul>
 *   <li>the shape info buffer, read as {@code long[]};</li>
 *   <li>the data buffer's raw bytes;</li>
 *   <li>the data type byte;</li>
 *   <li>{@code ByteOrder.BE}.</li>
 * </ul>
 *
 * @param builder the FlatBuffers builder to write into
 * @return the value returned by {@code FlatArray.createFlatArray} (presumably the table offset)
 */
@Override
public int toFlatArray(FlatBufferBuilder builder) {
    final int shapeVec = FlatArray.createShapeVector(builder, this.shapeInfoDataBuffer().asLong());
    final int dataVec = FlatArray.createBufferVector(builder, this.data().asBytes());
    final byte dtype = SameDiff.getDataTypeAsByte(this.data().dataType());
    return FlatArray.createFlatArray(builder, shapeVec, dataVec, dtype, ByteOrder.BE);
}
/**
 * Restores this array's shape information and data from a Java serialization stream.
 *
 * <p>Steps, in order (the order matters):
 * <ol>
 *   <li>Allocates a buffer of {@code Shape.shapeInfoLength(rank())} ints and fills it via
 *       {@code shapeInformation.read(s)}. {@code rank()} is evaluated before
 *       {@code shapeInformation} is replaced, so it presumably reads the previous shape info —
 *       TODO confirm this array already carries shape info of the right rank.</li>
 *   <li>Installs the buffer via {@code setShapeInformation}, paired with its {@code asLong()}
 *       copy.</li>
 *   <li>Allocates the data buffer with {@code length()}, which runs after the new shape info is
 *       installed, and fills it via {@code data().read(s)}.</li>
 * </ol>
 *
 * <p>NOTE(review): the shape buffer is created from an {@code int[]} but consumed through
 * {@code asLong()}; confirm the serialized shape info element type matches. The meaning of the
 * {@code 0} and {@code false} arguments to {@code Nd4j.createBuffer} is not visible here.
 *
 * @param s the stream to read shape info and then data from
 */
protected void read(ObjectInputStream s) { shapeInformation = Nd4j.createBuffer(new int[Shape.shapeInfoLength(rank())], 0); shapeInformation.read(s); setShapeInformation(Pair.create(shapeInformation, shapeInformation.asLong())); data = Nd4j.createBuffer(length(), false); data().read(s); }
// Read the underlying indices stored for entry i as a long[] and check whether they match
// physicalIndexes element-by-element (Arrays.equals compares contents, not references).
long[] idx = getUnderlyingIndicesOf(i).asLong(); if (Arrays.equals(idx, physicalIndexes)) {
// Take the shape values directly from arr's data buffer, read as a long[].
this.shape = arr.data().asLong();
// Build the sub-array for the given index from the precomputed TAD info. In tadInfo, the first
// element is the shape info buffer and the second holds per-index offsets. The resulting array
// reuses this array's data() with the TAD's shape/stride, starting at offset() plus the
// per-index offset. Presumably this is a view sharing storage — confirm against Nd4j.create semantics.
DataBuffer shapeInfo = tadInfo.getFirst(); val shape = Shape.shape(shapeInfo); val stride = Shape.stride(shapeInfo).asLong(); long offset = offset() + tadInfo.getSecond().getLong(index); INDArray toTad = Nd4j.create(data(), shape, stride, offset);
/**
 * Replaces this array's shape information buffer and rebuilds the cached
 * {@code JvmShapeInfo} from the buffer's contents.
 *
 * <p>PLEASE NOTE: Never use this method, unless you 100% have to.
 *
 * @param buffer the new shape information buffer; stored as-is, not copied
 */
public void setShapeInfoDataBuffer(DataBuffer buffer) {
    shapeInformation = buffer;
    final long[] shapeInfoValues = shapeInformation.asLong();
    jvmShapeInfo = new JvmShapeInfo(shapeInfoValues);
}
/** * PLEASE NOTE: This method implementation is hardware-dependant. * PLEASE NOTE: This method does NOT allow concurrent use of any array * * @param dataBuffer * @return */ @Override public DataBuffer relocateConstantSpace(DataBuffer dataBuffer) { // we always assume that data is sync, and valid on host side Integer deviceId = AtomicAllocator.getInstance().getDeviceId(); ensureMaps(deviceId); if (dataBuffer instanceof CudaIntDataBuffer) { int[] data = dataBuffer.asInt(); return getConstantBuffer(data); } else if (dataBuffer instanceof CudaFloatDataBuffer) { float[] data = dataBuffer.asFloat(); return getConstantBuffer(data); } else if (dataBuffer instanceof CudaDoubleDataBuffer) { double[] data = dataBuffer.asDouble(); return getConstantBuffer(data); } else if (dataBuffer instanceof CudaHalfDataBuffer) { float[] data = dataBuffer.asFloat(); return getConstantBuffer(data); } else if (dataBuffer instanceof CudaLongDataBuffer) { long[] data = dataBuffer.asLong(); return getConstantBuffer(data); } throw new IllegalStateException("Unknown CudaDataBuffer opType"); }