CLSum.java

package neureka.backend.main.operations.linear.internal.opencl;

import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.ImplementationFor;
import neureka.devices.opencl.KernelCaller;
import neureka.devices.opencl.OpenCLDevice;

import java.util.function.Supplier;

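/**
 *  An OpenCL based implementation of a full sum reduction over all elements of a
 *  {@code float} tensor. A single pass reduces the input to an array of partial sums
 *  using one of two ad-hoc kernels, and {@link #run(Tensor, OpenCLDevice)} recurses
 *  on those partial sums until only a single element remains.
 *  <p>
 *  A minimal usage sketch (how the {@code device} is obtained is left open here,
 *  it would typically be supplied by Neureka's CL backend):
 *  <pre>{@code
 *  OpenCLDevice device = ...; // e.g. provided by the CL backend of the current context
 *  Tensor<Float> data = Tensor.of(Float.class, Shape.of(1_000), 1).to(device);
 *  Tensor<Float> sum  = CLSum.run(data, device); // A single element tensor holding 1000.0
 *  }</pre>
 */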
public class CLSum implements ImplementationFor<OpenCLDevice>
{
    @Override
    public Tensor<?> run(ExecutionCall<OpenCLDevice> call) {
        return run(call.input(Float.class, 0), call.getDevice());
    }

    /**
     *  This method compiles (or fetches from the kernel cache) and executes a kernel
     *  which reduces the {@code in} tensor to an array of partial sums, and then
     *  recurses on those partial sums until only a single element, the total, remains.
     */
    public static Tensor<Float> run(
            Tensor<Float> in, OpenCLDevice device
    ) {
        final long RTS = device.maxWorkGroupSize(); // Register tile size
        final int SIZE = in.size();
        long localSize = device.maxWorkGroupSize();
        while (SIZE % localSize != 0) { localSize--; } // Find the largest divisor of SIZE which does not exceed the max workgroup size.
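        // (Illustrative: for SIZE = 300 and a max workgroup size of 256 the loop
        //  settles on localSize = 150, which yields N = 2 partial sums below.)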
        final int N = (int) (SIZE / localSize); // The number of partial sums

        long[] local  = new long[]{ localSize };
        long[] global = new long[]{(long) SIZE };

        Tensor<Float> out;

        if ( localSize == 1 ) { // Oh, the user wants to process a prime number of elements (or one whose smallest prime factor exceeds the workgroup limit)... sigh! Ok let's do it (slower)!
            // Determine the number of RTS sized tiles needed to cover all elements;
            // if SIZE is not a multiple of RTS, then the last tile is only partially
            // filled, which the kernel accounts for with a bounds check.
            double fraction = (double) SIZE / (double) RTS;
            final int newN = (int) Math.ceil(fraction);
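            // (Illustrative: SIZE = 1000 and RTS = 256 gives fraction = 3.90625 and
            //  therefore newN = 4, where the last tile covers only 232 elements.)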

            out = Tensor.of(Float.class, Shape.of(newN), 0).to(device).mut().setIsVirtual(false);
            KernelCaller caller = _processPrivate(RTS, device);
            global = new long[]{ newN }; // One work-item per RTS sized tile, not one per element!
            caller.pass(SIZE).pass(in).pass(out).call(global, local);
        }
        else
        {
            out = Tensor.of(Float.class, Shape.of( N ), 0).to(device).mut().setIsVirtual(false);
            KernelCaller caller = _processLocal(device);
            caller.pass(in).pass(out).passLocalFloats((int) localSize).call(global, local);
        }

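        // A single pass only produces partial sums, so we recurse until one element
        // remains, e.g. SIZE = 65536 with localSize = 256 leaves N = 256 partial
        // sums, which the recursive call below then reduces to a single value.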
        if ( N > 1 ) {
            Tensor<Float> reduced = run(out, device);
            out.mut().delete();
            return reduced;
        }
        return out;
    }

    private static KernelCaller _processPrivate( long RTS, OpenCLDevice device )
    {
        String kernelName = "fast_private_sum_reduction_RTS"+RTS;
        Supplier<String> code = () ->
                        "   #define RTS "+RTS+"                                                                    \n" +
                        "   __kernel void "+kernelName+"(                                                          \n" +
                        "               const int size,                                                            \n" +
                        "               const __global float* in,                                                  \n" +
                        "                     __global float* out                                                  \n" +
                        "   ) {                                                                                    \n" +
                        "       size_t ni = get_global_id(0); //   global N-tile id                                \n" +
                        "                                                                                          \n" +
                        "       int offset = ni * RTS;                                                             \n" +
                        "       int limit = min( offset + RTS, size ); // Boundary condition!                      \n" +
                        "       float value = in[offset];                                                          \n" +
                        "       offset++;                                                                          \n" +
                        "                                                                                          \n" +
                        "       #pragma unroll                                                                     \n" +
                        "       for ( uint i=offset; i < limit; ++i )                                              \n" +
                        "           value += in[i];                                                                \n" +
                        "                                                                                          \n" +
                        "       out[ni] = value;                                                                   \n" +
                        "   }                                                                                      \n";

        return device.findOrCompileAdHocKernel(kernelName, code);
    }

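    /*
     *  A worked trace of the reduction loop in the kernel below (illustrative,
     *  for a group size of 5, where s = localSums): strides are ceil-halved, 3 -> 2 -> 1.
     *    stride 3:  s[0] += s[3],  s[1] += s[4]   (s[2] has no partner, `right_id < last` skips it)
     *    stride 2:  s[0] += s[2]
     *    stride 1:  s[0] += s[1]
     *  After the loop, localSums[0] holds the sum of all 5 elements of the group.
     */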
    private static KernelCaller _processLocal( OpenCLDevice device )
    {
        String kernelName = "fast_local_mem_based_sum";
        Supplier<String> code = () ->
                "                                                                                                  \n" +
                "    int div(int i) { return i % 2 == 1 ? (i+1) / 2 : i / 2;  }                                    \n" +
                "                                                                                                  \n" +
                "    __kernel void "+kernelName+" (                                                                \n" +
                "       __global const float *input,                                                               \n" +
                "       __global float *partialSums,                                                               \n" +
                "       __local float *localSums                                                                   \n" +
                "    ){                                                                                            \n" +
                "        uint local_id = get_local_id(0);                                                          \n" +
                "        uint group_size = get_local_size(0);                                                      \n" +
                "                                                                                                  \n" +
                "        // Copy from global to local memory                                                       \n" +
                "        localSums[local_id] = input[get_global_id(0)];                                            \n" +
                "                                                                                                  \n" +
                "        // Loop for computing localSums : divide WorkGroup into 2 parts                           \n" +
                "        uint last = group_size;                                                                   \n" +
                "        for (uint stride = div(group_size); stride > 0 && last > stride; stride=div(stride))      \n" +
                "        {                                                                                         \n" +
                "             // Waiting for each 2x2 addition into given workgroup                                \n" +
                "             barrier(CLK_LOCAL_MEM_FENCE);                                                        \n" +
                "                                                                                                  \n" +
                "             // Add elements 2 by 2 between local_id and local_id + stride                        \n" +
                "             uint right_id = local_id + stride; // We copy from the right part.                   \n" +
                "             if (local_id < stride && right_id < last )                                           \n" +
                "                 localSums[local_id] += localSums[local_id + stride];                             \n" +
                "             last = stride;                                                                       \n" +
                "        }                                                                                         \n" +
                "                                                                                                  \n" +
                "        // Write result into partialSums[nWorkGroups]                                             \n" +
                "        if (local_id == 0)                                                                        \n" +
                "            partialSums[get_group_id(0)] = localSums[0];                                          \n" +
                "    }                                                                                             \n";

        return device.findOrCompileAdHocKernel(kernelName, code);
    }

}