ConvUtil.java

package neureka.backend.main.operations;

import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.main.algorithms.NDConvolution;
import neureka.backend.main.operations.other.Permute;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.devices.Device;

public class ConvUtil
{
    public static NDConvolution createDeconvolutionFor(String op ) {
        return new NDConvolution()
                .setAutogradModeFor( call -> {
                    if ( call.getOperation().supports( NDConvolution.class ) ) return AutoDiffMode.BACKWARD_ONLY;
                    Tensor<?> last = null;
                    for ( Tensor<?> t : call.inputs() ) {
                        if ( last != null && !last.shape().equals(t.shape()) ) return AutoDiffMode.BACKWARD_ONLY;
                        last = t; // Note: shapes are cached!
                    }
                    return AutoDiffMode.FORWARD_AND_BACKWARD;
                })
                .setExecution(
                    (outerCaller, outerCall) ->
                    Result.of(AbstractDeviceAlgorithm.executeFor(
                        outerCaller, outerCall,
                        call ->
                        {
                            int offset = ( call.input(0) == null ? 1 : 0 );
                            Tensor<?>[] tensors = new Tensor[]{call.input(offset+0), call.input(offset+1), call.input(offset+2)};
                            Permute.makeFit(tensors, false); // This might not fit here... (fitting should probably be a setup thing...)
                            for ( Tensor<?> t : tensors ) t.mut().setIsVirtual( false );
                            return AbstractDeviceAlgorithm.prepareAndExecute(
                                    ExecutionCall.of( tensors )
                                            .andArgs( Arg.DerivIdx.of(0) )
                                            .running( call.getOperation() )
                                            .on( call.getDevice() ),
                                    a -> ConvUtil.executeRecursively(op, a)
                            );
                        }
                    ))
                    .withAutoDiff( ( Function f, ExecutionCall<? extends Device<?>> adCall ) -> {
                        throw new UnsupportedOperationException("Not yet implemented!");
                    } )
                )
                .setCallPreparation(
                     call -> {
                         if ( call.input( 0 ) == null )
                             return call.withRemovedInputAt( 0 );
                         return call;
                     }
                )
                .buildFunAlgorithm();
    }

    public static Shape shapeOfCon(int[] shape1, int[] shape2 ) {
        int[] shape = new int[ ( shape1.length + shape2.length ) / 2 ];
        for ( int i = 0; i < shape1.length && i < shape2.length; i++ )
            shape[ i ] = Math.abs( shape1[ i ] - shape2[ i ] ) + 1;
        return Shape.of(shape);
    }
    
    public static Tensor<?> executeRecursively(
            String op,
            ExecutionCall<? extends Device<?>> call
    ) {
        int d = call.getValOf( Arg.DerivIdx.class );
        if ( op.equals("x") ) {
            if ( d >= 0 ) {
                if ( d == 0 )
                    call = call.withInputAt( 0, call.input( 2 ) );
                else
                    call = call.withInputAt( 0, call.input( 1 ) );
                return
                    call.input( 0 );
            } else {
                call.rearrangeInputs( 0, 1, 2 );
            }
        } else if ( op.equals("x"+ ((char) 187)) ) {
            call.rearrangeInputs( 2, 1, 0 );
        }
        return AbstractDeviceAlgorithm.executeDeviceAlgorithm( call );
    }

}