// AbstractOperation.java

package neureka.backend.api.template.operations;


import neureka.backend.api.Algorithm;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Operation;
import neureka.backend.api.template.algorithms.FallbackAlgorithm;
import neureka.math.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.LinkedHashMap;
import java.util.Map;

/**
 *  This abstract {@link Operation} implementation is a useful template for creating new operations.
 *  It provides a partial implementation which consists of a simple component system for hosting {@link Algorithm} instances
 *  as well as a set of properties which {@link Operation} implementations are expected to have. <br>
 *  Therefore, the number of properties this class needs to receive is rather large.
 *  In order to instantiate it one has to pass an {@link OperationBuilder} instance to the constructor.
 *  Using the builder will make the property configuration as readable as possible. <br>
 *
 */
public abstract class AbstractOperation implements Operation
{
    private static final Logger _LOG = LoggerFactory.getLogger( AbstractOperation.class );

    /**
     *  An operation may have two ways in which it can describe itself as String within a Function AST.
     *  The first one is an operator style of representation and the second one a classical function.
     *  So for the 'Addition' operation the following two representations exist: <br>
     * <ul>
     *      <li> Operator: '+';   Example: 'I[0] + 3 + 5 * I[1]'
     *      <li> Function: 'add'; Example: 'add( I[0], 3, 5*I[1] )'
     * </ul>
     * The following String is the latter way of representing the operation, namely: a functional way.
     */
    protected final String _function;

    /**
     *  An operation may have two ways in which it can describe itself as String within a Function AST.
     *  The first one is an operator style of representation and the second one a classical function.
     *  So for the 'Addition' operation the following two representations exist: <br>
     * <ul>
     *      <li> Operator: '+';   Example: 'I[0] + 3 + 5 * I[1]'
     *      <li> Function: 'add'; Example: 'add( I[0], 3, 5*I[1] )'
     * </ul>
     * The following String is the primary way of representing the operation, namely: as an operator.
     */
    protected final String _operator;

    /**
     * Arity is the number of arguments or operands
     * that this function or operation takes.
     */
    protected final int _arity;

    /**
     *  This flag determines if this operation is auto-indexing passed input arguments.
     *  Auto-indexing inputs means that for a given array of input arguments
     *  the wrapping Function instance will call its child nodes targeted via an
     *  index incrementally.
     *  The variable 'j' in a Function's expression containing 'I[j]' will then be
     *  resolved to an actual input for a given indexer...
     */
    protected final boolean _isIndexer;

    /**
     *  Certain operations are not differentiable, meaning they cannot participate
     *  in either forward or reverse mode differentiation.
     *  In order to avoid error-prone behaviour, trying to involve
     *  non-differentiable operations will yield proper exceptions.
     */
    protected final boolean _isDifferentiable;

    /**
     *  Inline operations are operations which change the state of the arguments passed to them.
     */
    protected final boolean _isInline;

    /**
     *  The flag returned by {@link #isOperator()}.
     *  Within this class it decides whether {@link #stringify(String[])} renders the operation
     *  in operator style (children joined by the operator) or in function style (identifier followed by an argument list).
     */
    protected final boolean _isOperator;

    /**
     *  The registered algorithms keyed by the class they were registered under.
     *  An insertion ordered map is used, so iteration happens in registration order.
     */
    private final Map<Class<?>, Algorithm> _algorithms = new LinkedHashMap<>();

    /**
     *  This is the default algorithm for every Operation extending this class.
     *  It may not fit the purpose of every Operation implementation,
     *  however for most operation types it will provide useful functionalities.
     *
     *  The default algorithm assumes an operation that is either a function or operator.
     *  Meaning that it assumes that the operation is also differentiable.
     *  Therefore, it contains functionality that goes alongside this assumption,
     *  just to name a few :                                                                                            <br>
     *                                                                                                                  <br>
     *  - An ADAction supplier returning ADAction instances capable of performing both forward- and reverse- mode AD.     <br>
     *  - A simple result tensor instantiation implementation.                                                          <br>
     *  - A basic threaded execution based on the AST of a given Function object.                                       <br>
     */
    private final FallbackAlgorithm _defaultAlgorithm;

    /**
     *  Creates a new operation from the properties configured on the provided builder
     *  and sets up the default {@link FallbackAlgorithm} for it.
     *
     * @param builder The builder from which identifier, operator, arity and the various flags are read.
     */
    public AbstractOperation( OperationBuilder builder )
    {
        // NOTE(review): dispose() is called before the properties are read;
        // presumably it validates/finalizes the builder - confirm against OperationBuilder.
        builder.dispose();

        _function         = builder.getIdentifier();
        _arity            = builder.getArity();
        _operator         = builder.getOperator();
        _isOperator       = builder.getIsOperator();
        _isIndexer        = builder.getIsIndexer();
        _isDifferentiable = builder.getIsDifferentiable();
        _isInline         = builder.getIsInline();
        _defaultAlgorithm = new FallbackAlgorithm( "default", _arity, this );
    }

