DimFit.java

package neureka.backend.main.operations.other;

import neureka.Tensor;
import neureka.backend.api.Algorithm;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;

public class DimFit extends AbstractOperation
{
    public DimFit()
    {
        super(
            new OperationBuilder()
                    .identifier(       "dimfit"    )
                    .operator(         "dimfit"    )
                    .arity(            -1          )
                    .isOperator(       false       )
                    .isIndexer(        false       )
                    .isDifferentiable( true        )
                    .isInline(         false       )
        );

        setAlgorithm(
            Algorithm
            .withName("dimFit")
            .setIsSuitableFor( call -> SuitabilityPredicate.GOOD )
            .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
            .setExecution(
                ( caller, call ) ->
                {
                    assert call.getValOf( Arg.DerivIdx.class ) < 0;
                    Tensor<?>[] inputs = AbstractDeviceAlgorithm.flatten( caller, call ).inputs();

                    int largest = -1;
                    int[] shape = null;
                    for ( Tensor<?> t : inputs ) if ( t.rank() > largest ) {
                        largest = t.rank();
                        shape = t.getNDConf().shape();
                    }
                    int prefix = 0;
                    for ( int s : shape ) if ( s == 1 ) prefix++; else break;
                    int postfix = 0;
                    for ( int i = shape.length-1; i>=0; i-- ) if ( shape[ i ] == 1 ) postfix++; else break;

                    int[][] change = new int[inputs.length][];

                    for ( int i=0; i<inputs.length; i++)
                    {
                        if ( inputs[ i ].rank()!=largest)
                        {
                            int[] oldShape = inputs[ i ].getNDConf().shape();
                            int[] newReshape = new int[largest];
                            int padding = largest-oldShape.length;

                            int handle = ( postfix <= prefix )? padding : largest-padding;
                            for ( int ii = 0; ii < handle; ii++ ) newReshape[ ii ]      = ( postfix <= prefix )? -1 : ii;
                            for ( int ii = handle; ii < largest; ii++) newReshape[ ii ] = ( postfix <= prefix )? ii-padding : -1;

                            change[ i ] = newReshape;
                        }
                    }
                    return Result.of(null).withADAction(null);
                }
            )
            .buildFunAlgorithm()
        );
    }

    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) {
        return src[ 0 ].call( inputs, j );
    }
}