// CLBackend.java
package neureka.backend.ocl;
import neureka.backend.api.BackendContext;
import neureka.backend.api.BackendExtension;
import neureka.backend.api.Extensions;
import neureka.backend.api.ini.BackendLoader;
import neureka.backend.api.ini.ReceiveForDevice;
import neureka.backend.main.algorithms.*;
import neureka.backend.main.implementations.broadcast.*;
import neureka.backend.main.implementations.convolution.CLConvolution;
import neureka.backend.main.implementations.elementwise.*;
import neureka.backend.main.implementations.fun.api.ScalarFun;
import neureka.backend.main.implementations.linear.CLDot;
import neureka.backend.main.implementations.matmul.CLMatMul;
import neureka.backend.main.implementations.scalar.CLScalarFunction;
import neureka.backend.main.operations.functions.*;
import neureka.backend.main.operations.linear.*;
import neureka.backend.main.operations.linear.internal.opencl.CLSum;
import neureka.backend.main.operations.operator.*;
import neureka.backend.main.operations.other.AssignLeft;
import neureka.backend.main.operations.other.Randomization;
import neureka.backend.main.operations.other.Sum;
import neureka.common.composition.Component;
import neureka.devices.Device;
import neureka.devices.opencl.OpenCLDevice;
import neureka.devices.opencl.OpenCLPlatform;
import neureka.devices.opencl.utility.Messages;
import neureka.math.parsing.ParseUtil;
import org.jocl.cl_platform_id;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.jocl.CL.clGetPlatformIDs;
/**
 *  This is an OpenCL context component for any given {@link BackendContext} which
 *  extends a given backend context instance with additional functionality, which in
 *  this case is the OpenCL backend storing platform and device information.
 *  {@link BackendContext}s are thread local states
 *  used for managing {@link neureka.backend.api.Operation}, {@link neureka.math.Function}
 *  as well as {@link Component} implementation instances like this one.
 *  A given state might not be compatible with the concepts introduced in other contexts,
 *  which is why it makes sense to have separate "worlds" with potentially different operations...
 *  The component system of the {@link BackendContext} exists so that a given context
 *  can be extended with more functionality
 *  and also to attach relevant states: for example, in this case the {@link CLBackend}
 *  instance will directly or indirectly reference kernels, memory objects and other concepts
 *  exposed by OpenCL...
 */
public final class CLBackend implements BackendExtension
{
    private static final Logger _LOG = LoggerFactory.getLogger(CLBackend.class);

    /** The platforms created by the most recent successful call to {@code update}. */
    private final List<OpenCLPlatform> _platforms = new ArrayList<>();
    private final CLSettings _settings = new CLSettings();

    /**
     *  Use this constructor if you want to create a new OpenCL world in which there
     *  are unique {@link OpenCLPlatform} and {@link OpenCLDevice} instances.
     */
    public CLBackend() {}

    /**
     * @return The number of all {@link OpenCLDevice} instances across all {@link OpenCLPlatform}s
     *         (0 if there are no platforms).
     */
    public int getTotalNumberOfDevices() {
        // The sum over an empty stream is 0, so the "no platforms" case needs no special handling.
        return _platforms.stream().mapToInt( p -> p.getDevices().size() ).sum();
    }

    /**
     * @return An unmodifiable view on the context specific {@link OpenCLPlatform} instances,
     *         possibly containing {@link OpenCLDevice}s.
     */
    public List<OpenCLPlatform> getPlatforms() { return Collections.unmodifiableList( _platforms ); }

    /**
     * @return A container for OpenCL specific settings.
     */
    public CLSettings getSettings() { return _settings; }

