// SimpleCPUConvolution.java
package neureka.backend.main.implementations.convolution;
import neureka.Tensor;
import neureka.devices.host.CPU;
import neureka.ndim.config.NDConfiguration;
/**
 *  Performs a direct ("valid", i.e. without padding) 2D convolution on nd-arrays of rank 2 (matrices)
 *  or rank 3 (where the first axis is treated as the batch size),
 *  where one tensor is the kernel and the other one is the image.
 *  <p>
 *  Which of the two inputs is the image and which one is the kernel is determined
 *  by trial: an operand order is accepted if all tensors have a supported rank, layout and
 *  item type and if the shapes are consistent with the size of the output tensor.
 *  If no order is accepted, then {@link #isSuitable()} returns {@code false}
 *  and {@link #run()} throws an {@link IllegalStateException}.
 *  Only {@code Float} and {@code Double} item types are supported.
 */
class SimpleCPUConvolution
{
    /**
     *  If a single result matrix has fewer elements than this, the batch entries are
     *  processed sequentially on the calling thread instead of being handed to the CPU executor.
     */
    private static final int MIN_WORK_FOR_THREADING = 1000;

    /** The chosen implementation, or {@code null} if the given tensors can not be handled by this class. */
    final Conv2DImpl _impl;

    /**
     *  Tries to set up a convolution with {@code in1} as image and {@code in2} as kernel,
     *  and only if that is not possible, with {@code in2} as image and {@code in1} as kernel.
     *  This constructor never throws because of unsupported tensors, use {@link #isSuitable()}
     *  to find out if the convolution can be executed.
     *
     * @param in1 The first operand (image or kernel).
     * @param in2 The second operand (kernel or image).
     * @param out The tensor which will receive the result of the convolution.
     */
    SimpleCPUConvolution( Tensor<?> in1, Tensor<?> in2, Tensor<?> out ) {
        Conv2DImpl impl = _createImplOrNull( in1, in2, out );
        // Both orders can only be valid at the same time if image and kernel have identical
        // dimensions, in which case they describe the same computation, so the first match wins.
        if ( impl == null )
            impl = _createImplOrNull( in2, in1, out );
        _impl = impl;
    }

    /**
     *  Executes the convolution and writes the result into the data array of the output tensor.
     *
     * @throws IllegalStateException If this instance is not suitable for the tensors it was created with.
     */
    public void run() {
        if ( _impl == null ) throw new IllegalStateException("Not runnable!");
        _impl.run();
    }

    /**
     * @return {@code true} if an implementation could be created for the tensors passed to the constructor.
     */
    public boolean isSuitable() { return _impl != null; }

    /**
     *  Wraps {@link #_tryCreatingImplFor(Tensor, Tensor, Tensor)} so that any failure
     *  simply means "this operand order is not supported".
     *
     * @return The implementation for the given operand order or {@code null} if it is not supported.
     */
    private static Conv2DImpl _createImplOrNull(
        final Tensor<?> image,
        final Tensor<?> kernel,
        final Tensor<?> result
    ) {
        try {
            return _tryCreatingImplFor( image, kernel, result );
        }
        catch ( Exception ignored ) {
            // Deliberately swallowed: a failed attempt is reported through 'isSuitable()'
            // so that the caller can choose another convolution implementation.
            return null;
        }
    }

