// Summation.java
package neureka.backend.main.operations.indexer;
import neureka.Neureka;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Operation;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
/**
 *  An indexer operation which is closely related to the {@link Product} operation.
 *  When executed, the calling function is evaluated once for every input index
 *  of the incoming call, which yields one tensor per index.
 *  These tensors are then added together by the "+" operation,
 *  and that sum is the result of this operation, hence the name {@link Summation}.
 */
public final class Summation extends AbstractOperation
{
    /**
     *  Configures this operation as the differentiable, non-inline
     *  indexer function "sumJs" with an arity of 1.
     */
    public Summation()
    {
        super(
            new OperationBuilder()
                .identifier( "sumJs" )
                .operator( "sumJs" )
                .arity( 1 )
                .isOperator( false )
                .isIndexer( true )
                .isDifferentiable( true )
                .isInline( false )
        );
        // No algorithms are registered for the summation operation because
        // it is a special derivative case of the "addition" operation.
    }

    /**
     *  Flattens the call once per input index, collects the first input of every
     *  flattened call and hands all of them over to the "+" operation.
     *
     * @param caller The function from which this operation is being executed.
     * @param call The execution call whose arity determines how many indices are visited.
     * @return The result produced by the "+" operation for the collected tensors.
     */
    @Override
    public Result execute( final Function caller, final ExecutionCall<?> call )
    {
        final Tensor<?>[] terms = new Tensor[ call.arity() ];
        int index = 0;
        while ( index < terms.length ) {
            // Every index gets its own flattened call; only its first input is of interest.
            terms[ index ] = AbstractDeviceAlgorithm
                                .flattenForIndexer( caller, call.withArgs( Arg.VarIdx.of( index ) ) )
                                .input( 0 );
            index++;
        }

        final Operation addition = Neureka.get().backend().getOperation( "+" );
        final FunctionParser parser = new FunctionParser( Neureka.get().backend() );
        final Function sumOfTerms = parser.parse( addition, terms.length, caller.isDoingAD() );

        // The addition is executed with a derivative index of -1.
        return addition.execute(
                    sumOfTerms,
                    call.withInputs( terms )
                        .withOperation( addition )
                        .withArgs( Arg.DerivIdx.of( -1 ) )
                );
    }

    /**
     *  Scalar counterpart of this operation.
     *
     * @param inputs The scalar inputs passed on to the first source function.
     * @param j The input index; a negative value delegates to {@link #calculate(double[], int, Function[])}.
     * @param d The derivative index; a negative value requests the plain sum.
     * @param src The source functions of which only the first one is used.
     * @return The sum over all indices if {@code j} or {@code d} is negative,
     *         otherwise the derivative of the first source function for {@code d} and {@code j}.
     */
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) {
        if ( j >= 0 ) {
            return d >= 0
                    ? src[ 0 ].derive( inputs, d, j )
                    : _calculate( inputs, src );
        }
        return calculate( inputs, d, src );
    }

    /**
     *  Sums the first source function (or its derivative) over every input index.
     *  If {@code inputs} is empty, then {@code src[0].call(inputs)} is returned instead.
     *
     * @param inputs The scalar inputs passed on to the first source function.
     * @param d The derivative index; a negative value requests the plain sum of calls.
     * @param src The source functions of which only the first one is used.
     * @return The accumulated sum, or the direct call result for empty inputs.
     */
    public static double calculate( double[] inputs, int d, Function[] src ) {
        if ( d < 0 ) return _calculate( inputs, src );

        // Without any inputs there is nothing to derive, so the function is simply called.
        if ( inputs.length == 0 ) return src[ 0 ].call( inputs );

        double total = 0;
        for ( int index = 0; index < inputs.length; index++ )
            total += src[ 0 ].derive( inputs, d, index );

        return total;
    }

    /**
     *  Sums the results of calling the first source function for every input index.
     *  If {@code inputs} is empty, then {@code src[0].call(inputs)} is returned instead.
     *
     * @param inputs The scalar inputs passed on to the first source function.
     * @param src The source functions of which only the first one is used.
     * @return The accumulated sum, or the direct call result for empty inputs.
     */
    private static double _calculate( double[] inputs, Function[] src ) {
        // Without any inputs there is no index to iterate over, so the function is simply called.
        if ( inputs.length == 0 ) return src[ 0 ].call( inputs );

        double total = 0;
        for ( int index = 0; index < inputs.length; index++ )
            total += src[ 0 ].call( inputs, index );

        return total;
    }
}