Division.java
package neureka.backend.main.operations.operator;
import neureka.Neureka;
import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Call;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.FallbackAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.algorithms.BiElementwise;
import neureka.backend.main.algorithms.BiScalarBroadcast;
import neureka.backend.main.algorithms.Broadcast;
import neureka.devices.Device;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
import neureka.ndim.NDimensional;
import java.util.Arrays;
/**
 * The n-ary division operator {@code /} (identifier {@code "divide"}).
 * <p>
 * Division over more than two operands is evaluated left-associatively,
 * i.e. {@code a / b / c / d} is computed as {@code (((a / b) / c) / d)}.
 * Derivatives follow the quotient rule {@code (u / v)' = (u' * v - u * v') / v^2},
 * applied once for every additional operand.
 */
public class Division extends AbstractOperation
{
    /**
     * Creates the division operation and registers its device algorithms:
     * <ul>
     *     <li>{@link BiElementwise}: element-wise division, using the AD-action
     *         supplier of the default algorithm.</li>
     *     <li>{@link Broadcast}: division with broadcasting. Forward- and backward-mode AD
     *         is allowed only when all non-null operands share the same shape, otherwise
     *         backward-mode only. Its AD-action multiplies the incoming error with the
     *         partial derivative.</li>
     *     <li>{@link BiScalarBroadcast}: reports {@link SuitabilityPredicate#BAD} suitability
     *         and delegates auto-differentiation to {@link FallbackAlgorithm}.</li>
     * </ul>
     */
    public Division()
    {
        super(
            new OperationBuilder()
                .identifier( "divide" )
                .operator( "/" )
                .arity( -1 ) // presumably "variadic"; this class handles any number of operands (see reducePairwise / calculate)
                .isOperator( true )
                .isIndexer( false )
                .isDifferentiable( true )
                .isInline( false )
        );

        setAlgorithm(
            BiElementwise.class,
            new BiElementwise()
                .setSupplyADActionFor( getDefaultAlgorithm() )
                .buildFunAlgorithm()
        );

        setAlgorithm(
            Broadcast.class,
            new Broadcast()
                .setAutogradModeFor(
                    call -> call
                            .validate().allNotNullHaveSame(NDimensional::shape)
                            .ifValid(AutoDiffMode.FORWARD_AND_BACKWARD)
                            .orElse(AutoDiffMode.BACKWARD_ONLY)
                )
                .setSupplyADActionFor(
                    ( Function f, ExecutionCall<? extends Device<?>> call ) ->
                    {
                        // Broadcasting only supports backward-mode AD:
                        if ( call.autogradMode().allowsForward() )
                            throw new IllegalArgumentException("Broadcast implementation does not support forward-AD!");
                        Tensor<?> ctxDerivative = (Tensor<?>) call.getValOf(Arg.Derivative.class);
                        Function mul = Neureka.get().backend().getFunction().mul();
                        // Reuse a partial derivative that was already passed along with the call...
                        if ( ctxDerivative != null ) {
                            return ADAction.of( target -> mul.execute( target.error(), ctxDerivative ) );
                        }
                        // ...otherwise compute it eagerly for the requested input index:
                        int d = call.getDerivativeIndex();
                        Tensor<?> derivative = f.executeDerive( call.inputs(), d );
                        return ADAction.of( target -> mul.execute( target.error(), derivative ) );
                    }
                )
                .buildFunAlgorithm()
        );

        setAlgorithm(
            BiScalarBroadcast.class,
            new BiScalarBroadcast()
                .setIsSuitableFor( call -> SuitabilityPredicate.BAD )
                .setAutogradModeFor( call -> AutoDiffMode.FORWARD_AND_BACKWARD )
                .setExecution(
                    (caller, call) ->
                        Result.of(AbstractDeviceAlgorithm.executeFor(caller, call, AbstractDeviceAlgorithm::executeDeviceAlgorithm))
                              .withAutoDiff( FallbackAlgorithm::ADAction )
                )
                .buildFunAlgorithm()
        );
    }