    /**
     * @return A new array containing all registered {@link Algorithm} instances in registration order
     *         (the default algorithm is not part of this array).
     */
    @Override
    public final Algorithm[] getAllAlgorithms() { return _algorithms.values().toArray(new Algorithm[0]); }

    /**
     *  {@link Operation} implementations embody a component system hosting unique {@link Algorithm} instances.
     *  For a given class implementing the {@link Algorithm} class, there can only be a single
     *  instance of it referenced (aka supported) by a given {@link Operation} instance.
     *  This method ensures this in terms of read access by returning only a single instance or null
     *  based on the provided class instance whose type extends the {@link Algorithm} interface. <br>
     *  An entry registered under exactly the provided class is preferred; otherwise the first
     *  entry (in registration order) whose key is a subtype of the provided class is returned.
     *
     * @param type The class of the type which implements {@link Algorithm} as a key to get an existing instance.
     * @param <T> The type parameter of the {@link Algorithm} type class.
     * @return The instance of the specified type if any exists within this {@link Operation}, otherwise null.
     */
    @Override
    @SuppressWarnings("unchecked") // Entries are only stored through the type safe 'setAlgorithm' method.
    public final <T extends Algorithm> T getAlgorithm( Class<T> type ) {
        Algorithm exactMatch = _algorithms.get( type );
        if ( exactMatch != null )
            return (T) exactMatch;

        // Maybe the provided type is a superclass of one of the entries...
        // A plain loop is used on purpose: it may safely yield a null value,
        // whereas 'Stream.findFirst()' throws if the selected element is null.
        for ( Map.Entry<Class<?>, Algorithm> entry : _algorithms.entrySet() )
            if ( type.isAssignableFrom( entry.getKey() ) )
                return (T) entry.getValue();

        return null;
    }

    /**
     *  This method checks if this {@link Operation} contains an instance of the
     *  {@link Algorithm} implementation specified via its type class.
     *  Only exact key matches count, subtypes of the provided class are not considered.
     *
     * @param type The class of the type which implements {@link Algorithm}.
     * @param <T> The type parameter of the {@link Algorithm} type class.
     * @return The truth value determining if this {@link Operation} contains an instance of the specified {@link Algorithm} type.
     */
    @Override
    public final <T extends Algorithm> boolean supportsAlgorithm( Class<T> type ) {
        return _algorithms.containsKey( type );
    }

    /**
     *  {@link Operation} implementations embody a component system hosting unique {@link Algorithm} instances.
     *  For a given class implementing the {@link Algorithm} class, there can only be a single
     *  instance of it referenced (aka supported) by a given {@link Operation} instance.
     *  This method enables the registration of {@link Algorithm} types in the component system of this {@link Operation}.
     *
     * @param type The class of the type which implements {@link Algorithm} as key for the provided instance.
     * @param instance The instance of the provided type class which ought to be referenced (supported) by this {@link Operation}.
     * @param <T> The type parameter of the {@link Algorithm} type class.
     * @return This very {@link Operation} instance to enable method chaining on it.
     * @throws IllegalArgumentException If an algorithm is already registered under the provided class.
     */
    @Override
    public final <T extends Algorithm> Operation setAlgorithm( Class<T> type, T instance ) {
        if ( _algorithms.containsKey( type ) )
            throw new IllegalArgumentException(
                        "Algorithm of type '"+type.getSimpleName()+"' already defined for this operation!"
                    );

        _algorithms.put( type, instance );
        return this;
    }

    /**
     *  Selects the algorithm which reports the highest suitability score for the provided call.
     *  Registered algorithms are queried in registration order; an algorithm scoring exactly 1
     *  is returned immediately, and on equal scores the one registered first wins.
     *  The default algorithm is only chosen if its score is strictly greater than the best
     *  score among the registered algorithms.
     *
     * @param call The execution call for which a suitable algorithm is searched.
     * @return The most suitable algorithm for the provided call.
     * @throws IllegalStateException If neither a registered algorithm nor the default algorithm
     *                               reports a suitability greater than zero.
     */
    @Override
    public final Algorithm getAlgorithmFor( ExecutionCall<?> call )
    {
        float bestScore = 0f;
        Algorithm bestImpl = null;
        //Device<?> device = call.getDevice();
        for ( Algorithm impl : _algorithms.values() ) {
            //if ( impl instanceof DeviceAlgorithm<?> && !((DeviceAlgorithm)impl).hasImplementationFor(device) )
            //    continue;

            float currentScore = impl.isSuitableFor( call );
            if ( currentScore > bestScore ) {
                if ( currentScore == 1.0f ) return impl; // A perfect match, no need to look any further.
                bestScore = currentScore;
                bestImpl = impl;
            }
        }
        float defaultSuitability = _defaultAlgorithm.isSuitableFor( call );

        if ( defaultSuitability > bestScore ) {
            // Parameterized logging avoids building the message when debug logging is disabled.
            _LOG.debug( "Default algorithm picked for call targeting operation '{}'.", call.getOperation() );
            return _defaultAlgorithm;
        }

        if ( bestImpl == null ) {
            String message = "No suitable implementation for execution call '"+call+"' could be found.\n" +
                                "Execution process aborted.";
            _LOG.error( message );
            throw new IllegalStateException( message );
        }
        return bestImpl;
    }

