DotProductAlgorithm.java

package neureka.backend.main.algorithms;

import neureka.Neureka;
import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.AbstractFunDeviceAlgorithm;
import neureka.devices.Device;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.ndim.config.NDConfiguration;
import neureka.ndim.config.types.simple.Simple1DConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

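/**
 *  The {@link DotProductAlgorithm} implements the dot product {@code a . b = c}
 *  between two vectors (rank 1 tensors) {@code a} and {@code b}, producing a
 *  scalar-like result tensor {@code c} of shape {@code (1)}.
 *  It supports backward mode automatic differentiation only, where the partial
 *  derivative with respect to one operand is simply the other operand.
 *  <p>
 *  A rough usage sketch (the builder and {@code dot} convenience calls below are
 *  assumed from the public {@link Tensor} API and may differ in detail):
 *  <pre>{@code
 *  // hypothetical usage; exact factory/method names may differ:
 *  Tensor<Double> a = Tensor.of(Double.class).withShape(3).andFill(1d, 2d, 3d);
 *  Tensor<Double> b = Tensor.of(Double.class).withShape(3).andFill(4d, 5d, 6d);
 *  Tensor<Double> c = a.dot(b); // 1*4 + 2*5 + 3*6 = 32
 *  }</pre>
 */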
public class DotProductAlgorithm extends AbstractFunDeviceAlgorithm<DotProductAlgorithm>
{
    private static final Logger _LOG = LoggerFactory.getLogger(DotProductAlgorithm.class);

    public DotProductAlgorithm() {
        super("dot_algorithm");
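        // This algorithm targets numeric operands which are effectively one dimensional;
        // calls whose operands have simple 1d configurations receive a boosted suitability estimate: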
        setIsSuitableFor(
            call -> call.validate()
                    .allNotNull( t -> Number.class.isAssignableFrom(t.getItemType()) )
                    .allNotNull( t -> t.shape().count( d -> d > 1 ) <= 1 )
                    .getEstimator()
                    .goodIfAnyNonNull( t -> t.getNDConf() instanceof Simple1DConfiguration )
                    .getEstimation() * 1.1f
        );
        setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY );
        setExecution(
            (function, call) -> {
                call = _prepare( call );
                return
                    Result.of(AbstractDeviceAlgorithm.executeDeviceAlgorithm( call ))
                    .withAutoDiff( (Function f, ExecutionCall<? extends Device<?>> adCall ) ->
                    {
                        if ( adCall.autogradMode().allowsForward() )
                            throw new IllegalArgumentException("Dot product does not support forward-AD!");
                        Function mul = Neureka.get().backend().getFunction().mul();
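                        // The partial derivative of a dot product with respect to one operand
                        // is simply the other operand, so we select the other input's index: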
                        int d = ( 1 + adCall.getValOf( Arg.DerivIdx.class ) ) % 2;
                        Tensor<?> derivative = Util.transpose(adCall.input( d )).deepCopy().mut().setIsIntermediate( true ); // We deep copy it so that it has a simple ND configuration...
                        derivative.to(adCall.getDevice());
                        return ADAction.of( target -> mul.execute( target.error(), derivative ) );
                    });
            }
        );
        setCallPreparation( c -> c ); // No-op here; the call is actually prepared in the execution lambda above (see _prepare).
    }


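    /**
     *  Brings the raw execution call into the canonical {@code (c, a, b)} form expected
     *  by the device kernels: a missing output slot is reserved, superfluous dimensions
     *  are trimmed, a fresh output tensor is created if needed and inputs with
     *  non-simple access patterns are cloned.
     */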
    private static ExecutionCall<Device<Object>> _prepare( ExecutionCall call )
    {
        assert call.arity() <= 3;
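        // If only the two operands (a, b) were provided, we reserve the output slot at index 0: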
        if ( call.arity() == 2 ) call = call.withAddedInputAt(0, null);

        call = _withDimTrim( call );

        if ( call.input( 0 ) == null ) // Creating a new tensor:
            call = _withNewOutput( call );

        return (ExecutionCall<Device<Object>>) _autoClone( call );
    }

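    /**
     *  Applies the {@code dimTrim} function to any input of rank greater than 1
     *  so that the kernels receive plain vectors without superfluous dimensions.
     */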
    private static ExecutionCall<?> _withDimTrim( ExecutionCall<?> call ) {
        Tensor<?> a = call.input( 0 );
        Tensor<?> b = call.input( 1 );
        Tensor<?> c = call.input( 2 );
        Function dimTrim = Neureka.get().backend().getAutogradFunction().dimTrim();
        if ( a != null && a.rank() > 1 ) call = call.withInputAt( 0, dimTrim.execute( a ).deepClone() );
        if ( b != null && b.rank() > 1 ) call = call.withInputAt( 1, dimTrim.execute( b ).deepClone() );
        if ( c != null && c.rank() > 1 ) call = call.withInputAt( 2, dimTrim.execute( c ).deepClone() );
        return call;
    }

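    /**
     *  Creates a fresh output tensor of shape {@code (1)} for the dot product result,
     *  aligns its layout with the two input vectors and stores it on the device
     *  targeted by the call.
     */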
    private static ExecutionCall<?> _withNewOutput( ExecutionCall<?> call )
    {
        Class<Number> type = (Class<Number>) call.input( 1 ).getDataType().getItemTypeClass();

        Tensor<Number> output = Tensor.of( type ).withShape( 1 ).all( 0 ).mut().setIsIntermediate( true );

        call = _checkAndPrepareLayout( call, output );

        call.getDeviceFor(Number.class).store( output );
        return call.withInputAt( 0, output );
    }

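    /**
     *  Makes sure that the two input vectors and the output tensor {@code c}
     *  share a common/compatible layout before the dot product is executed.
     */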
    private static ExecutionCall<?> _checkAndPrepareLayout( ExecutionCall<?> call, Tensor<?> c )
    {
        Tensor<?> a = call.input( 1 );
        Tensor<?> b = call.input( 2 );
        // We need to make sure that the vectors have a common/compatible layout
        // before we can perform the actual a . b = c dot product!
        NDConfiguration.Layout layoutA = a.getNDConf().getLayout();
        NDConfiguration.Layout layoutB = b.getNDConf().getLayout();
        NDConfiguration.Layout layoutC = c.getNDConf().getLayout();

        boolean aIsCompatible = isSymmetric( layoutA );
        boolean bIsCompatible = isSymmetric( layoutB );
        /*
            Row major means that items are stored in a row-wise fashion and
            column major means that items are stored in a column-wise fashion.
            A plain vector can be interpreted as either a row vector or a column
            vector, which is why its layout is called symmetric.
        */

        if ( aIsCompatible ) {
            b = _toInline( b, layoutA );
            layoutC = layoutA;
        } else if ( bIsCompatible ) {
            a = _toInline( a, layoutB );
            layoutC = layoutB;
        } else {
            // Neither input has a symmetric layout (they are RM, CM or unspecific),
            // so we simply settle on one common valid layout for all three tensors:
            layoutC = isSymmetric(layoutC) ? layoutC : NDConfiguration.Layout.SYMMETRIC;

            a = _toInline( a, layoutC );
            b = _toInline( b, layoutC );
        }
        c.mut().toLayout( layoutC );
        c.mut().setIsVirtual( false ); // This statement is after the layout conversion for performance reasons (virtual tensors barely need copying).

        return call.withInputAt( 1, a ).withInputAt( 2, b );
    }

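    /**
     *  Returns a non-virtual version of the given tensor whose items are laid out
     *  according to the requested target layout (virtual tensors are deep copied first).
     */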
    private static Tensor<?> _toInline(Tensor<?> t, NDConfiguration.Layout targetLayout ) {
        Function relayout = Neureka.get().backend().getFunction().relayout();
        if ( t.isVirtual() ) {
            t = t.deepCopy().mut().setIsVirtual(false);
            if ( targetLayout != NDConfiguration.Layout.SYMMETRIC && targetLayout != NDConfiguration.Layout.UNSPECIFIC )
                t = t.mut().toLayout(targetLayout); // We convert the fresh copy to the requested target layout.
        } else
            t = relayout.with(Arg.Layout.of(targetLayout)).call( t ); // We convert the tensor to the requested target layout.
        return t;
    }

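    /**
     *  A layout is symmetric if the tensor it describes can be interpreted
     *  as both a row vector and a column vector.
     */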
    private static boolean isSymmetric( NDConfiguration.Layout layout ) {
        return layout == NDConfiguration.Layout.SYMMETRIC;
    }

    /**
     *  This method will clone {@link Tensor} instances if they do not
     *  possess a simple {@link neureka.ndim.config.NDConfiguration}.
     *  This is usually the case when they are slices or permuted views on data...
     *  The reason for this is simply that we need inline data for the OpenCL/CPU kernels!
     *
     * @param call The execution call whose tensors ought to be cloned based on the complexity of their access patterns.
     * @return An execution call whose input tensors all have simple ND configurations suitable for the kernels.
     */
    private static ExecutionCall<?> _autoClone( ExecutionCall<?> call ) {
        for (int i = 0; i < call.arity(); i++ ) {
            if (
                    !_isSimpleSymmetric( call.input( i ) )
                            ||
                    call.input( i ).isPartialSlice()
            ) {
                _LOG.debug("Auto cloning a tensor which does not have a simple symmetric ND configuration or which is a partial slice...");
                call = call.withInputAt( i, call.input( i ).deepCopy().mut().setIsIntermediate( true ) );
                /*
                    Ideally the user clones such tensors explicitly, because passing
                    slices causes the backend to perform this auto cloning every time
                    the slice is used in operations like this one...
                 */
            }
        }
        return call;
    }

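    /**
     *  Checks if the given tensor is a plain vector (rank 1) backed by a symmetric layout.
     */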
    private static boolean _isSimpleSymmetric( Tensor<?> t ) {
        return t.rank() == 1 && t.getNDConf().getLayout() == NDConfiguration.Layout.SYMMETRIC;
    }

}