BackendContext.java

package neureka.backend.api;


import neureka.Neureka;
import neureka.Tensor;
import neureka.backend.api.ini.BackendLoader;
import neureka.backend.api.ini.BackendRegistry;
import neureka.backend.api.ini.ImplementationReceiver;
import neureka.backend.api.ini.LoadingContext;
import neureka.math.Function;
import neureka.math.FunctionCache;
import neureka.math.Functions;
import neureka.math.parsing.FunctionParser;
import neureka.math.parsing.ParseUtil;
import neureka.common.utility.LogUtil;
import neureka.devices.Device;
import org.slf4j.Logger;

import java.util.*;
import java.util.function.Supplier;

/**
 *    Instances of this class are execution contexts hosting {@link Operation} instances which receive {@link Tensor}
 *    instances for execution.
 *    {@link BackendContext}s are managed by {@link Neureka}, a (thread-local) Singleton / Multiton library context.<br>
 *    Contexts are cloneable for testing purposes and to enable extending the backend dynamically.
 *    A given instance also hosts a reference to a {@link Functions} instance which exposes commonly used
 *    pre-instantiated {@link Function} implementation instances.
 *    <br><br>
 *    The {@link BackendContext} stores {@link Operation} instances in various data structures
 *    for fast access and querying (Mostly used by the {@link ParseUtil} and {@link FunctionParser}).
 *    <br>
 *    {@link Operation}s are stored in simple list and map collections,
 *    namely: <br>
 *    The "_operations" list and the "_lookup" map as declared below.
 *    <br>
 *    <br>
 *    Operations are registered through {@link #addOperation(Operation)}, whereas
 *    {@link BackendExtension}s (which may supply device specific implementations for the
 *    algorithms of already registered operations) are registered through {@link #set(BackendExtension)}.
 *    {@link BackendContext} instances expose a useful class called {@link Runner},
 *    which performs temporary context switching between the caller's context and this
 *    context during the execution of provided lambdas.
 *
 */
public final class BackendContext implements Cloneable
{
    private static final Logger log = org.slf4j.LoggerFactory.getLogger(BackendContext.class);

    private final Extensions _extensions = new Extensions();

    /**
     *  A mapping between OperationType identifiers and their corresponding instances.
     */
    private final Map<String, Operation> _lookup;

    /**
     *  A list of all OperationType instances.
     */
    private final List<Operation> _operations;

    /**
     *  The number of operation instances stored in this context.
     */
    private int _size;

    // Global context and cache:
    private final FunctionCache _functionCache = new FunctionCache();

    /**
     *  Lazily created {@link Functions} instance whose {@link Function}s track their computational history.
     */
    private final LazyRef<Functions> _getAutogradFunction;

    /**
     *  This {@link Functions} instance wraps pre-instantiated
     *  {@link Function} instances which are configured to not track their computational history.
     *  This means that no computation graph will be built by these instances.
     *  ( Computation graphs in Neureka are made of instances of the "GraphNode" class... )
     */
    private final LazyRef<Functions> _getFunction;


    /**
     *  This creates a new context which is completely void of any {@link Operation} implementation instances.
     *  Use this constructor to test, debug, build and populate custom execution contexts.
     */
    public BackendContext()
    {
        _getAutogradFunction = LazyRef.of( () -> new Functions( true ) );
        _getFunction = LazyRef.of( () -> new Functions( false ) );
        _lookup = new HashMap<>();
        _operations = new ArrayList<>();
        _size = 0;
    }

    /**
     *  Calls {@link BackendExtension#reset()} on every {@link BackendExtension} registered in this context.
     *  An exception thrown by one extension is logged and does not prevent
     *  the remaining extensions from being reset.
     */
    public void reset() {
        for ( BackendExtension extension : _extensions.getAll( BackendExtension.class ) ) {
            try {
                extension.reset();
            } catch ( Exception ex ) {
                // Best effort: one failing extension must not stop the others from being reset.
                log.error( "Error while resetting backend extension: {}", extension.getClass().getName(), ex );
            }
        }
    }

