GraphNodeUtility.java

package neureka.autograd;

import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.AutoDiffMode;
import neureka.devices.Device;

/**
 *  This class exists in order to allow for {@link GraphNode}s to be instantiated
 *  with final field variables by collecting them when defined
 *  within constructor methods...
 */
final class GraphNodeUtility {

    private GraphNodeUtility() {}

    /**
     *  Evaluates the auto-grad/auto-differentiation mode for the inputs of the given call:
     *  A positive value means that the AD-procedure will be forward mode AD,
     *  whereas a negative value is backward mode AD.
     *  If the resulting mode equals 0 then this means that no auto differentiation is needed.
     *  This tries to optimize the calculation of partial derivatives by forward propagating them
     *  for as long as only a single input for every computation graph node requires gradients,
     *  and they all are differentiable!
     *
     * @param adMode The differentiation capabilities of the executed operation,
     *               queried via {@code allowsForward()} and {@code allowsBackward()}.
     * @param call   The execution call whose inputs are inspected; every input is
     *               expected to already have a {@link GraphNode}.
     * @param <V>    The value type the input tensors are treated as.
     * @return If exactly one input has a non-zero mode and forward AD is allowed, a positive mode:
     *         {@code 1} when that input's mode is negative, otherwise that input's mode plus one.
     *         Otherwise, if backward AD is allowed, the negated number of inputs with a non-zero mode.
     *         Otherwise {@code 0}.
     * @throws IllegalStateException If one of the inputs of the call has no graph node.
     */
    public static <V> int modeOf( AutoDiffMode adMode, ExecutionCall<? extends Device<?>> call )
    {
        @SuppressWarnings("unchecked") // Generic array casts cannot be checked at runtime (type erasure).
        Tensor<V>[] inputs = (Tensor<V>[]) call.inputs();
        int[] modes = new int[ inputs.length ];
        int activeInputs = 0; // Number of inputs whose mode is non-zero.
        for ( int i = 0; i < inputs.length; i++ ) {
            final int index = i;
            // Every input is expected to already own a graph node (this is not null checked in the constructor!),
            // so fail with a descriptive message identifying the offending input:
            GraphNode<V> node = inputs[ i ].getGraphNode().orElseThrow(
                    () -> new IllegalStateException(
                            "Input tensor at index " + index + " of the execution call has no graph node!"
                    )
            );
            // A tensor which itself requires gradients always counts as forward mode 1:
            modes[ i ] = ( inputs[ i ].rqsGradient() ) ? 1 : node.getMode();
            activeInputs += ( modes[ i ] != 0 ) ? 1 : 0;
        }
        int resultMode = 0;
        if ( activeInputs == 1 && adMode.allowsForward() ) { // Convolution and reshaping prohibit forward AutoDiff
            // Exactly one mode is non-zero, so this sum only picks up that single input:
            // a negative (backward) mode yields 1, a positive mode is incremented by one.
            for ( int i = 0; i < inputs.length; i++ ) {
                resultMode +=
                        ( modes[ i ] == 0 )
                                ? 0
                                : ( modes[ i ] < 0 ) ? 1 : modes[ i ] + 1;
            }
        } // Reverse mode auto-differentiation :
        else if ( adMode.allowsBackward() ) resultMode = -activeInputs;

        return resultMode;
    }

}