AssignLeft.java
package neureka.backend.main.operations.other;
import neureka.Neureka;
import neureka.Tensor;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.algorithms.ElementwiseAlgorithm;
import neureka.backend.main.algorithms.BiScalarBroadcast;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
/**
 *  The in-place assignment operation, written {@code a <- b} (operator {@code "<"},
 *  identifier {@code "left_inline"}). The right-most operand is written into the
 *  left-most operand; longer chains such as {@code a <- b <- c} are rewritten into
 *  right-nested pairs before execution (see {@code reducePairwise}).
 *  <p>
 *  Two algorithms are registered:
 *  <ul>
 *      <li>{@link BiScalarBroadcast}: reports {@code PERFECT} suitability when the source
 *          (last) operand is virtual or holds a single element, otherwise {@code UNSUITABLE}.</li>
 *      <li>{@link ElementwiseAlgorithm}: reports {@code EXCELLENT} suitability and delegates
 *          the copy to the backend operation registered as {@code "idy"}
 *          (presumably the identity operation; confirm).</li>
 *  </ul>
 *  Neither algorithm supports autograd, and both increment the version of the tensor
 *  their execution produces.
 */
public class AssignLeft extends AbstractOperation
{
    /**
     *  Registers the operation metadata and its two algorithms.
     *  NOTE(review): {@code arity( -2 )} presumably marks a variable number of operands;
     *  confirm against {@link OperationBuilder}.
     */
    public AssignLeft() {
        super(
            new OperationBuilder()
                .identifier( "left_inline" )
                .operator( "<" )
                .arity( -2 )
                .isOperator( true )
                .isIndexer( false )
                .isDifferentiable( false )
                .isInline( true )
        );
        setAlgorithm(
            BiScalarBroadcast.class,
            new BiScalarBroadcast()
                .setIsSuitableFor(
                    call -> {
                        // Calls with fewer than 2 or more than 3 operands are rejected outright.
                        if ( call.arity() > 3 )
                            throw new IllegalArgumentException("AssignLeft operation only supports up to 3 arguments!");
                        if ( call.arity() < 2 )
                            throw new IllegalArgumentException("AssignLeft operation needs at least 2 arguments!");
                        int last = call.arity() - 1;
                        // Only a virtual or single-element source can be broadcast like a scalar.
                        if ( call.input( last ).isVirtual() || call.input( last ).size() == 1 )
                            return call.validate()
                                       .allNotNull( t -> t.getDataType().typeClassImplements(Object.class) )
                                       .tensors( tensors -> tensors.length == 2 || tensors.length == 3 )
                                       .suitabilityIfValid(SuitabilityPredicate.PERFECT);
                        else
                            return SuitabilityPredicate.UNSUITABLE;
                    }
                )
                .setAutogradModeFor( call -> AutoDiffMode.NOT_SUPPORTED)
                .setExecution( (caller, call) -> executeAndIncrementVersion( call ) )
                .setCallPreparation(
                    call -> {
                        // A null first input is skipped (presumably an unallocated output slot).
                        int offset = ( call.input( 0 ) == null ? 1 : 0 );
                        // The target is written to directly, so it is made non-virtual first.
                        call.input( offset ).mut().setIsVirtual( false );
                        return
                            ExecutionCall.of( call.input( offset ), call.input( offset + 1 ) )
                                .andArgs(Arg.DerivIdx.of(-1))
                                .running(this)
                                .on( call.getDevice() );
                    }
                )
                .buildFunAlgorithm()
        );
        setAlgorithm(
            new ElementwiseAlgorithm()
                .setIsSuitableFor(
                    call -> call.validate()
                                .allNotNull( t -> t.getDataType().typeClassImplements(Object.class) )
                                .tensors( tensors -> tensors.length == 2 || tensors.length == 3 )
                                .suitabilityIfValid(SuitabilityPredicate.EXCELLENT)
                )
                .setAutogradModeFor( call -> AutoDiffMode.NOT_SUPPORTED)
                .setExecution( (caller, call) -> executeAndIncrementVersion( call ) )
                .setCallPreparation(
                    call -> {
                        // A null first input is skipped (presumably an unallocated output slot).
                        int offset = ( call.input( 0 ) == null ? 1 : 0 );
                        return ExecutionCall.of( call.input(offset), call.input(1+offset) )
                                .running(Neureka.get().backend().getOperation("idy"))
                                .on( call.getDevice() );
                    }
                )
                .buildFunAlgorithm()
        );
    }