    /**
     *  A {@link Runner} wraps both the called context as well as the context of the caller in order
     *  to perform temporary context switching during the execution of lambdas passed to the {@link Runner}.
     *  After a given lambda was executed (whether it completed normally or threw an exception),
     *  the original context will be restored in the current
     *  thread local {@link Neureka} instance through the {@link Neureka#setBackend(BackendContext)}) method.
     *
     * @return A lambda {@link Runner} which performs temporary context switching between the caller's context and this context.
     */
    public Runner runner() { return new Runner( this, Neureka.get().backend() ); }

    /**
     * This method returns an unmodifiable view of the mapping between the {@link Operation#getIdentifier()} / {@link Operation#getOperator()} properties
     * and the {@link Operation} implementation instances to which they belong.
     * Query operations on the returned map "read through" to the underlying map,
     * and attempts to modify the returned map, whether direct or via its collection views,
     * result in an {@link UnsupportedOperationException}.
     *
     * @return An unmodifiable mapping of {@link Operation} properties to the {@link Operation} instances to which they belong.
     */
    public Map<String, Operation> getOperationLookupMap() { return Collections.unmodifiableMap( _lookup ); }

    /**
     * This method returns an unmodifiable view of the
     * list of {@link Operation} implementation instances managed by this context.
     * Query operations on the returned list "read through" to the underlying list,
     * and attempts to modify the returned list, whether direct or via its iterators / sub-list views,
     * result in an {@link UnsupportedOperationException}.
     *
     * @return An unmodifiable view of the list of {@link Operation} implementation instances managed by this context
     */
    public List<Operation> getOperations() { return Collections.unmodifiableList( _operations ); }

    /**
     * @return The number of {@link Operation} instances stored on this {@link BackendContext}.
     */
    public int size() { return _size; }

    /**
     * @return The {@link Function} and {@link Tensor} cache of this {@link BackendContext}
     */
    public FunctionCache getFunctionCache() { return _functionCache; }

    /**
     *  This method returns a {@link Functions} instance which wraps pre-instantiated
     *  {@link Function} instances which are configured to not track their computational history.
     *  This means that no computation graph will be built by these instances.
     *  ( Computation graphs in Neureka are made of instances of the {@link neureka.autograd.GraphNode} class... )
     *  The instance is created lazily on first access.
     *
     * @return A container object which exposes various types of functions without autograd support.
     */
    public Functions getFunction() { return _getFunction.get(); }

    /**
     *  This method returns a {@link Functions} instance which wraps pre-instantiated
     *  {@link Function} instances which are configured to track their computational history.
     *  This means that a computation graph will be built by these instances.
     *  ( Computation graphs in Neureka are made of instances of the {@link neureka.autograd.GraphNode} class... )
     *  The instance is created lazily on first access.
     *
     * @return A container object which exposes various types of functions with autograd support.
     */
    public Functions getAutogradFunction() { return _getAutogradFunction.get(); }

    /**
     *  This method registers {@link Operation} implementation instances in this {@link BackendContext}
     *  which is the thread local execution context receiving and processing {@link Tensor} instances...         <br><br>
     *  The operation becomes retrievable through its identifier, its operator and the
     *  lower case version of its operator.
     *  Duplicate identifiers / operators are only detected when assertions are enabled.
     *
     * @param operation The {@link Operation} instance which ought to be registered as part of this execution context.
     *                  Must not be {@code null}.
     * @return This very context instance to allow for method chaining.
     */
    public BackendContext addOperation( Operation operation )
    {
        LogUtil.nullArgCheck( operation, "operation", Operation.class );
        String function = operation.getIdentifier();
        String operator = operation.getOperator();
        // All checks happen before any mutation so that a failed check leaves this context untouched:
        assert !_lookup.containsKey( operator ) : "An operation with operator '" + operator + "' is already registered.";
        assert !_lookup.containsKey( function ) : "An operation with identifier '" + function + "' is already registered.";
        _operations.add( operation );
        _lookup.put( operator, operation );
        _lookup.put( function, operation );
        _lookup.put( operator.toLowerCase(), operation );
        _size++;
        return this;
    }

