// Modulo.java
package neureka.backend.main.operations.operator;

import neureka.Neureka;
import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Call;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.FallbackAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.algorithms.BiElementwise;
import neureka.backend.main.algorithms.BiScalarBroadcast;
import neureka.backend.main.algorithms.Broadcast;
import neureka.devices.Device;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
import neureka.ndim.NDimensional;

import java.util.Arrays;
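
/**
 *  The {@link Modulo} operation implements the "%" operator for tensors.
 *  It is registered with three algorithms: {@link BiElementwise} for operands
 *  of identical shape, {@link Broadcast} for shape-mismatched operands, and
 *  {@link BiScalarBroadcast} as a low-priority fallback for tensor-scalar calls.
 *  It can be reached through the function string syntax used elsewhere in this
 *  class, for example {@code Function.of("I[0] % I[1]", true)}.
 */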
public class Modulo extends AbstractOperation
{
    public Modulo()
    {
        super(
            new OperationBuilder()
                .identifier(       "modulo" )
                .operator(         "%"      )
                .arity(            -1       )
                .isOperator(       true     )
                .isIndexer(        false    )
                .isDifferentiable( true     )
                .isInline(         false    )
        );
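        /*
            Elementwise execution for operands of identical shape.
            Autodiff is delegated to the AD action supplier of the default algorithm.
        */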
        setAlgorithm(
            BiElementwise.class,
            new BiElementwise()
                .setSupplyADActionFor( getDefaultAlgorithm() )
                .buildFunAlgorithm()
        );
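        /*
            Broadcasting for shape-mismatched operands.
            Forward-AD is only permitted when all operands share the same shape;
            otherwise the backward-only AD action defined below is used.
        */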
        setAlgorithm(
            Broadcast.class,
            new Broadcast()
                .setAutogradModeFor(
                    call -> call.validate()
                                .allNotNullHaveSame(NDimensional::shape)
                                .ifValid(AutoDiffMode.FORWARD_AND_BACKWARD)
                                .orElse(AutoDiffMode.BACKWARD_ONLY)
                )
                .setSupplyADActionFor(
                    ( Function f, ExecutionCall<? extends Device<?>> call ) ->
                    {
                        if ( call.autogradMode().allowsForward() )
                            throw new IllegalArgumentException("Broadcast implementation does not support forward-AD!");

                        Tensor<?> ctxDerivative = (Tensor<?>) call.getValOf(Arg.Derivative.class);
                        Function mul = Neureka.get().backend().getFunction().mul();
                        if ( ctxDerivative != null )
                            return ADAction.of( target -> mul.execute( target.error(), ctxDerivative ) );

                        int d = call.getDerivativeIndex();
                        Tensor<?> derivative = f.executeDerive( call.inputs(), d );
                        return ADAction.of( target -> mul.execute( target.error(), derivative ) );
                    }
                )
                .buildFunAlgorithm()
        );
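        /*
            Scalar broadcasting (tensor "%" scalar), ranked as a last resort
            via SuitabilityPredicate.BAD so that it is only chosen when no
            better algorithm applies. Execution is forwarded to the generic
            device algorithm, and autodiff falls back to the FallbackAlgorithm.
        */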
        setAlgorithm(
            BiScalarBroadcast.class,
            new BiScalarBroadcast()
                .setIsSuitableFor( call -> SuitabilityPredicate.BAD )
                .setAutogradModeFor(
                    call -> call.validate()
                                .allNotNullHaveSame(NDimensional::shape)
                                .ifValid(AutoDiffMode.FORWARD_AND_BACKWARD)
                                .orElse(AutoDiffMode.BACKWARD_ONLY)
                )
                .setExecution(
                    (caller, call) ->
                        Result.of(
                            AbstractDeviceAlgorithm.executeFor(
                                caller, call, AbstractDeviceAlgorithm::executeDeviceAlgorithm
                            )
                        )
                        .withAutoDiff( FallbackAlgorithm::ADAction )
                )
                .buildFunAlgorithm()
        );
    }
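
    /**
     *  Executes this operation for the given call.
     *  Chained callers like {@code a % b % c} are first reduced to nested pairs
     *  (see {@link #reducePairwise(Function)}). Non-flat callers without a
     *  derivative request are flattened and re-dispatched. Derivative calls
     *  ({@code d >= 0}) are handled explicitly below under the assumption
     *  that the reduced caller has exactly two sub-functions.
     */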
    @Override
    public Result execute( final Function caller, final ExecutionCall<?> call )
    {
        Function reducedCaller = reducePairwise(caller);
        int d = call.getDerivativeIndex();
        if ( !reducedCaller.isFlat() && d < 0 ) {
            ExecutionCall<?> flatCall = AbstractDeviceAlgorithm.flatten( reducedCaller, call.withArgs(Arg.DerivIdx.of(-1)) );
            Arrays.stream(flatCall.inputs()).forEach( t -> t.mut().setIsIntermediate(false) );
            Function flat = new FunctionParser(Neureka.get().backend()).parse( flatCall.getOperation(), flatCall.arity(), true );
            return super.execute( flat, flatCall );
        }
        if ( d >= 0 ) {
            if ( !call.validate().all( (a, b) -> Util.canBeBroadcast(a.shape(), b.shape()) ).isValid() )
                throw new IllegalArgumentException(
                    "The shapes of the operands of the modulo operation must be equal or broadcast-compatible! (when deriving nested functions)"
                );
            // Here we assume that the reduced caller has exactly two sub-functions: a % b
            Function noAd = Function.of( reducedCaller.toString(), false );
            Function a = noAd.getSubFunctions().get(0);
            Function b = noAd.getSubFunctions().get(1);
            boolean deriveA = a.dependsOn(d);
            boolean deriveB = b.dependsOn(d);
            if ( !deriveA && !deriveB ) return super.execute( reducedCaller, call );

            Tensor<?> bResult = b.call((Call) call.withArgs(Arg.DerivIdx.of(-1)));
            Tensor<?> derivOfA = null;
            if ( deriveA ) {
                Function div = Neureka.get().backend().getFunction().div();
                // This is simple: we derive the first sub-function and divide it by the
                // second one, i.e. multiply it with the inverse of the second sub-function:
                Tensor<?> aDeriv = a.call((Call) call);
                derivOfA = div.call((Tensor<Object>) aDeriv, (Tensor<Object>) bResult);
            }
            if ( !deriveB && deriveA )
                return Result.of(derivOfA.mut().setIsIntermediate(true));
            Tensor<?> aResult = a.call((Call) call.withArgs(Arg.DerivIdx.of(-1)));
            if ( deriveB ) {
                Function mul = Neureka.get().backend().getFunction().mul();
                Tensor<?> innerDerivB = b.call((Call) call);
                /*
                    Here we want to derive the second sub-function 'b'.
                    The expression is treated as a * (1/b), and the derivative
                    of 1/b is -1/b^2. So we derive b first and later add the
                    derivative of 'a' to it (if 'a' must be derived as well):
                */
                Function derive = Function.of("-I[0] / (I[1] ** 2)", false);
                Tensor<?> derivOfB = derive.call( (Tensor<Object>) innerDerivB, (Tensor<Object>) bResult );
                derivOfB = mul.call((Tensor<Object>) aResult, (Tensor<Object>) derivOfB);
                if ( !deriveA )
                    return Result.of(derivOfB.mut().setIsIntermediate(true));
                else {
                    Function add = Neureka.get().backend().getFunction().add();
                    return Result.of( add.call((Tensor<Object>) derivOfA, (Tensor<Object>) derivOfB).mut().setIsIntermediate(true) );
                }
            }
        }
        return super.execute( reducedCaller, call );
    }
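
    /**
     *  Reduces a chained caller like {@code a % b % c % d} into the nested
     *  left-associative form {@code ((a % b) % c) % d}, which is how the
     *  operation is actually executed pairwise.
     */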
    private Function reducePairwise( final Function fun ) {
        Function reduced = fun;
        if ( reduced.getSubFunctions().size() > 2 ) {
            /*
                Currently we have something like this: a % b % c % d ...
                However, this is how it is really executed: (((a % b) % c) % d)
                ...so let's create a function that is nested like the above:
            */
            Function nested = reduced.getSubFunctions().get(0);
            for ( int i = 1; i < reduced.getSubFunctions().size(); i++ )
                nested = Function.of( nested + " % " + reduced.getSubFunctions().get(i), true );

            reduced = nested;
        }
        return reduced;
    }
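
    /**
     *  Scalar reference implementation: folds the "%" operator over the
     *  results of all source functions. For a derivative request
     *  ({@code d >= 0}) only the first operand is derived; the remaining
     *  operands are treated as locally constant.
     */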
    public static double calculate( double[] inputs, int d, Function[] src ) {
        if ( d < 0 ) {
            double result = src[ 0 ].call( inputs );
            for ( int i = 1; i < src.length; i++ ) {
                final double current = src[ i ].call( inputs );
                result %= current;
            }
            return result;
        }
        else return src[ 0 ].derive( inputs, d );
    }
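
    /**
     *  Builds the symbolic derivative of this operation. Only the first
     *  child contributes: {@code d(a % b)/dx} is expressed as the derivative
     *  of {@code a}, since {@code a % b} equals {@code a} minus {@code b}
     *  times a piecewise constant integer quotient (any dependence of
     *  {@code b} on the derivation target is ignored here).
     */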
    @Override
    public String asDerivative( Function[] children, int derivationIndex ) {
        return children[ 0 ].getDerivative(derivationIndex).toString();
    }
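
    /**
     *  Indexed variant of {@link #calculate(double[], int, Function[])}:
     *  {@code j} selects the input index at which the source functions are
     *  evaluated; a negative {@code j} falls back to the un-indexed version.
     */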
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) {
        if ( j < 0 ) return calculate( inputs, d, src );
        if ( d < 0 ) {
            double result = src[ 0 ].call( inputs, j );
            for ( int i = 1; i < src.length; i++ ) {
                final double current = src[ i ].call( inputs, j );
                result %= current;
            }
            return result;
        }
        else
            return src[ 0 ].derive( inputs, d, j );
    }
}