CLReduce.java
package neureka.backend.main.operations.linear.internal.opencl;
import neureka.Neureka;
import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.ImplementationFor;
import neureka.backend.ocl.CLBackend;
import neureka.backend.ocl.CLSettings;
import neureka.devices.opencl.KernelCaller;
import neureka.devices.opencl.OpenCLDevice;
import java.util.Locale;
import java.util.function.Supplier;
/**
 * OpenCL implementation of an arg-min / arg-max reduction.
 * <p>
 * Finds the index of the smallest ({@link Type#MIN}) or largest ({@link Type#MAX})
 * element of a float tensor and returns it as a single element integer tensor.
 * The work is done tile by tile:
 * <ul>
 *     <li>an ad-hoc kernel finds the best index within each tile of
 *         {@code device.maxWorkGroupSize()} elements;</li>
 *     <li>the winning values are gathered into a smaller tensor;</li>
 *     <li>the process recurses until a single tile remains.</li>
 * </ul>
 * Because the kernel uses a strict comparison and scans each tile in ascending
 * order, ties resolve to the first occurrence of the extremum.
 */
public class CLReduce implements ImplementationFor<OpenCLDevice>
{
    /**
     * Name under which the index-to-value gather kernel (see {@code _fetch}) is
     * registered on the device.
     * NOTE(review): this is a mutable public static; kernels are looked up by
     * this name, so it should not be reassigned at runtime.
     */
    public static String INDICES_MAPPER_ID = "indices_to_values_mapper";

    /** The kind of extremum to search for. */
    public enum Type { MIN, MAX }

    private final Type _type;
    /** OpenCL C condition deciding whether {@code current} replaces the best {@code value} so far. */
    private final String _comparator;

    /**
     * @param type whether to search for the minimum or the maximum.
     * @throws IllegalArgumentException if the type is not supported.
     */
    public CLReduce(Type type) {
        String comparator;
        switch (type) {
            case MIN: comparator = "current < value"; break;
            case MAX: comparator = "current > value"; break;
            default: throw new IllegalArgumentException("Unsupported reduction type: "+type);
        }
        _comparator = comparator;
        _type = type;
    }

    /**
     * Runs the reduction on the call's device.
     * Input 0 is used unless it is {@code null}, in which case input 1 is used.
     * The backend's auto-convert-to-float setting is switched off for the
     * duration of the reduction and always restored afterwards, even on failure.
     *
     * @param call the execution call holding the float input tensor and target device.
     * @return a single element integer tensor holding the index of the extremum.
     */
    @Override
    public Tensor<Integer> run(ExecutionCall<OpenCLDevice> call) {
        CLBackend context = Neureka.get().backend().find(CLBackend.class).orElse(null);
        CLSettings settings = context == null ? null : context.getSettings();
        // Guard on the settings object itself so a backend without settings cannot cause an NPE.
        boolean autoConvert = settings == null || settings.isAutoConvertToFloat();
        // Presumably disabled so the float/int tensors reach the kernels unchanged — TODO confirm.
        if ( settings != null ) settings.setAutoConvertToFloat(false);
        int index;
        try {
            Tensor<Float> in = call.input(0) == null ? call.input(Float.class, 1) : call.input(Float.class, 0);
            index = _runRecursively(in, call.getDevice());
        } finally {
            // Restore the caller's setting even if the reduction failed.
            if ( settings != null ) settings.setAutoConvertToFloat(autoConvert);
        }
        return Tensor.of(Integer.class, Shape.of( 1 ), index);
    }

