MatMul.java
package neureka.backend.main.operations.linear;
import neureka.Neureka;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.algorithms.MatMulAlgorithm;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
public class MatMul extends AbstractOperation
{
public MatMul()
{
super(
new OperationBuilder()
.identifier( "matMul" )
.operator( "@" )
.arity( 2 )
.isOperator( true )
.isIndexer( false )
.isDifferentiable( true )
.isInline( false )
);
setAlgorithm(
new MatMulAlgorithm().buildFunAlgorithm()
);
}
@Override
public Result execute( final Function caller, final ExecutionCall<?> call )
{
if ( !caller.isFlat() ) {
Function reducedCaller = reducePairwise(caller);
ExecutionCall<?> flatCall = AbstractDeviceAlgorithm.flatten( reducedCaller, call.withArgs(Arg.DerivIdx.of(-1)) );
Function flat = new FunctionParser(Neureka.get().backend()).parse( flatCall.getOperation(), flatCall.arity(), true );
return super.execute( flat, flatCall );
}
return super.execute( reducePairwise(caller), call );
}
private Function reducePairwise( final Function fun ) {
Function reduced = fun;
if ( reduced.getSubFunctions().size() > 2 ) {
/*
So currently we have something like this: a@b@c@d...
However, this is how it is really executed: ((((a@b)@c)@d)..)
...so let's create a function that is nested like the above:
*/
Function nested = reduced.getSubFunctions().get(0);
for ( int i = 1; i < reduced.getSubFunctions().size(); i++ )
nested = Function.of( nested + " @ " + reduced.getSubFunctions().get(i), true );
reduced = nested;
}
return reduced;
}
@Override
public double calculate( double[] inputs, int j, int d, Function[] src ) {
return src[ 0 ].call( inputs, j );
}
}