CPUBackend.java

package neureka.backend.cpu;

import neureka.backend.api.BackendExtension;
import neureka.backend.api.ini.BackendLoader;
import neureka.backend.api.ini.ReceiveForDevice;
import neureka.backend.main.algorithms.*;
import neureka.backend.main.implementations.broadcast.*;
import neureka.backend.main.implementations.convolution.CPUConvolution;
import neureka.backend.main.implementations.elementwise.*;
import neureka.backend.main.implementations.fun.api.ScalarFun;
import neureka.backend.main.implementations.linear.CPUDot;
import neureka.backend.main.implementations.matmul.CPUMatMul;
import neureka.backend.main.implementations.scalar.CPUScalarFunction;
import neureka.backend.main.operations.functions.*;
import neureka.backend.main.operations.linear.*;
import neureka.backend.main.operations.operator.*;
import neureka.backend.main.operations.other.AssignLeft;
import neureka.backend.main.operations.other.Randomization;
import neureka.backend.main.operations.other.Sum;
import neureka.backend.main.operations.other.internal.CPUSum;
import neureka.devices.host.CPU;

/**
 *  This class loads the CPU operations into the Neureka library context.
 *  <p>
 *  It answers device look-ups for the host {@link CPU} (see {@link #find(String)})
 *  and supplies a {@link BackendLoader} which registers the CPU implementations
 *  for the operation / algorithm pairs listed in {@link #_load(ReceiveForDevice)}.
 */
public class CPUBackend implements BackendExtension
{
    /**
     *  Resolves the supplied search key to a {@link DeviceOption} wrapping the {@link CPU} device.
     *  The keys "cpu", "jvm" and "java" are matched case-insensitively and yield an option
     *  created with the value {@code 1f}, every other key yields an option created with {@code 0f}.
     *  (The float is presumably a match confidence; confirm against {@link DeviceOption}.)
     *
     * @param searchKey The key identifying the requested device, must not be null.
     * @return A {@link DeviceOption} for the {@link CPU}, never null.
     */
    @Override
    public DeviceOption find(String searchKey) {
        // Each match has to be returned right away, otherwise the option created
        // for it is discarded and the method always falls through to the 0f result.
        if ( searchKey.equalsIgnoreCase("cpu")  ) return new DeviceOption( CPU.get(), 1f );
        if ( searchKey.equalsIgnoreCase("jvm")  ) return new DeviceOption( CPU.get(), 1f );
        if ( searchKey.equalsIgnoreCase("java") ) return new DeviceOption( CPU.get(), 1f );
        return new DeviceOption( CPU.get(), 0f );
    }

    /**
     *  Disposes this backend by delegating to {@link CPU#dispose()} on the instance returned by {@link CPU#get()}.
     */
    @Override
    public void dispose() { CPU.get().dispose(); }

    /**
     * @return A {@link BackendLoader} which, when given a registry, requests the receiver
     *         for the {@link CPU} device type and registers all CPU implementations on it.
     */
    @Override
    public BackendLoader getLoader() { return registry -> _load( registry.forDevice(CPU.class) ); }

    /**
     *  Registers the CPU implementations for every supported operation.
     *  Each {@code set(..)} call pairs an algorithm type with a factory lambda that creates
     *  a new implementation instance whenever it is invoked (the lambda argument is not used here).
     *
     * @param receive The receiver through which implementations for the {@link CPU} device are registered.
     */
    private void _load( ReceiveForDevice<CPU> receive )
    {
        // Arithmetic operators: scalar broadcast, broadcast and element-wise implementations.
        receive.forOperation( Power.class )
                .set( BiScalarBroadcast.class, context -> new CPUScalaBroadcastPower() )
                .set( Broadcast.class,     context -> new CPUBroadcastPower() )
                .set( BiElementwise.class, context -> new CPUBiElementWisePower() );

        receive.forOperation( Addition.class )
                .set( BiScalarBroadcast.class, context -> new CPUScalarBroadcastAddition() )
                .set( Broadcast.class,     context -> new CPUBroadcastAddition() )
                .set( BiElementwise.class, context -> new CPUBiElementWiseAddition() );

        receive.forOperation( Subtraction.class )
                .set( BiScalarBroadcast.class, context -> new CPUScalarBroadcastSubtraction() )
                .set( Broadcast.class,     context -> new CPUBroadcastSubtraction() )
                .set( BiElementwise.class, context -> new CPUBiElementWiseSubtraction() );

        receive.forOperation( Multiplication.class )
                .set( BiScalarBroadcast.class, context -> new CPUScalarBroadcastMultiplication() )
                .set( Broadcast.class,     context -> new CPUBroadcastMultiplication() )
                .set( BiElementwise.class, context -> new CPUBiElementWiseMultiplication() );

        receive.forOperation( Division.class )
                .set( BiScalarBroadcast.class, context -> new CPUScalarBroadcastDivision() )
                .set( Broadcast.class,     context -> new CPUBroadcastDivision() )
                .set( BiElementwise.class, context -> new CPUBiElementWiseDivision() );

        receive.forOperation( Modulo.class )
                .set( BiScalarBroadcast.class, context -> new CPUScalarBroadcastModulo() )
                .set( Broadcast.class,     context -> new CPUBroadcastModulo() )
                .set( BiElementwise.class, context -> new CPUBiElementWiseModulo() );

        // Assignment: scalar broadcast and element-wise implementations only.
        receive.forOperation( AssignLeft.class )
                .set( BiScalarBroadcast.class, context -> new CPUScalarBroadcastIdentity() )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseAssignFun() );

