// AbstractFunDeviceAlgorithm.java
package neureka.backend.api.template.algorithms;
import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.*;
import neureka.backend.api.fun.*;
import neureka.backend.main.memory.MemValidator;
import neureka.math.Function;
import neureka.devices.Device;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Arrays;
/**
* This is the base class for implementations of the {@link Algorithm} interface.
* The class implements a basic component system, as is implicitly expected by said interface.
* Additionally, it contains useful methods used to process passed arguments of {@link ExecutionCall}
* as well as an implementation of the {@link Algorithm} interface which allows its methods to
* be implemented in a functional programming style, meaning that instances of concrete implementations
* extending this abstract class have setters for lambdas representing the {@link Algorithm} methods.
 * It is used by Neureka's standard backend as an abstract base class for various algorithms.
*
* Conceptually an implementation of the {@link Algorithm} interface represents "a sub-kind of operation" for
* an instance of an implementation of the {@link Operation} interface. <br>
* The "+" operator for example has different {@link Algorithm} instances tailored to specific requirements
* originating from different {@link ExecutionCall} instances with unique arguments.
 * If the {@link Tensor} instances within an execution call all have the same shape, then
 * the {@link Operation} instance would choose an {@link Algorithm} instance which is responsible
 * for performing element-wise operations, whereas otherwise the {@link neureka.backend.main.algorithms.Broadcast}
 * algorithm might be chosen to perform the operation.
*
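 * <p>
 * The following is a minimal, purely illustrative sketch of how a concrete subclass
 * (hypothetically called {@code MyAlgorithm} here) might be configured through the
 * lambda based setters of this class; the lambda bodies, the suitability value and the
 * chosen {@link AutoDiffMode} constant are assumptions and not taken from the standard backend:
 * <pre>{@code
 * DeviceAlgorithm<?> algorithm =
 *     new MyAlgorithm("my_algorithm")                             // hypothetical subclass
 *         .setIsSuitableFor( call -> 0.8f )                       // fitness estimate for the call
 *         .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
 *         .setCallPreparation( call -> call )                     // usually allocates output tensors
 *         .setExecution( (caller, call) -> {                      // must produce a Result
 *             throw new UnsupportedOperationException("illustration only");
 *         })
 *         .buildFunAlgorithm();
 * }</pre>
 *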
* @param <C> The final type extending this class.
*/
public abstract class AbstractFunDeviceAlgorithm<C extends DeviceAlgorithm<C>>
extends AbstractDeviceAlgorithm<C> implements ExecutionPreparation
{
private static final Logger _LOG = LoggerFactory.getLogger( AbstractFunDeviceAlgorithm.class );
/*
Consider the following lambdas as effectively immutable because this
class will warn us if any field variable is set for a second time.
This makes the backend somewhat hackable, but also manageable with respect to complexity.
*/
private SuitabilityPredicate _isSuitableFor;
private ADSupportPredicate _autogradModeFor;
private Execution _execution;
private ADActionSupplier _supplyADActionFor;
private ExecutionPreparation _instantiateNewTensorsForExecutionIn;
/*
This flag will ensure that we can warn the user that the state has been illegally modified.
*/
private boolean _isFullyBuilt = false;
public AbstractFunDeviceAlgorithm( String name ) { super(name); }
/**
* The {@link SuitabilityPredicate} checks if a given instance of an {@link ExecutionCall} is
 * suitable to be executed by the {@link neureka.backend.api.ImplementationFor} instances
 * residing in this {@link Algorithm} as components.
 * It can be implemented as a simple lambda.
*/
@Override
public final float isSuitableFor( ExecutionCall<? extends Device<?>> call ) {
_checkReadiness();
return _isSuitableFor.isSuitableFor(call);
}
/**
* Preparing refers to instantiating output tensors for the provided {@link ExecutionCall}.
*
* @param call The execution call which needs to be prepared for execution.
* @return The prepared {@link ExecutionCall} instance.
*/
@Override
public final ExecutionCall<? extends Device<?>> prepare( ExecutionCall<? extends Device<?>> call ) {
_checkReadiness();
if ( call != null ) {
Tensor<?>[] inputs = call.inputs().clone();
ExecutionCall<? extends Device<?>> prepared = _instantiateNewTensorsForExecutionIn.prepare(call);
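            // Tensors created by the preparation lambda, i.e. tensors which are not
            // part of the original inputs, are flagged as 'intermediate' so that the
            // backend can identify and delete them once they are no longer needed.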
Arrays.stream(prepared.inputs())
.filter(
out -> Arrays.stream(inputs)
.noneMatch(in -> in == out)
)
.forEach(t -> t.mut().setIsIntermediate(true));
return prepared;
}
else return null;
}
/**
 * @return This instance, cast to the concrete implementation type {@code C}, which
 * is fully built and ready to be used as an {@link Operation} component.
*/
public final C buildFunAlgorithm() {
        if (
            _isSuitableFor == null ||
            _autogradModeFor == null ||
            _execution == null || // note: the ADActionSupplier is optional
            _instantiateNewTensorsForExecutionIn == null
        ) {
throw new IllegalStateException(
"Instance '"+getClass().getSimpleName()+"' incomplete!"
);
}
_isFullyBuilt = true;
return (C) this;
}
/**
 * This method ensures that this algorithm has been fully supplied with all the
 * required lambdas...
*/
private void _checkReadiness() {
if ( !_isFullyBuilt ) {
throw new IllegalStateException(
"Trying use an instance of '"+this.getClass().getSimpleName()+"' with name '" + getName() + "' " +
"which was not fully built!"
);
}
}
/**
* Neureka is supposed to be extremely modular and in a sense its backend should be "hackable" to a degree.
 * However, this comes with a lot of risk, because it requires us to expose mutable state.
 * This class is therefore semi-immutable: it simply warns us about any mutation happening after building has completed!
 *
 * @param <T> The type of the lambda which is supposed to be set.
 * @param newState The state which will be set.
 * @param current The state which is currently set.
 * @param type The type of the lambda, used for logging.
 * @return The checked new state.
*/
private <T> T _checked( T newState, T current, Class<T> type ) {
if ( _isFullyBuilt )
_LOG.warn(
"Implementation '" + type.getSimpleName() + "' in algorithm '"+this+"' was modified! " +
"Please consider only modifying the standard backend state of Neureka for experimental reasons."
);
else if ( current != null && newState == null )
throw new IllegalArgumentException(
"Trying set an already specified implementation of lambda '"+current.getClass().getSimpleName()+"' to null!"
);
return newState;
}
/**
* The {@link SuitabilityPredicate} received by this method
* checks if a given instance of an {@link ExecutionCall} is
* suitable to be executed in {@link neureka.backend.api.ImplementationFor} instances
* residing in this {@link Algorithm} as components.
 * The lambda will be called through the {@link #isSuitableFor(ExecutionCall)} method
 * by any {@link Operation} instance this algorithm belongs to.
*
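 * <p>
 * A minimal sketch of such a predicate, applied to some algorithm instance under construction
 * (the input count and the returned suitability value are assumptions chosen purely for illustration):
 * <pre>{@code
 * algorithm.setIsSuitableFor(
 *     call -> call.inputs().length == 2 ? 0.9f : 0.0f // favor calls with exactly two tensors
 * );
 * }</pre>
 *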
* @param isSuitableFor The suitability predicate which determines if the algorithm is suitable or not.
* @return This very instance to enable method chaining.
*/
public final AbstractFunDeviceAlgorithm<C> setIsSuitableFor( SuitabilityPredicate isSuitableFor ) {
_isSuitableFor = _checked(isSuitableFor, _isSuitableFor, SuitabilityPredicate.class);
return this;
}
/**
 * This method receives an {@link ADActionSupplier} which will supply
* {@link ADAction} instances which can perform backward and forward auto differentiation.
*
* @param supplyADActionFor A supplier for an {@link ADAction} containing implementation details for autograd.
* @return This very instance to enable method chaining.
*/
public final AbstractFunDeviceAlgorithm<C> setSupplyADActionFor( ADActionSupplier supplyADActionFor ) {
_supplyADActionFor = _checked(supplyADActionFor, _supplyADActionFor, ADActionSupplier.class);
return this;
}
/**
* An {@link Algorithm} will produce a {@link Result} when executing an {@link ExecutionCall}.
* This result must be created somehow.
 * An {@link ExecutionPreparation} implementation instance will do just that...
 * Oftentimes the first entry in the array of tensors stored inside the call
 * will be null, serving as a position for the output tensor to be placed at.
 * The creation of this output tensor is of course highly dependent on the type
 * of operation and algorithm that is currently being used.
 * Element-wise operations, for example, require the creation of an output tensor
 * with the same shape as the provided input tensors, whereas a linear operation,
 * like for example a broadcast operation, requires a very different approach...
 * The lambda passed to this method will be called through the {@link #prepare(ExecutionCall)} method
 * by any {@link Operation} instance this algorithm belongs to.
*
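 * <p>
 * A minimal sketch of such a preparation lambda (the body is only an assumption; a real
 * implementation would allocate the missing output tensor for the call at hand):
 * <pre>{@code
 * algorithm.setCallPreparation( call -> {
 *     // Allocate and attach the missing output tensor here (shape and device specific)...
 *     return call; // ...and return a call which now references all required tensors.
 * });
 * }</pre>
 *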
* @param instantiateNewTensorsForExecutionIn A lambda which prepares the provided execution call (usually output instantiation).
* @return This very instance to enable method chaining.
*/
public final AbstractFunDeviceAlgorithm<C> setCallPreparation(ExecutionPreparation instantiateNewTensorsForExecutionIn ) {
_instantiateNewTensorsForExecutionIn = _checked(instantiateNewTensorsForExecutionIn, _instantiateNewTensorsForExecutionIn, ExecutionPreparation.class);
return this;
}
/**
 * An {@link ADSupportPredicate} lambda checks what kind of auto differentiation mode an
 * {@link Algorithm} supports for a given {@link ExecutionCall}.
 * The lambda will be called through the {@link #autoDiffModeFrom(ExecutionCall)} method
 * by any {@link Operation} instance this algorithm belongs to.
*
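 * <p>
 * A minimal sketch (assuming {@link AutoDiffMode} exposes a constant like the one used below):
 * <pre>{@code
 * algorithm.setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY );
 * }</pre>
 *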
 * @param autogradModeFor A predicate lambda which determines the auto diff mode of this algorithm for a given execution call.
* @return This very instance to enable method chaining.
*/
public final AbstractFunDeviceAlgorithm<C> setAutogradModeFor(ADSupportPredicate autogradModeFor ) {
_autogradModeFor = _checked(autogradModeFor, _autogradModeFor, ADSupportPredicate.class);
return this;
}
@Override
public AutoDiffMode autoDiffModeFrom(ExecutionCall<? extends Device<?>> call ) {
_checkReadiness();
return _autogradModeFor.autoDiffModeFrom( call );
}
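    /**
     * Receives the {@link Execution} lambda which is responsible for performing the actual
     * computation for a given {@link ExecutionCall} and producing a {@link Result}.
     * It will be invoked by the {@link #execute(Function, ExecutionCall)} method.
     *
     * @param execution The lambda which executes a given {@link ExecutionCall} and returns a {@link Result}.
     * @return This very instance to enable method chaining.
     */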
public AbstractFunDeviceAlgorithm<C> setExecution( Execution execution ) {
_execution = _checked(execution, _execution, Execution.class);
return this;
}
@Override
public Result execute( Function caller, ExecutionCall<? extends Device<?>> call ) {
_checkReadiness();
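        // Without a call there are no inputs to validate, so the execution lambda is invoked directly: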
if ( call == null ) {
if ( _supplyADActionFor != null )
return _execution.execute( caller, call ).withAutoDiff(_supplyADActionFor);
else
return _execution.execute( caller, call );
}
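        // Run the execution inside a MemValidator, which verifies that the produced result
        // is flagged as 'intermediate' if and only if it is not one of the input tensors.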
MemValidator checker = MemValidator.forInputs( call.inputs(), ()-> {
if ( _supplyADActionFor != null )
return _execution.execute( caller, call ).withAutoDiff(_supplyADActionFor);
else
return _execution.execute( caller, call );
});
if ( checker.isWronglyIntermediate() ) {
throw new IllegalStateException(
"Output of algorithm '" + this.getName() + "' " +
"is marked as intermediate result, despite the fact " +
"that it is a member of the input array. " +
"Tensors instantiated by library users instead of operations in the backend are not supposed to be flagged " +
"as 'intermediate', because they are not eligible for deletion!"
);
}
if ( checker.isWronglyNonIntermediate() ) {
throw new IllegalStateException(
"Output of algorithm '" + this.getName() + "' " +
"is neither marked as intermediate result nor a member of the input array. " +
"Tensors instantiated by operations in the backend are expected to be flagged " +
"as 'intermediate' in order to be eligible for deletion!"
);
}
return checker.getResult();
}
}