    /**
     *  Checks if an {@link Algorithm} is registered under exactly the provided class.
     *  This is equivalent to {@link #supportsAlgorithm(Class)}.
     *
     * @param implementation The class under which an {@link Algorithm} may have been registered.
     * @param <T> The type parameter of the {@link Algorithm} type class.
     * @return True if an algorithm is registered under the provided class.
     */
    @Override
    public final <T extends Algorithm> boolean supports( Class<T> implementation ) {
        return supportsAlgorithm( implementation );
    }

    /** @return The flag determining if this operation is represented in operator style. */
    @Override public final boolean isOperator() { return _isOperator; }

    /** @return The function style identifier of this operation (for example 'add'). */
    @Override public String getIdentifier() { return _function; }

    /** @return The operator style representation of this operation (for example '+'). */
    @Override public final String getOperator() { return _operator; }

    /** @return The number of arguments or operands this operation takes. */
    @Override public final int getArity() { return _arity; }

    /** @return The flag determining if this operation is auto-indexing passed input arguments. */
    @Override public final boolean isIndexer() { return _isIndexer; }

    /** @return The flag determining if this operation is differentiable. */
    @Override public final boolean isDifferentiable() { return _isDifferentiable; }

    /** @return The flag determining if this operation changes the state of the arguments passed to it. */
    @Override public boolean isInline() { return _isInline; }

    /** @return The default algorithm which was created for this operation during construction. */
    public final FallbackAlgorithm getDefaultAlgorithm() { return _defaultAlgorithm; }

    /**
     *  {@inheritDoc}
     *  <br>
     *  This base implementation always throws, subclasses supporting dynamic derivation have to override it.
     *
     * @throws IllegalStateException Always, unless overridden.
     */
    @Override
    public String asDerivative(Function[] children, int derivationIndex) {
        throw new IllegalStateException("Operation '"+this.getIdentifier()+"' does not support dynamic derivation!");
    }

    /**
     *  {@inheritDoc}
     *  <br>
     *  Operators are rendered as their children joined by the operator and wrapped in parentheses,
     *  for example '(a + b)'. Everything else is rendered in function style, for example 'add(a, b)'.
     *  In function style no additional parentheses are added if the joined children are
     *  already enclosed by a single matching pair of outer parentheses.
     */
    @Override
    public String stringify( String[] children ) {
        if ( this.isOperator() ) {
            String separator = " " + this.getOperator() + " ";
            return "(" + String.join( separator, children ) + ")";
        }
        String expression = String.join( ", ", children );
        if ( _isEnclosedByOuterParentheses( expression ) )
            return getIdentifier() + expression;
        else
            return getIdentifier() + "(" + expression + ")";
    }

    /**
     *  Determines if the provided expression starts with an opening parenthesis
     *  which is not closed before the very last character, which itself has to be a closing parenthesis.
     *  So '(a + b)' is considered enclosed, whereas '(a), (b)' and the empty String are not.
     *
     * @param expression The expression which ought to be checked.
     * @return True if the whole expression is wrapped by one outer pair of parentheses.
     */
    private static boolean _isEnclosedByOuterParentheses( String expression ) {
        int last = expression.length() - 1;
        if ( last < 1 || expression.charAt( 0 ) != '(' || expression.charAt( last ) != ')' )
            return false;

        int depth = 0;
        for ( int i = 0; i < last; i++ ) {
            char c = expression.charAt( i );
            if      ( c == '(' ) depth++;
            else if ( c == ')' ) depth--;
            if ( depth == 0 ) return false; // The leading '(' got closed before the end of the expression.
        }
        return true;
    }

    /**
     * @return A String consisting of the {@link #operationName()} (or 'AnonymousOperation' if it is blank),
     *         the hexadecimal hash code, as well as the identifier and operator of this operation.
     */
    @Override
    public final String toString() {
        String operationName = operationName().trim();
        operationName = operationName.isEmpty() ? "AnonymousOperation" : operationName;
        String asString = operationName+"@"+Integer.toHexString(hashCode());
        asString = asString + "[identifier='" + _function + "',operator='"+_operator+"']";
        return asString;
    }

    /**
     *  Override this if you want your operation to have a string representation
     *  with a custom prefix which is something other than the simple class name!
     *
     * @return The simple class name, or something else if overridden.
     */
    protected String operationName() {
        return this.getClass().getSimpleName();
    }
}