AbstractDeviceAlgorithm.java

  1. package neureka.backend.api.template.algorithms;

  2. import neureka.Neureka;
  3. import neureka.Shape;
  4. import neureka.Tensor;
  5. import neureka.backend.api.*;
  6. import neureka.backend.api.fun.ExecutionPreparation;
  7. import neureka.backend.main.algorithms.ElementwiseAlgorithm;
  8. import neureka.backend.main.internal.FinalExecutor;
  9. import neureka.backend.main.memory.MemUtil;
  10. import neureka.math.Function;
  11. import neureka.math.args.Arg;
  12. import neureka.math.parsing.FunctionParser;
  13. import neureka.math.implementations.FunctionConstant;
  14. import neureka.common.utility.LogUtil;
  15. import neureka.devices.Device;
  16. import org.slf4j.Logger;
  17. import org.slf4j.LoggerFactory;

  18. import java.util.HashMap;
  19. import java.util.Map;
  20. import java.util.function.Consumer;
  21. import java.util.function.Supplier;
  22. import java.util.stream.Collectors;
  23. import java.util.stream.IntStream;

  24. /**
  25.  *  This is a partial implementation of the {@link Algorithm} interface which implements
  26.  *  the component system for implementation instances of the {@link ImplementationFor} interface.
  27.  *  These components implement an algorithm for a specific {@link Algorithm}.
  28.  *
  29.  * @param <C> The type of the concrete extension of this class.
  30.  */
  31. public abstract class AbstractDeviceAlgorithm<C extends DeviceAlgorithm<C>>
  32. extends AbstractAlgorithm
  33. implements DeviceAlgorithm<C>
  34. {
  35.     private final static Logger _LOG = LoggerFactory.getLogger(AbstractDeviceAlgorithm.class);

  36.     protected final Map<Class<Device<?>>, ImplementationFor<Device<?>>> _implementations = new HashMap<>();

  37.     public AbstractDeviceAlgorithm( String name ) { super( name ); }

  38.     @Override
  39.     public <D extends Device<?>, E extends ImplementationFor<D>> C setImplementationFor(
  40.             Class<D> deviceClass, E implementation
  41.     ) {
  42.         if ( _implementations.containsKey( deviceClass ) )
  43.             _LOG.info(
  44.                 "Implementation for device '" + deviceClass.getSimpleName() + "' already defined!"
  45.             );

  46.         _implementations.put(
  47.             (Class<Device<?>>) deviceClass,
  48.             (ImplementationFor<Device<?>>) implementation
  49.         );
  50.         return (C) this;
  51.     }

  52.     @Override
  53.     public <D extends Device<?>> ImplementationFor<D> getImplementationFor(Class<D> deviceClass )
  54.     {
  55.         ImplementationFor<D> found = (ImplementationFor<D>) _implementations.get( deviceClass );
  56.         if ( found == null )
  57.             for ( Class<Device<?>> type : _implementations.keySet() )
  58.                 if ( type.isAssignableFrom(deviceClass) )
  59.                     return (ImplementationFor<D>) _implementations.get(type);

  60.         return found;
  61.     }

  62.     @Override
  63.     public String toString() {
  64.         String algorithmString = getClass().getSimpleName()+"@"+Integer.toHexString(hashCode());
  65.         String implementations = _implementations.keySet().stream().map(Class::getSimpleName).collect(Collectors.joining(","));
  66.         algorithmString = ( algorithmString + "[name=" + getName() + ",support=[" + implementations + "]]" );
  67.         return algorithmString;
  68.     }


  69.     public static Tensor<?> executeFor(
  70.             final Function caller,
  71.             final ExecutionCall<? extends Device<?>> call,
  72.             final FinalExecutor executor
  73.     ) {
  74.         Function[] nodes = caller.getSubFunctions().toArray(new Function[0]);
  75.         Operation operation = caller.getOperation();
  76.         assert call.getOperation() == operation;
  77.         boolean isFlat = caller.isFlat();
  78.         boolean isDoingAD = caller.isDoingAD();
  79.         if ( call.getValOf( Arg.DerivIdx.class ) < 0 )
  80.             return _deepActivation( call, nodes, isFlat, isDoingAD, executor );
  81.         else
  82.             return _deepDerivative( call, nodes,  executor );
  83.     }

  84.     public static Tensor<?> prepareAndExecute(
  85.             ExecutionCall<? extends Device<?>> executionCall,
  86.             FinalExecutor executor
  87.     ) {
  88.         ExecutionCall<? extends Device<?>> call = _prepareForExecution(executionCall);
  89.         return executeOnCommonDevice(call, ()->{
  90.              /*
  91.                 Below is the core lambda of recursive preprocessing
  92.                 which is defined for each Algorithm individually :
  93.              */
  94.             Tensor<?> result = null;
  95.             if ( executor != null )
  96.                 result = executor.execute(call);
  97.             return result;
  98.         });
  99.     }

  100.     public static ExecutionCall<? extends Device<?>> _prepareForExecution(ExecutionCall<? extends Device<?>> executionCall) {
  101.         Algorithm currentAlgorithm = executionCall.getAlgorithm();
  102.         if ( currentAlgorithm instanceof ExecutionPreparation)
  103.             executionCall = ( (ExecutionPreparation) currentAlgorithm ).prepare( executionCall );

  104.         for ( Tensor<?> t : executionCall.inputs() )
  105.             if ( t == null ) throw new IllegalArgumentException(
  106.                                 "Device arguments may not be null!\n" +
  107.                                 "One or more tensor arguments within the given ExecutionCall instance is null."
  108.                             );
  109.         return executionCall;
  110.     }

  111.     public static Tensor<?> executeDeviceAlgorithm(
  112.             ExecutionCall<? extends Device<?>> call
  113.     ) {
  114.         for ( Tensor<?> t : call.inputs() )
  115.             if ( t == null ) throw new IllegalArgumentException(
  116.                     "Device arguments may not be null!\n" +
  117.                     "One or more tensor arguments within the given ExecutionCall instance is null."
  118.                 );

  119.         Device<?> device = call.getDevice();

  120.         Algorithm algorithm = call.getAlgorithm();
  121.         if ( algorithm == null ) {
  122.             String message = _couldNotFindSuitableAlgorithmFor( device.getClass() );
  123.             _LOG.error( message );
  124.             throw new IllegalStateException( message );
  125.         } else {
  126.             DeviceAlgorithm<?> deviceAlgorithm = ( algorithm instanceof DeviceAlgorithm ? ((DeviceAlgorithm<?>) algorithm) : null );
  127.             ImplementationFor<Device<?>> implementation =  ( deviceAlgorithm == null ? null : deviceAlgorithm.getImplementationFor(device) );
  128.             if ( implementation == null ) {
  129.                 String message = _couldNotFindSuitableImplementationFor( call.getOperation(), algorithm, device.getClass() );
  130.                 _LOG.error( message );
  131.                 throw new IllegalStateException( message );
  132.             }
  133.             else {
  134.                 device.approve( call );
  135.                 return implementation.run( (ExecutionCall<Device<?>>) call );
  136.             }
  137.         }
  138.     }

  139.     public static <D extends Device<?>> ExecutionCall<D> flatten(
  140.             Function caller, ExecutionCall<D> call
  141.     ) {
  142.         return _flatten( call, caller.getSubFunctions().toArray(new Function[0]), true );
  143.     }

  144.     public static <D extends Device<?>> ExecutionCall<D> flattenForIndexer(
  145.             Function caller, ExecutionCall<D> call
  146.     ) {
  147.         return _flatten( call, caller.getSubFunctions().toArray(new Function[0]), false );
  148.     }

  149.     private static <D extends Device<?>> ExecutionCall<D> _flatten(
  150.             ExecutionCall<D> call, Function[] src
  151.     ) {
  152.         return _flatten( call, src, true );
  153.     }

  154.    
  155.     private static <D extends Device<?>> ExecutionCall<D> _flatten(
  156.             ExecutionCall<D> call, Function[] src, boolean ignoreJs
  157.     ) {
  158.         ExecutionCall<D> innerCall = !ignoreJs ? call : call.withArgs( Arg.DerivIdx.of(-1) );
  159.         Tensor<?>[] inputs = innerCall.inputs();
  160.         return MemUtil.keep( inputs, () ->
  161.         {
  162.             Shape tempShape = null;
  163.             Class<?> tempType = null;
  164.             Tensor<?>[] tensors = new Tensor[src.length];
  165.             for ( int i = 0; i < tensors.length; i++ ) {//constants need to be figured out!
  166.                 if ( !( src[i] instanceof FunctionConstant) ) {
  167.                     tensors[ i ] = src[i].execute(innerCall);
  168.                     tempShape = ( tempShape == null ? tensors[ i ].shape() : tempShape );
  169.                     tempType  = ( tempType  == null ? tensors[ i ].getItemType()     : tempType  );
  170.                 }
  171.             }
  172.             int j = innerCall.getValOf( Arg.VarIdx.class );
  173.             for ( int i = 0; i < tensors.length; i++ )
  174.                 if ( tensors[ i ] == null )
  175.                     tensors[ i ] =
  176.                             j < 0
  177.                                 ? Tensor.of( tempType, tempShape, ((FunctionConstant) src[i]).value() ).mut().setIsIntermediate( true ).to(call.getDevice())
  178.                                 : Tensor.of( tempType, tempShape, src[i].call(new double[]{}, j)      ).mut().setIsIntermediate( true ).to(call.getDevice());

  179.             return innerCall.withInputs(tensors);
  180.         });
  181.     }

  182.    
  183.     private static Tensor<?> _deepActivation(
  184.             final ExecutionCall<? extends Device<?>> call,
  185.             final Function[] nodes,
  186.             final boolean isFlat,
  187.             final boolean isDoingAD,
  188.             final FinalExecutor executor
  189.     ) {
  190.         int j = call.getValOf( Arg.VarIdx.class );
  191.         assert call.getValOf( Arg.DerivIdx.class ) == -1;

  192.         ExecutionCall<?> flattenedCall = _flatten( call.withArgs( Arg.VarIdx.of(j) ), nodes );

  193.         if (
  194.                 !isFlat && j < 0 && (
  195.                         call.getOperation().isOperator()
  196.                                 ||
  197.                         call.getOperation().supportsAlgorithm(ElementwiseAlgorithm.class)
  198.                 )
  199.         ) {/*   '+', '-', 'x', '*', '%', '«', '»', ',', ...   */
  200.             String asStr = call.getOperation().stringify(
  201.                                         IntStream.range(0, nodes.length)
  202.                                             .mapToObj(i -> "I[" + i + "]")
  203.                                             .toArray(String[]::new)
  204.                                     );
  205.             Tensor<?>[] finalTensors = flattenedCall.inputs();
  206.             Tensor<?> result = MemUtil.keep(finalTensors, () -> new FunctionParser(Neureka.get().backend()).parse(asStr, isDoingAD).execute(finalTensors));
  207.             for ( int i = 1; i < finalTensors.length; i++ )
  208.                 _deleteIfNotIn(call.inputs(), finalTensors[i]);

  209.             return result;
  210.         } else {
  211.             int numberOfInputs = flattenedCall.arity();
  212.             boolean anyNumberOfInputs = flattenedCall.getOperation().getArity() < 0;
  213.             int operationArity = flattenedCall.getOperation().getArity();
  214.             if (numberOfInputs < operationArity)
  215.                 throw new IllegalArgumentException(
  216.                         "The number of inputs to the operation " + flattenedCall.getOperation() + " is " + numberOfInputs +
  217.                         " but the operation requires " + operationArity + " inputs."
  218.                     );

  219.             boolean tooManyArgs = numberOfInputs > operationArity + 1;

  220.             Tensor<?>[] tensors;

  221.             if ( !tooManyArgs || anyNumberOfInputs )
  222.                 tensors = flattenedCall.withAddedInputAt(0, null).inputs();
  223.             else
  224.                 tensors = flattenedCall.inputs();

  225.             return prepareAndExecute(
  226.                         call.withInputs( tensors ).withArgs( Arg.DerivIdx.of(-1), Arg.VarIdx.of(-1) ),
  227.                         executor
  228.                     );
  229.         }
  230.     }

  231.     /**
  232.      *  This method return the index of the tensor
  233.      *  in the given tensor array which is virtual and contains "1.0".
  234.      *  However, if not all tensors are virtual or their values are not all "0.0" except one
  235.      *  whose value is "1.0" then it returns -1, because the optimization cannot
  236.      *  be made...
  237.      *
  238.      * @param tensors An array of tensors which ought to be analyzed.
  239.      * @return The index of the tensor whose value is "1.0" (if all others are "0.0"), otherwise : -1
  240.      */
  241.    
  242.     private static int _indexOfFoundDerivative( final Tensor<?>[] tensors )
  243.     {
  244.         boolean allVirtual = true;
  245.         for ( Tensor<?> t : tensors )
  246.             if ( t != null && !t.isVirtual() ) allVirtual = false;

  247.         if ( allVirtual ) {
  248.             int index = -1;
  249.             for ( int i = 0; i < tensors.length; i++ ) {
  250.                 double value = ( tensors[ i ] == null ? 0.0 : tensors[ i ].getItemsAs( double[].class )[ 0 ] );
  251.                 if ( value == 1.0 ) {
  252.                     if ( index >= 0 ) return -1;
  253.                     index = i;
  254.                 }
  255.                 else if ( value != 0.0 ) return -1;
  256.             }
  257.             return index;
  258.         }
  259.         return -1;
  260.     }

  261.    
  262.     private static Tensor<?> _deepDerivative(
  263.             final ExecutionCall<? extends Device<?>> call,
  264.             final Function[] nodes,
  265.             final FinalExecutor executor
  266.     ) {
  267.         Supplier<Tensor<?>> actor = () ->
  268.                 MemUtil.keep( call.inputs(), () -> {
  269.                     int d = call.getValOf( Arg.DerivIdx.class );
  270.                     final int j = call.getValOf( Arg.VarIdx.class );
  271.                     assert d >= 0;

  272.                     Tensor<?>[] tensors;
  273.                     if ( call.getOperation().isIndexer() ) tensors = new Tensor[ 1 + call.arity() ];
  274.                     else tensors = new Tensor[ 1 + nodes.length ];

  275.                     // Chain-rule (forward AutoDiff):
  276.                     // inner times outer means:
  277.                     // first derive source!
  278.                     // like so:
  279.                     if ( call.getOperation().isIndexer() )
  280.                         for ( int i = 1; i < tensors.length; i++ )
  281.                             tensors[ i ] = nodes[ 0 ].executeDerive( call.inputs(), d, i - 1 );
  282.                     else
  283.                         for ( int i = 1; i < tensors.length; i++ )
  284.                             tensors[ i ] =
  285.                                         j >= 0
  286.                                             ? nodes[ i - 1 ].executeDerive( call.inputs(), d, j )
  287.                                             : nodes[ i - 1 ].executeDerive( call.inputs(), d    );

  288.                     //...then add them all together! (is possible because of linearity...)
  289.                     Tensor<?> inner;
  290.                     if ( tensors.length > 2 ) {// Optimization: Finds index of "1.0" among otherwise all "0.0" virtual tensors!
  291.                         int index = _indexOfFoundDerivative( tensors );
  292.                         if ( index >= 0 ) inner = tensors[ index ];
  293.                         else {
  294.                             // Optimization above did not apply, so we accumulate all the derivatives!
  295.                             tensors[0] = prepareAndExecute(
  296.                                                 ExecutionCall.of( tensors )
  297.                                                         .andArgs( Arg.DerivIdx.of( -1 ) )
  298.                                                         .running( Neureka.get().backend().getOperation("+") )
  299.                                                         .on( call.getDevice() ),
  300.                                                 innerCall -> AbstractDeviceAlgorithm.executeDeviceAlgorithm( innerCall )
  301.                                         );
  302.                             inner = tensors[ 0 ];//-> this is now the inner derivative!
  303.                         }
  304.                     }
  305.                     else inner = tensors[ 1 ];

  306.                     tensors[ 0 ] = null;
  307.                     //...then activate (No differentiation!) the source like so:
  308.                     if ( call.getOperation().isIndexer() ) // Indexer pass an index j of course!
  309.                         for ( int i = 1; i < tensors.length; i++ )
  310.                             tensors[ i ] = nodes[ 0 ].execute( call.inputs(), i - 1 ); // i - 1 := j
  311.                     else
  312.                         for ( int i = 1; i < tensors.length; i++ )
  313.                             tensors[ i ] =
  314.                                     j >= 0
  315.                                         ? nodes[ i - 1 ].execute( call.inputs(), j )
  316.                                         : nodes[ i - 1 ].execute( call.inputs() );

  317.                     //...get derivative index within src list:
  318.                     for ( int i = 0; i < nodes.length; i++ )
  319.                         if ( nodes[ i ].dependsOn( d ) && !call.getOperation().isIndexer() ) {
  320.                             d = i;
  321.                             break;
  322.                         }

  323.                     // Use those tensors for the outer derivative:
  324.                     tensors[0] = prepareAndExecute(
  325.                                         ExecutionCall.of( tensors )
  326.                                                 .andArgs( Arg.DerivIdx.of( d ) )
  327.                                                 .running( call.getOperation() )
  328.                                                 .on( call.getDevice() ),
  329.                                         executor
  330.                                     );
  331.                     // At the end:
  332.                     //...multiply inner times outer: ( if inner is not 1 entirely... )
  333.                     Tensor<?> result = _innerTimesOuter( inner, tensors, call );
  334.                     // done!

  335.                     _delete( inner );

  336.                     return result;
  337.                 });

  338.         int d = call.getValOf( Arg.DerivIdx.class );
  339.         Tensor<?> out = null;
  340.         for ( int i = 0; i < nodes.length; i++ )
  341.         {
  342.             // constants need to be figured out!
  343.             int di = ( nodes[ i ].dependsOn( d ) ? i : -1 );
  344.             if ( di >= 0 )
  345.                 if ( out == null ) out = actor.get();
  346.                 else
  347.                     break;
  348.         }
  349.         return out;
  350.     }

  351.     private static Tensor<?> _innerTimesOuter(Tensor<?> inner, Tensor<?>[] tensors, ExecutionCall<?> call)
  352.     {
  353.         if ( !( ( inner.isVirtual() || inner.size() == 1 ) && inner.getItemsAs( double[].class )[ 0 ] == 1.0 ) ) {
  354.             tensors = new Tensor[]{ null, inner, tensors[ 0 ] };
  355.             tensors[0] = prepareAndExecute(
  356.                     ExecutionCall.of( tensors )
  357.                             .andArgs( Arg.DerivIdx.of( -1 ) )
  358.                             .running( Neureka.get().backend().getOperation("*") )
  359.                             .on( call.getDevice() ),
  360.                     AbstractDeviceAlgorithm::executeDeviceAlgorithm
  361.             );
  362.             for ( int i = 1; i < tensors.length; i++ )
  363.                 _deleteIfNotIn( call.inputs(), tensors[ i ] );
  364.         }
  365.         return tensors[ 0 ];
  366.     }

  367.     private static void _deleteIfNotIn(Tensor<?>[] array, Tensor<?> tensor ) {
  368.         if ( Neureka.get().settings().debug().isDeletingIntermediateTensors() ) {
  369.             for ( int i = 1; i < array.length; i++ )
  370.                 if ( array[i] == tensor ) return;

  371.             if ( !tensor.isDeleted() ) tensor.mut().delete();
  372.         }
  373.     }

  374.     private static void _delete( Tensor<?> tensor ) {
  375.         Neureka.Settings.Debug debug = Neureka.get().settings().debug();
  376.         if (  !tensor.isDeleted() && debug.isDeletingIntermediateTensors() )
  377.             tensor.mut().delete();
  378.     }

  379.     public static <R> R executeOnCommonDevice( ExecutionCall<?> call, Supplier<R> execution ) {
  380.         Device<Object> device = call.getDeviceFor(Object.class);

  381.         Consumer<Tensor<?>>[] rollbacks = new Consumer[ call.arity() ];
  382.         for (int i = 0; i < call.arity(); i++ )
  383.             if ( call.input( i ) != null && !call.input( i ).isOutsourced() ) {
  384.                 device.store( call.input( i ) );
  385.                 rollbacks[ i ] = tensor -> device.restore( (Tensor<Object>) tensor );
  386.             }
  387.             else
  388.                 rollbacks[ i ] = t -> {};

  389.         R result = execution.get();

  390.         if ( result == null )
  391.             throw new IllegalStateException( "Execution of " + call + " failed!" );

  392.         for ( int i = 0; i < rollbacks.length; i++ )
  393.             if ( call.input( i ) != null && !call.input( i ).isDeleted() && !call.input( i ).isUndefined() )
  394.                 rollbacks[ i ].accept( call.input( i ) );

  395.         return result;
  396.     }

  397.     private static String _couldNotFindSuitableAlgorithmFor( Class<?> type ) {
  398.         return LogUtil.format(
  399.                 "No suitable '"+ Algorithm.class.getSimpleName()+"' found for device of type '{}'.",
  400.                 type.getSimpleName()
  401.         );
  402.     }

  403.     private static String _couldNotFindSuitableImplementationFor(
  404.             Operation operation,
  405.             Algorithm algorithm,
  406.             Class<?> type
  407.     ) {
  408.         return LogUtil.format(
  409.                 "No suitable implementation found for operation '{}', algorithm '{}' and device type '{}'.",
  410.                 operation.getIdentifier(),
  411.                 algorithm.getName(),
  412.                 type.getSimpleName()
  413.         );
  414.     }


  415. }