ADAction.java
package neureka.autograd;
import neureka.Tensor;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/**
 *  This interface is the declaration for lambda actions which implement
 *  the {@link #act(ADTarget)} method. <br><br>
 *  Implementations of this perform auto-differentiation forwards or backwards along the computation graph.
 *  These differentiation actions are performed through the "{@link ADAction#act(ADTarget)}"
 *  method, which is called by instances of the {@link GraphNode} class during propagation.
 *  An {@link ADAction} may also wrap and expose a partial derivative
 *  which may or may not be present for certain operations.
 *  Said derivative must be tracked and flagged as derivative by a {@link GraphNode}
 *  to make sure that it will not be deleted after a forward pass.
 *  <p>
 *  Note: Do not access the {@link GraphNode#getPayload()} of the {@link GraphNode}
 *  passed to implementations of this.
 *  The payload is weakly referenced, meaning that this method can return null!
 */
@FunctionalInterface
public interface ADAction
{
    /**
     *  Wraps the provided action in a {@link DefaultADAction} without an explicitly supplied derivative
     *  (the derivative argument of the wrapper is {@code null}).
     *
     * @param action The action which should be wrapped.
     * @return A new {@link DefaultADAction} created from the provided action.
     */
    static ADAction of( ADAction action ) { return new DefaultADAction( action, null ); }

    /**
     *  Wraps the provided action together with the provided derivative in a {@link DefaultADAction}.
     *
     * @param derivative The tensor handed to the wrapper alongside the action (presumably the partial derivative).
     * @param action The action which should be wrapped.
     * @return A new {@link DefaultADAction} created from the provided action and derivative.
     */
    static ADAction of( Tensor<?> derivative, ADAction action ) { return new DefaultADAction( action, derivative ); }

    /**
     *  Performs the auto-differentiation forward or backward pass of this {@link ADAction},
     *  which propagates partial differentiations along the computation graph.
     *
     * @param target A wrapper for the {@link GraphNode} at which the differentiation ought to
     *               be performed and error which ought to be used for the forward or backward differentiation.
     * @return The result of a forward or backward mode auto differentiation.
     */
    Tensor<?> act( ADTarget<?> target );

    /**
     *  Finds captured {@link Tensor} instances in this current action
     *  using reflection (This is usually a partial derivative).
     *  Every field declared as {@link Tensor} in the runtime class of this action
     *  and in all of its super classes is inspected.
     *  Fields whose current value is {@code null} are skipped,
     *  so the returned array never contains {@code null} entries.
     *
     * @return The captured {@link Tensor} instances, or an empty array if there are none.
     */
    default Tensor<?>[] findCaptured() {
        List<Tensor<?>> captured = new ArrayList<>();
        // Walk up the class hierarchy, getDeclaredFields() does not include inherited fields.
        for ( Class<?> type = this.getClass(); type != null; type = type.getSuperclass() ) {
            for ( java.lang.reflect.Field field : type.getDeclaredFields() ) {
                if ( !field.getType().equals( Tensor.class ) )
                    continue;

                field.setAccessible( true );
                try {
                    Object value = field.get( this );
                    // An unset field is not a captured tensor; exposing null here would
                    // make partialDerivative() fail inside Optional.of(..).
                    if ( value != null )
                        captured.add( (Tensor<?>) value );
                } catch ( IllegalAccessException e ) {
                    // Best effort: the field was made accessible above, so this is not expected.
                    // The field is skipped instead of failing the whole lookup.
                    e.printStackTrace();
                }
            }
        }
        return captured.toArray( new Tensor[0] );
    }

    /**
     *  Exposes the last non-null {@link Tensor} found by {@link #findCaptured()}
     *  as the partial derivative of this action.
     *
     * @return An {@link Optional} holding the last captured tensor,
     *         or an empty {@link Optional} if no tensor was captured.
     */
    default Optional<Tensor<?>> partialDerivative() {
        Tensor<?>[] captured = this.findCaptured();
        // Scan backwards so that null entries produced by an overriding
        // findCaptured() implementation cannot cause a NullPointerException.
        for ( int i = captured.length - 1; i >= 0; i-- ) {
            if ( captured[i] != null )
                return Optional.of( captured[i] );
        }
        return Optional.empty();
    }
}