ReLayout.java

package neureka.backend.main.operations.other;

import neureka.Neureka;
import neureka.Tensor;
import neureka.backend.api.Algorithm;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.algorithms.Util;
import neureka.backend.main.memory.MemUtil;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.ndim.config.NDConfiguration;

import java.util.function.Supplier;

public class ReLayout extends AbstractOperation
{
    public ReLayout()
    {
        super(
            new OperationBuilder()
            .identifier(       "layout"  )
            .operator(         "layout"  )
            .arity(            1          )
            .isOperator(       false      )
            .isIndexer(        false      )
            .isDifferentiable( true       )
            .isInline(         false      )
        );
        setAlgorithm(
            Algorithm
            .withName( "layout" )
            .setIsSuitableFor( call -> SuitabilityPredicate.GOOD )
            .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
            .setExecution(
                ( caller, call ) ->
                {
                    Tensor<?>[] inputs = AbstractDeviceAlgorithm.flatten(caller, call).inputs();
                    Tensor<Object> input = (Tensor<Object>) inputs[0];

                    NDConfiguration.Layout originalLayout = input.getNDConf().getLayout();
                    NDConfiguration.Layout newLayout = call.getValOf( Arg.Layout.class );

                    Tensor<?> reLayout = toLayout( input.deepCopy(), newLayout );

                    return Result.of(reLayout.mut().setIsIntermediate(true))
                            .withADAction( target -> {
                                Tensor<Object> error = (Tensor<Object>) target.error().deepCopy();
                                return error.mut().toLayout(originalLayout);
                            });
                }
            )
            .buildFunAlgorithm()
        );
    }

    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) { return src[ 0 ].call( inputs, j ); }


    public static Tensor<?> toLayout(Tensor<?> t, NDConfiguration.Layout target )
    {
        if ( target == t.getNDConf().getLayout() ) return t;
        if ( target == NDConfiguration.Layout.SYMMETRIC )
            throw new UnsupportedOperationException(
                    "Conversion of a non-symmetric tensor to a symmetric tensor is not possible!"
            );
        if ( target == NDConfiguration.Layout.UNSPECIFIC )
            throw new UnsupportedOperationException(
                    "Conversion of a tensor to an unspecific layout is not possible!"
            );

        if ( target == NDConfiguration.Layout.ROW_MAJOR || target == NDConfiguration.Layout.COLUMN_MAJOR ) {
            if ( t.getNDConf().getLayout() == NDConfiguration.Layout.SYMMETRIC )
                return t; // Symmetric tensors are both row and column major.
        }

        NDConfiguration old = t.getNDConf();

        if ( target == NDConfiguration.Layout.ROW_MAJOR )
            _fromCMToRM( t );
        else
            _fromRMToCM( t );

        _checkLayoutConversion( t.getNDConf(), old, target );
        return t;
    }

    /**
     *  Converts this tensor from column major to column major layout.
     */
    private static void _fromCMToRM( Tensor<?> t ) {
        if ( t.getNDConf().isVirtual() ) {
            t.mut().setIsVirtual( false ); // We actualized the tensor before conversion!
            if ( t.getNDConf().getLayout() == NDConfiguration.Layout.ROW_MAJOR )
                return;
        }
        Tensor<?> clone = t.deepCopy(); // A clone will have by default a row major layout.
        t.mut().setNDConf( clone.getNDConf() );
        _assignIfActual( t, () -> clone );
    }

    /**
     *  Converts this tensor from row major to column major layout.
     */
    private static void _fromRMToCM( Tensor<?> t ) {
        _assignIfActual( t, () -> Util.transpose(t).deepCopy().getMut().detach() );
        NDConfiguration old = t.getNDConf();
        int[] newTranslation = NDConfiguration.Layout.COLUMN_MAJOR.newStridesFor(old.shape());
        if ( old.isVirtual() ) {
            t.mut().setIsVirtual(false);
            old = t.getNDConf();
        }
        t.mut().setNDConf( _createNewNDCFrom( old, newTranslation ) );
    }

    /**
     *  This will only call the supplier and copy its result into this tensor
     *  if this tensor is not virtual (meaning this is an actual tensor).
     */
    private static void _assignIfActual(Tensor<?> t, Supplier<Tensor<?>> provider ) {
        if ( !t.isVirtual() ) {
            Tensor<?> toBeAssigned = provider.get();
            MemUtil.keep(t, toBeAssigned,
                    () -> Neureka.get().backend().getFunction().idy().execute( t, toBeAssigned )
            );
        }
    }

    private static NDConfiguration _createNewNDCFrom(
            NDConfiguration old, int[] newTranslation
    ) {
        assert !old.isVirtual();
        return NDConfiguration.of(
                    old.shape(), newTranslation, old.indicesMap(), old.spread(), old.offset()
                );
    }

    private static void _checkLayoutConversion(
            NDConfiguration newConf,
            NDConfiguration oldConf,
            NDConfiguration.Layout targetLayout
    ) {
        if ( newConf.isVirtual() )
            throw new IllegalStateException("Layout conversion produced a virtual nd-configuration!");
        if ( !newConf.getLayout().isCompatible(targetLayout) )
            throw new IllegalArgumentException(
                    "Failed to convert this tensor from its original layout '"+oldConf.getLayout()+"' " +
                            "to target layout '"+targetLayout+"'. Instead this tensor has layout '"+newConf.getLayout()+"'."
            );
    }

}