OpenCLPlatform.java

package neureka.devices.opencl;

import neureka.Neureka;
import neureka.backend.api.ImplementationFor;
import neureka.backend.api.Algorithm;
import neureka.backend.api.DeviceAlgorithm;
import neureka.backend.api.Operation;
import neureka.backend.ocl.CLBackend;
import neureka.backend.main.algorithms.*;
import neureka.backend.main.implementations.CLImplementation;
import neureka.backend.main.implementations.SimpleCLImplementation;
import org.jocl.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

import static org.jocl.CL.*;

/**
 *  This class models the OpenCL concept of platforms, which refer to device
 *  vendors / or vendor OpenCL runtime drivers.
 *  For example, in a system with 1 Intel CPU, 1 Nvidia GPUs and 2 AMD GPU,
 *  you will have 3 OpenCL platforms exposed by the OpenCL API, one for Intel,
 *  one for Nvidia and one for AMD. With an AMD CPU and AMD GPU,
 *  you will have a single Platform for both.
 *  Same with Intel CPU and Intel GPU/FPGA, also 1 Platform only.
 *
 *  Here an example for an exception to the Platforms=vendors rule:
 *  There is 1 Intel CPU in the system,
 *  but the Intel OpenCL runtime and also POCL OpenCL runtime installed.
 *  Then you have 2 Platforms (Intel and POCL),
 *  each with the same Intel CPU as device.
 *
 *  For every platform exposed by the OpenCL runtime (modelled by a {@link CLBackend} instance),
 *  there will be a {@link OpenCLPlatform} instance.
 */
public class OpenCLPlatform
{
    private final static Logger _LOG = LoggerFactory.getLogger( OpenCLPlatform.class );

    private final cl_platform_id _pid;
    private final cl_context _context;

    private final Map<cl_device_id, OpenCLDevice> _id_device;
    private final Map<String, cl_kernel> _kernels = new HashMap<>();


    public OpenCLPlatform(cl_platform_id pid)
    {
        _id_device = new TreeMap<>(Comparator.comparingInt(NativePointerObject::hashCode));
        _pid = pid;
        // Obtain the number of devices for the current platform
        int[] numDevices = new int[ 1 ];
        clGetDeviceIDs(pid, CL_DEVICE_TYPE_ALL, 0, null, numDevices);
        cl_device_id[] devicesArray = new cl_device_id[numDevices[ 0 ]];
        clGetDeviceIDs(pid, CL_DEVICE_TYPE_ALL, numDevices[ 0 ], devicesArray, null);
        if ( numDevices[0] == 0 ) {
            String vendor = OpenCLDevice.Query.getString(pid, CL_PLATFORM_VENDOR);
            String platformName = OpenCLDevice.Query.getString(pid, CL_PLATFORM_NAME);
            _LOG.warn(
                "Could not find any OpenCL devices for platform '{}' with id '0x{}' from vendor '{}'. \n" +
                "Although an OpenCL platform is present, it does not seem to find any devices. \n" +
                "Does your hardware support OpenCL? \n",
                platformName, Long.toHexString(pid.getNativePointer()), vendor,
                new Throwable()
            );
        }


        // Enable exceptions and subsequently omit error checks in this sample
        setExceptionsEnabled( true );

        // Initialize the context properties
        cl_context_properties contextProperties = new cl_context_properties();
        contextProperties.addProperty(CL_CONTEXT_PLATFORM, pid);

        // Create a context for the selected device
        _context = clCreateContext(
                contextProperties, devicesArray.length, devicesArray,
                null, null, null
        );

        List<cl_device_id> successfullyLoaded = new ArrayList<>();

        List<String> failures = new ArrayList<>();
        // Collect all devices of this platform
        for (cl_device_id did : devicesArray) {
            try {
                OpenCLDevice clDevice = OpenCLDevice.of(this, did);
                _id_device.put(did, clDevice);
                successfullyLoaded.add(did);
            } catch ( Exception e ) {
                String message =
                        "Failed to create '"+OpenCLDevice.class.getSimpleName()+"' instance for " +
                        "OpenCL device id '0x" + Long.toHexString(did.getNativePointer()) + "' under platform id '0x"+Long.toHexString(pid.getNativePointer())+"'!";
                _LOG.error(message, e);
                failures.add(message + " Reason: " + e.getMessage());
            }
        }
        if ( !successfullyLoaded.isEmpty() )
            _compile(successfullyLoaded.toArray(new cl_device_id[0]));
        else
            _LOG.warn(
                "'"+this.getClass().getSimpleName()+"' with id '"+Long.toHexString(pid.getNativePointer())+"' does not have a valid device attached to it!"
            );

        if ( successfullyLoaded.isEmpty() && devicesArray.length > 0 )
            throw new RuntimeException(
                "Failed to create '"+OpenCLDevice.class.getSimpleName()+"' instances for all devices of platform id '0x"+Long.toHexString(pid.getNativePointer())+"'! \n" +
                "Reasons: \n    " + failures.stream().collect(Collectors.joining("\n    "))
            );
    }

