Convolution.java

package neureka.backend.main.operations.linear;

import neureka.Neureka;
import neureka.Shape;
import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.algorithms.NDConvolution;
import neureka.backend.main.operations.ConvUtil;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
import neureka.devices.Device;

/**
 * The convolution operation, identified as {@code "mul_conv"} and written with the
 * binary operator {@code "x"}.
 * <p>
 * Execution is delegated to the {@link NDConvolution} algorithm. Chains with more than two
 * operands are executed left-to-right as nested pairwise convolutions, i.e.
 * {@code a x b x c} runs as {@code (a x b) x c}.
 */
public class Convolution extends AbstractOperation
{
    /**
     * Declares the operator metadata and registers the {@link NDConvolution} algorithm together with
     * its autograd mode selection, its execution (including the back-prop de-convolution) and its
     * call preparation, which allocates the output tensor in input slot 0.
     */
    public Convolution()
    {
        super(
            new OperationBuilder()
                .identifier(       "mul_conv"  )
                .operator(         "x"         )
                .arity(            2           )
                .isOperator(       true        )
                .isIndexer(        false       )
                .isDifferentiable( true        )
                .isInline(         false       )
        );

        setAlgorithm(
            NDConvolution.class,
            new NDConvolution()
            .setAutogradModeFor( call -> {
                // NOTE(review): this operation registers NDConvolution, so this guard presumably always
                // returns BACKWARD_ONLY and the shape comparison below never runs — confirm whether that
                // is intended before relying on FORWARD_AND_BACKWARD ever being selected.
                if ( call.getOperation().supports( NDConvolution.class ) ) return AutoDiffMode.BACKWARD_ONLY;
                Tensor<?> last = null;
                for ( Tensor<?> t : call.inputs() ) {
                    if ( last != null && !last.shape().equals(t.shape()) ) return AutoDiffMode.BACKWARD_ONLY;
                    last = t; // Note: shapes are cached!
                }
                return AutoDiffMode.FORWARD_AND_BACKWARD;
            })
            .setExecution(
                (outerCaller, outerCall) ->
                    Result.of(AbstractDeviceAlgorithm.prepareAndExecute(
                        outerCall,
                        call ->
                                AbstractDeviceAlgorithm.executeDeviceAlgorithm(
                                        call
                                )
                    ))
                    .withAutoDiff(( Function f, ExecutionCall<? extends Device<?>> adCall ) ->
                    {
                        int d = adCall.getDerivativeIndex();
                        // Validate preconditions before they are used below.
                        assert d >= 0 && d <= 1;
                        assert adCall.arity() >= 2 && adCall.arity() <= 3;
                        // De-convolution expression applied to: target error, derivative, zero tensor.
                        Function deConv = new FunctionParser( Neureka.get().backend() ).parse(
                                "I[ 0 ] x>> I[ 1 ] x>> I[ 2 ]",
                                false
                        );
                        assert deConv != null;
                        Tensor<?> derivative = f.derive( (Tensor[]) adCall.inputs(), d );
                        assert derivative != null;
                        // Remember the shape of the input targeted for back-prop, because it will be the
                        // shape of the output of the de-convolution. With 3 inputs, slot 0 presumably holds
                        // the output (the call preparation below reserves it), so the operand index shifts by one.
                        Shape shape = Shape.of(adCall.input( adCall.arity() > 2 ? d + 1 : d ).getNDConf().shape());
                        Number zero = zeroOf( derivative.getItemType() );
                        return ADAction.of( target ->
                                deConv.execute(
                                        target.error(),
                                        derivative,
                                        Tensor.of(shape, zero).mut().setIsIntermediate( false )
                                )
                        );
                    })
            )
            .setCallPreparation(
                 call -> {
                     // Reserve input slot 0 for the output tensor when only the two operands were passed.
                     if ( call.arity() <= 2 ) call = call.withAddedInputAt( 0, null );
                     Shape shp = ConvUtil.shapeOfCon(call.input( 1 ).getNDConf().shape(), call.input( 2 ).getNDConf().shape());
                     Tensor<Number> output = (Tensor<Number>) Tensor.of( call.input(1).getItemType(), shp, 0 )
                                                             .mut()
                                                             .setIsIntermediate( true );
                     output.mut().setIsVirtual( false );
                     // TODO: storing `output` on the call's device at this point causes problems — find out why.
                     return call.withInputAt( 0, output );
                 }
            )
            .buildFunAlgorithm()
        );

    }

