// AbstractDeviceAlgorithm.java
package neureka.backend.api.template.algorithms;
import neureka.Neureka;
import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.*;
import neureka.backend.api.fun.ExecutionPreparation;
import neureka.backend.main.algorithms.ElementwiseAlgorithm;
import neureka.backend.main.internal.FinalExecutor;
import neureka.backend.main.memory.MemUtil;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
import neureka.math.implementations.FunctionConstant;
import neureka.common.utility.LogUtil;
import neureka.devices.Device;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* This is a partial implementation of the {@link Algorithm} interface which implements
* the component system for instances of the {@link ImplementationFor} interface.
* Each of these components implements the algorithm for a specific type of {@link Device}.
*
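* <p>
* A minimal usage sketch (the device type and the implementation body
* are hypothetical, and it is assumed here that {@link ImplementationFor}
* can be used as a functional interface):
* <pre>{@code
* myAlgorithm.setImplementationFor(
*     MyDevice.class,
*     call -> call.input( 0 ) // ...the actual kernel dispatch would happen here!
* );
* ImplementationFor<MyDevice> impl = myAlgorithm.getImplementationFor( MyDevice.class );
* }</pre>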
* @param <C> The type of the concrete extension of this class.
*/
public abstract class AbstractDeviceAlgorithm<C extends DeviceAlgorithm<C>>
extends AbstractAlgorithm
implements DeviceAlgorithm<C>
{
private static final Logger _LOG = LoggerFactory.getLogger(AbstractDeviceAlgorithm.class);
protected final Map<Class<Device<?>>, ImplementationFor<Device<?>>> _implementations = new HashMap<>();
public AbstractDeviceAlgorithm( String name ) { super( name ); }
@Override
public <D extends Device<?>, E extends ImplementationFor<D>> C setImplementationFor(
Class<D> deviceClass, E implementation
) {
if ( _implementations.containsKey( deviceClass ) )
_LOG.info(
"Implementation for device '" + deviceClass.getSimpleName() + "' already defined! " +
"The previous implementation will be replaced."
);
_implementations.put(
(Class<Device<?>>) deviceClass,
(ImplementationFor<Device<?>>) implementation
);
return (C) this;
}
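/*
Note: If no implementation was registered for the requested device class exactly,
then the lookup below falls back to the first registered device type
which is a supertype of the requested one.
*/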
@Override
public <D extends Device<?>> ImplementationFor<D> getImplementationFor( Class<D> deviceClass )
{
ImplementationFor<D> found = (ImplementationFor<D>) _implementations.get( deviceClass );
if ( found == null )
for ( Class<Device<?>> type : _implementations.keySet() )
if ( type.isAssignableFrom(deviceClass) )
return (ImplementationFor<D>) _implementations.get(type);
return found;
}
@Override
public String toString() {
String algorithmString = getClass().getSimpleName()+"@"+Integer.toHexString(hashCode());
String implementations = _implementations.keySet().stream().map(Class::getSimpleName).collect(Collectors.joining(","));
algorithmString = ( algorithmString + "[name=" + getName() + ",support=[" + implementations + "]]" );
return algorithmString;
}
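/**
* Executes the given {@link ExecutionCall} for the provided caller {@link Function}
* by dispatching either to a deep activation, when no derivative is requested
* (a negative {@link Arg.DerivIdx}), or to a deep derivative execution otherwise.
*/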
public static Tensor<?> executeFor(
final Function caller,
final ExecutionCall<? extends Device<?>> call,
final FinalExecutor executor
) {
Function[] nodes = caller.getSubFunctions().toArray(new Function[0]);
Operation operation = caller.getOperation();
assert call.getOperation() == operation;
boolean isFlat = caller.isFlat();
boolean isDoingAD = caller.isDoingAD();
if ( call.getValOf( Arg.DerivIdx.class ) < 0 )
return _deepActivation( call, nodes, isFlat, isDoingAD, executor );
else
return _deepDerivative( call, nodes, executor );
}
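/**
* First prepares the given call (using the {@link ExecutionPreparation} of its
* algorithm, if present) and then runs the provided executor on a common device.
* The following is a rough usage sketch mirroring how this method is invoked
* internally in this class (the {@code tensors} array, whose first slot is a
* {@code null} output placeholder, and the {@code device} variable are assumptions):
* <pre>{@code
* Tensor<?> result = AbstractDeviceAlgorithm.prepareAndExecute(
*     ExecutionCall.of( tensors )
*         .andArgs( Arg.DerivIdx.of( -1 ) )
*         .running( Neureka.get().backend().getOperation( "+" ) )
*         .on( device ),
*     AbstractDeviceAlgorithm::executeDeviceAlgorithm
* );
* }</pre>
*/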
public static Tensor<?> prepareAndExecute(
ExecutionCall<? extends Device<?>> executionCall,
FinalExecutor executor
) {
ExecutionCall<? extends Device<?>> call = _prepareForExecution(executionCall);
return executeOnCommonDevice(call, ()->{
/*
Below is the core lambda of the recursive preprocessing,
which is defined for each Algorithm individually:
*/
Tensor<?> result = null;
if ( executor != null )
result = executor.execute(call);
return result;
});
}
public static ExecutionCall<? extends Device<?>> _prepareForExecution(ExecutionCall<? extends Device<?>> executionCall) {
Algorithm currentAlgorithm = executionCall.getAlgorithm();
if ( currentAlgorithm instanceof ExecutionPreparation )
executionCall = ( (ExecutionPreparation) currentAlgorithm ).prepare( executionCall );
for ( Tensor<?> t : executionCall.inputs() )
if ( t == null ) throw new IllegalArgumentException(
"Device arguments may not be null!\n" +
"One or more tensor arguments within the given ExecutionCall instance is null."
);
return executionCall;
}
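/**
* Executes the given call directly on its device by looking up a fitting
* {@link ImplementationFor} instance on the {@link DeviceAlgorithm} of the call.
* This fails with an {@link IllegalStateException} if no suitable algorithm
* or implementation can be found.
*/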
public static Tensor<?> executeDeviceAlgorithm(
ExecutionCall<? extends Device<?>> call
) {
for ( Tensor<?> t : call.inputs() )
if ( t == null ) throw new IllegalArgumentException(
"Device arguments may not be null!\n" +
"One or more tensor arguments within the given ExecutionCall instance is null."
);
Device<?> device = call.getDevice();
Algorithm algorithm = call.getAlgorithm();
if ( algorithm == null ) {
String message = _couldNotFindSuitableAlgorithmFor( device.getClass() );
_LOG.error( message );
throw new IllegalStateException( message );
} else {
DeviceAlgorithm<?> deviceAlgorithm = ( algorithm instanceof DeviceAlgorithm ? ((DeviceAlgorithm<?>) algorithm) : null );
ImplementationFor<Device<?>> implementation = ( deviceAlgorithm == null ? null : deviceAlgorithm.getImplementationFor(device) );
if ( implementation == null ) {
String message = _couldNotFindSuitableImplementationFor( call.getOperation(), algorithm, device.getClass() );
_LOG.error( message );
throw new IllegalStateException( message );
}
else {
device.approve( call );
return implementation.run( (ExecutionCall<Device<?>>) call );
}
}
}
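/**
* Flattens the given call by executing the sub-functions of the caller,
* so that the resulting call references plain tensor inputs only.
* Constant sub-functions are materialized as intermediate tensors.
*/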
public static <D extends Device<?>> ExecutionCall<D> flatten(
Function caller, ExecutionCall<D> call
) {
return _flatten( call, caller.getSubFunctions().toArray(new Function[0]), true );
}
public static <D extends Device<?>> ExecutionCall<D> flattenForIndexer(
Function caller, ExecutionCall<D> call
) {
return _flatten( call, caller.getSubFunctions().toArray(new Function[0]), false );
}
private static <D extends Device<?>> ExecutionCall<D> _flatten(
ExecutionCall<D> call, Function[] src
) {
return _flatten( call, src, true );
}
private static <D extends Device<?>> ExecutionCall<D> _flatten(
ExecutionCall<D> call, Function[] src, boolean ignoreJs
) {
ExecutionCall<D> innerCall = !ignoreJs ? call : call.withArgs( Arg.VarIdx.of(-1) );
Tensor<?>[] inputs = innerCall.inputs();
return MemUtil.keep( inputs, () ->
{
Shape tempShape = null;
Class<?> tempType = null;
Tensor<?>[] tensors = new Tensor[src.length];
for ( int i = 0; i < tensors.length; i++ ) { // constants need to be figured out!
if ( !( src[i] instanceof FunctionConstant) ) {
tensors[ i ] = src[i].execute(innerCall);
tempShape = ( tempShape == null ? tensors[ i ].shape() : tempShape );
tempType = ( tempType == null ? tensors[ i ].getItemType() : tempType );
}
}
int j = innerCall.getValOf( Arg.VarIdx.class );
for ( int i = 0; i < tensors.length; i++ )
if ( tensors[ i ] == null )
tensors[ i ] =
j < 0
? Tensor.of( tempType, tempShape, ((FunctionConstant) src[i]).value() ).mut().setIsIntermediate( true ).to(call.getDevice())
: Tensor.of( tempType, tempShape, src[i].call(new double[]{}, j) ).mut().setIsIntermediate( true ).to(call.getDevice());
return innerCall.withInputs(tensors);
});
}
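/*
Handles calls without a derivative request (DerivIdx == -1):
If the caller is not flat and is an operator (or supports elementwise execution),
then it is re-parsed and executed on the flattened inputs,
otherwise the flattened call is passed on to the provided executor.
*/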
private static Tensor<?> _deepActivation(
final ExecutionCall<? extends Device<?>> call,
final Function[] nodes,
final boolean isFlat,
final boolean isDoingAD,
final FinalExecutor executor
) {
int j = call.getValOf( Arg.VarIdx.class );
assert call.getValOf( Arg.DerivIdx.class ) == -1;
ExecutionCall<?> flattenedCall = _flatten( call.withArgs( Arg.VarIdx.of(j) ), nodes );
if (
!isFlat && j < 0 && (
call.getOperation().isOperator()
||
call.getOperation().supportsAlgorithm(ElementwiseAlgorithm.class)
)
) {/* '+', '-', 'x', '*', '%', '«', '»', ',', ... */
String asStr = call.getOperation().stringify(
IntStream.range(0, nodes.length)
.mapToObj(i -> "I[" + i + "]")
.toArray(String[]::new)
);
Tensor<?>[] finalTensors = flattenedCall.inputs();
Tensor<?> result = MemUtil.keep(finalTensors, () -> new FunctionParser(Neureka.get().backend()).parse(asStr, isDoingAD).execute(finalTensors));
for ( int i = 1; i < finalTensors.length; i++ )
_deleteIfNotIn(call.inputs(), finalTensors[i]);
return result;
} else {
int numberOfInputs = flattenedCall.arity();
int operationArity = flattenedCall.getOperation().getArity();
boolean anyNumberOfInputs = operationArity < 0;
if (numberOfInputs < operationArity)
throw new IllegalArgumentException(
"The number of inputs to the operation " + flattenedCall.getOperation() + " is " + numberOfInputs +
" but the operation requires " + operationArity + " inputs."
);
boolean tooManyArgs = numberOfInputs > operationArity + 1;
Tensor<?>[] tensors;
if ( !tooManyArgs || anyNumberOfInputs )
tensors = flattenedCall.withAddedInputAt(0, null).inputs();
else
tensors = flattenedCall.inputs();
return prepareAndExecute(
call.withInputs( tensors ).withArgs( Arg.DerivIdx.of(-1), Arg.VarIdx.of(-1) ),
executor
);
}
}
/**
* This method returns the index of the tensor in the given tensor array
* which is virtual and whose value is "1.0".
* However, if not all tensors are virtual, or their values are not all "0.0"
* except for a single one whose value is "1.0", then it returns -1,
* because the optimization cannot be applied...
* For example, for virtual tensors holding the values {"0.0", "1.0", "0.0"}
* this method returns the index 1.
*
* @param tensors An array of tensors which ought to be analyzed.
* @return The index of the tensor whose value is "1.0" (if all others are "0.0"), otherwise: -1.
*/
private static int _indexOfFoundDerivative( final Tensor<?>[] tensors )
{
boolean allVirtual = true;
for ( Tensor<?> t : tensors )
if ( t != null && !t.isVirtual() ) allVirtual = false;
if ( allVirtual ) {
int index = -1;
for ( int i = 0; i < tensors.length; i++ ) {
double value = ( tensors[ i ] == null ? 0.0 : tensors[ i ].getItemsAs( double[].class )[ 0 ] );
if ( value == 1.0 ) {
if ( index >= 0 ) return -1;
index = i;
}
else if ( value != 0.0 ) return -1;
}
return index;
}
return -1;
}
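/*
Performs the forward mode chain rule:
First every source function is derived with respect to the derivative target
(the "inner" derivatives, which may be summed up thanks to linearity),
then the operation itself is derived with respect to the source which
depends on that target (the "outer" derivative),
and finally the inner and outer derivatives are multiplied.
*/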
private static Tensor<?> _deepDerivative(
final ExecutionCall<? extends Device<?>> call,
final Function[] nodes,
final FinalExecutor executor
) {
Supplier<Tensor<?>> actor = () ->
MemUtil.keep( call.inputs(), () -> {
int d = call.getValOf( Arg.DerivIdx.class );
final int j = call.getValOf( Arg.VarIdx.class );
assert d >= 0;
Tensor<?>[] tensors;
if ( call.getOperation().isIndexer() ) tensors = new Tensor[ 1 + call.arity() ];
else tensors = new Tensor[ 1 + nodes.length ];
// Chain rule (forward mode autodiff):
// "inner times outer" means that
// the source functions must be derived first,
// like so:
if ( call.getOperation().isIndexer() )
for ( int i = 1; i < tensors.length; i++ )
tensors[ i ] = nodes[ 0 ].executeDerive( call.inputs(), d, i - 1 );
else
for ( int i = 1; i < tensors.length; i++ )
tensors[ i ] =
j >= 0
? nodes[ i - 1 ].executeDerive( call.inputs(), d, j )
: nodes[ i - 1 ].executeDerive( call.inputs(), d );
//...then add them all together! (This is possible because differentiation is linear.)
Tensor<?> inner;
if ( tensors.length > 2 ) {// Optimization: Finds index of "1.0" among otherwise all "0.0" virtual tensors!
int index = _indexOfFoundDerivative( tensors );
if ( index >= 0 ) inner = tensors[ index ];
else {
// Optimization above did not apply, so we accumulate all the derivatives!
tensors[0] = prepareAndExecute(
ExecutionCall.of( tensors )
.andArgs( Arg.DerivIdx.of( -1 ) )
.running( Neureka.get().backend().getOperation("+") )
.on( call.getDevice() ),
AbstractDeviceAlgorithm::executeDeviceAlgorithm
);
inner = tensors[ 0 ];//-> this is now the inner derivative!
}
}
else inner = tensors[ 1 ];
tensors[ 0 ] = null;
//...then activate (No differentiation!) the source like so:
if ( call.getOperation().isIndexer() ) // Indexers pass an index j, of course!
for ( int i = 1; i < tensors.length; i++ )
tensors[ i ] = nodes[ 0 ].execute( call.inputs(), i - 1 ); // i - 1 := j
else
for ( int i = 1; i < tensors.length; i++ )
tensors[ i ] =
j >= 0
? nodes[ i - 1 ].execute( call.inputs(), j )
: nodes[ i - 1 ].execute( call.inputs() );
//...get derivative index within src list:
for ( int i = 0; i < nodes.length; i++ )
if ( nodes[ i ].dependsOn( d ) && !call.getOperation().isIndexer() ) {
d = i;
break;
}
// Use those tensors for the outer derivative:
tensors[0] = prepareAndExecute(
ExecutionCall.of( tensors )
.andArgs( Arg.DerivIdx.of( d ) )
.running( call.getOperation() )
.on( call.getDevice() ),
executor
);
// At the end:
//...multiply inner times outer: (but only if inner is not entirely 1.0...)
Tensor<?> result = _innerTimesOuter( inner, tensors, call );
// done!
_delete( inner );
return result;
});
int d = call.getValOf( Arg.DerivIdx.class );
Tensor<?> out = null;
for ( int i = 0; i < nodes.length; i++ )
{
// Nodes which do not depend on the derivative target (like constants) are skipped!
int di = ( nodes[ i ].dependsOn( d ) ? i : -1 );
if ( di >= 0 )
if ( out == null ) out = actor.get();
else
break;
}
return out;
}
private static Tensor<?> _innerTimesOuter(Tensor<?> inner, Tensor<?>[] tensors, ExecutionCall<?> call)
{
if ( !( ( inner.isVirtual() || inner.size() == 1 ) && inner.getItemsAs( double[].class )[ 0 ] == 1.0 ) ) {
tensors = new Tensor[]{ null, inner, tensors[ 0 ] };
tensors[0] = prepareAndExecute(
ExecutionCall.of( tensors )
.andArgs( Arg.DerivIdx.of( -1 ) )
.running( Neureka.get().backend().getOperation("*") )
.on( call.getDevice() ),
AbstractDeviceAlgorithm::executeDeviceAlgorithm
);
for ( int i = 1; i < tensors.length; i++ )
_deleteIfNotIn( call.inputs(), tensors[ i ] );
}
return tensors[ 0 ];
}
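/*
Deletes the given intermediate tensor, but only if it is not referenced
by the provided array (the output slot at index 0 is ignored) and the
debug settings permit the deletion of intermediate tensors.
*/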
private static void _deleteIfNotIn(Tensor<?>[] array, Tensor<?> tensor ) {
if ( Neureka.get().settings().debug().isDeletingIntermediateTensors() ) {
for ( int i = 1; i < array.length; i++ )
if ( array[i] == tensor ) return;
if ( !tensor.isDeleted() ) tensor.mut().delete();
}
}
private static void _delete( Tensor<?> tensor ) {
Neureka.Settings.Debug debug = Neureka.get().settings().debug();
if ( !tensor.isDeleted() && debug.isDeletingIntermediateTensors() )
tensor.mut().delete();
}
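/**
* Temporarily stores all inputs of the given call which are not yet outsourced
* on the device of the call, runs the provided execution,
* and then restores the migrated tensors to where they came from.
*/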
public static <R> R executeOnCommonDevice( ExecutionCall<?> call, Supplier<R> execution ) {
Device<Object> device = call.getDeviceFor(Object.class);
Consumer<Tensor<?>>[] rollbacks = new Consumer[ call.arity() ];
for (int i = 0; i < call.arity(); i++ )
if ( call.input( i ) != null && !call.input( i ).isOutsourced() ) {
device.store( call.input( i ) );
rollbacks[ i ] = tensor -> device.restore( (Tensor<Object>) tensor );
}
else
rollbacks[ i ] = t -> {};
R result = execution.get();
if ( result == null )
throw new IllegalStateException( "Execution of " + call + " failed!" );
for ( int i = 0; i < rollbacks.length; i++ )
if ( call.input( i ) != null && !call.input( i ).isDeleted() && !call.input( i ).isUndefined() )
rollbacks[ i ].accept( call.input( i ) );
return result;
}
private static String _couldNotFindSuitableAlgorithmFor( Class<?> type ) {
return LogUtil.format(
"No suitable '"+ Algorithm.class.getSimpleName()+"' found for device of type '{}'.",
type.getSimpleName()
);
}
private static String _couldNotFindSuitableImplementationFor(
Operation operation,
Algorithm algorithm,
Class<?> type
) {
return LogUtil.format(
"No suitable implementation found for operation '{}', algorithm '{}' and device type '{}'.",
operation.getIdentifier(),
algorithm.getName(),
type.getSimpleName()
);
}
}