    /**
     * @param operation The {@link Operation} which may or may not be part of this {@link BackendContext}.
     * @return The truth value determining if an {@link Operation} with the identifier
     *         of the provided {@link Operation} is part of this {@link BackendContext}.
     */
    public boolean hasOperation( Operation operation ) {
        return _lookup.containsKey( operation.getIdentifier() );
    }

    /**
     * @param operationIdentifier The {@link Operation} identifier which may be the function name or operator if present.
     * @return The truth value determining if the provided {@link Operation} is part of this {@link BackendContext}.
     */
    public boolean hasOperation( String operationIdentifier ) {
        return _lookup.containsKey( operationIdentifier );
    }

    /**
     *  This method queries the operations in this {@link BackendContext}
     *  by a provided index integer targeting an entry in the list of {@link Operation} implementation instances
     *  sitting in this execution context.
     *
     * @param index The index of the operation.
     * @return The Operation instance at the given index.
     * @throws IndexOutOfBoundsException If the index is negative or not smaller than {@link #size()}.
     */
    public Operation getOperation( int index ) { return _operations.get( index ); }

    /**
     *  This method queries the operations in this BackendContext
     *  by a provided identifier which has to match the name of
     *  an existing operation (or its operator).
     *
     * @param identifier The operation identifier, aka: its name.
     * @return The requested Operation or null.
     */
    public Operation getOperation( String identifier ) { return _lookup.get( identifier ); }

    /**
     *  This method produces a shallow copy of this {@link BackendContext}.
     *  This is useful for debugging, testing and extending contexts during runtime without side effects!  <br>
     *  The clone shares the registered {@link Operation} instances with this context but owns its own
     *  lookup map and operation list, so operations added to the clone do not affect this context.
     *  Note that {@link BackendExtension}s are NOT copied, and that the clone starts out with
     *  a fresh {@link FunctionCache} as well as fresh {@link Functions} instances.
     *
     * @return A shallow copy of this operation / execution context.
     */
    @Override
    public BackendContext clone()
    {
        BackendContext clone = new BackendContext();
        clone._size = _size;
        clone._lookup.putAll( _lookup );
        clone._operations.addAll( _operations );
        return clone;
    }

    @Override
    public String toString() {
        return getClass().getSimpleName()+"[size=" + this.size() + "]";
    }

    /**
     *  Checks if this context has an instance of the provided {@link BackendExtension} type.
     *
     * @param extensionClass The type class of the extensions whose presence should be checked.
     * @param <E> The type parameter of the provided type class which requires the type to be an extension.
     * @return The truth value determining if the provided type is present.
     */
    public <E extends BackendExtension> boolean has( Class<E> extensionClass ) {
        return _extensions.has( extensionClass );
    }

    /**
     *  Returns an {@link Optional} instance of the provided {@link BackendExtension} type
     *  or an empty {@link Optional} if no extension of that type was found.
     *
     * @param componentClass The type class of the requested extension.
     * @param <E> The type of the requested extension.
     * @return An {@link Optional} holding the extension, or an empty one if none is registered.
     */
    public <E extends BackendExtension> Optional<E> find( Class<E> componentClass ) {
        return _extensions.find( componentClass );
    }

    /**
     * @return A list of all {@link BackendExtension} instances.
     */
    public List<BackendExtension> getExtensions() {
        return _extensions.getAll( BackendExtension.class );
    }


    /**
     *  A device specific implementation announced by a {@link BackendLoader},
     *  waiting to be attached to a matching {@link DeviceAlgorithm} of a registered {@link Operation}.
     *  Static because it never needs access to an enclosing {@link BackendContext} instance.
     */
    private static final class Registered<D extends Device<?>> {

        final Class<? extends Operation> operationType;
        final Class<? extends DeviceAlgorithm> algorithmType;
        final Class<? extends D> deviceType;
        final java.util.function.Function<LoadingContext, ImplementationFor<D>> function;

        private Registered(Class<? extends Operation> operationType, Class<? extends DeviceAlgorithm> algorithmType, Class<? extends D> deviceType, java.util.function.Function<LoadingContext, ImplementationFor<D>> function) {
            this.operationType = operationType;
            this.algorithmType = algorithmType;
            this.deviceType = deviceType;
            this.function = function;
        }
    }

