// Division.java

package neureka.backend.main.operations.operator;

import neureka.Neureka;
import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Call;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.FallbackAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.algorithms.BiElementwise;
import neureka.backend.main.algorithms.BiScalarBroadcast;
import neureka.backend.main.algorithms.Broadcast;
import neureka.devices.Device;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
import neureka.ndim.NDimensional;

import java.util.Arrays;


/**
 *  The operation behind the {@code "/"} operator, registered under the identifier {@code "divide"}.
 *  <p>
 *  The operation is declared with an arity of {@code -1}, and a chain of operands like
 *  {@code a/b/c/d} is evaluated from left to right, i.e. as {@code (((a/b)/c)/d)}.
 *  Besides registering the {@link BiElementwise}, {@link Broadcast} and {@link BiScalarBroadcast}
 *  algorithms, this class implements the derivation of divisions (quotient rule) for tensors,
 *  for scalar {@code double} inputs and in the form of a derivative expression string.
 */
public class Division extends AbstractOperation
{
    /**
     *  Declares the operation properties and registers the algorithms of the division operation.
     */
    public Division()
    {
        super(
            new OperationBuilder()
                .identifier(       "divide"   )
                .operator(         "/"        )
                .arity(            -1         )
                .isOperator(       true       )
                .isIndexer(        false      )
                .isDifferentiable( true       )
                .isInline(         false      )
        );

        // Element-wise division: the AD action is supplied by the default algorithm of this operation.
        setAlgorithm(
            BiElementwise.class,
            new BiElementwise()
            .setSupplyADActionFor( getDefaultAlgorithm() )
            .buildFunAlgorithm()
        );

        // Broadcast division: forward-AD is only offered when all (non-null) operands share the same shape.
        setAlgorithm(
                Broadcast.class,
                new Broadcast()
                .setAutogradModeFor(
                    call -> call
                            .validate().allNotNullHaveSame(NDimensional::shape)
                            .ifValid(AutoDiffMode.FORWARD_AND_BACKWARD)
                            .orElse(AutoDiffMode.BACKWARD_ONLY)
                )
                .setSupplyADActionFor(
                    ( Function f, ExecutionCall<? extends Device<?>> call ) ->
                    {
                        if ( call.autogradMode().allowsForward() )
                            throw new IllegalArgumentException("Broadcast implementation does not support forward-AD!");
                        Tensor<?> ctxDerivative = (Tensor<?>) call.getValOf(Arg.Derivative.class);
                        Function mul = Neureka.get().backend().getFunction().mul();
                        if ( ctxDerivative != null ) {
                            // A derivative was already passed along with the call, so it is used as is.
                            return ADAction.of( target -> mul.execute( target.error(), ctxDerivative ) );
                        }
                        // Otherwise the derivative is computed once, up front, and captured by the action.
                        int d = call.getDerivativeIndex();
                        Tensor<?> derivative = f.executeDerive( call.inputs(), d );
                        return ADAction.of( target -> mul.execute( target.error(), derivative ) );
                    }
                )
                .buildFunAlgorithm()
        );

        // Scalar broadcast division: registered, but it always reports itself as a BAD fit.
        setAlgorithm(
            BiScalarBroadcast.class,
            new BiScalarBroadcast()
            .setIsSuitableFor( call -> SuitabilityPredicate.BAD )
            .setAutogradModeFor( call -> AutoDiffMode.FORWARD_AND_BACKWARD )
            .setExecution(
                (caller, call) ->
                    Result.of(
                        AbstractDeviceAlgorithm.executeFor( caller, call, AbstractDeviceAlgorithm::executeDeviceAlgorithm )
                    )
                    .withAutoDiff( FallbackAlgorithm::ADAction )
            )
            .buildFunAlgorithm()
        );
    }