    /**
     *  Updating the CLContext will cause the list of existing {@link OpenCLPlatform} instances to be
     *  cleared and refilled with completely new {@link OpenCLPlatform} instances.
     *  This will in effect also cause the recreation of any {@link OpenCLDevice} instances
     *  as part of these {@link OpenCLPlatform}s.
     *  This will subsequently cause the recompilation of many OpenCL kernels.
     *
     * @param changeRequest The pending change of this component, which is executed after the platforms were reloaded.
     * @return Always {@code true}.
     */
    @Override
    public boolean update( OwnerChangeRequest<Extensions> changeRequest ) {
        // Discover first: if the discovery throws, the currently stored platforms are left untouched
        // instead of being dropped without ever having been disposed.
        List<OpenCLPlatform> freshPlatforms = _findLoadAndCompileForAllPlatforms();
        _platforms.clear();
        _platforms.addAll( freshPlatforms );
        changeRequest.executeChange(); // This can be an 'add', 'remove' or 'transfer' of this component!
        return true;
    }

    @Override
    public String toString() {
        return this.getClass().getSimpleName()+"@"+Integer.toHexString(hashCode())+"[" +
                    "platforms=["+
                        _platforms.stream().map(Object::toString).collect(Collectors.joining(","))+
                    "]" +
                "]";
    }

    /**
     *  Queries OpenCL for all platform ids and tries to wrap each of them in an {@link OpenCLPlatform}.
     *  Platforms which fail to instantiate are logged and skipped.
     *
     * @return A new list of freshly created {@link OpenCLPlatform} instances containing freshly instantiated {@link OpenCLDevice}s and kernels.
     * @throws RuntimeException If platform ids were found but not a single {@link OpenCLPlatform} could be instantiated.
     *                          The individual failures are attached as suppressed exceptions.
     */
    private static List<OpenCLPlatform> _findLoadAndCompileForAllPlatforms()
    {
        // Obtain the number of platforms
        int[] numPlatforms = new int[ 1 ];
        clGetPlatformIDs( 0, null, numPlatforms );

        // Obtain the platform IDs
        cl_platform_id[] platforms = new cl_platform_id[ numPlatforms[ 0 ] ];
        // The OpenCL specification defines 'num_entries == 0' together with a non-null
        // 'platforms' array as CL_INVALID_VALUE, so the query is skipped if nothing was reported.
        if ( platforms.length > 0 )
            clGetPlatformIDs( platforms.length, platforms, null );

        List<OpenCLPlatform> loadedPlatforms = new ArrayList<>();
        List<String> failures = new ArrayList<>();
        List<Exception> causes = new ArrayList<>(); // Kept so that no stack trace gets lost.
        for ( cl_platform_id id : platforms ) {
            try {
                loadedPlatforms.add( new OpenCLPlatform( id ) );
            } catch ( Exception e ) {
                String message =
                        "Failed to instantiate '"+OpenCLPlatform.class.getSimpleName()+"' " +
                        "with id '0x"+Long.toHexString(id.getNativePointer())+"'!";
                _LOG.error( message, e );
                failures.add( message + " Reason: " + e.getMessage() );
                causes.add( e );
            }
        }
        // Note: 'allMatch' is also true for an empty list, which covers the "no platform at all" case.
        if ( loadedPlatforms.stream().allMatch( p -> p.getDevices().isEmpty() ) )
            _LOG.info( Messages.clContextCouldNotFindAnyDevices() );

        if ( loadedPlatforms.isEmpty() && platforms.length > 0 ) {
            // There should be at least one platform with at least one device!
            RuntimeException failure = new RuntimeException(
                    "Failed to instantiate any '"+OpenCLPlatform.class.getSimpleName()+"' instance!\n" +
                    "Reasons: \n " + String.join( "\n ", failures )
                );
            causes.forEach( failure::addSuppressed );
            throw failure;
        }
        return loadedPlatforms;
    }