    /**
     * Executes this division (or one of its partial derivatives) for the given call.
     * <p>
     * The caller is first reduced to a pairwise-nested form (see {@link #reducePairwise(Function)}).
     * <ul>
     *     <li>If the reduced function is not flat and no derivative is requested, it is flattened
     *         into a single call whose inputs are marked as non-intermediate, and that call is executed.</li>
     *     <li>If a derivative is requested ({@code d >= 0}), the quotient rule is applied manually
     *         to the two top-level sub-functions {@code a / b}.</li>
     *     <li>Otherwise execution is delegated to the super implementation.</li>
     * </ul>
     *
     * @param caller The function being executed (may have more than two sub-functions).
     * @param call   The execution call carrying the operands and the derivative index.
     * @return The result of the division or of its partial derivative.
     * @throws IllegalArgumentException If a derivative is requested but the operand shapes are
     *                                  neither equal nor broadcast compatible.
     */
    @Override
    public Result execute( final Function caller, final ExecutionCall<?> call )
    {
        Function reducedCaller = reducePairwise( caller );
        int d = call.getDerivativeIndex();
        if ( !reducedCaller.isFlat() ) {
            if ( d < 0 ) {
                ExecutionCall<?> flatCall = AbstractDeviceAlgorithm.flatten( reducedCaller, call.withArgs(Arg.DerivIdx.of(-1)) );
                // The flattened inputs must survive the execution, so they must not be treated as intermediate:
                Arrays.stream(flatCall.inputs()).forEach( t -> t.mut().setIsIntermediate(false) );
                Function flat = new FunctionParser(Neureka.get().backend()).parse( flatCall.getOperation(), flatCall.arity(), true );
                return super.execute( flat, flatCall );
            }
        }
        if ( d >= 0 ) {
            if ( !call.validate().all( (a, b) -> Util.canBeBroadcast(a.shape(), b.shape()) ).isValid() )
                throw new IllegalArgumentException("The shapes of the operands of the division operation must be equal or broadcast compatible! (when deriving nested functions)");

            // So here we assume that there are only two sub-functions: a/b
            // (reducePairwise guarantees this for callers with more than two sub-functions).
            Function noAd = Function.of( reducedCaller.toString(), false );
            Function a = noAd.getSubFunctions().get(0);
            Function b = noAd.getSubFunctions().get(1);
            boolean deriveA = a.dependsOn(d);
            boolean deriveB = b.dependsOn(d);

            if ( !deriveA && !deriveB ) return super.execute( reducedCaller, call );

            Tensor<?> bResult = b.call((Call) call.withArgs(Arg.DerivIdx.of(-1)));
            Tensor<?> derivOfA = null;
            if ( deriveA ) {
                Function div = Neureka.get().backend().getFunction().div();
                // This is simple, we just derive the first sub-function and multiply it with the inverse of the second sub-function:
                Tensor<?> aDeriv = a.call((Call)call);
                derivOfA = div.call((Tensor<Object>)aDeriv, (Tensor<Object>)bResult);
            }
            if ( !deriveB && deriveA )
                return Result.of(derivOfA.mut().setIsIntermediate(true));

            // From here on 'deriveB' is necessarily true (the other cases returned above).
            Tensor<?> aResult = a.call((Call)call.withArgs(Arg.DerivIdx.of(-1)));
            if ( deriveB )
                return _deriveB( call, b, deriveA, derivOfA, aResult, bResult );
        }
        return super.execute( reducedCaller, call );
    }

