/**
 * N-dimensional convolution computed via the FFT: both operands are
 * transformed along the requested axes, multiplied point-wise in the
 * frequency domain, and inverse-transformed back (convolution theorem).
 *
 * @param input  the input to convolve
 * @param kernel the kernel to convolve with
 * @param type   which portion of the full result to return (FULL, SAME, or VALID)
 * @param axes   the axes along which to perform the convolution
 * @return the convolution of the given input and kernel
 */
@Override
public IComplexNDArray convn(IComplexNDArray input, IComplexNDArray kernel, Convolution.Type type, int[] axes) {
    // Degenerate case: two scalars convolve to their product.
    if (kernel.isScalar() && input.isScalar())
        return kernel.mul(input);

    // Padded transform length per axis: size(input) + size(kernel) - 1.
    INDArray paddedShape = NDArrayUtil.toNDArray(Shape.sizeForAxes(axes, input.shape()))
            .add(NDArrayUtil.toNDArray(Shape.sizeForAxes(axes, kernel.shape())))
            .subi(1);
    int[] fftShape = NDArrayUtil.toInts(paddedShape);

    // conv(a, b) = ifft(fft(a) * fft(b)), evaluated along the given axes.
    IComplexNDArray inputSpectrum = FFT.rawfftn(input, fftShape, axes);
    IComplexNDArray kernelSpectrum = FFT.rawfftn(kernel, fftShape, axes);
    IComplexNDArray full = FFT.rawifftn(inputSpectrum.muli(kernelSpectrum), fftShape, axes);

    if (type == Convolution.Type.SAME) {
        // Crop the full result back down to the input's own shape.
        return ComplexNDArrayUtil.center(full, input.shape());
    }
    if (type == Convolution.Type.VALID) {
        // Keep only the region of complete overlap: |input - kernel + 1| per dimension.
        int[] validShape = NDArrayUtil.toInts(Transforms.abs(
                NDArrayUtil.toNDArray(input.shape())
                        .sub(NDArrayUtil.toNDArray(kernel.shape()))
                        .addi(1)));
        return ComplexNDArrayUtil.center(full, validShape);
    }
    // FULL (and any unrecognized mode) returns the untrimmed result.
    return full;
}
return convolution.getReal(); case SAME: return ComplexNDArrayUtil.center(convolution, input.shape()).getReal(); case VALID: int[] shape2 = NDArrayUtil.toInts(Transforms.abs(NDArrayUtil.toNDArray(input.shape()) .sub(NDArrayUtil.toNDArray(kernel.shape())).addi(1))); return ComplexNDArrayUtil.center(convolution, shape2).getReal();