    /**
     *  Searches the devices of all platforms for the one which matches the provided key best.
     *  Every device is rated by the highest similarity between the key and one of the
     *  words "opencl", the device type name, the device name and the device vendor
     *  (each trimmed and converted to lower case).
     *  The search stops early as soon as a device reaches a score of exactly 1.
     *
     * @param searchKey The key against which the device descriptions are compared (used as passed in).
     * @return A {@link DeviceOption} holding the best scoring device and its score,
     *         or {@code null} and a score of 0 if no device had a similarity greater than 0.
     */
    @Override
    public DeviceOption find( String searchKey ) {
        Device<Number> result = null;
        double score = 0;
        for ( OpenCLPlatform p : _platforms ) {
            for ( OpenCLDevice d : p.getDevices() ) {
                double similarity = Stream.of("opencl",d.type().name(),d.name(),d.vendor())
                                            .map( word -> word.trim().toLowerCase() )
                                            .mapToDouble( word -> ParseUtil.similarity( word, searchKey ) )
                                            .max()
                                            .orElse(0);
                if ( similarity > score ) {
                    result = d;
                    score = similarity;
                    if ( score == 1 )
                        return new DeviceOption( result, score ); // A perfect match cannot be beaten.
                }
            }
        }
        return new DeviceOption( result, score );
    }

    /**
     *  Resets the OpenCL specific settings of this backend.
     *  The platforms and devices are not affected by this.
     */
    @Override
    public void reset() {
        _settings.reset();
    }

    /**
     *  This method will free all the resources occupied by this context,
     *  meaning that all platforms and their devices will be disposed.
     *  Their kernels will be removed and their tensors restored.
     */
    @Override
    public void dispose() {
        for ( OpenCLPlatform platform : _platforms ) {
            // The devices are disposed before the platform which they belong to.
            for ( OpenCLDevice device : platform.getDevices() ) device.dispose();
            platform.dispose();
        }
        _platforms.clear();
    }

    /**
     * @return A loader which registers the OpenCL implementations of this backend for the {@link OpenCLDevice} type.
     */
    @Override
    public BackendLoader getLoader() {
        return receiver -> _load( receiver.forDevice(OpenCLDevice.class) );
    }