    /**
     * Computes the derivative of {@code a / b} for the case where {@code b} depends on the
     * derivation target. It uses {@code d(a * (1/b)) = a * (-b' / b^2)} and adds {@code derivOfA}
     * ({@code a' / b}) if {@code a} also depends on the target.
     *
     * @param call     The execution call holding the derivative index.
     * @param b        The (non-autograd) denominator sub-function.
     * @param deriveA  Whether the numerator also depends on the derivation target.
     * @param derivOfA The already computed {@code a' / b} (only used if {@code deriveA}).
     * @param aResult  The value of the numerator {@code a}.
     * @param bResult  The value of the denominator {@code b}.
     * @return The derivative, marked as intermediate.
     */
    private Result _deriveB(
        ExecutionCall<?> call,
        Function b,
        boolean deriveA,
        Tensor<?> derivOfA,
        Tensor<?> aResult,
        Tensor<?> bResult
    ) {
        Function mul = Neureka.get().backend().getFunction().mul();
        Tensor<?> innerDerivB = b.call((Call)call);
        // So we have something like this: a/b, where we want to derive b.
        // This is how it is really executed: (a/b) = (a * (1/b))
        // So we can derive b and then later on add the derivative of 'a' to it (if it must be derived).
        // The derivative of 1/b is -1/b^2 (times the inner derivative b' by the chain rule).
        // Let's derive b:
        Function derive = Function.of("-I[0] / (I[1] ** 2)", false);
        Tensor<?> derivOfB = derive.call( (Tensor<Object>)innerDerivB, (Tensor<Object>)bResult );
        derivOfB = mul.call((Tensor<Object>)aResult, (Tensor<Object>)derivOfB);
        if ( !deriveA )
            return Result.of(derivOfB.mut().setIsIntermediate(true));
        else {
            Function add = Neureka.get().backend().getFunction().add();
            return Result.of( add.call((Tensor<Object>)derivOfA, (Tensor<Object>)derivOfB).mut().setIsIntermediate(true) );
        }
    }

    /**
     * Rewrites an n-ary division {@code a / b / c / ...} with more than two sub-functions into
     * the equivalent left-nested binary form {@code ((a / b) / c) / ...}.
     * Functions with at most two sub-functions are returned unchanged.
     *
     * @param fun The division function to reduce.
     * @return A function whose top level has at most two sub-functions.
     */
    private Function reducePairwise( Function fun ) {
        Function reduced = fun;
        if ( reduced.getSubFunctions().size() > 2 ) {
            /*
                So currently we have something like this: a/b/c/d...
                However, this is how it is really executed: ((((a/b)/c)/d)..)
                ...so let's create a function that is nested like the above:
            */
            Function nested = reduced.getSubFunctions().get(0);
            for ( int i = 1; i < reduced.getSubFunctions().size(); i++ )
                nested = Function.of( nested + " / " + reduced.getSubFunctions().get(i), true );

            reduced = nested;
        }
        return reduced;
    }

    /**
     * Builds the source expression of the partial derivative of
     * {@code children[0] / children[1] / ... / children[n-1]} with respect to input
     * {@code derivationIndex}.
     *
     * @param children        The operands of the division, in evaluation order.
     * @param derivationIndex The index of the input to differentiate for.
     * @return The derivative expression. For two or more operands this is an empty string if no
     *         operand depends on the given input.
     */
    @Override
    public String asDerivative( Function[] children, int derivationIndex) {
        return _asDerivative( children, derivationIndex, children.length - 1 );
    }

