// KernelCaller.java

package neureka.devices.opencl;

import neureka.Tensor;
import org.jocl.*;

import java.util.ArrayList;
import java.util.List;

import static org.jocl.CL.*;

/**
 *  Instances of this class are utility factories provided by {@link OpenCLDevice} instances.
 *  When building new operations for tensors then this {@link KernelCaller} class is essential
 *  for calling compiled kernels residing within the gpu.
 *  <p>
 *  Arguments are bound to the kernel in the order in which the various {@code pass...}
 *  methods are invoked: an internal argument counter advances by one for every
 *  argument passed, so call order must match the kernel's parameter order.
 */
public class KernelCaller
{
    private final cl_command_queue  _queue;
    private final cl_kernel         _kernel;
    private final List<Tensor<Number>> _inputs;

    // Index of the next kernel argument slot; incremented by every "pass..." operation.
    private int _argId = 0;

    /**
     *
     * @param kernel The kernel which ought to be called.
     * @param queue The queue on which calls ought to be dispatched.
     */
    public KernelCaller( cl_kernel kernel, cl_command_queue queue ) {
        _queue  = queue;
        _kernel = kernel;
        _inputs = new ArrayList<>();
    }

    /**
     * This method passes 2 arguments to the kernel.
     * One for the data of the tensor and one for the configuration data!
     * @param tensor The tensor whose data and configuration ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller passAllOf( Tensor<Number> tensor ) {
        _inputs.add( tensor );
        _passDataOf( tensor );
        return passConfOf( tensor );
    }

    /**
     *  This method passes the ND-Configuration in the form of a flattened int array to the kernel.
     *  Kernels can use this information for more complex indexing mechanisms as one would
     *  expect them to be present in tensors which have been permuted or are simply
     *  slices of other tensors.
     *
     *  @param tensor The tensor whose ND configuration ought to be passed to the kernel.
     *  @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller passConfOf( Tensor<Number> tensor ) {
        OpenCLDevice device = (OpenCLDevice) tensor.getDevice();
        return _passArg( Sizeof.cl_mem, Pointer.to( device.clConfigOf( tensor ).data ) );
    }

    /**
     * This method passes 1 argument to the kernel.
     * Namely, the data of the tensor!
     * @param tensor The tensor whose data ought to be passed to the kernel.
     * @param <T> The value type of the tensor (a {@link Number} subtype).
     * @return This very KernelCaller instance (factory pattern).
     */
    public <T extends Number> KernelCaller pass( Tensor<T> tensor ) {
        _inputs.add( tensor.getMut().upcast(Number.class) );
        _passDataOf( tensor );
        return this;
    }

    /**
     *  Binds the OpenCL memory object holding the data of the provided tensor
     *  to the next kernel argument slot.
     *  (Shared by {@link #passAllOf(Tensor)} and {@link #pass(Tensor)} to avoid duplication.)
     */
    private void _passDataOf( Tensor<?> tensor ) {
        _passArg( Sizeof.cl_mem, Pointer.to( tensor.getMut().getData().as( OpenCLDevice.cl_tsr.class ).value.data ) );
    }

    /**
     *  Sets the next kernel argument slot to the provided pointer (or, for a
     *  {@code null} pointer, reserves local memory of the given size) and
     *  advances the argument counter.
     *
     * @param sizeInBytes The size of the argument in bytes.
     * @param pointer The pointer to the argument data, or {@code null} for local memory.
     * @return This very KernelCaller instance (factory pattern).
     */
    private KernelCaller _passArg( long sizeInBytes, Pointer pointer ) {
        clSetKernelArg( _kernel, _argId, sizeInBytes, pointer );
        _argId++;
        return this;
    }

    /**
     *
     * @param value An int value which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( int value ) {
        return this.pass( new int[]{ value } );
    }

    /**
     *  Use this to pass an array of int values to the kernel.
     *
     * @param values An array of int values which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( int... values ) {
        return _passArg( Sizeof.cl_int * (long) values.length, Pointer.to( values ) );
    }

    /**
     *  Use this to pass an array of float values to the kernel.
     *
     * @param values An array of float values which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( float... values ) {
        return _passArg( Sizeof.cl_float * (long) values.length, Pointer.to( values ) );
    }

    /**
     *  Use this to pass an array of double values to the kernel.
     *
     * @param values An array of double values which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( double... values ) {
        return _passArg( Sizeof.cl_double * (long) values.length, Pointer.to( values ) );
    }

    /**
     *  Use this to pass an array of short values to the kernel.
     *
     * @param values An array of short values which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( short... values ) {
        return _passArg( Sizeof.cl_short * (long) values.length, Pointer.to( values ) );
    }

    /**
     *  Use this to pass an array of long values to the kernel.
     *
     * @param values An array of long values which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( long... values ) {
        return _passArg( Sizeof.cl_long * (long) values.length, Pointer.to( values ) );
    }

    /**
     *  Use this to pass an array of byte values to the kernel.
     *
     * @param values An array of byte values which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( byte... values ) {
        return _passArg( Sizeof.cl_char * (long) values.length, Pointer.to( values ) );
    }

    /**
     * @param value A float value which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( float value ) {
        return this.pass( new float[]{ value } );
    }

    /**
     * @param value A double value which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( double value ) {
        return this.pass( new double[]{ value } );
    }

    /**
     * @param value A short value which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( short value ) {
        return this.pass( new short[]{ value } );
    }

    /**
     * @param value A long value which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( long value ) {
        return this.pass( new long[]{ value } );
    }

    /**
     * @param value A byte value which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller pass( byte value ) {
        return this.pass( new byte[]{ value } );
    }

    /**
     *  Dispatches the provided boxed number to the matching primitive overload.
     *
     * @param value A number value which ought to be passed to the kernel.
     * @return This very KernelCaller instance (factory pattern).
     * @throws IllegalArgumentException If the concrete {@link Number} subtype is unsupported.
     */
    public KernelCaller pass( Number value ) {
        if ( value instanceof Float ) return this.pass( value.floatValue() );
        else if ( value instanceof Double ) return this.pass( value.doubleValue() );
        else if ( value instanceof Integer ) return this.pass( value.intValue() );
        else if ( value instanceof Long ) return this.pass( value.longValue() );
        else if ( value instanceof Short ) return this.pass( value.shortValue() );
        else if ( value instanceof Byte ) return this.pass( value.byteValue() );
        else throw new IllegalArgumentException( "Unsupported number type: " + value.getClass().getName() );
    }

