CLSum.java

  1. package neureka.backend.main.operations.linear.internal.opencl;

  2. import neureka.Shape;
  3. import neureka.Tensor;
  4. import neureka.backend.api.ExecutionCall;
  5. import neureka.backend.api.ImplementationFor;
  6. import neureka.devices.opencl.KernelCaller;
  7. import neureka.devices.opencl.OpenCLDevice;

  8. import java.util.function.Supplier;

  9. public class CLSum implements ImplementationFor<OpenCLDevice>
  10. {
  11.     @Override
  12.     public Tensor<?> run(ExecutionCall<OpenCLDevice> call) {
  13.         return run(call.input(Float.class, 0), call.getDevice());
  14.     }

  15.     /**
  16.      *  This method compiles and executes the kernel that will return the sum of the
  17.      *  elements in the {@code in} tensor.
  18.      */
  19.     public static Tensor<Float> run(
  20.             Tensor<Float> in, OpenCLDevice device
  21.     ) {
  22.         final long RTS = device.maxWorkGroupSize(); // Register tile size
  23.         final int SIZE = in.size();
  24.         long localSize = device.maxWorkGroupSize();
  25.         while (SIZE % localSize != 0) { localSize--; } // We want to have a multiple of the max workgroup size + as large as possible
  26.         final int N = (int) (SIZE / localSize); // The number of partial sums

  27.         long[] local  = new long[]{ localSize };
  28.         long[] global = new long[]{(long) SIZE };

  29.         Tensor<Float> out;

  30.         if ( localSize == 1 ) { // Oh, the user wants to process a prime number of elements... sigh! Ok let's do it (slower)!
  31.             double fraction = (double) SIZE / (double) RTS;
  32.             final int newN;
  33.             // Determining optimal number of tiles!
  34.             // Check if fraction is an integer
  35.             if ( fraction == Math.floor(fraction) )
  36.                 newN = (int) fraction;
  37.             else
  38.                 newN = (int) Math.ceil(fraction); // The last tile we do a partial reduction (bound check)

  39.             out = Tensor.of(Float.class, Shape.of(newN), 0).to(device).mut().setIsVirtual(false);
  40.             KernelCaller caller = _processPrivate(RTS, device);
  41.             caller.pass(SIZE).pass(in).pass(out).call(global, local);
  42.         }
  43.         else
  44.         {
  45.             out = Tensor.of(Float.class, Shape.of( N ), 0).to(device).mut().setIsVirtual(false);
  46.             KernelCaller caller = _processLocal(device);
  47.             caller.pass(in).pass(out).passLocalFloats((int) localSize).call(global, local);
  48.         }

  49.         if ( N > 1 ) {
  50.             Tensor<Float> reduced = run(out, device);
  51.             out.mut().delete();
  52.             return reduced;
  53.         }
  54.         return out;
  55.     }

  56.     private static KernelCaller _processPrivate( long RTS, OpenCLDevice device )
  57.     {
  58.         String kernelName = "fast_private_sum_reduction_RTS"+RTS;
  59.         Supplier<String> code = () ->
  60.                         "   #define RTS "+RTS+"                                                                    \n" +
  61.                         "   __kernel void "+kernelName+"(                                                          \n" +
  62.                         "               const int size,                                                            \n" +
  63.                         "               const __global float* in,                                                  \n" +
  64.                         "                     __global float* out                                                  \n" +
  65.                         "   ) {                                                                                    \n" +
  66.                         "       size_t ni = get_global_id(0); //   global N-tile id                                \n" +
  67.                         "                                                                                          \n" +
  68.                         "       int offset = ni * RTS;                                                             \n" +
  69.                         "       int limit = min( offset + RTS, size ); // Boundary condition!                      \n" +
  70.                         "       float value = in[offset];                                                          \n" +
  71.                         "       offset++;                                                                          \n" +
  72.                         "                                                                                          \n" +
  73.                         "       #pragma unroll                                                                     \n" +
  74.                         "       for ( uint i=offset; i < limit; ++i )                                              \n" +
  75.                         "           value += in[i];                                                                \n" +
  76.                         "                                                                                          \n" +
  77.                         "       out[ni] = value;                                                                   \n" +
  78.                         "   }                                                                                      \n";

  79.         return device.findOrCompileAdHocKernel(kernelName, code);
  80.     }

  81.     private static KernelCaller _processLocal(
  82.             OpenCLDevice device
  83.     ) {
  84.         String kernelName = "fast_local_mem_based_sum";
  85.         Supplier<String> code = () ->
  86.                 "                                                                                                  \n" +
  87.                 "    int div(int i) { return i % 2 == 1 ? (i+1) / 2 : i / 2;  }                                    \n" +
  88.                 "                                                                                                  \n" +
  89.                 "    __kernel void "+kernelName+" (                                                                \n" +
  90.                 "       __global const float *input,                                                               \n" +
  91.                 "       __global float *partialSums,                                                               \n" +
  92.                 "       __local float *localSums                                                                   \n" +
  93.                 "    ){                                                                                            \n" +
  94.                 "        uint local_id = get_local_id(0);                                                          \n" +
  95.                 "        uint group_size = get_local_size(0);                                                      \n" +
  96.                 "                                                                                                  \n" +
  97.                 "        // Copy from global to local memory                                                       \n" +
  98.                 "        localSums[local_id] = input[get_global_id(0)];                                            \n" +
  99.                 "                                                                                                  \n" +
  100.                 "        // Loop for computing localSums : divide WorkGroup into 2 parts                           \n" +
  101.                 "        uint last = group_size;                                                                   \n" +
  102.                 "        for (uint stride = div(group_size); stride > 0 && last > stride; stride=div(stride))      \n" +
  103.                 "        {                                                                                         \n" +
  104.                 "             // Waiting for each 2x2 addition into given workgroup                                \n" +
  105.                 "             barrier(CLK_LOCAL_MEM_FENCE);                                                        \n" +
  106.                 "                                                                                                  \n" +
  107.                 "             // Add elements 2 by 2 between local_id and local_id + stride                        \n" +
  108.                 "             uint right_id = local_id + stride; // We copy from the right part.                   \n" +
  109.                 "             if (local_id < stride && right_id < last )                                           \n" +
  110.                 "                 localSums[local_id] += localSums[local_id + stride];                             \n" +
  111.                 "             last = stride;                                                                       \n" +
  112.                 "        }                                                                                         \n" +
  113.                 "                                                                                                  \n" +
  114.                 "        // Write result into partialSums[nWorkGroups]                                             \n" +
  115.                 "        if (local_id == 0)                                                                        \n" +
  116.                 "            partialSums[get_group_id(0)] = localSums[0];                                          \n" +
  117.                 "    }                                                                                             \n";

  118.         return device.findOrCompileAdHocKernel(kernelName, code);
  119.     }

  120. }