    /**
     *  Registers the provided {@link BackendExtension} instance
     *  which can then be accessed via {@link #find(Class)}.
     *  Before the extension is stored, all implementations supplied by its {@link BackendLoader}
     *  are attached to the matching algorithms of the operations registered in this context.
     *  Implementations that were attached before a failure is detected remain attached.
     *
     * @param extension The backend extension component which ought to be stored by this.
     * @return This very {@link BackendContext} instance to allow for method chaining.
     * @throws IllegalStateException If at least one implementation could not be matched
     *                               to an operation and algorithm of this context.
     */
    public BackendContext set( BackendExtension extension )
    {
        LogUtil.nullArgCheck( extension, "extension", BackendExtension.class );
        BackendLoader loader = extension.getLoader();
        LogUtil.nullArgCheck( loader, "loader", BackendLoader.class );
        // Now before adding the extension to the backend we first try to load all the implementations:
        List<Registered<?>> registeredList = new ArrayList<>();
        loader.load(BackendRegistry.of(
                new ImplementationReceiver() {
                    @Override
                    public <D extends Device<?>> void accept(
                            Class<? extends Operation> operationType,
                            Class<? extends DeviceAlgorithm> algorithmType,
                            Class<? extends D> deviceType,
                            java.util.function.Function<LoadingContext, ImplementationFor<D>> function
                    ) {
                        registeredList.add(new Registered<>(operationType, algorithmType, deviceType, function));
                    }
                }
        ));
        int failures = 0;
        for ( Registered<?> registered : registeredList )
            if ( !_register( registered ) )
                failures++;

        if ( failures != 0 )
            throw new IllegalStateException(
                "Failed to register "+failures+" implementations for extension of type '"+extension.getClass().getSimpleName()+"'."
            );

        _extensions.set( extension );
        return this;
    }

    /**
     *  Attaches the implementation described by the provided {@link Registered} entry to the first
     *  {@link DeviceAlgorithm} (of the first operation whose class equals {@link Registered#operationType})
     *  that is an instance of {@link Registered#algorithmType}.
     *
     * @param registered The implementation entry to attach.
     * @return {@code true} if a matching algorithm was found and the implementation was set, {@code false} otherwise.
     */
    private boolean _register( Registered<?> registered ) {
        for ( Operation o : _operations ) {
            if ( o.getClass().equals( registered.operationType ) ) {
                for ( Algorithm a : o.getAllAlgorithms() ) {
                    // We make sure it is a device algorithm:
                    if ( a instanceof DeviceAlgorithm ) {
                        DeviceAlgorithm da = (DeviceAlgorithm) a;
                        if ( registered.algorithmType.isAssignableFrom(da.getClass()) ) {
                            da.setImplementationFor(
                                registered.deviceType,
                                registered.function.apply(new LoadingContext() {
                                    @Override public String getAlgorithmName() { return da.getName(); }
                                    @Override public String getOperationIdentidier() { return o.getIdentifier(); }
                                })
                            );
                            return true;
                        }
                    }
                }
            }
        }
        return false;
    }

    /**
     *  This is a very simple class with a single purpose, namely
     *  it exposes methods which receive lambda instances in order to then execute them
     *  in a given {@link BackendContext}, just to then switch back to the original context again.
     *  Switching a context simply means that the {@link BackendContext} which produced this {@link Runner}
     *  will temporarily be set as execution context for the current thread
     *  local {@link Neureka} instance.                                              <br><br>
     *
     *  A {@link Runner} wraps both the called context as well as the context of the caller in order
     *  to perform this temporary context switching throughout the execution of the lambdas passed to the {@link Runner}.
     *  After a given lambda was executed (even if it threw an exception), the original context will be
     *  restored in the current thread local {@link Neureka} instance through the
     *  {@link Neureka#setBackend(BackendContext)}) method.
     */
    public static class Runner
    {
        private final BackendContext originalContext;
        private final BackendContext visitedContext;

