MatMulAlgorithm.java

package neureka.backend.main.algorithms;

import neureka.Neureka;
import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.AbstractFunDeviceAlgorithm;
import neureka.devices.Device;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.ndim.config.NDConfiguration;
import neureka.ndim.config.types.simple.Simple2DConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

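/**
 *  A matrix multiplication algorithm for dense 2D tensors backed by a
 *  {@link Simple2DConfiguration}, meaning plain row- or column-major matrices.
 *  Its suitability estimator favours such tensors, it supports backward mode
 *  auto differentiation only, and it prepares incoming calls by creating the
 *  output tensor, aligning operand layouts and auto-cloning tensors whose
 *  access patterns are too complex for the kernels (slices, permuted views, ...).
 *  <p>
 *  A rough usage sketch (not a definitive API reference) of how a matrix product
 *  would typically be routed through this algorithm via the backend {@link Function} API,
 *  assuming {@code a} and {@code b} are two conformable 2D tensors:
 *  <pre>{@code
 *  Function matMul = Neureka.get().backend().getFunction().matMul();
 *  Tensor<?> c = matMul.execute( a, b ); // may dispatch to this algorithm for simple matrices
 *  }</pre>
 */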
public class MatMulAlgorithm extends AbstractFunDeviceAlgorithm<MatMulAlgorithm>
{
    private static final Logger _LOG = LoggerFactory.getLogger(MatMulAlgorithm.class);

    public MatMulAlgorithm() {
        super("simple_matmul");
        setIsSuitableFor(
                call -> call.validate()
                        .allNotNull( t -> Number.class.isAssignableFrom(t.getItemType()) )
                        .getEstimator()
                        .goodIfAnyNonNull( t -> t.getNDConf() instanceof Simple2DConfiguration)
                        .badIfAnyNonNull( t -> !( t.getNDConf() instanceof Simple2DConfiguration) )
                        .getEstimation()
        );
        setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY );
        setExecution(
            (outerCaller, outerCall) ->
                Result.of(AbstractDeviceAlgorithm.executeDeviceAlgorithm(_prepare(outerCall)))
                .withAutoDiff( (Function f, ExecutionCall<? extends Device<?>> adCall ) ->
                {
                    if ( adCall.autogradMode().allowsForward() )
                        throw new IllegalArgumentException("Matrix multiplication does not support forward-AD!");
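                    // Standard matmul backward rule (for c = a @ b):
                    //   dL/da = dL/dc @ b^T   and   dL/db = a^T @ dL/dc
                    // So we pick the operand *not* being differentiated (index flip below),
                    // transpose it, and later multiply it with the error from the appropriate side.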
                    Function matMul = Neureka.get().backend().getFunction().matMul();
                    int d = ( 1 + adCall.getValOf( Arg.DerivIdx.class ) ) % 2;
                    Tensor<?> derivative = Util.transpose(adCall.input( d )).deepCopy().mut().setIsIntermediate( true ); // We need to clone it to make it have a simple nd configuration...
                    derivative.to(adCall.getDevice());
                    return ADAction.of(target -> {
                        Tensor<?> result;
                        switch ( d ) {
                            case 0:
                                result = matMul.execute(derivative, target.error());
                                break;
                            case 1:
                                result = matMul.execute(target.error(), derivative);
                                break;
                            default:
                                throw new IllegalStateException("This should never happen!");
                        }
                        return result;
                    });
                })
        );
        setCallPreparation(MatMulAlgorithm::_prepare);
    }

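    /**
     *  Prepares a raw matrix multiplication call for execution:
     *  it makes sure an output slot exists at index 0 (creating a new output tensor
     *  if none is present) and auto-clones inputs whose ND configurations are not
     *  simple enough for the kernels (see {@link #_autoClone(ExecutionCall)}).
     */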
    private static ExecutionCall<Device<Object>> _prepare( ExecutionCall<?> call )
    {
        assert call.arity() <= 3;
        if ( call.arity() == 2 ) call = call.withAddedInputAt(0, null);
        if ( call.input( 0 ) == null ) // Creating a new tensor:
            call = _withNewOutput( call );

        return (ExecutionCall<Device<Object>>) _autoClone( call );
    }

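    /**
     *  Creates a zero filled output tensor with shape (rows of input 1) x (columns of input 2),
     *  aligns the layouts of all involved tensors, stores the new output on the
     *  device of the call and finally places it at input index 0.
     */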
    private static ExecutionCall<?> _withNewOutput( ExecutionCall<?> call )
    {
        Class<Number> type = (Class<Number>) call.input( 1 ).getDataType().getItemTypeClass();

        int[] shp = new int[]{ call.input( 1 ).shape(0), call.input( 2 ).shape(1) };
        Tensor<Number> output = Tensor.of( type ).withShape( shp ).all( 0 ).mut().setIsIntermediate( true );

        call = _checkAndPrepareLayout( call, output );

        call.getDeviceFor(Number.class).store( output );
        return call.withInputAt( 0, output );
    }

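    /**
     *  Ensures that the operands a and b as well as the output tensor c agree on a
     *  common row- or column-major layout before the actual a @ b = c multiplication,
     *  re-layouting tensors where necessary.
     */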
    private static ExecutionCall<?> _checkAndPrepareLayout( ExecutionCall<?> call, Tensor<?> c )
    {
        Tensor<?> a = call.input( 1 );
        Tensor<?> b = call.input( 2 );
        // We need to make sure that the matrices have a common/compatible layout
        // before we can perform the actual a @ b = c matrix multiplication!
        NDConfiguration.Layout layoutA = a.getNDConf().getLayout();
        NDConfiguration.Layout layoutB = b.getNDConf().getLayout();
        NDConfiguration.Layout layoutC = c.getNDConf().getLayout();

        boolean aIsCompatible = isRMOrCM( layoutA );
        boolean bIsCompatible = isRMOrCM( layoutB );

        Function relayout = Neureka.get().backend().getFunction().relayout();

        if ( aIsCompatible ) {
            if ( layoutB != NDConfiguration.Layout.SYMMETRIC )
                b = relayout.with(Arg.Layout.of(layoutA)).call(b); // We choose a valid layout based on a
            layoutC = layoutA;
        } else if ( bIsCompatible ) {
            if ( layoutA != NDConfiguration.Layout.SYMMETRIC )
                a = relayout.with(Arg.Layout.of(layoutB)).call(a); // We choose a valid layout based on b
            layoutC = layoutB;
        } else {
            // Neither input has a row- or column-major layout (they are unspecific/symmetric),
            // so we simply need to settle on any valid layout:
            layoutC = isRMOrCM(layoutC) ? layoutC : NDConfiguration.Layout.ROW_MAJOR;
            a = relayout.with(Arg.Layout.of(layoutC)).call(a);
            b = relayout.with(Arg.Layout.of(layoutC)).call(b);
        }

        c.mut().toLayout( layoutC );
        c.mut().setIsVirtual( false ); // This statement is after the layout conversion for performance reasons (virtual tensors barely need copying).

        return call.withInputAt( 1, a ).withInputAt( 2, b );
    }

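    /**
     *  Returns true if the given layout is row-major or column-major
     *  (as opposed to symmetric or unspecific).
     */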
    private static boolean isRMOrCM( NDConfiguration.Layout layout ) {
        return layout == NDConfiguration.Layout.ROW_MAJOR ||
               layout == NDConfiguration.Layout.COLUMN_MAJOR;
    }

    /**
     *  This method clones {@link Tensor} instances if they do not
     *  possess a simple {@link neureka.ndim.config.NDConfiguration},
     *  which is usually the case when they are slices or permuted views on data.
     *  The reason for this is simply that the OpenCL kernels require inline data!
     *
     * @param call The execution call whose tensors ought to be cloned based on the complexity of their access patterns.
     * @return A call whose inputs all have sufficiently simple ND configurations (clones where necessary).
     */
    private static ExecutionCall<?> _autoClone( ExecutionCall<?> call ) {
        for ( int i = 0; i < call.arity(); i++ )
            if (
                (!_isSimpleRowMajorMatrix( call.input( i ) ) && !_isSimpleColumnMajorMatrix( call.input( i ) ))
                        ||
                call.input( i ).isPartialSlice()
            ) {
                _LOG.debug("Auto cloning a tensor which does not have a simple ND configuration...");
                call = call.withInputAt( i, call.input( i ).deepCopy().mut().setIsIntermediate( true ) );
                /*
                    Ideally the user clones such tensors explicitly, because passing
                    slices causes the backend to perform this auto-cloning every time
                    the slice is used in an operation like this one...
                 */
            }

        return call;
    }

    private static boolean _isSimpleColumnMajorMatrix( Tensor<?> t ) {
        return t.rank() == 2 && t.getNDConf().getLayout() == NDConfiguration.Layout.COLUMN_MAJOR;
    }

    private static boolean _isSimpleRowMajorMatrix( Tensor<?> t ) {
        return t.rank() == 2 && t.getNDConf().getLayout() == NDConfiguration.Layout.ROW_MAJOR;
    }

}