    /**
     *  Validates the given tensors and creates the implementation matching their item type.
     *
     * @param image The tensor to be convolved, of rank 2 or of rank 3 with the batch size as first axis.
     * @param kernel The kernel tensor of rank 2 or of rank 3 with a first axis of size 1.
     * @param result The output tensor, whose size must be batch size * result height * result width.
     * @return An implementation which is ready to be run.
     * @throws IllegalArgumentException If the tensors are not supported or not consistent with each other.
     */
    private static Conv2DImpl _tryCreatingImplFor(
        final Tensor<?> image,
        final Tensor<?> kernel,
        final Tensor<?> result
    ) {
        validate(image);
        validate(kernel);
        validate(result);

        // Rank 3 tensors carry a leading batch axis, so their height/width axes are shifted by one.
        // The offset has to be determined per tensor because image and kernel may differ in rank.
        final int imageOffset  = ( image.rank()  == 3 ? 1 : 0 );
        final int kernelOffset = ( kernel.rank() == 3 ? 1 : 0 );

        int kernelBatchSize = ( kernel.rank() == 3 ? kernel.shape(0) : 1 );
        if ( kernelBatchSize > 1 )
            throw new IllegalArgumentException("Kernel batch size must be 1!");

        int batchSize    = ( image.rank() == 3 ? image.shape(0) : 1 );
        int width        = image.shape(imageOffset + 1);
        int height       = image.shape(imageOffset + 0);
        int kernelWidth  = kernel.shape(kernelOffset + 1);
        int kernelHeight = kernel.shape(kernelOffset + 0);
        int resultWidth  = width  - kernelWidth  + 1;
        int resultHeight = height - kernelHeight + 1;

        // Without this check two negative result dimensions could multiply to a positive
        // number which matches the result size, yielding an implementation which does nothing.
        if ( resultWidth < 1 || resultHeight < 1 )
            throw new IllegalArgumentException("The kernel must not be larger than the image!");

        // Computed as long so that an overflowing product can not match the result size by accident.
        if ( (long) batchSize * resultHeight * resultWidth != result.size() )
            throw new IllegalArgumentException("The result array must have the same length as the batch size times the result height times the result width!");

        Class<?> c1 = image.itemType();
        Class<?> c2 = kernel.itemType();
        Class<?> c3 = result.itemType();

        if ( c1 != c2 || c2 != c3 )
            throw new IllegalArgumentException("All inputs must be of the same type!");

        if ( c1 == Float.class )
            return new ImplF32(
                        image.mut().getDataAs(float[].class),
                        kernel.mut().getDataAs(float[].class),
                        result.mut().getDataForWriting(float[].class),
                        width,
                        height,
                        kernelWidth,
                        kernelHeight,
                        resultWidth,
                        resultHeight,
                        batchSize
                    );
        else if ( c1 == Double.class )
            return new ImplF64(
                        image.mut().getDataAs(double[].class),
                        kernel.mut().getDataAs(double[].class),
                        result.mut().getDataForWriting(double[].class),
                        width,
                        height,
                        kernelWidth,
                        kernelHeight,
                        resultWidth,
                        resultHeight,
                        batchSize
                    );
        else
            throw new IllegalArgumentException("Unsupported data type!");
    }

    /**
     *  A fully configured convolution which can be executed through {@link #run()}.
     */
    interface Conv2DImpl {
        void run();
    }

    /**
     *  The convolution implementation for flat {@code float} arrays.
     *  Images and results are indexed as {@code batchIndex * height * width + y * width + x}.
     */
    private static class ImplF32 implements Conv2DImpl {
        private final float[] _image;
        private final float[] _kernel;
        private final float[] _result;
        private final int _width, _height, _kernelWidth, _kernelHeight, _resultWidth, _resultHeight, _batchSize;

        /**
         * @throws IllegalArgumentException If one of the arrays does not fit the given dimensions.
         */
        private ImplF32(
            float[] image,
            float[] kernel,
            float[] result,
            int width,
            int height,
            int kernelWidth,
            int kernelHeight,
            int resultWidth,
            int resultHeight,
            int batchSize
        ) {
            // The input arrays are checked here so that a mismatch is reported at construction
            // instead of surfacing as an out of bounds access somewhere inside of 'run()'.
            if ( (long) batchSize * height * width > image.length )
                throw new IllegalArgumentException("The image array is too small for the given batch size, height and width!");
            if ( (long) kernelHeight * kernelWidth > kernel.length )
                throw new IllegalArgumentException("The kernel array is too small for the given kernel height and width!");
            if ( (long) batchSize * resultHeight * resultWidth != result.length )
                throw new IllegalArgumentException("The result array must have the same length as the batch size times the result height times the result width!");

            _image = image;
            _kernel = kernel;
            _result = result;
            _width = width;
            _height = height;
            _kernelWidth = kernelWidth;
            _kernelHeight = kernelHeight;
            _resultWidth = resultWidth;
            _resultHeight = resultHeight;
            _batchSize = batchSize;
        }

        /**
         *  Convolves every batch entry, either sequentially or through the CPU executor,
         *  depending on the number of elements of a single result matrix.
         */
        @Override
        public void run() {
            int work = _resultHeight * _resultWidth;
            if ( work < MIN_WORK_FOR_THREADING )
                for ( int bi = 0; bi < _batchSize; bi++ ) run(bi);
            else
                CPU.get().getExecutor().threaded(_batchSize, this::run);
        }

