// CLDot.java

package neureka.backend.main.implementations.linear;

import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.ImplementationFor;
import neureka.backend.main.operations.linear.internal.opencl.CLSum;
import neureka.devices.opencl.KernelCaller;
import neureka.devices.opencl.OpenCLDevice;

import java.util.function.Supplier;

  10. /**
  11.  *  Performs a dot product on two vectors using OpenCL.
  12.  */
  13. public class CLDot implements ImplementationFor<OpenCLDevice>
  14. {
  15.     @Override
  16.     public Tensor<?> run(ExecutionCall<OpenCLDevice> call ) {
  17.         // First we unpack the input tensors:
  18.         Tensor<Float> c = call.input(Float.class, 0);
  19.         Tensor<Float> a = call.input(Float.class, 1);
  20.         Tensor<Float> b = call.input(Float.class, 2);
  21.         OpenCLDevice device = call.getDevice();

  22.         if ( a.rank() != 1 || b.rank() != 1 )
  23.             throw new IllegalArgumentException("Input tensors must be vectors.");

  24.         int size = a.shape(0);
  25.         if ( b.shape(0) != size )
  26.             throw new IllegalArgumentException("Input vectors must have the same length.");

  27.         // First we multiply the two vectors:
  28.         String kernelName = "multiply_arrays_for_dot_product";
  29.         Supplier<String> code = () ->
  30.                     "__kernel void " + kernelName + "(__global const float* a, \n" +
  31.                     "                              __global const float* b, \n" +
  32.                     "                              __global float* c,\n" +
  33.                     "                              const int n) {\n" +
  34.                     "    int i = get_global_id(0);\n" +
  35.                     "    if (i < n) {\n" +
  36.                     "        c[i] = a[i] * b[i];\n" +
  37.                     "    }\n" +
  38.                     "}";

  39.         Tensor<Float> temp = Tensor.of(Float.class, Shape.of(size), 0).to(device).mut().setIsVirtual(false);

  40.         // Kernels are cached, so if it is already compiled, it will be retrieved from the cache:
  41.         KernelCaller caller = device.findOrCompileAdHocKernel(kernelName, code);
  42.         // We call OpenCL to do the work:
  43.         caller.pass(a).pass(b).pass(temp).pass(size).call(new long[]{size}, null);

  44.         Tensor<Float> out = CLSum.run(temp, device);
  45.         c.mut().at(0).set(out.item());
  46.         return c;
  47.     }
  48. }