BiElementwise.java

package neureka.backend.main.algorithms;

import neureka.Tensor;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.AbstractFunDeviceAlgorithm;
import neureka.devices.Device;
import neureka.dtype.NumericType;
import neureka.ndim.NDimensional;

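/**
 *  The {@link BiElementwise} algorithm covers elementwise operations
 *  between tensors which all share the same shape (and therefore the same size)
 *  and whose data types implement the {@link NumericType} interface.
 *  It supports both forward and backward mode auto differentiation and
 *  prepares every execution call so that it carries a suitable output tensor
 *  before the device specific implementation runs.
 */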
public final class BiElementwise extends AbstractFunDeviceAlgorithm<BiElementwise>
{
    public BiElementwise() {
        super("elementwise");
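        // This algorithm is only suitable if all non-null operands have the same
        // size and shape and their data types implement NumericType.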
        setIsSuitableFor(
            call -> call
                    .validate()
                    .allNotNullHaveSame(NDimensional::size)
                    .allNotNullHaveSame(NDimensional::shape)
                    .allNotNull( t -> t.getDataType().typeClassImplements( NumericType.class ) )
                    .basicSuitability()
        );
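        // Elementwise operations are differentiable in both forward and backward mode.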
        setAutogradModeFor( call -> AutoDiffMode.FORWARD_AND_BACKWARD );
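        // Execution delegates to the generic device algorithm routine, which
        // dispatches the call to the implementation registered for its device.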
        setExecution(
            (outerCaller, outerCall) ->
                Result.of(AbstractDeviceAlgorithm.executeFor(
                        outerCaller, outerCall,
                        innerCall -> AbstractDeviceAlgorithm.executeDeviceAlgorithm( innerCall )
                ))
        );
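        // Every call is normalized by _prepare so that it carries a writable output tensor at position 0.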
        setCallPreparation(this::_prepare);
    }

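    /**
     *  Makes sure that the supplied {@link ExecutionCall} has an output tensor
     *  at position 0. If the call has fewer than three arguments, a {@code null}
     *  placeholder is inserted first. A missing output is then created as a
     *  zero filled, non-virtual intermediate tensor whose shape and item type
     *  match the first input operand, and it is stored on the call's device.
     */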
    @SuppressWarnings("unchecked")
    private ExecutionCall<?> _prepare( final ExecutionCall<?> inputCall ) {
        ExecutionCall<?> call = inputCall;
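        // A call with fewer than three arguments has no output slot yet, so reserve one at position 0.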
        if ( call.arity() < 3 ) call = call.withAddedInputAt(0, null);
        Device<Object> device = (Device<Object>) call.getDevice();
        if ( call.input( 0 ) == null ) // Creating a new tensor:
        {
            int[] outShape = call.input( 1 ).getNDConf().shape();

            Class<Object> type = (Class<Object>) call.input( 1 ).getItemType();
            Tensor<Object> output = Tensor.of( type ).withShape( outShape ).all( 0.0 ).mut().setIsIntermediate( true );
            output.mut().setIsVirtual( false );
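            // Store the freshly allocated output on the call's device so the implementation can write to it.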
            try {
                device.store( output );
            } catch( Exception e ) {
                e.printStackTrace();
            }
            call = call.withInputAt( 0, output );
        }
        return call;
    }

}