// FallbackAlgorithm.java

package neureka.backend.api.template.algorithms;

import neureka.Neureka;
import neureka.Shape;
import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Operation;
import neureka.backend.api.fun.ADActionSupplier;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.fun.ExecutionPreparation;
import neureka.backend.api.Result;
import neureka.backend.main.implementations.CPUImplementation;
import neureka.backend.main.memory.MemUtil;
import neureka.backend.main.operations.linear.MatMul;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
import neureka.math.parsing.ParseUtil;
import neureka.devices.Device;
import neureka.devices.host.CPU;
import neureka.dtype.NumericType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.stream.Stream;

/**
 *  A generic last-resort CPU algorithm used for operations which have no dedicated
 *  implementation for a given {@link ExecutionCall}. <br>
 *  It executes the operation item-wise using one of three strategies:
 *  <ol>
 *      <li>If all inputs are numeric, a scalar {@link Function} parsed from the operation
 *          is evaluated per item.</li>
 *      <li>If the item type is {@link String} and the operation is "add", items are concatenated.</li>
 *      <li>Otherwise a matching operator method (e.g. "plus", "times", ...) is looked up
 *          reflectively on the item type and invoked per item.</li>
 *  </ol>
 */
public final class FallbackAlgorithm extends AbstractDeviceAlgorithm<FallbackAlgorithm>
implements ExecutionPreparation, ADActionSupplier
{

    private static final Logger _LOG = LoggerFactory.getLogger(FallbackAlgorithm.class);

    /**
     * @param name The name of this algorithm, passed to the super class.
     * @param arity The number of tensor arguments the CPU implementation expects.
     * @param type The {@link Operation} which this fallback approximates item-wise.
     */
    public FallbackAlgorithm( String name, int arity, Operation type )
    {
        super( name );
        setImplementationFor(
                CPU.class,
                CPUImplementation
                    .withArity( arity )
                    .andImplementation(
                        call -> {
                            // Parse a scalar function representing the operation so that it
                            // can be evaluated for every item index individually.
                            Function f = new FunctionParser( Neureka.get().backend() )
                                                .parse( type, call.arity() - 1, false );

                            boolean allNumeric = call.validate()
                                                        .all( t -> t.getDataType().typeClassImplements(NumericType.class) )
                                                        .isValid();

                            // The item type of the first tensor decides the execution strategy.
                            Class<?> typeClass = Stream.of( call.inputs() )
                                                        .map( t -> t.getDataType().getItemTypeClass() )
                                                        .findFirst()
                                                        .get();
                            if ( allNumeric )
                            {
                                call.getDevice()
                                    .getExecutor()
                                    .threaded(
                                        call.input( Number.class, 0 ).size(),
                                        ( start, end ) -> {
                                            // Scratch buffer allocated per worker thread!
                                            // (A buffer shared across threads would be raced on,
                                            // see the analogous allocation in _tryExecute below.)
                                            double[] inputs = new double[ call.arity()-1 ];
                                            for ( int i = start; i < end; i++ ) {
                                                for ( int ii = 0; ii < inputs.length; ii++ )
                                                    inputs[ ii ] = call.input( Number.class, 1 + ii ).at( i ).get().doubleValue();

                                                call.input( Number.class, 0 ).mut().set( i, f.call( inputs ) );
                                            }
                                        }
                                    );
                            }
                            else if ( typeClass == String.class && call.getOperation().getIdentifier().equals("add") )
                            {
                                // String "addition" is implemented as item-wise concatenation
                                // of all argument tensors (inputs 1..arity-1) into output tensor 0.
                                call.getDevice()
                                    .getExecutor()
                                    .threaded(
                                        call.input( Object.class, 0 ).size(),
                                        ( start, end ) -> {
                                            for ( int i = start; i < end; i++ ) {
                                                StringBuilder b = new StringBuilder();
                                                for (int ii = 1; ii < call.arity(); ii++ ) {
                                                    b.append(call.input( Object.class, ii ).item(i));
                                                }
                                                setAt( call.input( Object.class, 0 ), i, b.toString() );
                                            }
                                        }
                                    );
                            }
                            else
                                _tryExecute(call, typeClass); // Reflection based fallback.

                            return call.input( typeClass, 0 );
                        }
                    )
        );
    }

    /**
     *  This fallback is only weakly suitable (0.5f) and only when all non-null
     *  inputs share the same shape (it operates strictly item-wise).
     *  Matrix multiplication is explicitly rejected because it cannot
     *  be expressed as an item-wise operation.
     */
    @Override
    public float isSuitableFor( ExecutionCall<? extends Device<?>> call ) {
        int[] shape = null;
        for ( Tensor<?> t : call.inputs() ) {
            if ( t != null ) {
                if ( shape == null ) shape = t.getNDConf().shape();
                else if ( !Arrays.equals(shape, t.getNDConf().shape()) ) return 0.0f;
            }
        }
        if ( call.getOperation().getClass() == MatMul.class ) return 0;
        return 0.5f;
    }

    @Override
    public ADAction supplyADActionFor( Function function, ExecutionCall<? extends Device<?>> call )
    {
        return ADAction( function, call );
    }

    /**
     *  Builds a generic autograd action which propagates the error by multiplying
     *  it with the derivative of the executed function. <br>
     *  If the call already carries a derivative ({@link Arg.Derivative}) it is used
     *  directly, otherwise a local derivative is computed eagerly from the inputs.
     *
     * @param function The function whose derivative should drive the backward pass.
     * @param call The execution call providing inputs and the derivative index.
     * @return An {@link ADAction} performing {@code error * derivative}.
     */
    public static ADAction ADAction( Function function, ExecutionCall<? extends Device<?>> call )
    {
        Tensor<?> derivative = (Tensor<?>) call.getValOf(Arg.Derivative.class);
        Function mul = Neureka.get().backend().getFunction().mul();
        if ( derivative != null )
            return ADAction.of( target -> mul.execute( target.error(), derivative ) );

        // No precomputed derivative -> derive it now, while the inputs are still kept alive.
        Tensor<?> localDerivative = MemUtil.keep( call.inputs(), () -> function.executeDerive( call.inputs(), call.getDerivativeIndex() ) );
        localDerivative.mut().setIsIntermediate( false );
        return ADAction.of( target -> mul.execute( target.error(), localDerivative ) );
        // TODO: Maybe delete local derivative??
    }

    /**
     *  Dispatches the given call through the generic device algorithm execution pipeline.
     */
    public Tensor<?> dispatch(Function caller, ExecutionCall<? extends Device<?>> call ) {
        return AbstractDeviceAlgorithm.executeFor( caller, call, AbstractDeviceAlgorithm::executeDeviceAlgorithm );
    }

    /**
     *  Ensures the call has an output tensor at index 0: if none is present, either the
     *  first input is reused (for inline operations) or a fresh intermediate tensor with
     *  the shape and data type of input 1 is created and stored on the call's device.
     *
     * @param call The call to prepare.
     * @return The (possibly updated) call, guaranteed to have a non-null output at index 0.
     */
    @Override
    public ExecutionCall<? extends Device<?>> prepare( ExecutionCall<? extends Device<?>> call )
    {
        Device<Object> device = call.getDeviceFor(Object.class);
        if ( call.input( 0 ) == null ) // Creating a new tensor:
        {
            Shape shp = call.input( 1 ).shape();
            Tensor<Object> output;

            if ( call.getOperation().isInline() )
                output = call.input( Object.class, 1 ); // Inline ops overwrite their first argument.
            else
                output = (Tensor<Object>) Tensor.of( call.input( 1 ).getDataType(), shp )
                                                    .mut()
                                                    .setIsIntermediate(true);
            output.mut().setIsVirtual( false );
            try {
                device.store( output );
            } catch ( Exception e ) {
                // Best effort: execution may still succeed on the CPU even if storing failed.
                _LOG.error("Failed to store output tensor on device '"+device+"'.", e);
            }
            call = call.withInputAt( 0, output );
        }
        return call;
    }

    /**
     *  Reflection based fallback: resolves an operator method on the item type
     *  (by operation identifier, common operator aliases, or fuzzy name matching)
     *  and applies it item-wise across all inputs, writing results into input 0.
     */
    private void _tryExecute( ExecutionCall<CPU> call, Class<?> typeClass ) {
        Method m = _findMethod( call.getOperation().getIdentifier(), typeClass );
        if ( m == null ) {
            // The identifier did not resolve directly -> try common operator method names.
            switch (call.getOperation().getOperator()) {
                case "+": m = _findMethod("plus", typeClass);break;
                case "-": m = _findMethod("minus", typeClass);break;
                case "*":
                    m = _findMethod("times", typeClass);
                    if ( m == null ) m = _findMethod("multiply", typeClass);
                    if ( m == null ) m = _findMethod("mul", typeClass);
                    break;
                case "%": m = _findMethod("mod", typeClass);break;
            }
        }
        Method finalMethod = m; // Effectively final copy for use inside the lambda.
        call
            .getDevice()
            .getExecutor()
            .threaded(
                call.input( Object.class, 0 ).size(),
                ( start, end ) -> {
                    Object[] inputs = new Object[ call.arity() - 1 ]; // Per-thread scratch buffer.
                    for ( int i = start; i < end; i++ ) {
                        for ( int ii = 0; ii < inputs.length; ii++ ) {
                            inputs[ ii ] = call.input( Object.class, 1 + ii ).item(i);
                        }
                        setAt( call.input( Object.class, 0 ), i, _tryExecute(finalMethod, inputs, 0));
                    }
                }
            );
    }

    /** Writes {@code o} into tensor {@code t} at virtual index {@code i}. */
    private static void setAt(Tensor<Object> t, int i, Object o ) {
        t.mut().setDataAt( t.getNDConf().indexOfIndex( i ), o );
    }

    /**
     *  Left-folds the method {@code m} over {@code args}, i.e. computes
     *  {@code ((args[0] m args[1]) m args[2]) m ...} recursively.
     *
     * @param m The binary operator method to invoke, may be null if none was found.
     * @param args The items to fold; intermediate results are written back into this array.
     * @param offset The current fold position.
     * @return The folded result, or null if no method is available or invocation failed.
     */
    private static Object _tryExecute( Method m, Object[] args, int offset ) {
        if ( offset == args.length - 1 ) return args[offset];
        if ( m == null ) return null; // No suitable operator method was found for this type.
        try {
            args[offset + 1] = m.invoke(args[offset], args[offset + 1]);
        } catch ( Exception e ) {
            _LOG.debug("Failed to execute method '"+m.getName()+"'. "+e.getMessage());
            return null;
        }
        return _tryExecute( m, args, offset + 1 );
    }

    /**
     *  Looks up a public single-parameter method {@code name(typeClass)} on {@code typeClass}.
     *  If no exact match exists, falls back to the declared single-parameter method whose
     *  name is most similar to {@code name} (accepted only above a similarity of 0.5).
     *  Note: an exact match always wins; the fuzzy search is only consulted on lookup failure.
     *
     * @param name The desired method name.
     * @param typeClass The item type on which the method should be found.
     * @return A matching {@link Method} or null if none qualifies.
     */
    private static Method _findMethod( String name, Class<?> typeClass ) {
        try {
            return typeClass.getMethod(name, typeClass);
        } catch ( SecurityException e ) {
            _LOG.error("Security exception while looking up method '"+name+"' on type '"+typeClass.getSimpleName()+"'.", e);
        } catch ( NoSuchMethodException e ) {
            String message =
                    "Failed finding method named '"+name+"' on instance of type '"+typeClass.getSimpleName()+"'.\n" +
                    "Cause: "+e.getMessage();
            _LOG.debug(message);
        }
        // Fuzzy fallback: pick the declared single-parameter method with the most similar name.
        // (This must NOT live in a finally block - returning from finally would discard
        // the successful exact-match result above and swallow pending exceptions.)
        Method currentBest = null;
        double currentScore = 0;
        for ( Method m : typeClass.getDeclaredMethods() ) {
            int numberOfParams = m.getParameterCount();
            Class<?> type = (numberOfParams == 0) ? null : m.getParameterTypes()[0];
            if ( numberOfParams == 1 && type == typeClass ) {
                double score = ParseUtil.similarity( m.getName(), name );
                if ( score > currentScore ) {
                    currentBest = m;
                    currentScore = score;
                }
            }
        }
        if ( currentScore > 0.5 ) return currentBest;
        return null;
    }

    /** This generic fallback supports both forward and backward mode auto differentiation. */
    @Override
    public AutoDiffMode autoDiffModeFrom(ExecutionCall<? extends Device<?>> call ) { return AutoDiffMode.FORWARD_AND_BACKWARD; }

    @Override
    public Result execute(Function caller, ExecutionCall<? extends Device<?>> call) {
        return Result.of(this.dispatch(caller, call)).withAutoDiff(this);
    }
}