// CLFunctionCompiler.java

package neureka.devices.opencl.utility;

import neureka.Neureka;
import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.*;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.implementations.FunctionInput;
import neureka.math.implementations.FunctionVariable;
import neureka.devices.opencl.KernelCaller;
import neureka.devices.opencl.OpenCLDevice;
import neureka.dtype.DataType;
import neureka.dtype.NumericType;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 *  Turns a {@link Function} into OpenCL kernel code to make
 *  optimized just in time compilation possible.
 *  <p>
 *  The compiler wraps the function in a new {@link Operation} (see {@link #optimize()})
 *  whose {@link OpenCLDevice} implementation generates, compiles and invokes
 *  an element-wise kernel specific to the data types and ND-configurations
 *  of the tensors of each call.
 */
public final class CLFunctionCompiler
{
    private final OpenCLDevice _device;
    private final Function _functionToBeOptimized;
    private final String _functionName;
    /**
     *  The distinct input indices referenced by the function, in order of first appearance.
     *  Entry {@code k} is exposed inside of generated kernels as the local variable {@code v(k+1)}.
     */
    private final int[] _argPointer;

    /**
     * @param device The device on which the generated kernels are compiled and looked up.
     * @param toBeOptimized The function which ought to be translated to kernel code.
     * @param functionName The name used for the created operation and as prefix of every kernel name.
     */
    public CLFunctionCompiler( OpenCLDevice device, Function toBeOptimized, String functionName )
    {
        _device = device;
        _functionToBeOptimized = toBeOptimized;
        _functionName = functionName;
        _argPointer = toBeOptimized.getAllFunctions()
                                    .stream()
                                    .filter( fun -> fun instanceof FunctionInput )
                                    .mapToInt( fun -> ( (FunctionInput) fun ).index() )
                                    .distinct()
                                    .toArray();
    }


    /**
     *  Builds an {@link Operation} named after this compiler's function name and
     *  attaches an algorithm to it whose {@link OpenCLDevice} implementation
     *  is the ad hoc kernel generator of this class.
     *  The arity of the operation is the number of arguments of the wrapped function,
     *  or {@code -1} (indexer) if one of its sub-functions is a {@link FunctionVariable}.
     *
     * @return The newly built operation.
     */
    public Operation optimize() {
        int numberOfArgs = _functionToBeOptimized.numberOfArgs();
        if ( _functionToBeOptimized.getSubFunctions().stream().anyMatch(fun -> fun instanceof FunctionVariable ) )
            numberOfArgs = -1; // The function is an indexer which means that it can have any number of arguments...
        return Operation
                .builder()
                .identifier( _functionName )
                .operator( _functionName )
                .arity( numberOfArgs )
                .isIndexer( numberOfArgs < 0 )
                .isOperator( false )
                .isDifferentiable( true )
                .isInline( false )
                .stringifier(
                    children -> {
                        String expression = String.join( ", ", children );
                        // 'startsWith'/'endsWith' instead of 'charAt' so that an empty
                        // expression (no children) does not throw but yields "name()".
                        if ( expression.startsWith("(") && expression.endsWith(")") )
                            return _functionName + expression; // Already wrapped in parentheses.
                        return _functionName + "(" + expression + ")";
                    }
                )
                .build()
                .setAlgorithm(
                    DeviceAlgorithm
                        .withName( "generic_algorithm_for_"+ _functionName )
                        .setIsSuitableFor( call -> SuitabilityPredicate.GOOD )
                        .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
                        .setExecution(
                            (outerCaller, outerCall) ->
                                Result.of(AbstractDeviceAlgorithm.executeFor(
                                    outerCaller, outerCall,
                                    call -> AbstractDeviceAlgorithm.executeDeviceAlgorithm( call )
                                ))
                                .withAutoDiff((caller, call) -> ADAction.of( target -> Function.of(caller.toString(), false).derive(new Tensor[]{target.error()}, 0) ))
                        )
                        .setCallPreparation(
                            call -> {
                                if ( call.input( 0 ) == null ) // Creating a new tensor:
                                {
                                    // The output is modelled after the first real input (position 1),
                                    // made non-virtual and stored on the device of the call.
                                    Tensor<Number> output = Tensor.like( (Tensor<Number>) call.input( 1 ) ).all(0);
                                    output.getMut().setIsVirtual( false );
                                    call.getDeviceFor(Number.class).store(output);
                                    call = call.withInputAt( 0, output );
                                }
                                return call;
                            }
                        )
                        .buildFunAlgorithm()
                        .setImplementationFor( OpenCLDevice.class, this::_adHocKernelFor )
                );
    }


    /**
     *  Executes the wrapped function (or one of its derivatives) for the provided call
     *  by means of a kernel which is specific to the types and ND-configurations of the
     *  involved tensors. A kernel already known to the device under the same name is
     *  reused, otherwise a new one is generated and compiled.
     *
     * @param call The call whose input at position 0 is the output tensor and whose
     *             remaining inputs are the function arguments (shifted by one).
     * @return The output tensor, namely the input at position 0 of the call.
     */
    private Tensor<?> _adHocKernelFor( ExecutionCall<?> call ) {

        int derivIdx = call.getValOf( Arg.DerivIdx.class );

        // Only the inputs which the function actually references are passed to the kernel:
        List<Tensor<Number>> args = Arrays.stream( _argPointer )
                                    .mapToObj( p -> call.input( p + 1 ).getMut().upcast(Number.class) )
                                    .collect(Collectors.toList());

        args.add(0, call.input(Number.class, 0)); // The output always comes first!

        String kernelSignature = _signatureFor( derivIdx, args );

        KernelCaller caller = _device.hasAdHocKernel( kernelSignature )
                                ? _device.getAdHocKernel( kernelSignature )
                                : _device.compileAndGetAdHocKernel(
                                        kernelSignature,
                                        _kernelSourceFor( kernelSignature, derivIdx, args )
                                    );

        // The generated kernel declares exactly one '__global' pointer per tensor (the
        // ND-configurations are inlined as constants), so only the data of each tensor
        // is passed - for reused kernels as well as for freshly compiled ones.
        args.forEach( caller::pass );
        caller.call( args.get(0).size() );
        return call.input(0);
    }

    /**
     *  Creates the kernel name for the provided derivative index and tensors.
     *  The name encodes the function name, the derivative index (if it is not negative)
     *  as well as the type and shape (or full inline ND-configuration if the configuration
     *  is not simple) of every tensor.
     *  The derivative index is part of the name because derivatives with respect to
     *  different inputs compile to different code and must therefore not share a kernel.
     */
    private String _signatureFor( int derivIdx, List<Tensor<Number>> args ) {
        String derivativeSuffix = derivIdx >= 0 ? "_derivative" + derivIdx : "";
        return _functionName + derivativeSuffix +
                "_" +
                args.stream()
                    .map( arg ->
                            arg.getDataType().getRepresentativeType().getSimpleName() +
                            "$" +
                            (
                                arg.getNDConf().isSimple()
                                ? Arrays.stream( arg.getNDConf().shape() )
                                : Arrays.stream( arg.getNDConf().asInlineArray() )
                            )
                            .mapToObj( String::valueOf )
                            .collect( Collectors.joining("x") )
                    )
                    .collect( Collectors.joining( "_" ) );
    }

    /**
     *  Generates the OpenCL source of an element-wise kernel with the provided name, which
     *  reads every argument through its inlined ND-configuration, evaluates the function
     *  (or its derivative if the index is not negative) and writes the result to the output.
     */
    private String _kernelSourceFor( String kernelSignature, int derivIdx, List<Tensor<Number>> args ) {

        int rank = args.get(0).rank();

        List<String> types = args.stream()
                                    .map( CLFunctionCompiler::_clTypeOf )
                                    .collect(Collectors.toList());

        String argString = IntStream.range( 0, args.size() )
                                    .mapToObj( i -> "__global "+types.get(i)+"* arg" + i )
                                    .collect(Collectors.joining(", "));

        // One constant 'cfgN' array per tensor holding its inline ND-configuration:
        String configDeclarations = IntStream.range( 0, args.size() )
                                    .mapToObj(
                                        i -> "int cfg"+i+"[] = {" +
                                                Arrays.stream( args.get(i).getNDConf().asInlineArray() )
                                                        .mapToObj( String::valueOf )
                                                        .collect( Collectors.joining(",") ) +
                                            "};"
                                    )
                                    .collect(Collectors.joining("\n        "));

        String inputLoads = IntStream.range( 1, args.size() ) // We start at 1 because 0 is the output!
                                    .mapToObj(
                                        i -> types.get(i) + " v" + i + " = arg" + i + "[_i_of_i(i, cfg"+i+", "+rank+")];"
                                    )
                                    .collect(Collectors.joining("\n        "));

        Function toBeCompiled = derivIdx < 0
                                    ? _functionToBeOptimized
                                    : _functionToBeOptimized.getDerivative( derivIdx );

        // Every input reference 'I[p]' becomes the kernel local 'vN' which holds its value:
        String compilableFun = toBeCompiled.toString();
        for ( int i = 0; i < _argPointer.length; i++ )
            compilableFun = compilableFun.replace( "I["+_argPointer[i]+"]", "v" + (i + 1) );

        return
                "\n" +
                    _readAndGetIndexMapper() +
                "\n" +
                "    __kernel void " + kernelSignature + "(\n" +
                "        " + argString + "\n" +
                "    ) {                                                                                     \n" +
                "        " + configDeclarations +
                "                                                                                          \n" +
                "        unsigned int i = get_global_id( 0 );                                              \n" +
                "        " + inputLoads +
                "                                                                                          \n" +
                "        arg0[_i_of_i(i, cfg0, "+rank+")] = " + compilableFun + ";                         \n" +
                "    }                                                                                     \n\n";
    }

    /**
     *  Derives the name of the kernel element type for the provided tensor from the
     *  lower-cased simple name of its Java type ("integer" becomes "int").
     *  Numeric types whose holder type differs from their target type are
     *  treated as unsigned and receive a "u" prefix.
     *
     * @param tensor The tensor whose data type should be translated.
     * @return The type name used for the tensor inside of the generated kernel.
     */
    private static String _clTypeOf( Tensor<?> tensor ) {
        DataType<?> dtype = tensor.getDataType();
        // Locale.ROOT makes the result independent of the default locale (a Turkish
        // default locale would otherwise not lower "Integer" to "integer").
        java.util.function.Function<Class<?>, String> formatter = type -> type.getSimpleName()
                                                                                 .toLowerCase( java.util.Locale.ROOT )
                                                                                 .replace("integer", "int");
        if ( dtype.typeClassImplements(NumericType.class) ) {
            NumericType<?,?,?,?> instance = (NumericType<?,?,?,?>) dtype.getTypeClassInstance(NumericType.class);
            if ( instance.holderType() == instance.targetType() )
                return formatter.apply(instance.holderType()); // Float, Double, Long, Short...
            else // Unsigned types:
                return "u" + formatter.apply(instance.holderType());
        }
        return formatter.apply(dtype.getRepresentativeType());
    }

    /**
     *  This method reads the "utility.cl" resource to extract and return the
     *  "_i_of_idx_on_tln" and "_i_of_i" methods in the form of a simple {@link String}.
     *
     * @return The "_i_of_idx_on_tln" method followed by the "_i_of_i" method from the "utility.cl" file.
     */
    private static String _readAndGetIndexMapper() {
        String resource = Neureka.get()
                                    .utility()
                                    .readResource("kernels/utility.cl");
        // NOTE(review): 'String.split' interprets its argument as a regular expression and
        // "int _i_of_i" is also a prefix of "int _i_of_idx_on_tln", so which chunk ends up at
        // index 1 depends on the order of the definitions in "utility.cl" - confirm against the resource.
        return
                "    int _i_of_idx_on_tln" +
                        _textBetween( resource, "int _i_of_idx_on_tln", "// _i_of_idx_on_tln end!" ) +
               "\n" +
               "    int _i_of_i" +
                        _textBetween( resource, "int _i_of_i", "// _i_of_i end!" );
    }

    /**
     *  Returns the text following the first match of {@code start} up to the next
     *  match of {@code end} (both are used as {@link String#split(String)} patterns).
     */
    private static String _textBetween( String text, String start, String end ) {
        return text.split( start )[1].split( end )[0];
    }

}