CLScalarBroadcastAddition.java
package neureka.backend.main.implementations.broadcast;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.devices.opencl.OpenCLDevice;
import neureka.math.args.Arg;
/**
 *  OpenCL implementation of the scalar broadcast addition, built on top of
 *  {@link CLScalarBroadcast}.
 *  <p>
 *  Two kernel source snippets are handed to the parent class:
 *  {@code output = input1 + value;} and {@code output = 1;}
 *  (presumably the forward expression and its derivative, in that order;
 *  confirm against the {@link CLScalarBroadcast} constructor).
 */
public class CLScalarBroadcastAddition extends CLScalarBroadcast
{
    /**
     * @param id The identifier which is forwarded unchanged to the parent constructor.
     */
    public CLScalarBroadcastAddition( String id ) {
        super( id, "output = input1 + value;\n", "output = 1;\n" );
    }

    /**
     *  Executes the addition for the given call.
     *  <ul>
     *      <li>Derivative index {@code 0}: no kernel is launched, a new tensor filled
     *          with {@code 1d} and shaped like input 1 is returned.</li>
     *      <li>Derivative index {@code 1}: no kernel is launched, a new tensor filled
     *          with {@code 1d} and shaped like input 2 is returned.</li>
     *      <li>Any other derivative index: the kernel obtained from the device of the
     *          call is launched and input 0 of the call is returned.</li>
     *  </ul>
     *
     * @param call The execution call, which is asserted to have an arity of 3.
     * @return Either a freshly created tensor of ones flagged as intermediate,
     *         or input 0 of the supplied call.
     */
    @Override
    public Tensor<?> run( ExecutionCall<OpenCLDevice> call ) {
        assert call.arity() == 3;

        switch ( call.getDerivativeIndex() ) {
            case 0: return _intermediateOnesShapedLike( call.input( 1 ) );
            case 1: return _intermediateOnesShapedLike( call.input( 2 ) );
            default: break; // Fall through to the kernel launch below.
        }

        Tensor<Number> target = call.input( Number.class, 0 );
        // One work item per entry of input 0:
        int globalWorkSize = target.size();
        // Only the first item of input 2 is read, it is handed to the kernel as a float:
        float scalar = call.input( Number.class, 2 ).item( 0 ).floatValue();

        // The order of the passed arguments has to stay exactly as it is here!
        call.getDevice()
            .getKernel( call )
            .passAllOf( target )
            .passAllOf( call.input( Number.class, 1 ) )
            .pass( scalar )
            .pass( target.rank() )
            .pass( call.getValOf( Arg.DerivIdx.class ) )
            .call( globalWorkSize );

        return call.input( 0 );
    }

    /**
     *  Creates a tensor filled with {@code 1d} which has the shape of the given
     *  template tensor and is flagged as intermediate.
     *
     * @param template The tensor whose shape is used for the new tensor.
     * @return The new tensor of ones.
     */
    private static Tensor<?> _intermediateOnesShapedLike( Tensor<?> template ) {
        return Tensor.of( template.shape(), 1d ).mut().setIsIntermediate( true );
    }
}