OpenCLDevice.java

  1. /*
  2. MIT License

  3. Copyright (c) 2019 Gleethos

  4. Permission is hereby granted, free of charge, to any person obtaining a copy
  5. of this software and associated documentation files (the "Software"), to deal
  6. in the Software without restriction, including without limitation the rights
  7. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  8. copies of the Software, and to permit persons to whom the Software is
  9. furnished to do so, subject to the following conditions:

  10. The above copyright notice and this permission notice shall be included in all
  11. copies or substantial portions of the Software.

  12. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  13. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  15. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  18. SOFTWARE.
  19. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  20.    ____                    _____ _      _____             _
  21.   / __ \                  / ____| |    |  __ \           (_)
  22.  | |  | |_ __   ___ _ __ | |    | |    | |  | | _____   ___  ___ ___
  23.  | |  | | '_ \ / _ \ '_ \| |    | |    | |  | |/ _ \ \ / / |/ __/ _ \
  24.  | |__| | |_) |  __/ | | | |____| |____| |__| |  __/\ V /| | (_|  __/
  25.   \____/| .__/ \___|_| |_|\_____|______|_____/ \___| \_/ |_|\___\___|
  26.         | |
  27.         |_|

  28. ------------------------------------------------------------------------------------------------------------------------

  29.    'Any fool can write code that a computer can understand.
  30.     Good programmers write code that humans can understand.'
  31.     – Martin Fowler

  32.     Use the following as search keys :)

  33.     §(1) : FIELD VARIABLES
  34.     §(2) : CONSTRUCTION
  35.     §(3) : OPENCL PROPERTIES
  36.     §(4) : NESTED CLASSES

  37. */

  38. package neureka.devices.opencl;

  39. import neureka.Data;
  40. import neureka.Neureka;
  41. import neureka.Tensor;
  42. import neureka.backend.api.*;
  43. import neureka.backend.main.implementations.CLImplementation;
  44. import neureka.backend.ocl.CLBackend;
  45. import neureka.common.composition.Component;
  46. import neureka.common.utility.DataConverter;
  47. import neureka.common.utility.LogUtil;
  48. import neureka.devices.*;
  49. import neureka.devices.opencl.utility.CLFunctionCompiler;
  50. import neureka.dtype.DataType;
  51. import neureka.dtype.NumericType;
  52. import neureka.dtype.custom.F32;
  53. import neureka.framing.Relation;
  54. import neureka.math.Function;
  55. import neureka.ndim.config.NDConfiguration;
  56. import org.jocl.*;
  57. import org.slf4j.Logger;
  58. import org.slf4j.LoggerFactory;

  59. import java.nio.ByteBuffer;
  60. import java.nio.ByteOrder;
  61. import java.util.Map;
  62. import java.util.Objects;
  63. import java.util.Optional;
  64. import java.util.WeakHashMap;
  65. import java.util.function.Supplier;

  66. import static org.jocl.CL.*;

  67. /**
  68.  * This class models OpenCL supporting accelerator hardware like GPUs or FPGAs
  69.  * for storing tensors and executing operations on them.
  70.  */
  71. public class OpenCLDevice extends AbstractDevice<Number>
  72. {
    // SLF4J logger shared by all instances of this class.
    private static final Logger _LOG = LoggerFactory.getLogger(OpenCLDevice.class);

  74.     static OpenCLDevice of( OpenCLPlatform platform, cl_device_id did ) {
  75.         if (!platform.has(did)) platform.put(did, new OpenCLDevice(platform, did));
  76.         return platform.get(did);
  77.     }

  78.     public enum Type {
  79.         CPU, GPU, ACCELERATOR, DEFAULT, CUSTOM, ALL, UNKNOWN
  80.     }

  81.     enum cl_dtype { F32, F64, I64, I32, I16, I8, U32, U16, U8 }

  82.     /*==================================================================================================================
  83.     |
  84.     |       §(1) : FIELD VARIABLES
  85.     |   ---------------------------
  86.     */

    /**
     * Cache of the "ad hoc" kernels compiled through this device, looked up by kernel name.
     */
    private final KernelCache _kernelCache = new KernelCache();

    /**
     * The underlying OpenCL id of this device.
     */
    private final cl_device_id _deviceId;

    /**
     * The OpenCLPlatform :
     * The platform instance hosting this device (exposed through {@link #getPlatform()}).
     * A platform would for example be vendor specific like Intel, AMD, Nvidia...
     */
    private final OpenCLPlatform _platform;

    /**
     * The OpenCL command queue
     */
    private final cl_command_queue _queue;

    /**
     * The configuration buffers which were already written to the device, keyed by the
     * {@link NDConfiguration} they were created from, so that {@code clConfigOf} writes every
     * configuration only once. Entries are held in a {@link WeakHashMap}.
     */
    private final Map<NDConfiguration, cl_config> _configs = new WeakHashMap<>();

  100.     /*==================================================================================================================
  101.     |
  102.     |       §(2) : CONSTRUCTION
  103.     |   ---------------------------
  104.     */

  105.     /**
  106.      * @param platform The platform containing this device.
  107.      * @param deviceId The underlying OpenCL id of this device.
  108.      */
  109.     private OpenCLDevice( OpenCLPlatform platform, cl_device_id deviceId ) {
  110.         super();
  111.         _deviceId = deviceId;
  112.         _platform = platform;
  113.         _queue = clCreateCommandQueueWithProperties(// Create a command-queue for the selected device
  114.                         platform.getContext(), deviceId,
  115.                         null,
  116.                         null
  117.                     );
  118.         _cleaning(this, () -> clReleaseCommandQueue(_queue));
  119.     }

  120.     public final String toString() {
  121.         return "OpenCLDevice[id=0x" + Long.toHexString(_deviceId.getNativePointer()) + ",platform=0x" + Long.toHexString(_platform.getId()) + "]";
  122.     }

  123.     public final cl_device_id getId() { return _deviceId; }

  124.     public final OpenCLPlatform getPlatform() { return _platform; }

  125.     /**
  126.      * @param name The name of the kernel whose presents should be checked.
  127.      * @return True if the kernel is present in the cache, false otherwise.
  128.      */
  129.     public boolean hasAdHocKernel( String name ) { return _kernelCache.has(name); }

  130.     /**
  131.      * @param name The name of the kernel which should be retrieved.
  132.      * @return The kernel with the given name if it is present in the cache, throws an exception otherwise.
  133.      */
  134.     public KernelCaller getAdHocKernel( String name ) {
  135.         cl_ad_hoc adHoc = _kernelCache.get(name);
  136.         if (adHoc != null) return new KernelCaller(adHoc.kernel, _queue);
  137.         else throw new IllegalArgumentException("No ad hoc kernel with name '" + name + "' found!");
  138.     }

  139.     /**
  140.      * @param name The name of the kernel which should be retrieved.
  141.      * @return An {@link Optional} containing the kernel with the given name if it is present in the cache, an empty optional otherwise.
  142.      */
  143.     public Optional<KernelCaller> findAdHocKernel( String name ) {
  144.         cl_ad_hoc adHoc = _kernelCache.get(name);
  145.         if (adHoc != null) return Optional.of(new KernelCaller(adHoc.kernel, _queue));
  146.         else return Optional.empty();
  147.     }

  148.     /**
  149.      * @param name The name of the kernel which should be retrieved.
  150.      * @param source The source code of the kernel which should be compiled if it is not present in the cache.
  151.      * @return The kernel caller for the kernel of the requested name, either from cache,
  152.      *          or compiled from the given source code if it was not present in the cache.
  153.      */
  154.     public KernelCaller findOrCompileAdHocKernel( String name, Supplier<String> source ) {
  155.         cl_ad_hoc adHoc = _kernelCache.get(name);
  156.         if ( adHoc != null ) return new KernelCaller(adHoc.kernel, _queue);
  157.         else return compileAndGetAdHocKernel(name, source.get());
  158.     }

  159.     /**
  160.      * This method compiles and returns the {@link KernelCaller} for a so called "ad hoc" kernel.
  161.      * Ad hoc is a Latin phrase meaning literally 'to this'.
  162.      * In English, it generally signifies a solution designed for a specific problem or task,
  163.      * non-generalizable, and not intended to be adapted to other purposes.
  164.      * This leads to the purpose of ad hoc kernel compilation, namely to be able to compile
  165.      * unique kernels with a specific purpose created on the fly during runtime by operations.
  166.      * This might be useful for high performance operations on tensors with specific dimensions and
  167.      * or possibly other variables / properties which might be taken into account...
  168.      *
  169.      * @param name   The name of the kernel which ought to be compiled.
  170.      * @param source The source of the kernel which ought to be compiled.
  171.      * @return The {@link KernelCaller} for the compiled kernel.
  172.      */
  173.     public synchronized KernelCaller compileAndGetAdHocKernel( String name, String source ) {
  174.         return compileAdHocKernel( name, source )
  175.                 .findAdHocKernel( name )
  176.                 .orElseThrow(() -> new RuntimeException("Failed to compile kernel: " + name));
  177.     }

  178.     /**
  179.      * This method compiles so called "ad hoc" kernel.
  180.      * Ad hoc is a Latin phrase meaning literally 'to this'.
  181.      * In English, it generally signifies a solution designed for a specific problem or task,
  182.      * non-generalizable, and not intended to be adapted to other purposes.
  183.      * This leads to the purpose of ad hoc kernel compilation, namely to be able to compile
  184.      * unique kernels with a specific purpose created on the fly during runtime by operations.
  185.      * This might be useful for high performance operations on tensors with specific dimensions and
  186.      * or possibly other variables / properties which might be taken into account...
  187.      *
  188.      * @param name   The name of the kernel which ought to be compiled.
  189.      * @param source The source of the kernel which ought to be compiled.
  190.      * @return This very instance in order to enable the factory pattern.
  191.      */
  192.     public synchronized OpenCLDevice compileAdHocKernel( String name, String source ) {
  193.         if (this.hasAdHocKernel(name)) {
  194.             cl_ad_hoc adHoc = _kernelCache.get(name);
  195.             String message =
  196.                 "Cannot compile kernel source for name '" + name + "' because the name is already taken.\n" +
  197.                 "Use another name or find out why this kernel already exists.\n" +
  198.                 (
  199.                         adHoc.source.equals(source)
  200.                                 ? "Besides the name, the source code of the existing kernel is also identical.\n" : ""
  201.                 );
  202.             _log.error(message);
  203.             throw new IllegalArgumentException(message);
  204.         }

  205.         // Create the program for the kernel
  206.         cl_program cpProgram = clCreateProgramWithSource(
  207.                 getPlatform().getContext(),
  208.                 1,
  209.                 new String[]{source},
  210.                 null,
  211.                 null
  212.         );

  213.         // Build the program
  214.         int err = clBuildProgram(
  215.                         cpProgram,
  216.                         1,
  217.                         new cl_device_id[]{_deviceId},
  218.                         "-cl-mad-enable",
  219.                         null,
  220.                         null
  221.                 );

  222.         if ( err != CL_SUCCESS )
  223.             _log.error("Error when trying to compile 'ad hoc kernel' named '"+name+"'! Error code: "+err);

  224.         //TODO: check compilation errors!
  225.         cl_kernel kernel;
  226.         try {
  227.             // Create the kernel
  228.             kernel = clCreateKernel(cpProgram, name, null);
  229.         } catch (Exception e) {
  230.             if (e.getMessage().equals("CL_INVALID_KERNEL_NAME") && !source.contains("__kernel void " + name)) {
  231.                 String message = "Method 'clCreateKernel' failed! The name of the '__kernel' method declared inside \n" +
  232.                                  "the source String does not match the provided name needed for kernel creation.";
  233.                 _log.error(message, e);
  234.                 throw new IllegalArgumentException(message);
  235.             }
  236.             _log.error("Method call 'clCreateKernel(.., name=\"" + name + "\", ..)' failed!", e);
  237.             throw e;
  238.         }
  239.         cl_ad_hoc adHoc = new cl_ad_hoc(source, kernel, cpProgram);

  240.         // Storing the ad hoc object in a weak hash map for fast access by operations :
  241.         _kernelCache.put( name, adHoc );

  242.         _cleaning(adHoc, () -> {
  243.             clReleaseKernel(kernel);
  244.             clReleaseProgram(cpProgram);
  245.         });
  246.         return this;
  247.     }

  248.     @Override
  249.     public Operation optimizedOperationOf( Function function, String name ) {
  250.         return new CLFunctionCompiler( this, function, name ).optimize();
  251.     }

  252.     /**
  253.      * This method tells the to restore all tensors stored on it and release all resources.
  254.      */
  255.     @Override
  256.     public void dispose() {
  257.         _numberOfTensors = 0;
  258.         clFinish( _queue );
  259.         clReleaseCommandQueue( _queue );
  260.     }

  261.     /**
  262.      * This method assumes that the passed tensor is stored on this device instance.
  263.      * If the tensor is stored on the device then the method loads the outsourced
  264.      * data of the tensor back into primitive JVM arrays and restores the tensor
  265.      * freshly in RAM.
  266.      *
  267.      * @param tensor The tensor whose data ought to be restored (loaded to RAM).
  268.      * @return This device, which enables method chaining.
  269.      */
  270.     @Override
  271.     public Device<Number> restore( Tensor<Number> tensor ) {
  272.         if ( !this.has( tensor ) ) {
  273.             String message = "The passed tensor cannot be restored from this OpenCL device " +
  274.                                 "because the tensor is not stored on the device.\n";
  275.             _log.error(message);
  276.             throw new IllegalArgumentException(message);
  277.         }

  278.         Object value  = _read(JVMData.of(tensor.itemType(), tensor.isVirtual() ? 1 : tensor.size()), tensor, 0).getArray();

  279.         Class<?> arrayType = Objects.requireNonNull(tensor.getDataType().getTypeClassInstance(NumericType.class)).holderArrayType();

  280.         value = DataConverter.get().convert( value, arrayType );

  281.         this.free( tensor );
  282.         tensor.find( Tensor.class ).ifPresent( this::restore );
  283.         tensor.getMut().setItems( value );
  284.         return this;
  285.     }


  286.     /**
  287.      * Implementations of this method ought to store the value
  288.      * of the given tensor and the "parent" tensor in whatever
  289.      * formant suites the underlying implementation and or final type.
  290.      * {@link Device} implementations are also tensor storages
  291.      * which may also have to store tensors which are slices of bigger tensors.   <br><br>
  292.      *
  293.      * @param tensor The tensor whose data ought to be stored.
  294.      */
  295.     private <T extends Number> void _store(Tensor<T> tensor, Tensor<T> parent ) {
  296.         if (!parent.isOutsourced()) throw new IllegalStateException("Data parent is not outsourced!");
  297.         _add(
  298.             tensor.getMut().upcast(Number.class),
  299.             parent.getMut().getData(),
  300.             () -> tensor.set((Component) this)
  301.         );
  302.     }

  303.     private <T extends Number> void _add(
  304.         Tensor<Number> tensor,
  305.         Data<T> parentData,
  306.         Runnable migration // Causes the device to be a component of the tensor!
  307.     ) {
  308.         if ( this.has( tensor ) ) {
  309.             _LOG.debug("Trying to add a tensor to a device which already reports hosting it.");
  310.             return;
  311.         }

  312.         boolean convertToFloat = Neureka.get()
  313.                                     .backend()
  314.                                     .find(CLBackend.class)
  315.                                     .map( it -> it.getSettings().isAutoConvertToFloat() )
  316.                                     .orElse(false);

  317.         Data<Number> data;
  318.         if ( parentData == null ) {
  319.             if ( tensor.getMut().getData().owner() == this ) {
  320.                 migration.run();
  321.                 return;
  322.             }
  323.             JVMData jvmData = null;
  324.             jvmData = JVMData.of( tensor.getMut().getData().getOrNull(), convertToFloat );
  325.             cl_tsr<Number, Number> newClt;
  326.             newClt = _storeNew( jvmData );
  327.             if ( tensor.rqsGradient() && tensor.hasGradient() )
  328.                 this.store(tensor.gradient().orElseThrow(()->new IllegalStateException("Gradient missing!")));

  329.             cl_mem[] memos = new cl_mem[]{newClt.value.data};
  330.             clEnqueueMigrateMemObjects(
  331.                     _queue, memos.length, memos,
  332.                     CL_MIGRATE_MEM_OBJECT_HOST,
  333.                     0,
  334.                     null,
  335.                     null
  336.                 );

  337.             data = _dataArrayOf(newClt, (DataType<Number>) _dataTypeOf(newClt));
  338.         }
  339.         else
  340.             data = (Data<Number>) parentData;

  341.         tensor.getMut().setData( data );
  342.         migration.run();

  343.         // When tensors get stored on this device,
  344.         // they can be implicitly converted to a float tensor:
  345.         if ( convertToFloat )
  346.             tensor.getMut().toType(F32.class);
  347.     }

  348.     private cl_tsr<Number, Number> _storeNew( JVMData jvmData ) {
  349.         return _storeNew( jvmData, false );
  350.     }

  351.     private cl_tsr<Number, Number> _storeNew( JVMData jvmData, boolean allocateTargetSize ) {
  352.         cl_tsr.cl_value newVal = new cl_tsr.cl_value((int) (allocateTargetSize ? jvmData.getTargetLength() : jvmData.getLength()));
  353.         cl_tsr<Number, Number> newClt = new cl_tsr<>(newVal, jvmData.getType());
  354.         _store( jvmData, newClt, allocateTargetSize );
  355.         return newClt;
  356.     }

  357.     public cl_config clConfigOf(Tensor<?> t ) {
  358.         return clConfigOf( t.getNDConf() );
  359.     }

  360.     public cl_config clConfigOf(NDConfiguration ndc ) {
  361.         cl_config config = _configs.get(ndc);
  362.         if ( config == null ) {
  363.             config = _writeNewNDConfig( ndc );
  364.             _configs.put(ndc, config);
  365.         }
  366.         return config;
  367.     }

  368.     private cl_config _writeNewNDConfig(NDConfiguration ndc ) {

  369.         cl_config clf = new cl_config();

  370.         //Config format: <[ shape | strides | indicesMap | indices | scale ]>
  371.         int[] config = ndc.asInlineArray();

  372.         //shape/strides/map/offset/spread
  373.         clf.data = clCreateBuffer(
  374.                     _platform.getContext(),
  375.                     CL_MEM_READ_WRITE,
  376.                     (long) config.length * Sizeof.cl_int,
  377.                     null, null
  378.                 );

  379.         clEnqueueWriteBuffer(
  380.                 _queue, clf.data, CL_TRUE, 0,
  381.                 (long) config.length * Sizeof.cl_int,
  382.                 Pointer.to(config),
  383.                 0,
  384.                 null, null
  385.             );
  386.         final cl_mem clConfMem = clf.data;
  387.         _cleaning( clf, () -> clReleaseMemObject(clConfMem) );
  388.         return clf;
  389.     }

  390.     private void _store(
  391.        JVMData jvmData,
  392.        cl_tsr<?, ?> newClTensor,
  393.        boolean allocateTarget
  394.     ) {
  395.         long bufferLength = allocateTarget ? jvmData.getTargetLength() : jvmData.getLength();

  396.         cl_mem mem = clCreateBuffer(
  397.                         _platform.getContext(),
  398.                         CL_MEM_READ_WRITE,
  399.                         (long) jvmData.getItemSize() * bufferLength,
  400.                         null,
  401.                         null
  402.                     );

  403.         newClTensor.value.data = mem;

  404.         // Virtual means that there is only a single value in the JVM array.
  405.         // So we don't have to write the whole array to the device!
  406.         // Instead, we can just fill the device memory with the single value.
  407.         boolean isASingleValue = jvmData.isVirtual();

  408.         if ( isASingleValue )
  409.             clEnqueueFillBuffer(
  410.                     _queue, mem, jvmData.getPointer(), // pattern
  411.                     jvmData.getItemSize(), 0,
  412.                     (long) jvmData.getItemSize() * bufferLength,
  413.                     0, null, null
  414.                 );
  415.         else
  416.             clEnqueueWriteBuffer(
  417.                     _queue, mem,
  418.                     CL_TRUE, 0,
  419.                     (long) jvmData.getItemSize() * bufferLength,
  420.                     jvmData.getPointer(), 0, null, null
  421.                 );
  422.     }

  423.     @Override
  424.     public final <T extends Number> Device<Number> free( Tensor<T> tensor ) {
  425.         cl_tsr<?, ?> clt = tensor.getMut().getData().as( cl_tsr.class);
  426.         if ( clt == null ) return this;
  427.         tensor.getMut().setData(null);
  428.         tensor.find(Device.class).ifPresent(
  429.             device -> {
  430.                 tensor.remove( Device.class );
  431.                 tensor.find(Tensor.class).ifPresent(
  432.                     gradient ->
  433.                         ( (Tensor<Number>) gradient ).find(Device.class).ifPresent(
  434.                             gradDevice -> {
  435.                                 try {
  436.                                     if ( this.has( gradient ) ) gradDevice.restore( gradient );
  437.                                 }
  438.                                 catch ( Exception exception ) {
  439.                                     _LOG.error(
  440.                                         "Gradient could not be restored from device component when trying to migrate it back to RAM.",
  441.                                         exception
  442.                                     );
  443.                                     throw exception;
  444.                                 }
  445.                                 gradient.remove( Device.class );
  446.                             })
  447.                 );
  448.             }
  449.         );
  450.         return this;
  451.     }

  452.     @Override
  453.     protected final <T extends Number> T _readItem( Tensor<T> tensor, int index ) {
  454.         return (T) _read(JVMData.of(tensor.itemType(), 1), tensor.getMut().upcast(Number.class), index).getElementAt(0);
  455.     }

  456.     @Override
  457.     protected final <T extends Number, A> A _readArray( Tensor<T> tensor, Class<A> arrayType, int start, int size ) {
  458.         return (A) _read(JVMData.of(tensor.itemType(), size), tensor.getMut().upcast(Number.class), start).getArray();
  459.     }

  460.     @Override
  461.     protected final <T extends Number> void _writeItem( Tensor<T> tensor, T item, int start, int size ) {
  462.         _overwrite( tensor, start, JVMData.of(item, size, 0) );
  463.     }

  464.     @Override
  465.     protected final <T extends Number> void _writeArray(
  466.         Tensor<T> tensor,
  467.         Object array,
  468.         int offset,
  469.         int start,
  470.         int size
  471.     ) {
  472.         _overwrite( tensor, start, JVMData.of(array, size, offset) );
  473.     }

  474.     @Override
  475.     public <T extends Number> Data<T> allocate( DataType<T> dataType, NDConfiguration ndc ) {
  476.         JVMData jvmData = JVMData.of( dataType.getItemTypeClass(), ndc.size() );
  477.         cl_tsr<Number, Number> clt = _storeNew(jvmData );
  478.         return (Data<T>) _dataArrayOf(clt, (DataType<Number>) _dataTypeOf(clt));
  479.     }

  480.     @Override
  481.     public <T extends Number> Data<T> allocateFromOne( DataType<T> dataType, NDConfiguration ndc, T initialValue ) {
  482.         JVMData jvmData = JVMData.of( initialValue, ndc.size(), false, true );
  483.         cl_tsr<Number, Number> clt = _storeNew(jvmData );
  484.         return (Data<T>) _dataArrayOf(clt, (DataType<Number>) _dataTypeOf(clt));
  485.     }

  486.     @Override
  487.     public <T extends Number> Data<T> allocateFromAll( DataType<T> dataType, NDConfiguration ndc, Object data ) {
  488.         JVMData jvmData = JVMData.of( data );
  489.         cl_tsr<Number, Number> clt = _storeNew(jvmData );
  490.         return (Data<T>) _dataArrayOf(clt, (DataType<Number>) _dataTypeOf(clt));
  491.     }

  492.     @Override
  493.     protected Data<Number> _actualize( Tensor<?> tensor ) {
  494.         NDConfiguration ndc = tensor.getNDConf();
  495.         Object initialValue = tensor.item();
  496.         cl_tsr<?, ?> clt = tensor.getMut().getData().as( cl_tsr.class);
  497.         if ( clt == null ) throw new IllegalStateException("The tensor has no device component!");
  498.         JVMData jvmData = JVMData.of( initialValue, ndc.size(), false, true );
  499.         clt = _storeNew( jvmData, true );
  500.         return _dataArrayOf(clt, (DataType<Number>) _dataTypeOf(clt));
  501.     }

  502.     @Override
  503.     protected Data<Number> _virtualize( Tensor<?> tensor ) {
  504.         NDConfiguration ndc = tensor.getNDConf();
  505.         Object initialValue = tensor.item();
  506.         cl_tsr<?, ?> clt = tensor.getMut().getData().as( cl_tsr.class);
  507.         if ( clt == null ) throw new IllegalStateException("The tensor has no device component!");
  508.         JVMData jvmData = JVMData.of( initialValue, ndc.size(), false, true );
  509.         clt = _storeNew( jvmData, false );
  510.         return _dataArrayOf(clt, (DataType<Number>) _dataTypeOf(clt));
  511.     }

  512.     @Override
  513.     protected final DataType<?> _dataTypeOf( Object rawData ) {
  514.         LogUtil.nullArgCheck( rawData, "rawData", Object.class );
  515.         if ( rawData instanceof cl_tsr ) {
  516.             cl_dtype type = ((cl_tsr) rawData).dtype;
  517.             switch ( type ) {
  518.                 case F32: return DataType.of( Float.class );
  519.                 case F64: return DataType.of( Double.class );
  520.                 case I32: case U32:
  521.                     return DataType.of( Integer.class );
  522.                 case I64: return DataType.of( Long.class );
  523.                 case I16: case U16:
  524.                     return DataType.of( Short.class );
  525.                 case I8: case U8:
  526.                     return DataType.of( Byte.class );
  527.                 default: throw new IllegalStateException("Unknown OpenCL data type!");
  528.             }
  529.         }
  530.         throw new IllegalStateException("Unknown data type "+rawData.getClass()+"!");
  531.     }

  532.     private void _overwrite(
  533.         Tensor<?> tensor, long offset, JVMData jvmData
  534.     ) {
  535.         if ( jvmData.getLength() == 0 ) return;
  536.         cl_tsr<?, ?> clt = tensor.getMut().getData().as( cl_tsr.class);

  537.         if ( clt.value.event != null ) clWaitForEvents(1, new cl_event[]{clt.value.event});
  538.         clt.value.event = new cl_event();
  539.         long start = offset * jvmData.getItemSize();
  540.         long size  = jvmData.getItemSize() * jvmData.getLength();
  541.         clEnqueueWriteBuffer(
  542.                 _queue, clt.value.data, CL_TRUE,
  543.                 start, size,
  544.                 jvmData.getPointer(), 0, null,
  545.                 clt.value.event
  546.             );
  547.     }

  548.     @Override
  549.     protected final <T extends Number> void _swap( Tensor<T> former, Tensor<T> replacement) {
  550.         cl_tsr<Number, T> clTensor = former.mut().getData().as( cl_tsr.class);
  551.         former.getMut().setData(null);
  552.         replacement.getMut().setData( _dataArrayOf(clTensor, (DataType<T>) _dataTypeOf(clTensor)) );
  553.     }

  554.     @Override
  555.     public boolean update( OwnerChangeRequest<Tensor<Number>> changeRequest ) {
  556.         super.update(changeRequest);
  557.         if ( changeRequest.type() == IsBeing.ADDED ) {
  558.             Tensor<Number> newOwner = changeRequest.getNewOwner();
  559.             _updateInternal(newOwner, changeRequest::executeChange);
  560.         } else
  561.             changeRequest.executeChange(); // This can be an 'add', 'remove' or 'transfer' of this component!
  562.         return true;
  563.     }

  564.     @Override
  565.     protected <T extends Number> int _sizeOccupiedBy( Tensor<T> tensor ) { return tensor.getMut().getData().as( cl_tsr.class).value.size; }

  566.     @Override
  567.     protected <T extends Number> Object _readAll( Tensor<T> tensor, boolean clone ) {
  568.         cl_tsr<?, ?> clt = tensor.getMut().getData().as( cl_tsr.class);
  569.         return _readArray( tensor, tensor.getDataType().dataArrayType(), 0, clt.value.size );
  570.     }

  571.     private void _updateInternal( Tensor<Number> newOwner, Runnable migration) {
  572.         Tensor<Number> root = _findRoot( newOwner );
  573.         if (root != null) _store(newOwner, root);
  574.         else _add( newOwner, null, migration );
  575.     }

  576.     private Tensor<Number> _findRoot( Tensor<Number> newOwner ) {
  577.         Tensor<Number> root = null;
  578.         Relation<Number> relation = newOwner.get(Relation.class);
  579.         if ( relation != null )
  580.             root = ((Relation<Number>) newOwner.get(Relation.class)).findRootTensor().orElse(null);

  581.         return root;
  582.     }

  583.     private JVMData _read( JVMData jvmData, Tensor<Number> tensor, int offset ) {
  584.         cl_tsr<?, ?> clt = tensor.getMut().getData().as( cl_tsr.class);
  585.         clEnqueueReadBuffer(
  586.                 _queue,
  587.                 clt.value.data,
  588.                 CL_TRUE,
  589.                 (long) offset * jvmData.getItemSize(), // one double == eight byte
  590.                 (long) jvmData.getItemSize() * jvmData.getLength(),
  591.                 jvmData.getPointer(),
  592.                 0,
  593.                 null,
  594.                 null
  595.         );
  596.         return jvmData;
  597.     }

  598.     /**
  599.      * @param call The {@link ExecutionCall} which will be queried for a {@link CLImplementation} holding the kernel.
  600.      * @return The kernel call which uses the builder pattern to receive kernel arguments.
  601.      */
  602.     public KernelCaller getKernel( ExecutionCall<OpenCLDevice> call ) {
  603.         String chosen;
  604.         Algorithm algorithm = call.getAlgorithm();
  605.         DeviceAlgorithm<?> deviceAlgorithm = ( algorithm instanceof DeviceAlgorithm ? ((DeviceAlgorithm<?>) algorithm) : null );
  606.         // We create the kernel name from the chosen algorithm:
  607.         ImplementationFor<OpenCLDevice> impl = ( deviceAlgorithm == null ? null : deviceAlgorithm.getImplementationFor(OpenCLDevice.class) );
  608.         if ( impl instanceof CLImplementation && _platform.hasKernel(((CLImplementation) impl).getKernelFor(call).getName()) ) {
  609.             chosen = ((CLImplementation) impl).getKernelFor( call ).getName();
  610.         }
  611.         else
  612.             chosen = call.getAlgorithm().getName() + "_" + call.getOperation().getIdentifier();

  613.         cl_kernel kernel = _platform.getKernel( chosen );
  614.         if ( kernel == null )
  615.             throw new IllegalStateException(
  616.                     "No kernel found for signature '" + chosen + "' for operation '" +  call.getOperation().getIdentifier() + "'."
  617.                 );

  618.         return new KernelCaller(kernel, _queue);
  619.     }

  620.     /**
  621.      * @param name The name of the kernel for which a {@link KernelCaller} should be returned.
  622.      * @return A {@link KernelCaller} for calling the requested kernel.
  623.      */
  624.     public KernelCaller getKernel( String name ) {
  625.         cl_kernel kernel = _platform.getKernel( name );
  626.         if ( kernel == null )
  627.             throw new IllegalStateException("No kernel found with name '" + name + "'.");
  628.         return new KernelCaller(kernel, _queue);
  629.     }

  630.     @Override
  631.     protected boolean _approveExecutionOf(Tensor<?>[] tensors, int d, Operation type ) { return true; }


  632.     /*==================================================================================================================
  633.     |
  634.     |       §(3) : OPENCL PROPERTIES
  635.     |   ---------------------------
  636.     */

  637.     public String name() { return Query.getString( _deviceId, CL_DEVICE_NAME ); }

  638.     public String vendor() { return Query.getString(_deviceId, CL_DEVICE_VENDOR); }

  639.     public String version() { return Query.getString(_deviceId, CL_DRIVER_VERSION); }

  640.     public Type type() {
  641.         long deviceType = Query.getLong(_deviceId, CL_DEVICE_TYPE);
  642.         if ( (deviceType & CL_DEVICE_TYPE_CPU         ) != 0 ) return Type.CPU;
  643.         if ( (deviceType & CL_DEVICE_TYPE_GPU         ) != 0 ) return Type.GPU;
  644.         if ( (deviceType & CL_DEVICE_TYPE_ACCELERATOR ) != 0 ) return Type.ACCELERATOR;
  645.         if ( (deviceType & CL_DEVICE_TYPE_DEFAULT     ) != 0 ) return Type.DEFAULT;
  646.         if ( (deviceType & CL_DEVICE_TYPE_CUSTOM      ) != 0 ) return Type.CUSTOM;
  647.         if ( (deviceType & CL_DEVICE_TYPE_ALL         ) != 0 ) return Type.ALL;
  648.         return Type.UNKNOWN;
  649.     }

  650.     public int maxComputeUnits() { return Query.getInt(_deviceId, CL_DEVICE_MAX_COMPUTE_UNITS); }

  651.     public long maxWorkItemSimensions() { return Query.getLong(_deviceId, CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS); }

  652.     public long[] maxWorkItemSizes() { return Query.getSizes(_deviceId, CL_DEVICE_MAX_WORK_ITEM_SIZES, 3); }

  653.     public long maxWorkGroupSize() { return Query.getSize(_deviceId, CL_DEVICE_MAX_WORK_GROUP_SIZE); }

  654.     public long maxClockFrequenzy() { return Query.getLong(_deviceId, CL_DEVICE_MAX_CLOCK_FREQUENCY); }

  655.     public int maxAddressBits() { return Query.getInt(_deviceId, CL_DEVICE_ADDRESS_BITS); }

  656.     public long maxMemAllocSize() { return Query.getLong(_deviceId, CL_DEVICE_MAX_MEM_ALLOC_SIZE); }

  657.     public long globalMemSize() { return Query.getLong(_deviceId, CL_DEVICE_GLOBAL_MEM_SIZE); }

  658.     public int errorCorrectionSupport() { return Query.getInt(_deviceId, CL_DEVICE_ERROR_CORRECTION_SUPPORT); }

  659.     public int localMemType() { return Query.getInt(_deviceId, CL_DEVICE_LOCAL_MEM_TYPE); }

  660.     public long localMemSize() { return Query.getLong(_deviceId, CL_DEVICE_LOCAL_MEM_SIZE); }

  661.     public long maxConstantBufferSize() { return Query.getLong(_deviceId, CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE); }

  662.     public long maxConstantBufferSizeKB() { return (int) (Query.getLong(_deviceId, CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE) / 1024); }

  663.     public int imageSupport() { return Query.getInt(_deviceId, CL_DEVICE_IMAGE_SUPPORT); }

  664.     public int maxReadImageArgs() { return Query.getInt(_deviceId, CL_DEVICE_MAX_READ_IMAGE_ARGS); }

  665.     public int maxWriteImageArgs() { return Query.getInt(_deviceId, CL_DEVICE_MAX_WRITE_IMAGE_ARGS); }

  666.     public long singleFPConfig() { return Query.getLong(_deviceId, CL_DEVICE_SINGLE_FP_CONFIG); }

  667.     public long image2DMaxWidth() { return Query.getSize(_deviceId, CL_DEVICE_IMAGE2D_MAX_WIDTH); }

  668.     public long image2DMaxHeight() { return Query.getSize(_deviceId, CL_DEVICE_IMAGE2D_MAX_HEIGHT); }

  669.     public long image3DMaxWidth() { return Query.getSize(_deviceId, CL_DEVICE_IMAGE3D_MAX_WIDTH); }

  670.     public long image3DMaxHeight() { return Query.getSize(_deviceId, CL_DEVICE_IMAGE3D_MAX_HEIGHT); }

  671.     public long image3DMaxDepth() { return Query.getSize(_deviceId, CL_DEVICE_IMAGE3D_MAX_DEPTH); }

  672.     public int prefVecWidthChar() { return Query.getInt(_deviceId, CL_DEVICE_PREFERRED_VECTOR_WIDTH_CHAR); }

  673.     public int prefVecWidthShort() { return Query.getInt(_deviceId, CL_DEVICE_PREFERRED_VECTOR_WIDTH_SHORT); }

  674.     public int prefVecWidthInt() { return Query.getInt(_deviceId, CL_DEVICE_PREFERRED_VECTOR_WIDTH_INT); }

  675.     public int prefVecWidthLong() { return Query.getInt(_deviceId, CL_DEVICE_PREFERRED_VECTOR_WIDTH_LONG); }

  676.     public int prefVecWidthFloat() { return Query.getInt(_deviceId, CL_DEVICE_PREFERRED_VECTOR_WIDTH_FLOAT); }

  677.     public int prefVecWidthDouble() { return Query.getInt(_deviceId, CL_DEVICE_PREFERRED_VECTOR_WIDTH_DOUBLE); }

  678.     public static class Query {
  679.         /**
  680.          * Returns the value of the device info parameter with the given name
  681.          *
  682.          * @param device    The device
  683.          * @param paramName The parameter name
  684.          * @return The value
  685.          */
  686.         public static int getInt(cl_device_id device, int paramName) {
  687.             return getInts(device, paramName, 1)[0];
  688.         }

  689.         /**
  690.          * Returns the values of the device info parameter with the given name
  691.          *
  692.          * @param device    The device
  693.          * @param paramName The parameter name
  694.          * @param numValues The number of values
  695.          * @return The value
  696.          */
  697.         public static int[] getInts(cl_device_id device, int paramName, int numValues) {
  698.             int[] values = new int[numValues];
  699.             clGetDeviceInfo(device, paramName, (long) Sizeof.cl_int * numValues, Pointer.to(values), null);
  700.             return values;
  701.         }

  702.         /**
  703.          * Returns the value of the device info parameter with the given name
  704.          *
  705.          * @param device    The device
  706.          * @param paramName The parameter name
  707.          * @return The value
  708.          */
  709.         public static long getLong(cl_device_id device, int paramName) {
  710.             return getLongs(device, paramName, 1)[0];
  711.         }

  712.         /**
  713.          * Returns the values of the device info parameter with the given name
  714.          *
  715.          * @param device    The device
  716.          * @param paramName The parameter name
  717.          * @param numValues The number of values
  718.          * @return The value
  719.          */
  720.         public static long[] getLongs(cl_device_id device, int paramName, int numValues) {
  721.             long[] values = new long[numValues];
  722.             clGetDeviceInfo(device, paramName, (long) Sizeof.cl_long * numValues, Pointer.to(values), null);
  723.             return values;
  724.         }

  725.         /**
  726.          * Returns the value of the device info parameter with the given name
  727.          *
  728.          * @param device    The device
  729.          * @param paramName The parameter name
  730.          * @return The value
  731.          */
  732.         public static String getString(cl_device_id device, int paramName) {
  733.             // Obtain the length of the string that will be queried
  734.             long[] size = new long[1];
  735.             clGetDeviceInfo(device, paramName, 0, null, size);

  736.             // Create a buffer of the appropriate size and fill it with the info
  737.             byte[] buffer = new byte[(int) size[0]];
  738.             clGetDeviceInfo(device, paramName, buffer.length, Pointer.to(buffer), null);

  739.             // Create a string from the buffer (excluding the trailing \0 byte)
  740.             return new String(buffer, 0, buffer.length - 1);
  741.         }

  742.         /**
  743.          * Returns the value of the platform info parameter with the given name
  744.          *
  745.          * @param platform  The platform
  746.          * @param paramName The parameter name
  747.          * @return The value
  748.          */
  749.         public static String getString(cl_platform_id platform, int paramName) {
  750.             // Obtain the length of the string that will be queried
  751.             long[] size = new long[1];
  752.             clGetPlatformInfo(platform, paramName, 0, null, size);

  753.             // Create a buffer of the appropriate size and fill it with the info
  754.             byte[] buffer = new byte[(int) size[0]];
  755.             clGetPlatformInfo(platform, paramName, buffer.length, Pointer.to(buffer), null);

  756.             // Create a string from the buffer (excluding the trailing \0 byte)
  757.             return new String(buffer, 0, buffer.length - 1);
  758.         }

  759.         /**
  760.          * Returns the value of the device info parameter with the given name
  761.          *
  762.          * @param device    The device
  763.          * @param paramName The parameter name
  764.          * @return The value64
  765.          */
  766.         public static long getSize(cl_device_id device, int paramName) {
  767.             return getSizes(device, paramName, 1)[0];
  768.         }

  769.         /**
  770.          * Returns the values of the device info parameter with the given name
  771.          *
  772.          * @param device    The device
  773.          * @param paramName The parameter name
  774.          * @param numValues The number of values
  775.          * @return The value64
  776.          */
  777.         public static long[] getSizes(cl_device_id device, int paramName, int numValues) {
  778.             // The size of the returned data has to depend on
  779.             // the size of a size_t, which is handled here
  780.             ByteBuffer buffer = ByteBuffer.allocate(numValues * Sizeof.size_t).order(ByteOrder.nativeOrder());
  781.             clGetDeviceInfo(
  782.                     device,
  783.                     paramName,
  784.                     (long) Sizeof.size_t * numValues,
  785.                     Pointer.to(buffer),
  786.                     null
  787.             );
  788.             long[] values = new long[numValues];
  789.             return getLongs(numValues, buffer, values);
  790.         }

  791.         public static long[] getLongs(int numValues, ByteBuffer buffer, long[] values) {
  792.             if (Sizeof.size_t == 4)
  793.                 for (int i = 0; i < numValues; i++)
  794.                     values[i] = buffer.getInt(i * Sizeof.size_t);
  795.             else
  796.                 for ( int i = 0; i < numValues; i++ )
  797.                     values[i] = buffer.getLong(i * Sizeof.size_t);

  798.             return values;
  799.         }

  800.     }


  801.     private <T extends Number> Data<T> _dataArrayOf( Object data, DataType<T> dataType ) {
  802.         return (Data<T>) new CLData(this, data, (DataType<Number>) dataType);
  803.     }

  804.     private static class CLData extends AbstractDeviceData<Number> {

  805.         public CLData( AbstractBaseDevice<Number> owner, Object dataRef, DataType<Number> dataType ) {
  806.             super(owner, dataRef, dataType, ()->{
  807.                 // In this lambda we free the memory, because the data is no longer needed!
  808.                 cl_tsr<?,?> clTsr = (cl_tsr<?,?>) dataRef;
  809.                 if ( clTsr.value.event != null ) clWaitForEvents(1, new cl_event[]{clTsr.value.event});
  810.                 clReleaseMemObject(clTsr.value.data); // Removing data from the device!
  811.             });
  812.             assert !(dataRef instanceof Data);
  813.         }

  814.     }

  815.     /*==================================================================================================================
  816.     |
  817.     |       §(4) : NESTED CLASSES
  818.     |   ---------------------------
  819.     */

  820.     /**
  821.      * This class is an OpenCL-Device specific tensor component used to store
  822.      * the floating point size ( 1:float, 2:double, ...),
  823.      * a reference to a wrapper containing a pointer to the tensor's configuration (cl_config),
  824.      * and
  825.      * a reference to a wrapper containing a pointer to the tensor's data (cl_data)
  826.      * The latter two lend their identity for garbage collection!
  827.      */
  828.     static class cl_tsr<V, T extends V> {

  829.         cl_tsr(cl_tsr.cl_value value, cl_dtype  dtype) {
  830.             this.value = value;
  831.             this.dtype = dtype;
  832.         }

  833.         /**
  834.          * This class is responsible for representing the
  835.          * data of a tensor stored on the device.
  836.          * Instances of this class lend their identity to utilize garbage collection
  837.          * of the data that they reference via their "cl_mem" field.
  838.          * Meaning this inner memory object "cl_mem" will
  839.          * be freed via a call hook stored inside a Cleaner instance...
  840.          */
  841.         static class cl_value
  842.         {
  843.             cl_value( int size ) { this.size = size; }

  844.             public final int size;
  845.             public cl_mem    data;
  846.             public cl_event  event;
  847.         }

  848.         public final cl_dtype  dtype;
  849.         public final cl_value  value;

  850.         @Override
  851.         public boolean equals(Object obj) {
  852.             if ( !(obj instanceof cl_tsr) ) return false;
  853.             return ((cl_tsr) obj).value == this.value;
  854.         }

  855.         @Override public int hashCode() {
  856.             return value.hashCode();
  857.         }
  858.     }

  859.     /**
  860.      * This class manages a reference to a so called "ad hoc" program & kernel.
  861.      * Ad hoc is a Latin phrase meaning literally 'to this'.
  862.      * In English, it generally signifies a solution designed for a specific problem or task,
  863.      * non-generalizable, and not intended to be adapted to other purposes.
  864.      * This leads to the purpose of this class, namely to hold the context to a unique kernel with
  865.      * a uniquely associated purpose which has been created by an operation possibly for specific
  866.      * tensor dimensions or possibly other properties...
  867.      */
  868.     static final class cl_ad_hoc
  869.     {
  870.         public final String source;
  871.         public final cl_kernel kernel;
  872.         public final cl_program program;

  873.         public cl_ad_hoc(
  874.                 String source, cl_kernel kernel, cl_program program
  875.         ) {
  876.             this.source = source;
  877.             this.kernel = kernel;
  878.             this.program = program;
  879.         }
  880.     }

  881.     /**
  882.      * This is the class responsible for representing NDConfiguration data.
  883.      * Instances of this class lend their identity to utilize garbage collection
  884.      * of the data that they reference via their "cl_mem" field.
  885.      * Meaning this inner memory object "cl_mem" will
  886.      * be freed via a call hook stored inside a Cleaner instance...
  887.      */
  888.     static final class cl_config {
  889.         public cl_mem data;
  890.     }
  891. }