AbstractDeviceAlgorithm.java
- package neureka.backend.api.template.algorithms;
- import neureka.Neureka;
- import neureka.Shape;
- import neureka.Tensor;
- import neureka.backend.api.*;
- import neureka.backend.api.fun.ExecutionPreparation;
- import neureka.backend.main.algorithms.ElementwiseAlgorithm;
- import neureka.backend.main.internal.FinalExecutor;
- import neureka.backend.main.memory.MemUtil;
- import neureka.math.Function;
- import neureka.math.args.Arg;
- import neureka.math.parsing.FunctionParser;
- import neureka.math.implementations.FunctionConstant;
- import neureka.common.utility.LogUtil;
- import neureka.devices.Device;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
- import java.util.HashMap;
- import java.util.Map;
- import java.util.function.Consumer;
- import java.util.function.Supplier;
- import java.util.stream.Collectors;
- import java.util.stream.IntStream;
- /**
- * This is a partial implementation of the {@link Algorithm} interface which implements
- * the component system for implementation instances of the {@link ImplementationFor} interface.
- * These components implement this algorithm for a specific type of {@link Device}.
- *
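- * <p>
- * A minimal, hypothetical usage sketch (the device class and the lambda body
- * are illustrative only, not actual implementations from the library):
- * <pre>{@code
- * algorithm
- * .setImplementationFor(
- * CPU.class,
- * call -> call.input( 0 ) // e.g. a trivial pass-through implementation
- * )
- * .getImplementationFor( CPU.class ); // retrieves what was just registered
- * }</pre>
- *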
- * @param <C> The type of the concrete extension of this class.
- */
- public abstract class AbstractDeviceAlgorithm<C extends DeviceAlgorithm<C>>
- extends AbstractAlgorithm
- implements DeviceAlgorithm<C>
- {
- private static final Logger _LOG = LoggerFactory.getLogger( AbstractDeviceAlgorithm.class );
- protected final Map<Class<Device<?>>, ImplementationFor<Device<?>>> _implementations = new HashMap<>();
- public AbstractDeviceAlgorithm( String name ) { super( name ); }
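- /**
- * Registers the provided {@link ImplementationFor} instance for the given device class,
- * replacing any implementation which may have been registered for that class previously.
- */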
- @Override
- public <D extends Device<?>, E extends ImplementationFor<D>> C setImplementationFor(
- Class<D> deviceClass, E implementation
- ) {
- if ( _implementations.containsKey( deviceClass ) )
- _LOG.info(
- "Implementation for device '" + deviceClass.getSimpleName() + "' already defined! It will be replaced."
- );
- _implementations.put(
- (Class<Device<?>>) deviceClass,
- (ImplementationFor<Device<?>>) implementation
- );
- return (C) this;
- }
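- /**
- * Returns the {@link ImplementationFor} instance registered for the given device class, if one exists.
- * If there is no entry for the exact class, then the map of implementations is searched
- * for an entry whose device type is a supertype of the requested class,
- * so that an implementation registered for a more general device type is also found for its subtypes.
- */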
- @Override
- public <D extends Device<?>> ImplementationFor<D> getImplementationFor( Class<D> deviceClass )
- {
- ImplementationFor<D> found = (ImplementationFor<D>) _implementations.get( deviceClass );
- if ( found == null )
- for ( Class<Device<?>> type : _implementations.keySet() )
- if ( type.isAssignableFrom(deviceClass) )
- return (ImplementationFor<D>) _implementations.get(type);
- return found;
- }
- @Override
- public String toString() {
- String algorithmString = getClass().getSimpleName()+"@"+Integer.toHexString(hashCode());
- String implementations = _implementations.keySet().stream().map(Class::getSimpleName).collect(Collectors.joining(","));
- algorithmString = ( algorithmString + "[name=" + getName() + ",support=[" + implementations + "]]" );
- return algorithmString;
- }
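- /**
- * Executes the provided {@link Function} alongside the provided {@link ExecutionCall},
- * either as a simple activation (no derivative target) or as a derivative computation,
- * depending on the {@link Arg.DerivIdx} argument of the call.
- */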
- public static Tensor<?> executeFor(
- final Function caller,
- final ExecutionCall<? extends Device<?>> call,
- final FinalExecutor executor
- ) {
- Function[] nodes = caller.getSubFunctions().toArray(new Function[0]);
- Operation operation = caller.getOperation();
- assert call.getOperation() == operation;
- boolean isFlat = caller.isFlat();
- boolean isDoingAD = caller.isDoingAD();
- if ( call.getValOf( Arg.DerivIdx.class ) < 0 )
- return _deepActivation( call, nodes, isFlat, isDoingAD, executor );
- else
- return _deepDerivative( call, nodes, executor );
- }
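- /**
- * First prepares the provided {@link ExecutionCall} (see {@link ExecutionPreparation})
- * and then executes it on a device common to all of its inputs,
- * where the actual execution is delegated to the provided {@link FinalExecutor}.
- */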
- public static Tensor<?> prepareAndExecute(
- ExecutionCall<? extends Device<?>> executionCall,
- FinalExecutor executor
- ) {
- ExecutionCall<? extends Device<?>> call = _prepareForExecution(executionCall);
- return executeOnCommonDevice(call, ()->{
- /*
- Below is the core lambda of the recursive preprocessing
- which is defined for each Algorithm individually:
- */
- Tensor<?> result = null;
- if ( executor != null )
- result = executor.execute(call);
- return result;
- });
- }
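- /**
- * Gives the {@link Algorithm} of the provided call the chance to prepare
- * the call for execution (typically by making sure that an output tensor is present),
- * granted that it implements the {@link ExecutionPreparation} interface,
- * and then verifies that none of the call inputs are null.
- */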
- public static ExecutionCall<? extends Device<?>> _prepareForExecution( ExecutionCall<? extends Device<?>> executionCall ) {
- Algorithm currentAlgorithm = executionCall.getAlgorithm();
- if ( currentAlgorithm instanceof ExecutionPreparation )
- executionCall = ( (ExecutionPreparation) currentAlgorithm ).prepare( executionCall );
- for ( Tensor<?> t : executionCall.inputs() )
- if ( t == null ) throw new IllegalArgumentException(
- "Device arguments may not be null!\n" +
- "One or more tensor arguments within the given ExecutionCall instance is null."
- );
- return executionCall;
- }
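- /**
- * Executes the provided call directly (without any preprocessing) by looking up
- * the {@link ImplementationFor} instance which the {@link DeviceAlgorithm} of the call
- * has registered for the device of the call, approving the call on that device and running it.
- * If no suitable algorithm or implementation can be found, an {@link IllegalStateException} is thrown.
- */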
- public static Tensor<?> executeDeviceAlgorithm(
- ExecutionCall<? extends Device<?>> call
- ) {
- for ( Tensor<?> t : call.inputs() )
- if ( t == null ) throw new IllegalArgumentException(
- "Device arguments may not be null!\n" +
- "One or more tensor arguments within the given ExecutionCall instance is null."
- );
- Device<?> device = call.getDevice();
- Algorithm algorithm = call.getAlgorithm();
- if ( algorithm == null ) {
- String message = _couldNotFindSuitableAlgorithmFor( device.getClass() );
- _LOG.error( message );
- throw new IllegalStateException( message );
- } else {
- DeviceAlgorithm<?> deviceAlgorithm = ( algorithm instanceof DeviceAlgorithm ? ((DeviceAlgorithm<?>) algorithm) : null );
- ImplementationFor<Device<?>> implementation = ( deviceAlgorithm == null ? null : deviceAlgorithm.getImplementationFor(device) );
- if ( implementation == null ) {
- String message = _couldNotFindSuitableImplementationFor( call.getOperation(), algorithm, device.getClass() );
- _LOG.error( message );
- throw new IllegalStateException( message );
- }
- else {
- device.approve( call );
- return implementation.run( (ExecutionCall<Device<?>>) call );
- }
- }
- }
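- /**
- * "Flattens" the provided call by executing all of the sub-functions of the provided caller
- * (except {@link FunctionConstant}s, which are materialized as tensors afterwards)
- * and using the results as the new inputs of the returned call.
- */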
- public static <D extends Device<?>> ExecutionCall<D> flatten(
- Function caller, ExecutionCall<D> call
- ) {
- return _flatten( call, caller.getSubFunctions().toArray(new Function[0]), true );
- }
- public static <D extends Device<?>> ExecutionCall<D> flattenForIndexer(
- Function caller, ExecutionCall<D> call
- ) {
- return _flatten( call, caller.getSubFunctions().toArray(new Function[0]), false );
- }
- private static <D extends Device<?>> ExecutionCall<D> _flatten(
- ExecutionCall<D> call, Function[] src
- ) {
- return _flatten( call, src, true );
- }
-
- private static <D extends Device<?>> ExecutionCall<D> _flatten(
- ExecutionCall<D> call, Function[] src, boolean ignoreJs
- ) {
- ExecutionCall<D> innerCall = !ignoreJs ? call : call.withArgs( Arg.DerivIdx.of(-1) );
- Tensor<?>[] inputs = innerCall.inputs();
- return MemUtil.keep( inputs, () ->
- {
- Shape tempShape = null;
- Class<?> tempType = null;
- Tensor<?>[] tensors = new Tensor[src.length];
- for ( int i = 0; i < tensors.length; i++ ) { // constant sub-functions are skipped here and materialized below!
- if ( !( src[i] instanceof FunctionConstant) ) {
- tensors[ i ] = src[i].execute(innerCall);
- tempShape = ( tempShape == null ? tensors[ i ].shape() : tempShape );
- tempType = ( tempType == null ? tensors[ i ].getItemType() : tempType );
- }
- }
- int j = innerCall.getValOf( Arg.VarIdx.class );
- for ( int i = 0; i < tensors.length; i++ )
- if ( tensors[ i ] == null )
- tensors[ i ] =
- j < 0
- ? Tensor.of( tempType, tempShape, ((FunctionConstant) src[i]).value() ).mut().setIsIntermediate( true ).to(call.getDevice())
- : Tensor.of( tempType, tempShape, src[i].call(new double[]{}, j) ).mut().setIsIntermediate( true ).to(call.getDevice());
- return innerCall.withInputs(tensors);
- });
- }
-
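- /**
- * Executes the provided call without a derivative target (DerivIdx == -1).
- * If the caller is not flat and the operation is an elementwise operator,
- * then the call is first re-expressed as a flat function string which is parsed and executed.
- * Otherwise the flattened inputs are simply forwarded to
- * {@link #prepareAndExecute(ExecutionCall, FinalExecutor)}.
- */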
- private static Tensor<?> _deepActivation(
- final ExecutionCall<? extends Device<?>> call,
- final Function[] nodes,
- final boolean isFlat,
- final boolean isDoingAD,
- final FinalExecutor executor
- ) {
- int j = call.getValOf( Arg.VarIdx.class );
- assert call.getValOf( Arg.DerivIdx.class ) == -1;
- ExecutionCall<?> flattenedCall = _flatten( call.withArgs( Arg.VarIdx.of(j) ), nodes );
- if (
- !isFlat && j < 0 && (
- call.getOperation().isOperator()
- ||
- call.getOperation().supportsAlgorithm(ElementwiseAlgorithm.class)
- )
- ) {/* '+', '-', 'x', '*', '%', '«', '»', ',', ... */
- String asStr = call.getOperation().stringify(
- IntStream.range(0, nodes.length)
- .mapToObj(i -> "I[" + i + "]")
- .toArray(String[]::new)
- );
- Tensor<?>[] finalTensors = flattenedCall.inputs();
- Tensor<?> result = MemUtil.keep( finalTensors, () ->
- new FunctionParser( Neureka.get().backend() )
- .parse( asStr, isDoingAD )
- .execute( finalTensors )
- );
- for ( int i = 1; i < finalTensors.length; i++ )
- _deleteIfNotIn(call.inputs(), finalTensors[i]);
- return result;
- } else {
- int numberOfInputs = flattenedCall.arity();
- boolean anyNumberOfInputs = flattenedCall.getOperation().getArity() < 0;
- int operationArity = flattenedCall.getOperation().getArity();
- if (numberOfInputs < operationArity)
- throw new IllegalArgumentException(
- "The number of inputs to the operation " + flattenedCall.getOperation() + " is " + numberOfInputs +
- " but the operation requires " + operationArity + " inputs."
- );
- boolean tooManyArgs = numberOfInputs > operationArity + 1;
- Tensor<?>[] tensors;
- if ( !tooManyArgs || anyNumberOfInputs )
- tensors = flattenedCall.withAddedInputAt(0, null).inputs();
- else
- tensors = flattenedCall.inputs();
- return prepareAndExecute(
- call.withInputs( tensors ).withArgs( Arg.DerivIdx.of(-1), Arg.VarIdx.of(-1) ),
- executor
- );
- }
- }
- /**
- * This method returns the index of the tensor
- * in the given tensor array which is virtual and contains "1.0".
- * However, if not all tensors are virtual, or if their values are not all "0.0"
- * except for exactly one tensor whose value is "1.0", then it returns -1,
- * because the optimization cannot be applied...
- *
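- * <p>
- * A few hypothetical examples (assuming all tensors are virtual,
- * where {@code t(x)} stands for a tensor holding the value {@code x}):
- * <pre>{@code
- * [ null, t(0.0), t(1.0), t(0.0) ] -> 2 (exactly one "1.0")
- * [ null, t(1.0), t(1.0) ] -> -1 (more than one "1.0")
- * [ null, t(0.5) ] -> -1 (a value other than "0.0" or "1.0")
- * }</pre>
- *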
- * @param tensors An array of tensors which ought to be analyzed.
- * @return The index of the tensor whose value is "1.0" (if all others are "0.0"), otherwise : -1
- */
-
- private static int _indexOfFoundDerivative( final Tensor<?>[] tensors )
- {
- boolean allVirtual = true;
- for ( Tensor<?> t : tensors )
- if ( t != null && !t.isVirtual() ) allVirtual = false;
- if ( allVirtual ) {
- int index = -1;
- for ( int i = 0; i < tensors.length; i++ ) {
- double value = ( tensors[ i ] == null ? 0.0 : tensors[ i ].getItemsAs( double[].class )[ 0 ] );
- if ( value == 1.0 ) {
- if ( index >= 0 ) return -1;
- index = i;
- }
- else if ( value != 0.0 ) return -1;
- }
- return index;
- }
- return -1;
- }
-
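- /**
- * Executes the provided call with respect to a derivative target (DerivIdx >= 0)
- * by applying the chain rule of forward mode autodiff:
- * the derivatives of the sub-functions (the "inner" derivatives) are computed and summed up
- * (which is possible because differentiation is linear),
- * the derivative of the operation itself (the "outer" derivative) is computed,
- * and the two are then multiplied, as in d/dx f(g(x)) = f'(g(x)) * g'(x).
- */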
- private static Tensor<?> _deepDerivative(
- final ExecutionCall<? extends Device<?>> call,
- final Function[] nodes,
- final FinalExecutor executor
- ) {
- Supplier<Tensor<?>> actor = () ->
- MemUtil.keep( call.inputs(), () -> {
- int d = call.getValOf( Arg.DerivIdx.class );
- final int j = call.getValOf( Arg.VarIdx.class );
- assert d >= 0;
- Tensor<?>[] tensors;
- if ( call.getOperation().isIndexer() ) tensors = new Tensor[ 1 + call.arity() ];
- else tensors = new Tensor[ 1 + nodes.length ];
- // Chain rule (forward mode autodiff): "inner times outer",
- // meaning we first derive the sources (the inner functions), like so:
- if ( call.getOperation().isIndexer() )
- for ( int i = 1; i < tensors.length; i++ )
- tensors[ i ] = nodes[ 0 ].executeDerive( call.inputs(), d, i - 1 );
- else
- for ( int i = 1; i < tensors.length; i++ )
- tensors[ i ] =
- j >= 0
- ? nodes[ i - 1 ].executeDerive( call.inputs(), d, j )
- : nodes[ i - 1 ].executeDerive( call.inputs(), d );
- //...then add them all together! (This is possible because differentiation is linear.)
- Tensor<?> inner;
- if ( tensors.length > 2 ) { // Optimization: find the index of "1.0" among otherwise all-"0.0" virtual tensors!
- int index = _indexOfFoundDerivative( tensors );
- if ( index >= 0 ) inner = tensors[ index ];
- else {
- // Optimization above did not apply, so we accumulate all the derivatives!
- tensors[0] = prepareAndExecute(
- ExecutionCall.of( tensors )
- .andArgs( Arg.DerivIdx.of( -1 ) )
- .running( Neureka.get().backend().getOperation("+") )
- .on( call.getDevice() ),
- AbstractDeviceAlgorithm::executeDeviceAlgorithm
- );
- inner = tensors[ 0 ];//-> this is now the inner derivative!
- }
- }
- else inner = tensors[ 1 ];
- tensors[ 0 ] = null;
- //...then activate (No differentiation!) the source like so:
- if ( call.getOperation().isIndexer() ) // Indexers pass an index j, of course!
- for ( int i = 1; i < tensors.length; i++ )
- tensors[ i ] = nodes[ 0 ].execute( call.inputs(), i - 1 ); // i - 1 := j
- else
- for ( int i = 1; i < tensors.length; i++ )
- tensors[ i ] =
- j >= 0
- ? nodes[ i - 1 ].execute( call.inputs(), j )
- : nodes[ i - 1 ].execute( call.inputs() );
- //...get derivative index within src list:
- for ( int i = 0; i < nodes.length; i++ )
- if ( nodes[ i ].dependsOn( d ) && !call.getOperation().isIndexer() ) {
- d = i;
- break;
- }
- // Use those tensors for the outer derivative:
- tensors[0] = prepareAndExecute(
- ExecutionCall.of( tensors )
- .andArgs( Arg.DerivIdx.of( d ) )
- .running( call.getOperation() )
- .on( call.getDevice() ),
- executor
- );
- // Finally, multiply inner times outer (unless the inner derivative is entirely 1):
- Tensor<?> result = _innerTimesOuter( inner, tensors, call );
- // done!
- _delete( inner );
- return result;
- });
- int d = call.getValOf( Arg.DerivIdx.class );
- Tensor<?> out = null;
- for ( int i = 0; i < nodes.length; i++ )
- {
- // Only the first node which depends on the derivative target triggers the (one and only) execution:
- int di = ( nodes[ i ].dependsOn( d ) ? i : -1 );
- if ( di >= 0 ) {
- if ( out == null ) out = actor.get();
- else break;
- }
- }
- return out;
- }
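- /**
- * Multiplies the provided inner derivative with the outer derivative stored in {@code tensors[0]}
- * using the "*" operation, unless the inner derivative is simply "1.0",
- * in which case the multiplication can safely be skipped.
- */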
- private static Tensor<?> _innerTimesOuter(Tensor<?> inner, Tensor<?>[] tensors, ExecutionCall<?> call)
- {
- if ( !( ( inner.isVirtual() || inner.size() == 1 ) && inner.getItemsAs( double[].class )[ 0 ] == 1.0 ) ) {
- tensors = new Tensor[]{ null, inner, tensors[ 0 ] };
- tensors[0] = prepareAndExecute(
- ExecutionCall.of( tensors )
- .andArgs( Arg.DerivIdx.of( -1 ) )
- .running( Neureka.get().backend().getOperation("*") )
- .on( call.getDevice() ),
- AbstractDeviceAlgorithm::executeDeviceAlgorithm
- );
- for ( int i = 1; i < tensors.length; i++ )
- _deleteIfNotIn( call.inputs(), tensors[ i ] );
- }
- return tensors[ 0 ];
- }
- private static void _deleteIfNotIn(Tensor<?>[] array, Tensor<?> tensor ) {
- if ( Neureka.get().settings().debug().isDeletingIntermediateTensors() ) {
- for ( int i = 1; i < array.length; i++ )
- if ( array[i] == tensor ) return;
- if ( !tensor.isDeleted() ) tensor.mut().delete();
- }
- }
- private static void _delete( Tensor<?> tensor ) {
- Neureka.Settings.Debug debug = Neureka.get().settings().debug();
- if ( !tensor.isDeleted() && debug.isDeletingIntermediateTensors() )
- tensor.mut().delete();
- }
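- /**
- * Makes sure that all inputs of the provided call are stored on the device of the call
- * before running the provided execution lambda,
- * and afterwards restores every input which was not already outsourced beforehand
- * (provided it has not been deleted in the meantime).
- */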
- public static <R> R executeOnCommonDevice( ExecutionCall<?> call, Supplier<R> execution ) {
- Device<Object> device = call.getDeviceFor(Object.class);
- Consumer<Tensor<?>>[] rollbacks = new Consumer[ call.arity() ];
- for (int i = 0; i < call.arity(); i++ )
- if ( call.input( i ) != null && !call.input( i ).isOutsourced() ) {
- device.store( call.input( i ) );
- rollbacks[ i ] = tensor -> device.restore( (Tensor<Object>) tensor );
- }
- else
- rollbacks[ i ] = t -> {};
- R result = execution.get();
- if ( result == null )
- throw new IllegalStateException( "Execution of " + call + " failed!" );
- for ( int i = 0; i < rollbacks.length; i++ )
- if ( call.input( i ) != null && !call.input( i ).isDeleted() && !call.input( i ).isUndefined() )
- rollbacks[ i ].accept( call.input( i ) );
- return result;
- }
- private static String _couldNotFindSuitableAlgorithmFor( Class<?> type ) {
- return LogUtil.format(
- "No suitable '"+ Algorithm.class.getSimpleName()+"' found for device of type '{}'.",
- type.getSimpleName()
- );
- }
- private static String _couldNotFindSuitableImplementationFor(
- Operation operation,
- Algorithm algorithm,
- Class<?> type
- ) {
- return LogUtil.format(
- "No suitable implementation found for operation '{}', algorithm '{}' and device type '{}'.",
- operation.getIdentifier(),
- algorithm.getName(),
- type.getSimpleName()
- );
- }
- }