    /**
     *  Executes the division described by the provided function on the inputs of the provided call.
     *  <p>
     *  The caller is first turned into a chain of pairwise divisions (see {@code reducePairwise}).
     *  If no derivative is requested (derivative index below zero) and the reduced function is
     *  not flat, then the call is flattened and executed through the super implementation.
     *  If a derivative is requested, then the function is treated as {@code a/b} and derived
     *  according to the quotient rule: {@code a'/b - a * b'/b**2}.
     *
     * @param caller The function from which this operation is being executed.
     * @param call The execution call holding the input tensors and the derivative index.
     * @return The result of the division, or of its derivative if the derivative index of the call is not negative.
     * @throws IllegalArgumentException If a derivative is requested and the operand shapes cannot be broadcast.
     */
    @Override
    public Result execute( final Function caller, final ExecutionCall<?> call )
    {
        Function reducedCaller = reducePairwise( caller );

        int d = call.getDerivativeIndex();
        if ( !reducedCaller.isFlat() && d < 0 ) {
            ExecutionCall<?> flatCall = AbstractDeviceAlgorithm.flatten( reducedCaller, call.withArgs(Arg.DerivIdx.of(-1)) );
            Arrays.stream(flatCall.inputs()).forEach( t -> t.mut().setIsIntermediate(false) );
            Function flat = new FunctionParser(Neureka.get().backend()).parse( flatCall.getOperation(), flatCall.arity(), true );
            return super.execute( flat, flatCall );
        }
        if ( d < 0 ) return super.execute( reducedCaller, call );

        if ( !call.validate().all( (a, b) -> Util.canBeBroadcast(a.shape(), b.shape()) ).isValid() )
            throw new IllegalArgumentException("The shapes of the operands of the division operation must be equal or broadcast compatible! (when deriving nested functions)");

        // So here we assume that there are only two sub-functions: a/b

        Function noAd = Function.of( reducedCaller.toString(), false );
        Function a = noAd.getSubFunctions().get(0);
        Function b = noAd.getSubFunctions().get(1);
        boolean deriveA = a.dependsOn(d);
        boolean deriveB = b.dependsOn(d);

        if ( !deriveA && !deriveB ) return super.execute( reducedCaller, call );

        Tensor<?> bResult = b.call((Call) call.withArgs(Arg.DerivIdx.of(-1)));
        Tensor<?> derivOfA = null;
        if ( deriveA ) {
            Function div = Neureka.get().backend().getFunction().div();
            // This is simple, we just derive the first sub-function and divide it by the second sub-function:
            Tensor<?> aDeriv = a.call((Call)call);
            derivOfA = div.call((Tensor<Object>)aDeriv, (Tensor<Object>)bResult);
        }
        // At least one operand depends on 'd' here, so if 'b' does not, then 'a' does and 'derivOfA' is set.
        if ( !deriveB )
            return Result.of(derivOfA.mut().setIsIntermediate(true));

        // From here on 'b' is known to depend on 'd', which requires the plain result of 'a' as well.
        Tensor<?> aResult = a.call((Call)call.withArgs(Arg.DerivIdx.of(-1)));
        return _deriveB( call, b, deriveA, derivOfA, aResult, bResult );
    }

    /**
     *  Computes the part of the derivative of {@code a/b} that stems from the denominator {@code b},
     *  namely {@code a * ( -b' / b**2 )}, and adds the numerator part to it if there is one.
     *
     * @param call The call which is used to execute the derivative of {@code b}.
     * @param b The denominator function.
     * @param deriveA True if the numerator also depends on the derived input.
     * @param derivOfA The numerator part of the derivative ({@code a'/b}), only read if {@code deriveA} is true.
     * @param aResult The plain (not derived) result of the numerator.
     * @param bResult The plain (not derived) result of the denominator.
     * @return The derivative tensor wrapped in a {@link Result} and flagged as intermediate.
     */
    private Result _deriveB(
            ExecutionCall<?> call,
            Function b,
            boolean deriveA,
            Tensor<?> derivOfA,
            Tensor<?> aResult,
            Tensor<?> bResult
    ) {
        Function mul = Neureka.get().backend().getFunction().mul();
        Tensor<?> innerDerivB = b.call((Call)call);
        // So we have something like this: a/b, where we want to derive b.
        // This is how it is really executed:  (a/b) = (a * (1/b))
        // So we can derive b and then later on add the derivative of 'a' to it (if it must be derived).
        // The derivative of 1/b is -1/b^2, which is multiplied by the inner derivative of b (chain rule):
        Function derive = Function.of("-I[0] / (I[1] ** 2)", false);
        Tensor<?> derivOfB = derive.call( (Tensor<Object>)innerDerivB, (Tensor<Object>)bResult );
        derivOfB = mul.call((Tensor<Object>)aResult, (Tensor<Object>)derivOfB);
        if ( !deriveA )
            return Result.of(derivOfB.mut().setIsIntermediate(true));

        Function add = Neureka.get().backend().getFunction().add();
        return Result.of( add.call((Tensor<Object>)derivOfA, (Tensor<Object>)derivOfB).mut().setIsIntermediate(true) );
    }

    /**
     *  Turns a division with more than two operands into a chain of nested pairwise divisions.
     *  Functions with two or fewer sub-functions are returned unchanged.
     *
     * @param fun The function which ought to be reduced.
     * @return A function of the form {@code ((((a/b)/c)/d)..)}, or the provided function itself.
     */
    private Function reducePairwise( Function fun ) {
        if ( fun.getSubFunctions().size() <= 2 ) return fun;
        /*
            So currently we have something like this: a/b/c/d...
            However, this is how it is really executed:  ((((a/b)/c)/d)..)
            ...so let's create a function that is nested like the above:
        */
        Function nested = fun.getSubFunctions().get(0);
        for ( int i = 1; i < fun.getSubFunctions().size(); i++ )
            nested = Function.of( nested + " / " + fun.getSubFunctions().get(i), true );

        return nested;
    }

    /**
     *  Builds the expression string of the derivative of {@code children[0] / children[1] / ...}
     *  with respect to the input at the provided index.
     *
     * @param children The operands of the division, in the order in which they are divided.
     * @param derivationIndex The index of the input with respect to which the expression is derived.
     * @return The derivative as an expression string.
     */
    @Override
    public String asDerivative( Function[] children, int derivationIndex) {
        return _asDerivative( children, derivationIndex, children.length - 1 );
    }