        private Runner(BackendContext visited, BackendContext originalContext ) {
            if ( visited == originalContext ) log.warn("Context runner encountered two identical contexts!");
            this.originalContext = originalContext;
            this.visitedContext = visited;
        }

        /**
         *  Use this method to supply a lambda which will be executed in the {@link BackendContext}
         *  which produced this very {@link Runner} instance.
         *  After the lambda finished execution the original {@link BackendContext} will
         *  be restored for the current thread local {@link Neureka} instance,
         *  also when the lambda throws an exception (which is then propagated).
         *
         * @param contextSpecificAction The context specific action which will be execute in the {@link BackendContext} which produced this {@link Runner}.
         * @return This very {@link Runner} instance to enable method chaining.
         */
        public Runner run( Runnable contextSpecificAction ) {
            Neureka.get().setBackend( visitedContext );
            try {
                contextSpecificAction.run();
            } finally {
                // Restore the caller's context even if the action failed:
                Neureka.get().setBackend( originalContext );
            }
            return this;
        }

        /**
         *  Use this method to supply a lambda which will be executed in the {@link BackendContext}
         *  which produced this very {@link Runner} instance.
         *  After the lambda finished execution the original {@link BackendContext} will be restored,
         *  also when the lambda throws an exception (which is then propagated).
         *  This method distinguishes itself from the {@link #run(Runnable)} method because the
         *  lambda supplied to this method is expected to return something.
         *  What may be returned is up to the user, one might want to return the result
         *  of a tensor operation which might be exclusively available in the used context.
         *
         * @param contextSpecificAction The context specific action which will be execute in the {@link BackendContext} which produced this {@link Runner}.
         * @param <T> The return type of the supplied context action which will also be returned by this method.
         * @return The result of the supplied context action.
         */
        public <T> T runAndGet( Supplier<T> contextSpecificAction ) {
            Neureka.get().setBackend( visitedContext );
            try {
                return contextSpecificAction.get();
            } finally {
                // Restore the caller's context even if the action failed:
                Neureka.get().setBackend( originalContext );
            }
        }

        /**
         *  Use this method to supply a lambda which will be executed in the {@link BackendContext}
         *  which produced this very {@link Runner} instance.
         *  After the lambda finished execution the original {@link BackendContext} will be restored,
         *  also when the lambda throws an exception (which is then propagated).
         *  This method distinguishes itself from the {@link #run(Runnable)} method because the
         *  lambda supplied to this method is expected to return something.                            <br>
         *  What may be returned is up to the user, one might want to return the result
         *  of a tensor operation which might be exclusively available in the used context.
         *  This method is doing the exact same thing as the {@link #runAndGet(Supplier)} method,
         *  however its name is shorter and it can even be omitted entirely when using Groovy.          <br><br>
         *
         * @param contextSpecificAction The context specific action which will be execute in the {@link BackendContext} which produced this {@link Runner}.
         * @param <T> The return type of the supplied context action which will also be returned by this method.
         * @return The result of the supplied context action.
         */
        public <T> T call( Supplier<T> contextSpecificAction ) {
            return runAndGet( contextSpecificAction );
        }

        /**
         *  Use this method to supply a lambda which will be executed in the {@link BackendContext}
         *  which produced this very {@link Runner} instance.
         *  After the lambda finished execution the original {@link BackendContext} will be restored,
         *  also when the lambda throws an exception (which is then propagated).
         *  This method distinguishes itself from the {@link #run(Runnable)} method because the
         *  lambda supplied to this method is expected to return something.                            <br>
         *  What may be returned is up to the user, one might want to return the result
         *  of a tensor operation which might be exclusively available in the used context.
         *  This method is doing the exact same thing as the {@link #runAndGet(Supplier)} method,
         *  however its name is shorter and it can even be omitted entirely when using Kotlin.          <br><br>
         *
         * @param contextSpecificAction The context specific action which will be execute in the {@link BackendContext} which produced this {@link Runner}.
         * @param <T> The return type of the supplied context action which will also be returned by this method.
         * @return The result of the supplied context action.
         */
        public <T> T invoke( Supplier<T> contextSpecificAction ) {
            return call( contextSpecificAction );
        }
    }

}