FunctionVariable.java
package neureka.math.implementations;
import neureka.Neureka;
import neureka.Tensor;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.args.Args;
import neureka.math.parsing.FunctionParser;
import java.util.ArrayList;
import java.util.List;
/**
* Instances of this implementation of the {@link Function} interface
* are leave nodes within the abstract syntax tree of a function, representing indexed inputs to a function.
* When parsing an expression into a function then these inputs are recognized by the character 'i' or 'I',
* followed by the character 'j' or 'J' (optionally wrapped by '[' and ']'), which is a placeholder for the index
* of the argument within the list/array of arguments passed to a concrete {@link Function} instance. <br>
* So for example, when creating a function by calling the following factory method... <br>
* <br>
* {@link Function#of}( "3 * sum( (I[j] + 4) * I[0] )" ) <br>
* <br>
* ...then the substrings "I[j]" will be parsed into instances of this class! <br>
* The substring "I[0]" on the other hand will not be parsed into an instance of this class!
*/
public final class FunctionVariable implements Function, GradientProvider
{
private final boolean _providesGradient;
public FunctionVariable( String equation ) { _providesGradient = equation.contains("g"); }
@Override public boolean providesGradient() { return _providesGradient; }
@Override public boolean isFlat() { return true; }
@Override public boolean isDoingAD() { return false; }
@Override public AbstractOperation getOperation() { return null; }
@Override public boolean dependsOn( int index ) { return true; }
@Override public Function getDerivative( int index ) { return Function.of( "1" ); }
@Override public List<Function> getSubFunctions() { return new ArrayList<>(); }
@Override
public double call( final double[] inputs, int j ) {
if ( j < 0 ) {
double sum = 0;
for ( int i = 0; i < inputs.length; i++ ) sum += call(inputs, i);
return sum;
}
return inputs[j];
}
@Override
public double derive( final double[] inputs, final int index ) { return 1.0; }
@Override
public double derive( double[] inputs, int index, int j ) {
if ( j != index ) return 0;
return derive( inputs, index );
}
@Override
public Tensor<?> execute(Args arguments, Tensor<?>... inputs ) {
int d = ( arguments.has(Arg.DerivIdx.class) ? arguments.valOf(Arg.DerivIdx.class) : -1 );
int j = ( arguments.has(Arg.VarIdx.class) ? arguments.valOf(Arg.VarIdx.class) : -1 );
if ( d >= 0 ) {
if ( j < 0 )
return Tensor.of( inputs[ 0 ].shape(), 1.0 ).getMut().setIsIntermediate( true );
return j != d ? Tensor.of( inputs[ 0 ].shape(), 0.0 ).getMut().setIsIntermediate( true ) : executeDerive(inputs, d );
}
if ( j < 0 ) {
StringBuilder exp = new StringBuilder("I[ 0 ]");
for (int i = 1; i < inputs.length; i++ )
exp.append("+I[").append(i).append("]");
return new FunctionParser( Neureka.get().backend() )
.parse(exp.toString(), false)
.execute(inputs);
}
return inputs[j];
}
@Override
public String toString() { return "I" + ( (this.providesGradient()) ? "g" : "" ) + "[j]"; }
}