        // All three convolution operations are served by the same CPUConvolution implementation type.
        receive.forOperation( Convolution.class )
                .set( NDConvolution.class, context -> new CPUConvolution() );
        receive.forOperation( XConvLeft.class )
                .set( NDConvolution.class, context -> new CPUConvolution() );
        receive.forOperation( XConvRight.class )
                .set( NDConvolution.class, context -> new CPUConvolution() );

        // Linear algebra and reduction operations.
        receive.forOperation( MatMul.class )
                .set( MatMulAlgorithm.class, context -> new CPUMatMul() );

        receive.forOperation( DotProduct.class )
                .set( DotProductAlgorithm.class, context -> new CPUDot() );

        receive.forOperation( Sum.class )
                .set( SumAlgorithm.class, context -> new CPUSum() );

        receive.forOperation( Randomization.class )
                .set( ElementwiseAlgorithm.class, context -> new CPURandomization() );

        // Functions: every operation gets an element-wise and a scalar implementation,
        // both configured with the ScalarFun constant matching the operation
        // (Identity is the exception, its element-wise variant is CPUElementwiseAssignFun).
        receive.forOperation( Absolute.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.ABSOLUTE) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.ABSOLUTE) );
        receive.forOperation( Cosinus.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.COSINUS) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.COSINUS) );
        receive.forOperation( GaSU.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.GASU) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.GASU) );
        receive.forOperation( GaTU.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.GATU) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.GATU) );
        receive.forOperation( Gaussian.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.GAUSSIAN) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.GAUSSIAN) );
        receive.forOperation( GaussianFast.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.GAUSSIAN_FAST) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.GAUSSIAN_FAST) );
        receive.forOperation( GeLU.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.GELU) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.GELU) );
        receive.forOperation( Identity.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseAssignFun() )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.IDENTITY) );
        receive.forOperation( Logarithm.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.LOGARITHM) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.LOGARITHM) );
        receive.forOperation( Quadratic.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.QUADRATIC) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.QUADRATIC) );
        receive.forOperation( ReLU.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.RELU) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.RELU) );
        receive.forOperation( SeLU.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.SELU) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.SELU) );
        receive.forOperation( Sigmoid.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.SIGMOID) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.SIGMOID) );
        receive.forOperation( SiLU.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.SILU) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.SILU) );
        receive.forOperation( Sinus.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.SINUS) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.SINUS) );
        receive.forOperation( Softplus.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.SOFTPLUS) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.SOFTPLUS) );
        receive.forOperation( Softsign.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.SOFTSIGN) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.SOFTSIGN) );
        receive.forOperation( Tanh.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.TANH) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.TANH) );
        receive.forOperation( TanhFast.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.TANH_FAST) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.TANH_FAST) );
        receive.forOperation( Exp.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.EXP) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.EXP) );
        receive.forOperation( Cbrt.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.CBRT) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.CBRT) );
        receive.forOperation( Log10.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.LOG10) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.LOG10) );
        receive.forOperation( Sqrt.class )
                .set( ElementwiseAlgorithm.class, context -> new CPUElementwiseFunction( ScalarFun.SQRT) )
                .set( ScalarAlgorithm.class, context -> new CPUScalarFunction(ScalarFun.SQRT) );
    }

}