    public void recompile() {
        List<OpenCLDevice> devices = getDevices();
        cl_device_id[] devicesArray = new cl_device_id[devices.size()];
        for ( int i = 0; i < devicesArray.length; i++) devicesArray[ i ] = devices.get( i ).getId();
        _compile(devicesArray);
    }

    /**
     *   This is where all the kernels defined by all the {@link CLImplementation}
     *   in the standard backend, will be compiled to OpenCL programs.
     *   These kernels are usually based on pre-made template kernel source files...
     *   They are supposed to be as general purpose as possible, meaning they use
     *   a rather complicated indexing mechanism (see 'utility.cl').
     *
     * @param devicesArray The array of devices for which kernels should be compiled.
     */
    private void _compile( cl_device_id[] devicesArray )
    {
        //Reading all kernels!
        List<String> templateSources = new ArrayList<>();

        String[] fileNames = {
                "activation_template.cl",
                "broadcast_template.cl",
                "convolution_template.cl",
                "elementwise_template.cl",
                "scalarization_template.cl",
                "scalar_broadcast.cl",
                "utility.cl"
        };
        for ( String name : fileNames )
            templateSources.add(Neureka.get().utility().readResource("kernels/"+name));

        ArrayList<String> names = new ArrayList<>();
        ArrayList<String> sources = new ArrayList<>();
        for ( int i = 0; i < fileNames.length; i++ )
        {
            String kernelSource = templateSources.get( i );
            boolean templateFound = false;
            if ( kernelSource.contains( "__kernel" ) )
            {
                String[] parts = kernelSource.split("__kernel")[ 1 ].split("\\(")[ 0 ].split(" ");

                templateFound = parts[parts.length - 1].contains("template");
                if ( !templateFound ) names.add(parts[parts.length - 1]);
                else
                {
                    String preName = parts[ parts.length - 1 ].replace("template", "");
                    // Tensor t0_origin, Tensor t1_handle, Tsr t2_drain ... when d>=0
                    // Tsr t0_drain,  Tsr t1_src1,   Tsr t2_src2
                    // drn[di], src1[_i_of_idx_on_tln(prv_src1_cfg, rank)], src2[_i_of_idx_on_tln(prv_src2_cfg, rank)]
                    // default:  src1 o src2 -> drain
                    // inverse:  src1/fdrn <- src2 <- drain
                    //===========================================================================
                    Map<String, String> code = new HashMap<>();
                    ImplementationFor<OpenCLDevice> impl = null;
                    for ( Operation type : Neureka.get().backend().getOperations() ) {
                        if ( preName.contains("activation") && type.supportsAlgorithm(ElementwiseAlgorithm.class) )
                            impl = type.getAlgorithm(ElementwiseAlgorithm.class).getImplementationFor( OpenCLDevice.class );
                        else if ( preName.contains("elementwise") && type.supportsAlgorithm(BiElementwise.class) )
                            impl = type.getAlgorithm(BiElementwise.class).getImplementationFor( OpenCLDevice.class );
                        else if ( preName.contains("scalarization") && type.supportsAlgorithm(BiScalarBroadcast.class) )
                            impl = type.getAlgorithm(BiScalarBroadcast.class).getImplementationFor( OpenCLDevice.class );
                        else if ( preName.contains("broadcast") && type.supportsAlgorithm(Broadcast.class) )
                            impl = type.getAlgorithm(Broadcast.class).getImplementationFor( OpenCLDevice.class );
                        else if ( preName.contains("convolution") && type.supportsAlgorithm(NDConvolution.class) )
                            impl = type.getAlgorithm(NDConvolution.class).getImplementationFor( OpenCLDevice.class );
                        else if (
                                type.supportsAlgorithm(DeviceAlgorithm.class)
                                &&
                                preName.contains(type.getAlgorithm(DeviceAlgorithm.class).getName())
                        ) { // TODO: cover!
                            impl = type.getAlgorithm(DeviceAlgorithm.class).getImplementationFor( OpenCLDevice.class );
                        }
                        if ( impl instanceof CLImplementation ) {
                            for ( KernelCode kernelCode : ((CLImplementation) impl).getKernelCode() ) {
                                if (kernelCode.getCode() != null)
                                    code.put(kernelCode.getName(), kernelCode.getCode());
                            }
                        }
                    }
                    code.forEach( ( n, s ) -> { names.add( n ); sources.add( s ); } );
                }
            }
            if ( !templateFound ) sources.add( kernelSource );
        }

        for ( Operation type : Neureka.get().backend().getOperations() ) {
            for ( Algorithm algorithm : type.getAllAlgorithms()) {
                DeviceAlgorithm<?> deviceAlgorithm = ( algorithm instanceof DeviceAlgorithm ? ((DeviceAlgorithm<?>) algorithm) : null );
                ImplementationFor<OpenCLDevice> impl =  ( deviceAlgorithm == null ? null : deviceAlgorithm.getImplementationFor(OpenCLDevice.class) );
                if ( impl instanceof CLImplementation ) {
                    CLImplementation cli = ((CLImplementation) impl);
                    if ( cli instanceof SimpleCLImplementation ) {
                        for ( KernelCode kernelCode : cli.getKernelCode() ) {
                            names.add( kernelCode.getName() );
                            sources.add( kernelCode.getCode() );
                        }
                    }
                }
            }
        }

        // Create the program
        cl_program cpProgram = clCreateProgramWithSource(
                _context,
                sources.size(),
                sources.toArray( new String[ 0 ] ),
                null,
                null
        );

        // Build the program
        int err = clBuildProgram(
                cpProgram,
                devicesArray.length,
                devicesArray,
                "-cl-mad-enable",
                null,
                null
        );
        if ( err != CL_SUCCESS )
            _LOG.error("Failed to compile the OpenCL code of the current context. Error code: '"+err+"'.");

        //TODO: check compilation errors!

        // Create the kernels
        for ( String name : names )
            if ( name != null ) _kernels.put( name, clCreateKernel( cpProgram, name, null ) );
    }

