// DefaultADAction.java
package neureka.autograd;
import neureka.Tensor;
import java.util.Optional;
/**
 *  Default implementation of the {@link ADAction} interface.
 *  An {@link ADAction} is an auto-differentiation action, meaning that
 *  implementations are responsible for performing forward- and reverse-mode
 *  differentiation. These differentiation actions are performed through the
 *  {@link ADAction#act(ADTarget)} method, which is called by instances of the
 *  {@link GraphNode} class during propagation.
 *  An {@link ADAction} may also wrap and expose a partial derivative
 *  which may or may not be present for certain operations.
 *  <br>
 *  This class stores the implementation of the propagation method
 *  as a wrapped {@link ADAction} (typically a lambda). <br>
 *
 *  So in essence this class is a container for a lambda as well as an optional derivative.
 *  Both of them are allowed to be {@code null}.
 */
final class DefaultADAction implements ADAction
{
    /**
     *  This lambda ought to perform the forward or backward propagation
     *  for the concrete {@link neureka.backend.api.ImplementationFor} of a {@link neureka.devices.Device}.
     *  It may be {@code null}, in which case {@link #act(ADTarget)} fails with an exception.
     */
    private final ADAction _action;

    /**
     *  The explicitly supplied partial derivative, or {@code null} if none was supplied.
     */
    private final Tensor<?> _partialDerivative;

    /**
     * @param action     The action to which {@link #act(ADTarget)} delegates, may be {@code null}.
     * @param derivative The partial derivative exposed by {@link #partialDerivative()}, may be {@code null}.
     */
    DefaultADAction( ADAction action, Tensor<?> derivative ) {
        _action = action;
        _partialDerivative = derivative;
    }

    /**
     *  Delegates the propagation to the wrapped action.
     *
     * @param target The target passed on to the wrapped action.
     * @return The result produced by the wrapped action.
     * @throws IllegalStateException If this instance was created without an action.
     */
    @Override
    public Tensor<?> act( ADTarget<?> target ) {
        if ( _action == null )
            throw new IllegalStateException(
                "Cannot perform propagation because this "+ADAction.class.getSimpleName()+" does not have an auto-diff implementation."
            );
        return _action.act( target );
    }

    /**
     *  Exposes the partial derivative of this action, if there is one.
     *  The explicitly supplied derivative takes precedence. Otherwise the last
     *  of the tensors reported by {@link ADAction#findCaptured()} of the wrapped
     *  action is used.
     *
     * @return The partial derivative, or an empty {@link Optional} if there is
     *         neither a supplied derivative nor a wrapped action with captured tensors.
     */
    @Override
    public Optional<Tensor<?>> partialDerivative() {
        if ( _partialDerivative != null )
            return Optional.of( _partialDerivative );
        // The action is allowed to be absent (see act(..)), so there is nothing to inspect in that case.
        if ( _action == null )
            return Optional.empty();
        Tensor<?>[] captured = _action.findCaptured();
        if ( captured.length > 0 )
            return Optional.of( captured[ captured.length - 1 ] );
        return Optional.empty();
    }

    /**
     *  This method lets us view the state of this {@link ADAction} in the form of
     *  a {@link String}, which is the {@link String} representation of the
     *  partial derivative, if present.
     *
     * @return The String view of the partial derivative of this {@link ADAction},
     *         or an empty String if there is no partial derivative.
     */
    @Override
    public String toString() {
        return partialDerivative().map( Object::toString ).orElse( "" );
    }
}