FunctionInput.java

package neureka.math.implementations;


import neureka.Neureka;
import neureka.Tensor;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.args.Args;
import neureka.math.parsing.FunctionParser;

import java.util.ArrayList;
import java.util.List;

/**
 *  Instances of this implementation of the {@link Function} interface
 *  are leave nodes within the abstract syntax tree of a function, representing inputs to a function.
 *  When parsing an expression into a function then these inputs are recognized by the character 'i' or 'I',
 *  followed by a whole number starting at zero (optionally wrapped by '[' and ']'), which is the index
 *  of the argument within the list/array of arguments passed to a concrete {@link Function} instance. <br>
 *  So for example, when creating a function by calling the following factory method...     <br>
 *                                                                                          <br>
 *  {@link Function#of}( "I[1] + (4 * I[0]) / 2" )                                      <br>
 *                                                                                          <br>
 *  ...then the substrings "I[1]" and "I[0]" will be parsed into instances of this class!   <br>
 *  When calling this function by passing two arguments, let's say (first, second) then
 *  the {@link FunctionInput} "I[0]" will pick the first argument, whereas "I[1]"
 *  will pick the second argument when evaluating the array of arguments.
 *
 */
public class FunctionInput implements Function, GradientProvider
{
    private final int _index;


    public static Function of(String equation, boolean doAD) {
        if ( equation.charAt( 0 ) == '-' )
            return new FunctionParser( Neureka.get().backend() )
                    .parse(
                            equation.substring(1)+"*-1",
                            doAD
                    ); // TODO: This might be false!
        int number = 0;
        for ( int i = 0; i < equation.length(); ++i) {
            if ( equation.charAt( i ) <= '9' && equation.charAt( i ) >= '0' ) {
                number *= 10;
                number += Integer.parseInt(equation.charAt( i ) + "");
            }
        }
        if ( equation.contains("g") ) {
            number = -( number + 1 );
        }
        return new FunctionInput(number);
    }


    private FunctionInput( int number ) { _index = number; }


    public int index() { return ( this.providesGradient() ? ( Math.abs(_index) - 1 ) : _index ); }

    @Override public boolean providesGradient() { return ( _index < 0 ); }

    @Override public boolean isFlat() { return true; }

    @Override public boolean isDoingAD() { return false; }

    @Override public AbstractOperation getOperation() { return null; }

    @Override public boolean dependsOn( int index ) { return index() == index; }

    @Override public Function getDerivative( int index ) { return ( index == _index ) ? Function.of( "1" ) : Function.of( "0" ); }

    @Override public List<Function> getSubFunctions() { return new ArrayList<>(); }

    private Tensor<?> _extract(Tensor<?> t )
    {
        if ( this.providesGradient() ) {
            Tensor<?> gradient = t.gradient().orElse(null);
            if ( t.rqsGradient() ) {
                if ( gradient == null ) {
                    gradient = Tensor.of( (Class<? extends Number>) t.getItemType(), t.shape(), 0.0 );
                    t.set( (Tensor) gradient );
                }
                return gradient;
            }
            throw new IllegalArgumentException(
                    "The provided tensor does not require gradients, this function input however " +
                    "expects to receive such tensors (gradient receivers)."
            );
        }
        return t;
    }

    @Override
    public double call( final double[] inputs, int j ) {
        if ( j < 0 ) {
            return inputs[ ( _index >= 0 ) ? _index : ( Math.abs( _index ) - 1 ) ];
        }
        return inputs[index()];
    }

    @Override public double derive( final double[] inputs, final int index ) { return ( index == index() ) ? 1 : 0; }

    @Override
    public double derive( double[] inputs, int index, int j ) {
        if ( j < 0 || j == index() )
            return derive( inputs, index );
        else
            return 0;
    }

    @Override
    public Tensor<?> execute(Args arguments, Tensor<?>... inputs ) {
        int d = ( arguments.has(Arg.DerivIdx.class) ? arguments.valOf(Arg.DerivIdx.class) : -1 );
        if ( d >= 0 )
            return ( d == index() )
                ? Tensor.of( (Class<? extends Number>) inputs[ 0 ].getItemType(), inputs[ 0 ].shape(), 1.0 ).getMut().setIsIntermediate( true )
                : Tensor.of( (Class<? extends Number>) inputs[ 0 ].getItemType(), inputs[ 0 ].shape(), 0.0 ).getMut().setIsIntermediate( true );

        if ( index() >= inputs.length )
            throw new IllegalArgumentException(
                "Function input '"+index()+"' not satisfied! " +
                "Please supply at least "+(index()+1)+" input tensors."
            );

        return _extract( inputs[ index() ] );
    }

    @Override
    public String toString() { return "I" + ( this.providesGradient() ? "g" : "" ) + "[" + index() + "]"; }

}