    public List<OpenCLDevice> getDevices() {
        List<OpenCLDevice> devices = new ArrayList<>();
        _id_device.forEach( ( k, v ) -> devices.add( v ) );
        return devices;
    }

    /**
     * @param did The {@link cl_device_id} representing an OpenCL supporting device.
     * @return The truth value determining if this platform hosts the device represented by the provided id.
     */
    public boolean has( cl_device_id did ) { return _id_device.containsKey( did ); }

    public OpenCLDevice get( cl_device_id did ) {
        return _id_device.get( did );
    }

    void put( cl_device_id did, OpenCLDevice device ) {
       _id_device.put( did, device );
    }

    public cl_kernel getKernel( String kernelName ) {
        return _kernels.get( kernelName );
    }

    public boolean hasKernel( String kernelName ) {
        return _kernels.containsKey( kernelName );
    }

    public final long getId() { return _pid.getNativePointer(); }

    public cl_context getContext() {
        return _context;
    }

    public void dispose() {
        clReleaseContext( _context );
    }

    @Override
    public String toString() {
        return this.getClass().getSimpleName()+"@"+Integer.toHexString(hashCode())+"[" +
                    "id=0x" + Long.toHexString(_pid.getNativePointer()) + "," +
                    "context=0x"+Long.toHexString(_context.getNativePointer()) + "," +
                    "kernels=[.."+_kernels.size()+"..]" +
                "]";
    }

}