AbstractActivationOperation.java

package neureka.backend.main.operations.functions;

import neureka.Neureka;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Operation;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.algorithms.ElementwiseAlgorithm;
import neureka.backend.main.algorithms.ScalarAlgorithm;
import neureka.backend.main.algorithms.ScalarBroadcast;
import neureka.backend.main.implementations.fun.api.ScalarFun;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;

/**
 *  Common base for unary operations which are backed by a single {@link ScalarFun}.
 *  <p>
 *  The operation is configured with an arity of 1, is flagged as differentiable,
 *  and uses the id of the supplied {@link ScalarFun} as both identifier and operator.
 *  Three algorithms are registered on construction: an {@link ElementwiseAlgorithm},
 *  a {@link ScalarBroadcast} and a {@link ScalarAlgorithm}.
 */
abstract class AbstractActivationOperation extends AbstractOperation
{
    /** The scalar function evaluated by {@link #calculate(double[], int, int, Function[])}. */
    private final ScalarFun _fun;

    /**
     * @param fun The scalar function whose id names this operation and which
     *            is used for the scalar based {@code calculate} implementation.
     */
    AbstractActivationOperation( ScalarFun fun )
    {
        super( _builderFor( fun ) );

        _fun = fun;

        // The registration order below mirrors the order in which the algorithms were originally declared.
        ElementwiseAlgorithm elementwise = new ElementwiseAlgorithm();
        setAlgorithm( elementwise.setSupplyADActionFor( getDefaultAlgorithm() ).buildFunAlgorithm() );
        setAlgorithm( new ScalarBroadcast( fun ).buildFunAlgorithm() );
        setAlgorithm( new ScalarAlgorithm().buildFunAlgorithm() );
    }

    /**
     * Assembles the builder describing this operation: a unary, differentiable
     * operation which is neither an operator, an indexer nor inline.
     *
     * @param fun The scalar function providing the identifier and operator string.
     * @return The configured builder handed to the super constructor.
     */
    private static OperationBuilder _builderFor( ScalarFun fun )
    {
        return new OperationBuilder()
                    .identifier( fun.id() )
                    .operator( fun.id() )
                    .arity( 1 )
                    .isOperator( false )
                    .isIndexer( false )
                    .isDifferentiable( true )
                    .isInline( false );
    }

    /**
     * Executes this operation for the given caller.
     * <p>
     * A flat caller is forwarded to the super implementation unchanged.
     * Otherwise the call is flattened first (with the derivative index reset to -1);
     * if no derivative was requested the flattened call is executed directly,
     * else the result is the product of the inner and the outer derivative.
     *
     * @param caller The function from which this operation is invoked.
     * @param call   The execution call carrying the arguments of this execution.
     * @return The result of the execution.
     */
    @Override
    public Result execute( Function caller, ExecutionCall<?> call )
    {
        if ( caller.isFlat() )
            return super.execute( caller, call );

        final int derivativeIndex = call.getDerivativeIndex();
        final ExecutionCall<?> flattened =
                AbstractDeviceAlgorithm.flatten( caller, call.withArgs( Arg.DerivIdx.of( -1 ) ) );

        if ( derivativeIndex < 0 )
            return super.execute( _parseFlatFunction( flattened, true ), flattened );

        return _executeChainRule( caller, call, flattened );
    }

    /**
     * Parses a function for the operation and arity of the supplied flattened call.
     *
     * @param flattened The flattened call whose operation and arity are parsed.
     * @param doAD      The flag forwarded as last argument to the parser.
     * @return The parsed function.
     */
    private Function _parseFlatFunction( ExecutionCall<?> flattened, boolean doAD )
    {
        FunctionParser parser = new FunctionParser( Neureka.get().backend() );
        return parser.parse( flattened.getOperation(), flattened.arity(), doAD );
    }

    /**
     * Handles a derivative request on a caller which is not flat:
     * the derivative of the inner function is multiplied with the derivative
     * of this (outer) operation, using the "*" operation of the backend.
     *
     * @param caller    The non-flat function from which this operation is invoked.
     * @param call      The original (not flattened) execution call.
     * @param flattened The already flattened version of the call.
     * @return The result of multiplying inner and outer derivative.
     */
    private Result _executeChainRule( Function caller, ExecutionCall<?> call, ExecutionCall<?> flattened )
    {
        Function withoutAD = Function.of( caller.toString(), false );
        Function nested    = withoutAD.getSubFunctions().get( 0 );
        Function outer     = _parseFlatFunction( flattened, false );

        // The caller is not flat, so the requested derivative is: inner derivative times outer derivative.
        ExecutionCall<?> outerCall =
                AbstractDeviceAlgorithm.flatten( withoutAD, call.withArgs( Arg.DerivIdx.of( -1 ) ) );

        Operation nestedOperation = nested.getOperation();
        Tensor<?> innerDerivative = nestedOperation.execute( nested, call.withOperation( nestedOperation ) ).get();
        Tensor<?> outerDerivative = super.execute( outer, outerCall.withArgs( Arg.DerivIdx.of( 0 ) ) ).get();

        Operation multiplication = Neureka.get().backend().getOperation( "*" );
        Function  product        = new FunctionParser( Neureka.get().backend() ).parse( multiplication, 2, false );
        ExecutionCall<?> productCall =
                ExecutionCall.of( innerDerivative, outerDerivative )
                             .running( multiplication )
                             .on( call.getDevice() );

        return multiplication.execute( product, productCall );
    }

    /**
     * Renders this operation as {@code identifier(child, child, ...)}.
     * If the joined children are already enclosed in parentheses,
     * no additional pair of parentheses is added.
     *
     * @param children The string representations of the child expressions.
     * @return The string representation of this operation applied to the children.
     */
    @Override
    public final String stringify( String[] children ) {
        String arguments = String.join( ", ", children );
        boolean alreadyWrapped = arguments.startsWith("(") && arguments.endsWith(")");
        return alreadyWrapped
                    ? getIdentifier() + arguments
                    : getIdentifier() + "(" + arguments + ")";
    }

    /**
     * Scalar evaluation of this operation based on the first source function.
     * <p>
     * For {@code d < 0} the scalar function is applied to the value of the source.
     * For {@code d >= 0} the derivative of the scalar function at the value of the source
     * is multiplied with the derivative of the source itself.
     *
     * @param inputs The input values passed on to the source function.
     * @param j      The index forwarded to the source function.
     * @param d      The derivative index, a negative value means no derivative is requested.
     * @param src    The source functions, of which only the first one is used.
     * @return The computed value or derivative.
     */
    @Override
    public final double calculate( double[] inputs, int j, int d, Function[] src ) {
        final Function source = src[ 0 ];

        if ( d < 0 ) // Plain evaluation, no inner derivative involved.
            return _fun.calculate( source.call( inputs, j ), false );

        double innerDerivative = source.derive( inputs, d, j );
        double outerDerivative = _fun.calculate( source.call( inputs, j ), true );
        return outerDerivative * innerDerivative;
    }

}