Reshape.java

package neureka.backend.main.operations.other;

import neureka.Tensor;
import neureka.backend.api.Algorithm;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.framing.Relation;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.ndim.NDConstructor;
import neureka.ndim.config.NDConfiguration;

/**
 *  The "reshape" operation produces a tensor with a new shape that is built
 *  on top of the data object of its single input tensor.
 *  The target shape is read from the {@link Arg.Shape} call argument and may
 *  contain a single {@code -1} placeholder whose value is inferred from the
 *  size of the input tensor.
 */
public class Reshape extends AbstractOperation
{
    /**
     *  Registers the operation under the identifier and operator name "reshape"
     *  (arity 1, differentiable, neither an operator, an indexer nor inline)
     *  and installs the single algorithm performing the reshape.
     */
    public Reshape()
    {
        super(
            new OperationBuilder()
                .identifier(       "reshape"  )
                .operator(         "reshape"  )
                .arity(            1          )
                .isOperator(       false      )
                .isIndexer(        false      )
                .isDifferentiable( true       )
                .isInline(         false      )
        );
        setAlgorithm(
            Algorithm
            .withName( "reshape" )
            .setIsSuitableFor( call -> SuitabilityPredicate.GOOD )
            .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
            .setExecution(
                ( caller, call ) ->
                {
                    Tensor<?>[] inputs = AbstractDeviceAlgorithm.flatten(caller, call).inputs();
                    Tensor<Object> input = (Tensor<Object>) inputs[0];

                    int[] foundShape = call.getValOf( Arg.Shape.class );

                    if ( foundShape == null )
                        throw new IllegalArgumentException("Shape argument is missing!");

                    // Replaces a possible -1 placeholder and rejects shapes which do not fit the input size.
                    int[] shape = _resolveNewShape(input.size(), foundShape);

                    // The new tensor is constructed from the data object of the input
                    // (presumably sharing it instead of copying - TODO confirm against Tensor.of).
                    Tensor reshaped = Tensor.of(
                                    input.getDataType(),
                                    NDConstructor.of( shape ),
                                    input.mut().getData()
                                );

                    // Link both tensors: the result references the input as its parent and
                    // the input registers the result as one of its children.
                    reshaped.set( Relation.newChildToParent( input ) );
                    Relation parent = input.find( Relation.class ).orElseGet(Relation::newParentToChildren);
                    parent.addChild( reshaped );
                    input.set( parent );

                    // An outsourced input lives on a device, so the result is stored there as well.
                    if ( input.isOutsourced() )
                        input.getDevice().store( reshaped );

                    // Remembered so that the backward pass can give the error the layout of the input.
                    NDConfiguration originalConfig = input.getNDConf();

                    return Result.of(reshaped.mut().setIsIntermediate(true))
                            .withADAction( target -> {
                                // Backward pass: wrap the data of the incoming error
                                // in a tensor which uses the configuration of the original input.
                                Tensor<Object> error = (Tensor<Object>) target.error();
                                return Tensor.of(
                                        error.getDataType(),
                                        NDConstructor.of( originalConfig ),
                                        error.mut().getData()
                                    );
                            });
                }
            )
            .buildFunAlgorithm()
        );
    }

    /**
     *   If the provided shape array contains a -1 as one of its elements,
     *   then this method will resolve the -1 to the correct value
     *   which results in a shape array which is compatible with the provided size,
     *   meaning that when we multiply all the elements of the resolved shape array
     *   we will get the provided size.
     *   A shape without a -1 is copied as is, but only if the product of its
     *   elements is equal to the provided size.
     *
     * @param size The size which the resolved shape array should be compatible with.
     * @param shape The shape array which may contain a -1.
     * @return The resolved shape array.
     * @throws IllegalArgumentException If the shape contains more than one -1,
     *                                  a dimension smaller than -1, or if the shape
     *                                  cannot be made compatible with the provided size.
     */
    private static int[] _resolveNewShape( int size, int[] shape )
    {
        int[] resolvedShape = new int[ shape.length ];
        int minusOneIndex = -1;
        long product = 1; // A long, so that the product of large int dimensions does not wrap around.
        for ( int i = 0; i < shape.length; i++ )
        {
            int dimension = shape[ i ];
            if ( dimension == -1 )
            {
                if ( minusOneIndex >= 0 )
                    throw new IllegalArgumentException("The shape array contains more than one -1!");
                minusOneIndex = i;
            }
            else if ( dimension < -1 )
                throw new IllegalArgumentException(
                    "Invalid dimension " + dimension + " at index " + i +
                    " of shape " + java.util.Arrays.toString( shape ) + "!"
                );
            else
            {
                resolvedShape[ i ] = dimension;
                product *= dimension;
            }
        }
        if ( minusOneIndex >= 0 )
        {
            // The placeholder can only be inferred if the known dimensions divide the size
            // without remainder (a product of 0 would otherwise be a division by zero).
            if ( product == 0 || size % product != 0 )
                throw new IllegalArgumentException(
                    "Cannot resolve the -1 in shape " + java.util.Arrays.toString( shape ) +
                    " for a tensor of size " + size + "!"
                );
            resolvedShape[ minusOneIndex ] = (int) ( size / product );
        }
        else if ( product != size )
            throw new IllegalArgumentException(
                "Shape " + java.util.Arrays.toString( shape ) +
                " is not compatible with a tensor of size " + size + "!"
            );
        return resolvedShape;
    }

    /**
     *  Scalar evaluation of this operation: the call is simply forwarded
     *  to the first source function, the values themselves are not modified.
     *
     * @param inputs The scalar inputs passed on to the source function.
     * @param j The index passed on to the source function.
     * @param d Not used by this implementation (presumably the derivative index - TODO confirm).
     * @param src The source functions, of which only the first one is called.
     * @return The result of calling the first source function with the inputs and {@code j}.
     */
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src )
    {
        return src[ 0 ].call( inputs, j );
    }
}