AbstractFunAlgorithm.java

package neureka.backend.api.template.algorithms;

import neureka.backend.api.*;
import neureka.backend.api.fun.ADSupportPredicate;
import neureka.backend.api.fun.Execution;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.main.memory.MemValidator;
import neureka.math.Function;
import neureka.devices.Device;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AbstractFunAlgorithm extends AbstractAlgorithm
{
    private static final Logger _LOG = LoggerFactory.getLogger( AbstractFunAlgorithm.class );
    /*
        The lambdas below are meant to be effectively immutable:
        once this algorithm has been built (see buildFunAlgorithm()), any further
        modification is logged as a warning, and replacing an already specified
        lambda with null is rejected at any time.
        This makes the backend somewhat hackable, but also manageable with respect to complexity.
     */
    private SuitabilityPredicate _isSuitableFor;
    private ADSupportPredicate _autogradModeFor;
    private Execution _execution;
    /*
        Set to true by buildFunAlgorithm() once all lambdas are present.
        Used to reject usage of an incomplete algorithm and to warn about mutations after building.
     */
    private boolean _isFullyBuilt = false;


    /**
     * @param name The name of this algorithm, passed on to {@link AbstractAlgorithm}.
     */
    protected AbstractFunAlgorithm(String name) {
        super(name);
    }

    /**
     *  Delegates to the {@link SuitabilityPredicate} supplied via {@link #setIsSuitableFor(SuitabilityPredicate)}
     *  to determine how suitable the given {@link ExecutionCall} is for being executed by the
     *  {@link neureka.backend.api.ImplementationFor} components residing in this {@link Algorithm}.
     *
     * @param call The execution call whose suitability should be rated.
     * @return The suitability value reported by the supplied predicate.
     * @throws IllegalStateException If this algorithm was not fully built via {@link #buildFunAlgorithm()}.
     */
    @Override
    public final float isSuitableFor( ExecutionCall<? extends Device<?>> call ) {
        _checkReadiness();
        return _isSuitableFor.isSuitableFor(call);
    }

    /**
     *  Verifies that all required lambdas (suitability predicate, autograd mode predicate
     *  and execution) were supplied and then marks this algorithm as fully built.
     *  No new instance is created; this very instance is returned.
     *
     * @return This very instance, fully built and ready to be used as an {@link Operation} component.
     * @throws IllegalStateException If any of the required lambdas has not been set.
     */
    public final AbstractFunAlgorithm buildFunAlgorithm() {
        if (
                _isSuitableFor == null ||
                _autogradModeFor == null ||
                _execution == null
        ) {
            throw new IllegalStateException(
                    "Instance '"+getClass().getSimpleName()+"' incomplete!"
            );
        }

        _isFullyBuilt = true;
        return this;
    }

    /**
     *  Ensures that this algorithm was fully supplied with all the
     *  required lambdas, i.e. that {@link #buildFunAlgorithm()} was called successfully.
     *
     * @throws IllegalStateException If this algorithm has not been fully built yet.
     */
    private void _checkReadiness() {
        if ( !_isFullyBuilt ) {
            throw new IllegalStateException(
                    "Trying use an instance of '"+this.getClass().getSimpleName()+"' with name '" + getName() + "' " +
                            "which was not fully built!"
            );
        }
    }

    /**
     *  Neureka is supposed to be extremely modular and in a sense its backend should be "hackable" to a degree.
     *  However, this comes with a lot of risk, because it requires us to expose mutable state, which is not good.
     *  This class is semi-immutable: it warns about any mutation after building was completed,
     *  and it never allows an already specified lambda to be reset to {@code null}
     *  (a built algorithm holding a {@code null} lambda would fail with a {@link NullPointerException} on use).
     *
     * @param newState The state which will be set.
     * @param current The state which is currently set.
     * @param type The type of the lambda being set, used for the warning message.
     * @param <T> The type of the thing which is supposed to be set.
     * @return The checked thing, i.e. {@code newState}.
     * @throws IllegalArgumentException If {@code current} is non-null and {@code newState} is null.
     */
    private <T> T _checked( T newState, T current, Class<T> type ) {
        // Reject nulling out a lambda regardless of build state; after building this
        // would otherwise silently produce an instance that NPEs when used.
        if ( current != null && newState == null )
            throw new IllegalArgumentException(
                    "Trying set an already specified implementation of lambda '"+current.getClass().getSimpleName()+"' to null!"
            );

        if ( _isFullyBuilt )
            _LOG.warn(
                "Implementation '{}' in algorithm '{}' was modified! " +
                "Please consider only modifying the standard backend state of Neureka for experimental reasons.",
                type.getSimpleName(), this
            );

        return newState;
    }

    /**
     *  The {@link SuitabilityPredicate} received by this method
     *  checks if a given instance of an {@link ExecutionCall} is
     *  suitable to be executed in {@link neureka.backend.api.ImplementationFor} instances
     *  residing in this {@link Algorithm} as components.
     *  The lambda will be called by the {@link #isSuitableFor(ExecutionCall)} method
     *  by any given {@link Operation} instances this algorithm belongs to.
     *
     * @param isSuitableFor The suitability predicate which determines if the algorithm is suitable or not.
     * @return This very instance to enable method chaining.
     * @throws IllegalArgumentException If an already specified predicate is replaced with null.
     */
    public final AbstractFunAlgorithm setIsSuitableFor(SuitabilityPredicate isSuitableFor ) {
        _isSuitableFor = _checked(isSuitableFor, _isSuitableFor, SuitabilityPredicate.class);
        return this;
    }

    /**
     *  A {@link ADSupportPredicate} lambda checks what kind of auto differentiation mode an
     *  {@link Algorithm} supports for a given {@link ExecutionCall}.
     *  The lambda will be called by the {@link #autoDiffModeFrom(ExecutionCall)} method
     *  by any given {@link Operation} instances this algorithm belongs to.
     *
     * @param autogradModeFor A predicate lambda which determines the auto diff mode of this algorithm for a given execution call.
     * @return This very instance to enable method chaining.
     * @throws IllegalArgumentException If an already specified predicate is replaced with null.
     */
    public final AbstractFunAlgorithm setAutogradModeFor(ADSupportPredicate autogradModeFor ) {
        _autogradModeFor = _checked(autogradModeFor, _autogradModeFor, ADSupportPredicate.class);
        return this;
    }

    /**
     *  Delegates to the {@link ADSupportPredicate} supplied via {@link #setAutogradModeFor(ADSupportPredicate)}.
     *
     * @param call The execution call for which the auto differentiation mode should be determined.
     * @return The {@link AutoDiffMode} reported by the supplied predicate.
     * @throws IllegalStateException If this algorithm was not fully built via {@link #buildFunAlgorithm()}.
     */
    @Override
    public AutoDiffMode autoDiffModeFrom(ExecutionCall<? extends Device<?>> call ) {
        _checkReadiness();
        return _autogradModeFor.autoDiffModeFrom( call );
    }

    /**
     *  Sets the {@link Execution} lambda which will be invoked by {@link #execute(Function, ExecutionCall)}.
     *
     * @param execution The execution lambda performing the actual work of this algorithm.
     * @return This very instance to enable method chaining.
     * @throws IllegalArgumentException If an already specified execution is replaced with null.
     */
    public final AbstractFunAlgorithm setExecution(Execution execution ) {
        _execution = _checked(execution, _execution, Execution.class);
        return this;
    }

    /**
     *  Runs the supplied {@link Execution} lambda wrapped by a {@link MemValidator} built from the
     *  call inputs, and rejects results whose "intermediate" flag is inconsistent with whether
     *  the output is a member of the input array.
     *
     * @param caller The function which initiated this execution.
     * @param call The execution call holding the inputs.
     * @return The result produced by the execution lambda.
     * @throws IllegalStateException If this algorithm is not fully built, or if the output's
     *                               intermediate flag is inconsistent (see messages below).
     */
    @Override
    public Result execute( Function caller, ExecutionCall<? extends Device<?>> call ) {
        _checkReadiness();
        MemValidator checker = MemValidator.forInputs( call.inputs(), ()-> _execution.execute( caller, call ));
        if ( checker.isWronglyIntermediate() ) {
            throw new IllegalStateException(
                    "Output of algorithm '" + this.getName() + "' " +
                    "is marked as intermediate result, despite the fact " +
                    "that it is a member of the input array. " +
                    "Tensors instantiated by library users instead of operations in the backend are not supposed to be flagged " +
                    "as 'intermediate', because they are not eligible for deletion!"
            );
        }
        if ( checker.isWronglyNonIntermediate() ) {
            throw new IllegalStateException(
                    "Output of algorithm '" + this.getName() + "' " +
                    "is neither marked as intermediate result nor a member of the input array. " +
                    "Tensors instantiated by operations in the backend are expected to be flagged " +
                    "as 'intermediate' in order to be eligible for deletion!"
            );
        }
        return checker.getResult();
    }

}