        /**
         *  Convolves a single batch entry. Every batch entry writes to its own
         *  region of the result array.
         */
        private void run( int batchIndex ) {
            int imageOffset  = batchIndex * _width * _height;
            int resultOffset = batchIndex * _resultWidth * _resultHeight;
            for ( int y = 0; y < _resultHeight; y++ ) {
                for ( int x = 0; x < _resultWidth; x++ ) {
                    float sum = 0;
                    for ( int ky = 0; ky < _kernelHeight; ky++ )
                        for ( int kx = 0; kx < _kernelWidth; kx++ )
                            sum +=
                                _image[imageOffset + (y + ky) * _width + (x + kx)]
                                    *
                                _kernel[ky * _kernelWidth + kx];

                    _result[resultOffset + y * _resultWidth + x] = sum;
                }
            }
        }
    }

    /**
     *  The convolution implementation for flat {@code double} arrays.
     *  Images and results are indexed as {@code batchIndex * height * width + y * width + x}.
     */
    private static class ImplF64 implements Conv2DImpl {
        private final double[] _image;
        private final double[] _kernel;
        private final double[] _result;
        private final int _width, _height, _kernelWidth, _kernelHeight, _resultWidth, _resultHeight, _batchSize;

        /**
         * @throws IllegalArgumentException If one of the arrays does not fit the given dimensions.
         */
        private ImplF64(
            double[] image,
            double[] kernel,
            double[] result,
            int width,
            int height,
            int kernelWidth,
            int kernelHeight,
            int resultWidth,
            int resultHeight,
            int batchSize
        ) {
            // The input arrays are checked here so that a mismatch is reported at construction
            // instead of surfacing as an out of bounds access somewhere inside of 'run()'.
            if ( (long) batchSize * height * width > image.length )
                throw new IllegalArgumentException("The image array is too small for the given batch size, height and width!");
            if ( (long) kernelHeight * kernelWidth > kernel.length )
                throw new IllegalArgumentException("The kernel array is too small for the given kernel height and width!");
            if ( (long) batchSize * resultHeight * resultWidth != result.length )
                throw new IllegalArgumentException("The result array must have the same length as the batch size times the result height times the result width!");

            _image = image;
            _kernel = kernel;
            _result = result;
            _width = width;
            _height = height;
            _kernelWidth = kernelWidth;
            _kernelHeight = kernelHeight;
            _resultWidth = resultWidth;
            _resultHeight = resultHeight;
            _batchSize = batchSize;
        }

        /**
         *  Convolves every batch entry, either sequentially or through the CPU executor,
         *  depending on the number of elements of a single result matrix.
         */
        @Override
        public void run() {
            int work = _resultHeight * _resultWidth;
            if ( work < MIN_WORK_FOR_THREADING )
                for ( int bi = 0; bi < _batchSize; bi++ ) run(bi);
            else
                CPU.get().getExecutor().threaded(_batchSize, this::run);
        }

        /**
         *  Convolves a single batch entry. Every batch entry writes to its own
         *  region of the result array.
         */
        private void run( int batchIndex ) {
            int imageOffset  = batchIndex * _width * _height;
            int resultOffset = batchIndex * _resultWidth * _resultHeight;
            for ( int y = 0; y < _resultHeight; y++ ) {
                for ( int x = 0; x < _resultWidth; x++ ) {
                    double sum = 0;
                    for ( int ky = 0; ky < _kernelHeight; ky++ )
                        for ( int kx = 0; kx < _kernelWidth; kx++ )
                            sum +=
                                _image[imageOffset + (y + ky) * _width + (x + kx)]
                                    *
                                _kernel[ky * _kernelWidth + kx];

                    _result[resultOffset + y * _resultWidth + x] = sum;
                }
            }
        }
    }

    /**
     *  Makes sure that the given tensor has a rank and a layout which the implementations
     *  in this class can index directly.
     *
     * @param t The tensor to check.
     * @throws IllegalArgumentException If the rank is neither 2 nor 3, or if the layout
     *                                  is neither row major nor symmetric.
     */
    private static void validate( Tensor<?> t ) {
        final int rank = t.getRank();
        if ( rank != 2 && rank != 3 )
            throw new IllegalArgumentException("The rank of the tensor must be 2 or 3!");

        NDConfiguration.Layout layout = t.getNDConf().getLayout();
        if ( layout != NDConfiguration.Layout.ROW_MAJOR && layout != NDConfiguration.Layout.SYMMETRIC )
            throw new IllegalArgumentException("The layout of the tensor must be row major or symmetric!");
    }
}