    /**
     *  Registers, for every supported operation, the OpenCL implementations of its algorithms.
     *
     * @param receive The receiver for {@link OpenCLDevice} specific implementations.
     */
    private void _load( ReceiveForDevice<OpenCLDevice> receive )
    {
        // Operators:
        receive.forOperation( Power.class )
                .set( BiScalarBroadcast.class, context -> new CLScalarBroadcastPower( context.getOperationIdentidier() ) )
                .set( Broadcast.class,         context -> new CLBroadcastPower( context.getOperationIdentidier() ) )
                .set( BiElementwise.class,     context -> new CLBiElementwisePower( context.getOperationIdentidier() ) );

        receive.forOperation( Addition.class )
                .set( BiScalarBroadcast.class, context -> new CLScalarBroadcastAddition( context.getOperationIdentidier() ) )
                .set( Broadcast.class,         context -> new CLBroadcastAddition( context.getOperationIdentidier() ) )
                .set( BiElementwise.class,     context -> new CLBiElementwiseAddition( context.getOperationIdentidier() ) );

        receive.forOperation( Subtraction.class )
                .set( BiScalarBroadcast.class, context -> new CLScalarBroadcastSubtraction( context.getOperationIdentidier() ) )
                .set( Broadcast.class,         context -> new CLBroadcastSubtraction( context.getOperationIdentidier() ) )
                .set( BiElementwise.class,     context -> new CLBiElementwiseSubtraction( context.getOperationIdentidier() ) );

        receive.forOperation( Multiplication.class )
                .set( BiScalarBroadcast.class, context -> new CLScalarBroadcastMultiplication( context.getOperationIdentidier() ) )
                .set( Broadcast.class,         context -> new CLBroadcastMultiplication( context.getOperationIdentidier() ) )
                .set( BiElementwise.class,     context -> new CLBiElementwiseMultiplication( context.getOperationIdentidier() ) );

        receive.forOperation( Division.class )
                .set( BiScalarBroadcast.class, context -> new CLScalarBroadcastDivision( context.getOperationIdentidier() ) )
                .set( Broadcast.class,         context -> new CLBroadcastDivision( context.getOperationIdentidier() ) )
                .set( BiElementwise.class,     context -> new CLBiElementwiseDivision( context.getOperationIdentidier() ) );

        receive.forOperation( Modulo.class )
                .set( BiScalarBroadcast.class, context -> new CLScalarBroadcastModulo( context.getOperationIdentidier() ) )
                .set( Broadcast.class,         context -> new CLBroadcastModulo( context.getOperationIdentidier() ) )
                .set( BiElementwise.class,     context -> new CLBiElementwiseModulo( context.getOperationIdentidier() ) );

        receive.forOperation( AssignLeft.class )
                .set( BiScalarBroadcast.class,    context -> new CLScalarBroadcastIdentity( context.getOperationIdentidier() ) )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.IDENTITY ) );

        // Convolution and linear algebra:
        receive.forOperation( Convolution.class )
                .set( NDConvolution.class, context -> new CLConvolution( context.getOperationIdentidier() ) );
        receive.forOperation( XConvLeft.class )
                .set( NDConvolution.class, context -> new CLConvolution( context.getOperationIdentidier() ) );
        receive.forOperation( XConvRight.class )
                .set( NDConvolution.class, context -> new CLConvolution( context.getOperationIdentidier() ) );

        receive.forOperation( MatMul.class )
                .set( MatMulAlgorithm.class, context -> new CLMatMul() );

        receive.forOperation( DotProduct.class )
                .set( DotProductAlgorithm.class, context -> new CLDot() );

        receive.forOperation( Sum.class )
                .set( SumAlgorithm.class, context -> new CLSum() );

        receive.forOperation( Randomization.class )
                .set( ElementwiseAlgorithm.class, context -> new CLRandomization() );

        // Functions, each with an elementwise and a scalar implementation:
        receive.forOperation( Absolute.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.ABSOLUTE ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.ABSOLUTE ) );
        receive.forOperation( Cosinus.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.COSINUS ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.COSINUS ) );
        receive.forOperation( GaSU.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.GASU ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.GASU ) );
        receive.forOperation( GaTU.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.GATU ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.GATU ) );
        receive.forOperation( Gaussian.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.GAUSSIAN ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.GAUSSIAN ) );
        receive.forOperation( GaussianFast.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.GAUSSIAN_FAST ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.GAUSSIAN_FAST ) );
        receive.forOperation( GeLU.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.GELU ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.GELU ) );
        receive.forOperation( Identity.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.IDENTITY ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.IDENTITY ) );
        receive.forOperation( Logarithm.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.LOGARITHM ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.LOGARITHM ) );
        receive.forOperation( Quadratic.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.QUADRATIC ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.QUADRATIC ) );
        receive.forOperation( ReLU.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.RELU ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.RELU ) );
        receive.forOperation( SeLU.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.SELU ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.SELU ) );
        receive.forOperation( Sigmoid.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.SIGMOID ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.SIGMOID ) );
        receive.forOperation( SiLU.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.SILU ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.SILU ) );
        receive.forOperation( Sinus.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.SINUS ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.SINUS ) );
        receive.forOperation( Softplus.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.SOFTPLUS ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.SOFTPLUS ) );
        receive.forOperation( Softsign.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.SOFTSIGN ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.SOFTSIGN ) );
        receive.forOperation( Tanh.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.TANH ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.TANH ) );
        receive.forOperation( TanhFast.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.TANH_FAST ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.TANH_FAST ) );
        receive.forOperation( Exp.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.EXP ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.EXP ) );
        receive.forOperation( Cbrt.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.CBRT ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.CBRT ) );
        receive.forOperation( Log10.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.LOG10 ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.LOG10 ) );
        receive.forOperation( Sqrt.class )
                .set( ElementwiseAlgorithm.class, context -> new CLElementwiseFunction( ScalarFun.SQRT ) )
                .set( ScalarAlgorithm.class,      context -> new CLScalarFunction( ScalarFun.SQRT ) );
    }
}