CLScalarBroadcastMultiplication.java
package neureka.backend.main.implementations.broadcast;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.math.args.Arg;
import neureka.devices.opencl.OpenCLDevice;
/**
 * OpenCL implementation of element-wise multiplication where one operand is
 * treated as a single scalar value that is broadcast over the other operand.
 * <p>
 * The forward kernel computes {@code output = input1 * value}. The derivative
 * kernel yields {@code value} when the derivative index {@code d} is 0, and
 * {@code input1} otherwise.
 */
public class CLScalarBroadcastMultiplication extends CLScalarBroadcast
{
    /** Kernel body for the forward pass: scales {@code input1} by the scalar {@code value}. */
    private static final String FORWARD_SOURCE = "output = input1 * value;\n";

    /** Kernel body for the derivative pass (d/d input1 = value, otherwise input1). */
    private static final String DERIVATIVE_SOURCE =
            "if ( d == 0 ) {output = value;}else{output = input1;}\n";

    /**
     * Creates the implementation with the multiplication kernel sources.
     *
     * @param id identifier forwarded to {@link CLScalarBroadcast}
     */
    public CLScalarBroadcastMultiplication( String id ) {
        super( id, FORWARD_SOURCE, DERIVATIVE_SOURCE );
    }

    /**
     * Executes the call.
     * <p>
     * For derivative index 0 the derivative of a product with respect to the
     * first operand is the second operand, so a shallow copy of input 2 is
     * returned (marked intermediate). For derivative index 1 the same is done
     * with input 1. For any other derivative index the OpenCL kernel is
     * launched, and the tensor at input slot 0 is returned.
     *
     * @param call the execution call carrying the operands and target device
     * @return an intermediate shallow copy for derivative index 0 or 1,
     *         otherwise {@code call.input( 0 )}
     */
    @Override
    public Tensor<?> run( ExecutionCall<OpenCLDevice> call ) {
        switch ( call.getDerivativeIndex() ) {
            case 0:
                return _intermediateCopyOf( call.input( 2 ) );
            case 1:
                return _intermediateCopyOf( call.input( 1 ) );
            default:
                _launchKernel( call );
                return call.input( 0 );
        }
    }

    /** Returns a shallow copy of {@code source} flagged as an intermediate result. */
    private static Tensor<?> _intermediateCopyOf( Tensor<?> source ) {
        return source.shallowCopy().mut().setIsIntermediate( true );
    }

    /**
     * Binds the kernel arguments and enqueues the kernel.
     * <p>
     * Argument order: all of input 0, all of the tensor operand, the first
     * element of the scalar operand, rank of input 0, derivative index.
     * The global work size is the size of input 0.
     */
    private static void _launchKernel( ExecutionCall<OpenCLDevice> call ) {
        // Input 2 counts as the scalar when it is virtual or holds one element.
        boolean rightIsScalar = call.input( Number.class, 2 ).isVirtual()
                             || call.input( Number.class, 2 ).size() == 1;
        // NOTE(review): when input 2 is NOT scalar, the tensor argument below is
        // input 0 and the scalar is taken from input 1 -- confirm this is the
        // intended convention for a scalar on the left.
        int tensorIndex = rightIsScalar ? 1 : 0;
        int scalarIndex = tensorIndex + 1;
        int globalWorkSize = call.input( Number.class, 0 ).size();
        call.getDevice()
            .getKernel( call )
            .passAllOf( call.input( Number.class, 0 ) )
            .passAllOf( call.input( Number.class, tensorIndex ) )
            .pass( call.input( Number.class, scalarIndex ).at( 0 ).get() )
            .pass( call.input( Number.class, 0 ).rank() )
            .pass( call.getValOf( Arg.DerivIdx.class ) )
            .call( globalWorkSize );
    }
}