    /**
     * Recursively builds the expression for the prefix {@code P(index) = c0 / ... / c[index]}.
     * <p>
     * For {@code d < 0} the plain prefix expression is returned. Otherwise the quotient rule
     * {@code (P(i-1) / c[i])' = P(i-1)' / c[i] - (P(i-1) * c[i]') / c[i]^2} is applied, where the
     * first term is only emitted if any operand of the prefix depends on {@code d}.
     * The second term is only emitted if {@code c[i]} depends on {@code d} and the numerator
     * {@code c0} is not the literal {@code 0.0}.
     * For two operands the produced string is identical to the historical implementation.
     * Wider prefixes are wrapped in parentheses so that operator precedence cannot split them.
     */
    private String _asDerivative( Function[] children, int d, int index ) {
        if ( d < 0 ) {
            // Plain (non-derived) prefix expression "c0 / c1 / ... / c[index]":
            if ( index <= 0 ) return children[ 0 ].toString();
            return _asDerivative( children, -1, index - 1 ) + " / " + children[ index ].toString();
        }
        if ( index <= 0 ) return children[ 0 ].getDerivative( d ).toString();

        // A prefix of more than one operand is itself a compound expression and must be parenthesized:
        boolean isNestedPrefix = index > 1;
        // The prefix derivative is non-zero if ANY of its operands depends on d (not just the last one):
        boolean prefixDependsOnD = Arrays.stream( children, 0, index ).anyMatch( child -> child.dependsOn( d ) );

        String first = "";
        if ( prefixDependsOnD ) {
            String prefixDerivative = _asDerivative( children, d, index - 1 );
            if ( !prefixDerivative.isEmpty() ) // An empty inner derivative must not yield "( / c )".
                first = "(" + ( isNestedPrefix ? "(" + prefixDerivative + ")" : prefixDerivative ) + " / " + children[ index ] + " )";
        }

        if ( !children[ index ].dependsOn(d) ) return first;
        // A literal zero numerator makes the whole prefix, and therefore the second term, vanish:
        if ( children[ 0 ].toString().equals("0.0") ) return first;

        // The numerator of the second term is the entire prefix c0 / ... / c[index-1]:
        String prefix = isNestedPrefix
                ? "(" + _asDerivative( children, -1, index - 1 ) + ")"
                : children[ 0 ].toString();

        return first +
                " - ((" + // The second expression is the inner derivative (current index)! (inner times outer...)
                    prefix + " * " + children[ index ].getDerivative(d) +
                ") / ( "
                    + children[ index ] + "**2 " +
                ") )";
    }

    /**
     * Evaluates the division, or its partial derivative, for scalar inputs.
     *
     * @param inputs The scalar input values.
     * @param j      An additional index forwarded to the sub-functions' {@code call}/{@code derive};
     *               if negative, the index-free static {@link #calculate(double[], int, Function[])} is used.
     *               (Presumably an element/row index — TODO confirm.)
     * @param d      The input index to differentiate for, or a negative value for plain evaluation.
     * @param src    The operand sub-functions, divided left to right.
     * @return {@code src[0] / src[1] / ...}, or its derivative with respect to input {@code d}.
     */
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) {
        if ( j < 0 ) return calculate( inputs, d, src );
        if ( d < 0 ) {
            double result = src[ 0 ].call( inputs, j );
            for ( int i = 1; i < src.length; i++ ) {
                final double current = src[ i ].call( inputs, j );
                result /= current;
            }
            return result;
        } else {
            // u / ud: running quotient and its derivative; the quotient rule is applied per divisor.
            double u, ud, v, vd;
            u = src[ 0 ].call( inputs, j );
            ud = src[ 0 ].derive( inputs, d, j );
            for ( int i = 0; i < src.length - 1; i++ ) {
                v = src[ i + 1 ].call( inputs, j );
                vd = src[ i + 1 ].derive( inputs, d, j );
                ud = (ud * v - u * vd) / Math.pow(v, 2);
                u /= v;
            }
            return ud;
        }
    }

    /**
     * Evaluates the division, or its partial derivative, for scalar inputs without an extra index.
     *
     * @param inputs The scalar input values.
     * @param d      The input index to differentiate for, or a negative value for plain evaluation.
     * @param src    The operand sub-functions, divided left to right.
     * @return {@code src[0] / src[1] / ...}, or its derivative with respect to input {@code d}.
     */
    public static double calculate( double[] inputs, int d, Function[] src ) {
        if ( d < 0 ) {
            double result = src[ 0 ].call( inputs );
            for ( int i = 1; i < src.length; i++ ) {
                final double current = src[ i ].call( inputs );
                result /= current;
            }
            return result;
        } else {
            // 'tempVar' is the running quotient, 'derivative' its derivative (quotient rule per divisor).
            double derivative;
            double tempVar = src[ 0 ].call( inputs );
            derivative = src[ 0 ].derive( inputs, d );
            for ( int i = 0; i < src.length - 1; i++ ) {
                double u, ud, v, vd;
                v = src[ i + 1 ].call( inputs );
                vd = src[ i + 1 ].derive( inputs, d );
                u = tempVar;
                ud = derivative;
                derivative = ( ud * v - u * vd ) / Math.pow(v, 2);
                tempVar /= v;
            }
            return derivative;
        }
    }
}