    /**
     *  Recursively renders {@code children[0] / ... / children[index]} (if {@code d} is negative),
     *  or the derivative of that chain with respect to the input {@code d} (otherwise).
     *  <p>
     *  The derivative follows the quotient rule with {@code u = children[0] / ... / children[index-1]}
     *  and {@code v = children[index]}, that is: {@code (u' / v) - ((u * v') / v**2)}.
     *  Terms which vanish are omitted, which means that an empty string may be returned.
     *
     * @param children The operands of the division.
     * @param d The index of the input to derive for, or a negative number if nothing should be derived.
     * @param index The index of the last operand which is part of the rendered chain.
     * @return The expression string of the chain or of its derivative.
     */
    private String _asDerivative( Function[] children, int d, int index ) {
        if ( d < 0 ) {
            if ( index <= 0 ) return children[ 0 ].toString();
            return _asDerivative( children, -1, index - 1 ) + " / " + children[ index ].toString();
        }
        if ( index <= 0 ) return children[ 0 ].getDerivative( d ).toString();

        // The numerator 'u' is the whole chain in front of 'children[index]',
        // so it depends on 'd' as soon as any of its operands does (not only the last one).
        boolean numeratorDependsOnD = false;
        for ( int i = 0; i < index && !numeratorDependsOnD; i++ )
            numeratorDependsOnD = children[ i ].dependsOn( d );

        String first = "";
        if ( numeratorDependsOnD ) {
            String numeratorDerivative = _asDerivative( children, d, index - 1 );
            // An empty derivative denotes a vanished term, which must not be turned into "( / v )".
            if ( !numeratorDerivative.isEmpty() )
                first = "(" + numeratorDerivative + " / " + children[ index ]  + " )";
        }

        if ( !children[ index ].dependsOn(d) ) return first;
        // A numerator chain starting with zero is zero as a whole, so the second term vanishes:
        if ( children[ 0 ].toString().equals("0.0") ) return first;

        // The undifferentiated numerator 'u'. A chain of several operands is wrapped in parentheses.
        String s = ( index == 1 )
                        ? children[ 0 ].toString()
                        : "(" + _asDerivative( children, -1, index - 1 ) + ")";

        return first +
                " - ((" + // The second expression is the inner derivative (current index)! (inner times outer...)
                    s + " * " + children[ index ].getDerivative(d) +
                ") / ( "
                    + children[ index ] + "**2 " +
                ") )";
    }

    /**
     *  Calculates the division of the provided source functions, or its derivative, for scalar inputs.
     *  If {@code j} is negative this delegates to {@link #calculate(double[], int, Function[])}.
     *  Otherwise {@code j} is forwarded to every call and derivation of the source functions
     *  (presumably it selects a single input element, which cannot be confirmed from this file).
     *
     * @param inputs The scalar inputs of the source functions.
     * @param j The index forwarded to the source functions, or a negative number if there is none.
     * @param d The index of the input to derive for, or a negative number if nothing should be derived.
     * @param src The operands of the division, which are divided from left to right.
     * @return The result of the division if {@code d} is negative, otherwise its derivative.
     */
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) {
        if ( j < 0 ) return calculate( inputs, d, src );
        if ( d < 0 ) {
            double result = src[ 0 ].call( inputs, j );
            for ( int i = 1; i < src.length; i++ )
                result /= src[ i ].call( inputs, j );

            return result;
        }
        // Quotient rule, applied once for every division in the chain:
        // 'u' is the value of the chain so far and 'ud' is its derivative.
        double u  = src[ 0 ].call( inputs, j );
        double ud = src[ 0 ].derive( inputs, d, j );
        for ( int i = 1; i < src.length; i++ ) {
            double v  = src[ i ].call( inputs, j );
            double vd = src[ i ].derive( inputs, d, j );
            ud = (ud * v - u * vd) / Math.pow(v, 2);
            u /= v;
        }
        return ud;
    }

    /**
     *  Calculates the division of the provided source functions, or its derivative, for scalar inputs.
     *
     * @param inputs The scalar inputs of the source functions.
     * @param d The index of the input to derive for, or a negative number if nothing should be derived.
     * @param src The operands of the division, which are divided from left to right.
     * @return The result of the division if {@code d} is negative, otherwise its derivative.
     */
    public static double calculate( double[] inputs, int d, Function[] src ) {
        if ( d < 0 ) {
            double result = src[ 0 ].call( inputs );
            for ( int i = 1; i < src.length; i++ )
                result /= src[ i ].call( inputs );

            return result;
        }
        // Quotient rule, applied once for every division in the chain:
        // 'u' is the value of the chain so far and 'ud' is its derivative.
        double u  = src[ 0 ].call( inputs );
        double ud = src[ 0 ].derive( inputs, d );
        for ( int i = 1; i < src.length; i++ ) {
            double v  = src[ i ].call( inputs );
            double vd = src[ i ].derive( inputs, d );
            ud = ( ud * v - u * vd ) / Math.pow(v, 2);
            u /= v;
        }
        return ud;
    }

}