package neureka.backend.main.operations.linear.internal.opencl;

import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.ImplementationFor;
import neureka.devices.opencl.KernelCaller;
import neureka.devices.opencl.OpenCLDevice;

import java.util.function.Supplier;
public class CLSum implements ImplementationFor<OpenCLDevice>
{
    @Override
    public Tensor<?> run( ExecutionCall<OpenCLDevice> call ) {
        return run( call.input( Float.class, 0 ), call.getDevice() );
    }
    /**
     * This method compiles and executes a kernel which sums up all elements
     * of the {@code in} tensor. Every pass reduces the input to a smaller
     * tensor of partial sums, and the method recurses on that tensor until
     * only a single element, the total sum, remains.
     */
    public static Tensor<Float> run(
        Tensor<Float> in, OpenCLDevice device
    ) {
        final long RTS = device.maxWorkGroupSize(); // Register tile size
        final int SIZE = in.size();
        long localSize = device.maxWorkGroupSize();
        // Find the largest divisor of SIZE which does not exceed the maximum workgroup
        // size, so that the global work size is an exact multiple of the local work size.
        while ( SIZE % localSize != 0 ) localSize--;
        final int N = (int) (SIZE / localSize); // The number of partial sums
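        // Worked example with illustrative (assumed) numbers: for SIZE = 1000 and
        // a maxWorkGroupSize() of 256 the loop settles on localSize = 250, the
        // largest divisor of 1000 not exceeding 256, which yields N = 4 partial sums.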
        long[] local  = new long[]{ localSize };
        long[] global = new long[]{ (long) SIZE };
        Tensor<Float> out;
        if ( localSize == 1 ) { // Oh, the user wants to process a prime number of elements... sigh! Ok let's do it (slower)!
            double fraction = (double) SIZE / (double) RTS;
            // Determine the optimal number of tiles:
            final int newN;
            if ( fraction == Math.floor(fraction) ) // fraction is an integer, so the tiles fit exactly.
                newN = (int) fraction;
            else
                newN = (int) Math.ceil(fraction); // The last tile is only partially filled; the kernel bound-checks it.
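            // Worked example with illustrative (assumed) numbers: SIZE = 7919 (a prime)
            // and RTS = 256 give fraction ~ 30.93, so newN = 31 tiles, where the last
            // tile covers only the remaining 7919 - 30 * 256 = 239 elements.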
            out = Tensor.of(Float.class, Shape.of(newN), 0).to(device).mut().setIsVirtual(false);
            KernelCaller caller = _processPrivate(RTS, device);
            // Launch one work-item per tile, not per element, since every work-item
            // already sums up RTS elements privately (a global size of SIZE would
            // read and write out of bounds for work-items beyond the last tile).
            caller.pass(SIZE).pass(in).pass(out).call(new long[]{ newN }, local);
        }
        else
        {
            out = Tensor.of(Float.class, Shape.of(N), 0).to(device).mut().setIsVirtual(false);
            KernelCaller caller = _processLocal(device);
            caller.pass(in).pass(out).passLocalFloats((int) localSize).call(global, local);
        }
        if ( out.size() > 1 ) { // More than one partial sum remains, so we reduce recursively.
            Tensor<Float> reduced = run(out, device);
            out.mut().delete();
            return reduced;
        }
        return out;
    }
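    /*
        A minimal usage sketch. Note that the device lookup below is an assumption
        and may differ between Neureka versions:

            OpenCLDevice device = (OpenCLDevice) neureka.devices.Device.get("gpu");
            Tensor<Float> data  = Tensor.of(Float.class, Shape.of(1000), 1f).to(device);
            Tensor<Float> sum   = CLSum.run(data, device); // A 1-element tensor holding 1000.0!
    */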
    private static KernelCaller _processPrivate( long RTS, OpenCLDevice device )
    {
        String kernelName = "fast_private_sum_reduction_RTS"+RTS;
        Supplier<String> code = () ->
            "   #define RTS "+RTS+"                                              \n" +
            "   __kernel void "+kernelName+"(                                    \n" +
            "       const int size,                                              \n" +
            "       const __global float* in,                                    \n" +
            "             __global float* out                                    \n" +
            "   ) {                                                              \n" +
            "       size_t ni = get_global_id(0); // global N-tile id            \n" +
            "                                                                    \n" +
            "       int offset = ni * RTS;                                       \n" +
            "       int limit = min( offset + RTS, size ); // Boundary condition for the last tile! \n" +
            "       float value = in[offset];                                    \n" +
            "       offset++;                                                    \n" +
            "                                                                    \n" +
            "       #pragma unroll                                               \n" +
            "       for ( uint i = offset; i < limit; ++i )                      \n" +
            "           value += in[i];                                          \n" +
            "                                                                    \n" +
            "       out[ni] = value;                                             \n" +
            "   }                                                                \n";
        return device.findOrCompileAdHocKernel(kernelName, code);
    }
    private static KernelCaller _processLocal( OpenCLDevice device )
    {
        String kernelName = "fast_local_mem_based_sum";
        Supplier<String> code = () ->
            "                                                                            \n" +
            "   int div( int i ) { return i % 2 == 1 ? (i+1) / 2 : i / 2; } // Halves i, rounding up. \n" +
            "                                                                            \n" +
            "   __kernel void "+kernelName+" (                                           \n" +
            "       __global const float *input,                                         \n" +
            "       __global float *partialSums,                                         \n" +
            "       __local  float *localSums                                            \n" +
            "   ){                                                                       \n" +
            "       uint local_id   = get_local_id(0);                                   \n" +
            "       uint group_size = get_local_size(0);                                 \n" +
            "                                                                            \n" +
            "       // Copy from global to local memory:                                 \n" +
            "       localSums[local_id] = input[get_global_id(0)];                       \n" +
            "                                                                            \n" +
            "       // Tree reduction: repeatedly fold the upper half of the             \n" +
            "       // active range [0, last) onto the lower half.                       \n" +
            "       uint last = group_size;                                              \n" +
            "       for ( uint stride = div(group_size); stride > 0 && last > stride; stride = div(stride) ) \n" +
            "       {                                                                    \n" +
            "           // Wait until the previous round of pairwise additions is done.  \n" +
            "           barrier(CLK_LOCAL_MEM_FENCE);                                    \n" +
            "                                                                            \n" +
            "           // Add the element at local_id + stride onto the one at local_id: \n" +
            "           uint right_id = local_id + stride; // We copy from the right part. \n" +
            "           if ( local_id < stride && right_id < last )                      \n" +
            "               localSums[local_id] += localSums[right_id];                  \n" +
            "           last = stride;                                                   \n" +
            "       }                                                                    \n" +
            "                                                                            \n" +
            "       // Write this workgroup's result into partialSums[group_id]:         \n" +
            "       if ( local_id == 0 )                                                 \n" +
            "           partialSums[get_group_id(0)] = localSums[0];                     \n" +
            "   }                                                                        \n";
        return device.findOrCompileAdHocKernel(kernelName, code);
    }
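    /*
        An example trace of the tree reduction above, for an (assumed, illustrative)
        odd group size of 5:

            stride = 3 : ids 0,1 add ids 3,4  (id 2 idle, because right_id 5 >= last 5)
            stride = 2 : id 0 adds id 2       (id 1 idle, because right_id 3 >= last 3)
            stride = 1 : id 0 adds id 1       -> the total ends up in localSums[0]
    */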
}