    /**
     *  Shared execution step of both algorithms. It runs the device algorithm for the
     *  given call, then increments the version of the resulting tensor (passing the call
     *  along) and wraps it in a {@link Result}.
     *
     *  @param call the prepared execution call.
     *  @return the result wrapping the tensor produced by the device algorithm.
     */
    private static Result executeAndIncrementVersion( ExecutionCall<?> call ) {
        Tensor<?> result = AbstractDeviceAlgorithm.executeDeviceAlgorithm( call );
        result.mut().incrementVersion( call );
        return Result.of( result );
    }

    /**
     *  Executes an assignment. Derivative requests are rejected, chains of more than two
     *  operands are rewritten into right-nested pairs, and the call is flattened and
     *  handed to {@code super.execute} together with a function parsed for the flattened call.
     *
     *  @param caller the function representing the assignment expression.
     *  @param call the call holding the operand tensors.
     *  @return the result of executing the flattened assignment.
     *  @throws IllegalArgumentException if a derivative is requested (derivative index {@code >= 0}).
     */
    @Override
    public Result execute( final Function caller, ExecutionCall<?> call )
    {
        if ( call.getDerivativeIndex() >= 0 )
            throw new IllegalArgumentException("Assignment does not support autograd!");
        Function reducedCaller = reducePairwise(caller);
        ExecutionCall<?> flatCall = AbstractDeviceAlgorithm.flatten( reducedCaller, call.withArgs(Arg.DerivIdx.of(-1)) );
        // Mark every operand as non-intermediate. Presumably this stops the backend from
        // treating the operands as disposable temporaries (TODO confirm). A null slot,
        // which the call preparations above also anticipate, is skipped instead of
        // dereferenced.
        for ( Tensor<?> t : call.inputs() )
            if ( t != null ) t.mut().setIsIntermediate(false);
        Function flat = new FunctionParser(Neureka.get().backend()).parse( flatCall.getOperation(), flatCall.arity(), false );
        return super.execute( flat, flatCall );
    }

    /**
     *  Rewrites an assignment chain with more than two operands into right-nested pairs.
     *  Functions with two or fewer sub-functions are returned unchanged.
     *
     *  @param fun the (possibly multi-operand) assignment function.
     *  @return {@code fun} itself, or an equivalent right-nested function.
     */
    private Function reducePairwise( final Function fun ) {
        int count = fun.getSubFunctions().size();
        if ( count <= 2 ) return fun;
        /*
            A chain like  a <- b <- c <- d ...  is executed right-to-left, i.e. as if it
            were nested like this:  (a <- (b <- (c <- (d <- ...))))
            ...so build a function whose structure reflects that nesting:
        */
        Function nested = fun.getSubFunctions().get( count - 1 );
        for ( int i = count - 2; i >= 0; i-- )
            nested = Function.of( fun.getSubFunctions().get( i ) + " <- " + nested, true );
        return nested;
    }

    /**
     *  Renders the operands as {@code (a <- b <- ...)}.
     *
     *  @param children the string forms of the operands, left to right.
     *  @return the operands joined by {@code " <- "} and wrapped in parentheses.
     */
    @Override
    public String stringify( String[] children ) {
        return "(" + String.join( " <- ", children ) + ")";
    }

    /**
     *  Scalar evaluation: the value of an assignment is the value of its right-most operand.
     *  When {@code d >= 0}, the derivative of that operand for derivative index {@code d}
     *  is returned instead.
     */
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) {
        int right = src.length - 1;
        return d >= 0 ? src[ right ].derive( inputs, d, j ) : src[ right ].call( inputs, j );
    }
}