FunctionNode.java
package neureka.math.implementations;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Operation;
import neureka.backend.main.operations.other.Permute;
import neureka.devices.Device;
import neureka.devices.host.CPU;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.args.Args;
import java.util.Arrays;
import java.util.List;
/**
* The most common type of {@link Function} which references other {@link Function}s to
* form an abstract syntax tree.
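* <p>
* Instances of this type are usually not created directly but obtained by parsing an
* expression through the {@link Function#of(String)} factory (the same factory used by
* {@link #getDerivative(int)} below). A minimal usage sketch:
* <pre>{@code
* Function f = Function.of( "i0 * i1" ); // the root of the parsed graph is a FunctionNode
* double result = f.call( new double[]{ 3, 4 }, -1 ); // 12.0, where j = -1 means "no indexer index"
* }</pre>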
*/
public final class FunctionNode implements Function
{
private final Operation _operation;
private final boolean _isFlat;
private final boolean _isDoingAD;
private final Function[] _src;
/**
* @param type The operation which ought to be represented.
* @param sources The child function nodes of this node.
* @param doAD A flag determining if this function should perform autograd.
*/
public FunctionNode( Operation type, List<Function> sources, boolean doAD )
{
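// Operations with a negative arity accept a variable number of sources,
// so the following check only applies when the arity is non-negative: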
if ( type.getArity() >= 0 && sources.size() != type.getArity() ) {
String tip = ( type.isIndexer() )
? "\nNote: This function is an 'indexer'. Therefore it expects to sum variable 'I[j]' inputs, where 'j' is the index of an iteration."
: "";
throw new IllegalArgumentException(
"The function/operation '" + type.getOperator() + "' expects " + type.getArity() + " parameters, " +
"however " + sources.size() + " where given!" + tip
);
}
boolean isFlat = true;
for ( Function f : sources ) // A node is only considered "flat" if it exclusively references leaf nodes (inputs, variables, constants) of the function graph:
isFlat = (
(f instanceof FunctionInput) || (f instanceof FunctionVariable) || (f instanceof FunctionConstant)
) && isFlat;
_operation = type;
_isFlat = isFlat;
_src = sources.toArray( new Function[0] );
_isDoingAD = doAD;
for ( int i = 0; i < _src.length; i++ ) {
if ( _src[i] == null )
throw new IllegalArgumentException("The function node '" + this + "' has a null source at index " + i + "!");
if ( _src[i] instanceof FunctionNode && _src[i].isDoingAD() != _isDoingAD )
throw new IllegalArgumentException(
"Detected an attempt to mix autograd and non-autograd functions in the same function graph!\n" +
"A function can either be doing autograd or not doing autograd!"
);
}
}
@Override
public String toString()
{
return _operation.stringify(
Arrays.stream( _src )
.map( e -> e == null ? "(null)" : e.toString() )
.toArray( String[]::new )
);
}
@Override
public boolean dependsOn( int index ) {
for ( Function f : _src )
if ( f.dependsOn( index ) ) return true;
return false;
}
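/**
* Builds the symbolic derivative of this function with respect to the input at the given index.
* A sketch of the expected behaviour, assuming standard differentiation rules
* (the exact expression produced depends on {@link Operation#asDerivative}):
* <pre>{@code
* Function f = Function.of( "i0 * i0" );
* Function df = f.getDerivative( 0 ); // a function equivalent to "2 * i0"
* }</pre>
*/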
@Override
public Function getDerivative( int index ) { return Function.of( _operation.asDerivative( _src, index ) ); }
@Override
public List<Function> getSubFunctions() { return Arrays.asList(_src); }
@Override
public Tensor<?> execute( Args arguments, Tensor<?>... inputs )
{
if ( this.isDoingAD() )
Permute.makeFit( inputs, this.isDoingAD() ); // reshaping if needed
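// Bundle the inputs, the supplied meta arguments and the operation into an execution call bound to a fitting device: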
ExecutionCall<? extends Device<?>> call = ExecutionCall.of(inputs)
.andArgs(arguments.getAll(Arg.class))
.running(_operation)
.on(_deviceFor(inputs));
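// Let the operation itself carry out the execution and unwrap the resulting tensor: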
return call.getOperation()
.execute( this, call ).get();
}
/**
* This method tries to find a common {@link Device} for the provided {@link Tensor}s.
*
* @param inputs The input {@link Tensor}s for which a {@link Device} ought to be found and returned.
* @return The {@link Device} on which the execution of this function ought to take place.
*/
private Device<?> _deviceFor( Tensor<?>[] inputs )
{
if ( inputs.length == 0 ) return CPU.get();
Device<?> device = inputs[ 0 ].get( Device.class );
boolean onSameDevice = _shareGuestDevice( inputs );
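// Acceleration is only attempted if the operator is not ',' and all inputs share the same guest device: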
boolean doAccel = !_operation.getOperator().equals(",") && onSameDevice;
return ( doAccel && device != null ? device : inputs[ 0 ].getDevice() );
}
/**
* @param tensors An array of tensors which may or may not be stored on the same guest {@link Device}.
* @return The truth value determining whether all provided tensors share the same guest {@link Device}.
*/
private static boolean _shareGuestDevice( Tensor<?>[] tensors )
{
boolean onSameGuestDevice = true;
Device<?> device = null;
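// The device of the last outsourced tensor (if any) serves as the reference guest device: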
for ( Tensor<?> tensor : tensors ) device = ( tensor.isOutsourced() ? tensor.get( Device.class ) : device );
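// All tensors must be non-virtual and stored on that very same device: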
if ( device != null ) {
for ( Tensor<?> tensor : tensors ) {
onSameGuestDevice = ( !tensor.isVirtual() && device == tensor.get(Device.class) ) && onSameGuestDevice;
}
}
else onSameGuestDevice = false;
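// Edge case: a scalar second operand is not considered to break device sharing: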
if ( device != null && tensors.length == 2 && tensors[ 1 ].size() == 1 ) onSameGuestDevice = true;
return onSameGuestDevice;
}
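// The scalar execution methods below delegate to Operation#calculate, where 'j' is the
// input index used by indexer operations (-1 when not applicable) and 'd' is the index
// of the input with respect to which this function is derived (-1 for a plain evaluation).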
@Override
public double call( final double[] inputs, int j ) {
return this.getOperation().calculate( inputs, j, -1, _src );
}
@Override
public double derive( final double[] inputs, final int d, final int j ) {
return this.getOperation().calculate( inputs, j, d, _src );
}
@Override
public double derive( final double[] inputs, final int d ) {
return this.getOperation().calculate( inputs, -1, d, _src );
}
@Override
public Operation getOperation() { return _operation; }
@Override
public boolean isFlat() { return _isFlat; }
@Override
public boolean isDoingAD() { return _isDoingAD; }
}