    /**
     * Finds the index of the extremum of {@code in} by reducing it in tiles of
     * {@code device.maxWorkGroupSize()} elements. If more than one tile is
     * needed, it recurses on the gathered per-tile winners.
     * All intermediate device tensors are deleted before returning, also on failure.
     *
     * @param in the float tensor to reduce (must not be empty).
     * @param device the OpenCL device the kernels run on.
     * @return the index (into {@code in}) of the first occurrence of the extremum.
     * @throws IllegalArgumentException if {@code in} is empty.
     */
    private int _runRecursively(Tensor<Float> in, OpenCLDevice device)
    {
        final int SIZE = in.size();
        if ( SIZE < 1 )
            throw new IllegalArgumentException("Cannot reduce an empty tensor.");
        if ( SIZE == 1 )
            return 0; // A single element is trivially the extremum; no device work needed.

        final long RTS = device.maxWorkGroupSize(); // Register tile size
        // Number of tiles = ceil(SIZE / RTS); the last tile may be partial (bound-checked in the kernel).
        final int N = (int) ( ( SIZE + RTS - 1 ) / RTS );

        Tensor<Integer> out = Tensor.of(Integer.class, Shape.of(N), 0).to(device);
        try {
            out.mut().setIsVirtual(false);
            // Locale.ROOT keeps the kernel name ASCII regardless of the default locale (e.g. Turkish "I").
            String kernelName = "fast_"+_type.name().toLowerCase(Locale.ROOT)+"_reduce_RTS"+RTS;
            Supplier<String> code = () ->
                    " #define RTS "+RTS+" \n" +
                    " __kernel void "+kernelName+"( \n" +
                    " const int size, \n" +
                    " const __global float* in, \n" +
                    " __global int* out // indices \n" +
                    " ) { \n" +
                    " size_t ni = get_global_id(0); // global N-tile id \n" +
                    " \n" +
                    " int offset = ni * RTS; \n" +
                    " int limit = min( offset + RTS, size ); // Boundary condition! \n" +
                    " float value = in[offset]; \n" +
                    " int found_index = offset; \n" +
                    " offset++; \n" +
                    " \n" +
                    " #pragma unroll \n" +
                    " for ( uint i=offset; i < limit; ++i ) { \n" +
                    " float current = in[i]; \n" +
                    " if ( "+ _comparator +" ) { \n" +
                    " value = current; \n" +
                    " found_index = i; \n" +
                    " } \n" +
                    " } \n" +
                    " out[ni] = found_index; \n" +
                    " } \n";

            KernelCaller caller = device.findOrCompileAdHocKernel(kernelName, code);
            long[] local = null; // This kernel does not have local memory (uses register/private memory instead)
            long[] global = new long[]{ N };
            caller.pass(SIZE).pass( in ).pass( out ).call( global, local );

            if ( N == 1 )
                return out.at(0).get();

            // Gather the per-tile winners, reduce them, then map the winner back to an index into `in`.
            Tensor<Float> reduced = _fetch(in, out, device);
            try {
                return out.at(_runRecursively(reduced, device)).get();
            } finally {
                reduced.mut().delete();
            }
        } finally {
            out.mut().delete();
        }
    }

    /**
     * Creates and returns a new tensor with the size of the {@code indices}
     * tensor, where entry {@code i} holds {@code in[indices[i]]}.
     * The gather is done by a simple index-to-array-entry mapping kernel.
     *
     * @param in the source values.
     * @param indices the positions in {@code in} to read.
     * @param device the OpenCL device the kernel runs on.
     * @return a new (non-virtual) device tensor of gathered values; the caller owns and must delete it.
     */
    private Tensor<Float> _fetch(
        Tensor<Float> in, Tensor<Integer> indices, OpenCLDevice device
    ) {
        Tensor<Float> out = Tensor.of(Float.class, Shape.of(indices.size()), 0).to(device);
        out.mut().setIsVirtual(false);
        String kernelName = INDICES_MAPPER_ID;
        Supplier<String> code = () ->
                " __kernel void " + kernelName + "( \n" +
                " const __global int* indices, \n" +
                " const __global float* in, \n" +
                " __global float* out \n" +
                " ) { \n" +
                " size_t i = get_global_id(0); // global id \n" +
                " out[i] = in[indices[i]]; \n" +
                " } \n";
        // Same lookup-or-compile path as the reduction kernel in _runRecursively.
        KernelCaller caller = device.findOrCompileAdHocKernel(kernelName, code);
        long[] local = null; // This kernel does not have local memory (uses register/private memory instead)
        long[] global = new long[]{ indices.size() };
        caller.pass( indices ).pass( in ).pass( out ).call( global, local );
        return out;
    }
}