// Subtraction.java
package neureka.backend.main.operations.operator;
import neureka.Neureka;
import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.FallbackAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.algorithms.BiElementwise;
import neureka.backend.main.algorithms.BiScalarBroadcast;
import neureka.backend.main.algorithms.Broadcast;
import neureka.backend.main.operations.ElemWiseUtil;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
import neureka.devices.Device;
import neureka.ndim.NDimensional;
import java.util.Arrays;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
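
/**
 *  This operation implements the n-ary subtraction operator {@code "-"}
 *  (internal identifier {@code "subtract"}). It is differentiable and registers an element-wise,
 *  a scalar-broadcast and a general broadcast execution path in its constructor.
 *  <p>
 *  A minimal, hedged usage sketch based on the {@link Function} API already used within this class
 *  (assuming the parser's usual {@code "I[0]"}, {@code "I[1]"}, ... input placeholders):
 *  <pre>{@code
 *  Function f = Function.of("I[0] - I[1] - I[2]", true);
 *  double y  = f.call(new double[]{ 7, 2, 1 });       // 7 - 2 - 1 = 4
 *  double d0 = f.derive(new double[]{ 7, 2, 1 }, 0);  //  1 (the first operand keeps its sign)
 *  double d1 = f.derive(new double[]{ 7, 2, 1 }, 1);  // -1 (all following operands are negated)
 *  }</pre>
 */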
public class Subtraction extends AbstractOperation
{
    public Subtraction()
    {
        super(
            new OperationBuilder()
                .identifier( "subtract" )
                .operator( "-" )
                .arity( -1 )
                .isOperator( true )
                .isIndexer( false )
                .isDifferentiable( true )
                .isInline( false )
        );

        // Element-wise subtraction (default execution path):
        setAlgorithm(
            new BiElementwise()
            .setSupplyADActionFor( getDefaultAlgorithm() )
            .buildFunAlgorithm()
        );

        // Scalar broadcast variant, registered with low suitability so that other algorithms are preferred:
        setAlgorithm(
            BiScalarBroadcast.class,
            new BiScalarBroadcast()
            .setIsSuitableFor( call -> SuitabilityPredicate.BAD )
            .setExecution( (caller, call) ->
                Result.of(
                    AbstractDeviceAlgorithm.executeFor(
                        caller, call, AbstractDeviceAlgorithm::executeDeviceAlgorithm
                    )
                )
                .withAutoDiff( FallbackAlgorithm::ADAction )
            )
            .buildFunAlgorithm()
        );

        // General broadcasting variant, which only supports backward mode auto-differentiation:
        setAlgorithm(
            Broadcast.class,
            new Broadcast()
            .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
            .setSupplyADActionFor(
                ( Function f, ExecutionCall<? extends Device<?>> call ) ->
                {
                    if ( call.autogradMode().allowsForward() )
                        throw new IllegalArgumentException("Broadcast implementation does not support forward-AD!");

                    Tensor<?> ctxDerivative = (Tensor<?>) call.getValOf(Arg.Derivative.class);
                    assert ctxDerivative == null;
                    int d = call.getDerivativeIndex();
                    Tensor<?> derivative  = ElemWiseUtil.newTensorLike( call.input( d == 0 ? 1 : 0 ), 0 );
                    Tensor<?> toBeDerived = ElemWiseUtil.newTensorLike( call.input( d ), 0 );
                    Device device = call.getDevice();
                    return ADAction.of(
                        target ->
                            this.getAlgorithm( Broadcast.class )
                                .getImplementationFor( device )
                                .run(
                                    ExecutionCall.of(
                                        toBeDerived.mut().setIsVirtual(false),
                                        derivative,
                                        target.error()
                                    )
                                    .andArgs( Arg.DerivIdx.of(d) )
                                    .running( this )
                                    .on( device )
                                )
                    );
                }
            )
            .buildFunAlgorithm()
        );
    }
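
    /**
     *  Handles calls whose {@link Function} is not yet flat, i.e. still contains nested sub-functions.
     *  For plain execution (derivative index {@code d < 0}) the caller is reduced to nested pairs via
     *  {@link #reducePairwise(Function)} and the call is flattened before being executed.
     *  For derivation ({@code d >= 0}) the partial derivatives of all operands depending on the target
     *  are computed individually and summed up, with every operand after the first contributing with a
     *  minus sign; symbolically (sketch): d/dx ( a - b - c ) = a' - b' - c'.
     */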
    @Override
    public Result execute( final Function caller, final ExecutionCall<?> call )
    {
        if ( !caller.isFlat() ) {
            int d = call.getDerivativeIndex();
            if ( d < 0 ) {
                // Plain execution: reduce "a-b-c-..." to nested pairs and flatten the call.
                Function reducedCaller = reducePairwise( caller );
                ExecutionCall<?> flatCall = AbstractDeviceAlgorithm.flatten( reducedCaller, call.withArgs(Arg.DerivIdx.of(-1)) );
                Function flat = new FunctionParser( Neureka.get().backend() ).parse( flatCall.getOperation(), flatCall.arity(), true );
                return super.execute( flat, flatCall );
            } else {
                if ( !call.validate().allNotNullHaveSame(NDimensional::shape).isValid() )
                    throw new IllegalArgumentException(
                        "The shapes of the operands of the subtraction operation must be equal! (when deriving nested functions)"
                    );

                // Collect the indices of all sub-functions which depend on the derivation target:
                int[] toBeDerived = IntStream.range( 0, caller.getSubFunctions().size() )
                                             .filter( i -> caller.getSubFunctions().get(i).dependsOn(d) )
                                             .toArray();

                Tensor[] results = new Tensor[ toBeDerived.length ];
                Function neg = Neureka.get().backend().getFunction().neg();
                for ( int i = 0; i < results.length; i++ ) {
                    Function noAD = Function.of( caller.getSubFunctions().get( toBeDerived[i] ).toString(), false );
                    Tensor<?> deriv = noAD.execute( noAD.getOperation() == null ? call : call.withOperation(noAD.getOperation()) );
                    // Every operand except the very first one is subtracted, so its derivative must be negated
                    // (note: the sign depends on the operand's position in the chain, not on its index in `results`):
                    if ( toBeDerived[i] > 0 ) deriv = neg.execute( deriv );
                    results[ i ] = deriv;
                }
                if ( results.length == 1 ) return Result.of( results[0] );

                // Sum up the partial derivatives of all relevant operands:
                Function addAll = new FunctionParser( Neureka.get().backend() ).parse( Neureka.get().backend().getOperation("+"), results.length, false );
                return addAll.getOperation().execute(
                            addAll,
                            call.withOperation(addAll.getOperation()).withInputs(results).withArgs(Arg.DerivIdx.of(-1))
                        );
            }
        }
        return super.execute( reducePairwise( caller ), call );
    }
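
    /**
     *  Reduces an n-ary subtraction into a chain of nested binary subtractions, mirroring the
     *  left-associative way the operator is actually executed.
     *  For example (sketch): {@code "I[0]-I[1]-I[2]-I[3]"} becomes {@code "((I[0]-I[1])-I[2])-I[3]"}.
     */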
    private Function reducePairwise( final Function fun ) {
        Function reduced = fun;
        if ( reduced.getSubFunctions().size() > 2 ) {
            /*
                So currently we have something like this: a-b-c-d...
                However, this is how it is really executed: ((((a-b)-c)-d)..)
                ...so let's create a function that is nested like the above:
            */
            Function nested = reduced.getSubFunctions().get(0);
            for ( int i = 1; i < reduced.getSubFunctions().size(); i++ )
                nested = Function.of( nested + " - " + reduced.getSubFunctions().get(i), true );

            reduced = nested;
        }
        return reduced;
    }
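
    /**
     *  Builds the symbolic derivative of a subtraction chain with respect to the given index.
     *  Operands which do not depend on the index are dropped, and if the first operand is among
     *  the dropped ones then the remaining expression is prefixed with a minus sign.
     *  For example (sketch): for children (a, b, c) where only b and c depend on the index,
     *  the result is {@code "-db - dc"}, with db and dc being the derivatives of b and c.
     */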
    @Override
    public String asDerivative( Function[] children, int derivationIndex ) {
        return ( children[0].dependsOn(derivationIndex) ? "" : "-" ) +
                Arrays.stream( children )
                    .filter( child -> child.dependsOn(derivationIndex) )
                    .map( child -> child.getDerivative(derivationIndex) )
                    .map( Object::toString )
                    .collect( Collectors.joining( " - " ) );
    }
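
    /**
     *  Scalar evaluation of this operation: delegates to the static variant when {@code j < 0},
     *  otherwise evaluates (or derives, for {@code d >= 0}) each source function at index {@code j}
     *  and accumulates the results with the sign pattern of subtraction
     *  (the first operand keeps its sign, all following operands are subtracted).
     *  For example (sketch): inputs {7, 2, 1} with sources I[0], I[1], I[2] yield 7 - 2 - 1 = 4.
     */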
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) {
        if ( j < 0 ) return calculate( inputs, d, src );
        if ( d < 0 ) {
            double result = src[ 0 ].call( inputs, j );
            for ( int i = 1; i < src.length; i++ ) {
                final double current = src[ i ].call( inputs, j );
                result -= current;
            }
            return result;
        } else {
            double derivative = 0;
            for ( int i = 0; i < src.length; i++ ) {
                if ( i == 0 )
                    derivative += src[ i ].derive( inputs, d, j );
                else
                    derivative -= src[ i ].derive( inputs, d, j );
            }
            return derivative;
        }
    }
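
    /**
     *  Static counterpart of {@link #calculate(double[], int, int, Function[])} without the vararg index.
     *  A minimal sketch of what it computes, assuming simple input-placeholder source functions:
     *  <pre>{@code
     *  Function[] src = { Function.of("I[0]", false), Function.of("I[1]", false), Function.of("I[2]", false) };
     *  Subtraction.calculate(new double[]{ 7, 2, 1 }, -1, src); // 7 - 2 - 1 = 4
     *  Subtraction.calculate(new double[]{ 7, 2, 1 },  2, src); // derivative w.r.t. I[2] = -1
     *  }</pre>
     */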
    public static double calculate( double[] inputs, int d, Function[] src ) {
        if ( d < 0 ) {
            double result = src[ 0 ].call( inputs );
            for ( int i = 1; i < src.length; i++ ) {
                final double current = src[ i ].call( inputs );
                result -= current;
            }
            return result;
        } else {
            double derivative = 0;
            for ( int i = 0; i < src.length; i++ ) {
                if ( i == 0 )
                    derivative += src[ i ].derive( inputs, d );
                else
                    derivative -= src[ i ].derive( inputs, d );
            }
            return derivative;
        }
    }

}