    /**
     *  Reserves a block of kernel-local memory sized to hold the given
     *  number of float values (a {@code null} argument pointer tells OpenCL
     *  to allocate local memory).
     *
     * @param size The number of local floats to allocate.
     * @return This very KernelCaller instance (factory pattern).
     */
    public KernelCaller passLocalFloats( long size ) {
        return _passArg( Sizeof.cl_float * size, null );
    }

    /**
     *
     * @param globalWorkSize The number of global threads which will be dispatched.
     */
    public void call( int globalWorkSize )
    {
        // A single dimension with no explicit local work size: OpenCL picks the group size.
        call( new long[]{ globalWorkSize }, null );
    }

     /**
     *  Use this to call the kernel with 2 long arrays defining how the kernel should be indexed and parallelized.
     *  The {@code globalWorkSizes} span an n-dimensional grid of global threads,
     *  whereas the {@code localWorkSizes} defines the dimensions of a grid of local work items (which are called "work groups").
     *  The total number of work items is equal to the product of the {@code localWorkSizes} array entries
     *  exactly like the product of the {@code globalWorkSizes} array is the total number of (global) threads.        <br>
     *  Both sizes have to fulfill the following condition: {@code globalWorkSize = localWorkSize * numberOfGroups}. <br>
     *  Note: The {@code localWorkSizes} is optional, so the second argument may be null
     *  in which case OpenCL will choose a local group size appropriately for you.
     *  This is usually also the optimal choice,
     *  however if the global work size is a prime number (that is larger than the maximum local work size),
     *  then an OpenCL implementation may be forced to use a local work size of 1...
     * <p>
     * This can usually be circumvented by padding the data to be a multiple of a more appropriate
     * local work size or by introducing boundary checks in your kernel.
     *
     * @param globalWorkSizes An arrays of long values which span a nd-grid of global threads.
     * @param localWorkSizes  An arrays of long values which span a nd-grid of local threads (work groups).
     */
    public void call( long[] globalWorkSizes, long[] localWorkSizes )
    {
        _awaitInputEvents();
        assert localWorkSizes == null || globalWorkSizes.length == localWorkSizes.length;
        clEnqueueNDRangeKernel(
                _queue, _kernel,
                globalWorkSizes.length,
                null,
                globalWorkSizes,
                localWorkSizes,
                0,
                null,
                null
        );
    }

    /**
     *  Blocks until all pending events attached to the input tensors have
     *  completed, then releases those events.
     *  (Shared by both {@code call} overloads to avoid duplication.)
     */
    private void _awaitInputEvents() {
        cl_event[] events = _getWaitList();
        if ( events.length > 0 ) {
            clWaitForEvents( events.length, events );
            _releaseEvents();
        }
    }

    /**
     *  Releases and clears the OpenCL event attached to every input tensor, if any.
     */
    private void _releaseEvents() {
        for ( Tensor<Number> t : _inputs ) {
            // Resolve the backing OpenCL representation once instead of three times.
            OpenCLDevice.cl_tsr clTensor = t.getMut().getData().as( OpenCLDevice.cl_tsr.class );
            if ( clTensor.value.event != null ) {
                clReleaseEvent( clTensor.value.event );
                clTensor.value.event = null;
            }
        }
    }

    /**
     *  Collects the distinct pending events of all input tensors.
     *
     * @return An array of events which ought to be awaited before kernel dispatch.
     */
    private cl_event[] _getWaitList() {
        List<cl_event> list = new ArrayList<>();
        for ( Tensor<Number> t : _inputs ) {
            cl_event event = t.getMut().getData().as( OpenCLDevice.cl_tsr.class ).value.event;
            if ( event != null && !list.contains(event) ) {
                list.add( event );
            }
        }
        return list.toArray( new cl_event[ 0 ] );
    }

}