    /**
     * Returns a boxed zero matching the given tensor item type.
     *
     * @param itemType the item type of the tensor the zero is used with
     * @return {@code 0} boxed as the corresponding wrapper type
     * @throws IllegalArgumentException if {@code itemType} is not one of
     *         Double, Float, Integer, Long, Short or Byte
     */
    private static Number zeroOf( Class<?> itemType )
    {
        if ( itemType == Double.class  ) return 0d;
        if ( itemType == Float.class   ) return 0f;
        if ( itemType == Integer.class ) return 0;
        if ( itemType == Long.class    ) return 0L;
        if ( itemType == Short.class   ) return (short) 0;
        if ( itemType == Byte.class    ) return (byte) 0;
        throw new IllegalArgumentException("Unsupported item type for convolution derivative: " + itemType);
    }

    /**
     * Executes a convolution described by {@code caller} on the inputs of {@code call}.
     * <ul>
     *     <li>A non-flat caller is first reduced to left-nested pairwise convolutions, flattened,
     *         and then executed again through this method with the resulting flat function.</li>
     *     <li>For a derivative request (derivative index {@code >= 0}) the other operand is returned,
     *         analogous to matrix multiplication.</li>
     *     <li>Otherwise the call is flattened and handed to the superclass implementation.</li>
     * </ul>
     * Flattening always resets the derivative index of the call to {@code -1}.
     */
    @Override
    public Result execute( final Function caller, final ExecutionCall<?> call )
    {
        if ( !caller.isFlat() ) {
            ExecutionCall<?> flatCall = flattenedCallOf( caller, call );
            return this.execute( flatFunctionOf( flatCall ), flatCall );
        }
        if ( call.getDerivativeIndex() >= 0 ) {
            int d = call.getDerivativeIndex();
            /*
                In autograd convolution is similar to matrix multiplication.
                If the derivative index is 0 then the second operand is used for backward broadcasting.
                If the derivative index is 1 then the first operand is used for backward broadcasting.
             */
            return Result.of( call.input( d == 0 ? 1 : 0 ) );
        }
        ExecutionCall<?> flatCall = flattenedCallOf( caller, call );
        return super.execute( flatFunctionOf( flatCall ), flatCall );
    }

    /**
     * Reduces {@code caller} pairwise, flattens {@code call} against it with the derivative index
     * reset to {@code -1}, and marks every non-null input of the flat call as non-intermediate
     * (presumably so they are not released as temporaries during execution — confirm).
     */
    private ExecutionCall<?> flattenedCallOf( final Function caller, final ExecutionCall<?> call )
    {
        Function reducedCaller = reducePairwise( caller );
        ExecutionCall<?> flatCall = AbstractDeviceAlgorithm.flatten( reducedCaller, call.withArgs(Arg.DerivIdx.of(-1)) );
        for ( Tensor<?> t : flatCall.inputs() ) if ( t != null ) t.mut().setIsIntermediate(false);
        return flatCall;
    }

    /**
     * Parses a flat function for the operation and arity of the given (already flattened) call.
     */
    private Function flatFunctionOf( final ExecutionCall<?> flatCall )
    {
        return new FunctionParser(Neureka.get().backend()).parse( flatCall.getOperation(), flatCall.arity(), true );
    }

    /**
     * Rewrites an n-ary convolution {@code a x b x c x d ...} into the left-nested form
     * {@code ((((a x b) x c) x d) ...)} that reflects how it is actually executed.
     * Functions with at most two sub-functions are returned unchanged.
     */
    private Function reducePairwise( final Function fun ) {
        Function reduced = fun;
        if ( reduced.getSubFunctions().size() > 2 ) {
            Function nested = reduced.getSubFunctions().get(0);
            for ( int i = 1; i < reduced.getSubFunctions().size(); i++ )
                nested = Function.of( nested + " x " + reduced.getSubFunctions().get(i), true );

            reduced = nested;
        }
        return reduced;
    }

    /**
     * Scalar fallback: returns the value of the first source function for the given inputs;
     * the derivative index {@code d} is not consulted.
     */
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) { return src[ 0 ].call( inputs, j ); }
}