CLScalarBroadcastDivision.java

package neureka.backend.main.implementations.broadcast;

import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.math.args.Arg;
import neureka.devices.opencl.OpenCLDevice;

/**
 *  OpenCL implementation of element-wise division of a tensor by a scalar.
 *  <p>
 *  The class only contributes two OpenCL source snippets (forward and derivative)
 *  to its superclass and the host-side code which launches the resulting kernel.
 *  Inside the snippets {@code input1} denotes the tensor element and {@code value}
 *  the scalar, so the forward expression is {@code input1 / value}.
 */
public class CLScalarBroadcastDivision extends CLScalarBroadcast
{
    /**
     *  Registers the division snippets with the superclass.
     *  <p>
     *  The derivative snippet branches on {@code d}:
     *  <ul>
     *      <li>{@code d == 0}: derivative with respect to {@code input1}, i.e. {@code 1 / value}</li>
     *      <li>otherwise: derivative with respect to {@code value}, i.e. {@code -(input1 / value^2)}</li>
     *  </ul>
     *
     * @param id The identifier forwarded unchanged to the superclass constructor.
     */
    public CLScalarBroadcastDivision( String id ) {
        super(
                id,
                "output = ("+TYPE+")(((float)input1) / ((float)value));\n",
                "if ( d == 0 ) {                                                        \n" +
                "    output = ("+TYPE+")( 1 / (float)value );                                    \n" +
                "} else {                                                                        \n" +
                // Quotient rule for f = input1 / value  =>  df/dvalue = -(input1 / value^2).
                // The operands used to be swapped here (value / input1^2), which is the
                // derivative of 'value / input1' and contradicts the forward snippet above.
                "    output = -(("+TYPE+")(((float)input1) /(float)pow((float)value, 2.0f)));    \n" +
                "}                                                                               \n"
        );
    }

    /**
     *  Launches the kernel for the given call, using one work item per element of input 0.
     *  <p>
     *  The scalar operand is read from the first element of input 2 when that input
     *  is virtual or holds exactly one element, and from input 1 otherwise.
     *
     * @param call The execution call providing the device, the inputs and the derivative index argument.
     * @return The tensor at input position 0 of the call.
     */
    @Override
    public Tensor<?> run(ExecutionCall<OpenCLDevice> call) {
        Tensor<Number> target = call.input( Number.class, 0 );
        Tensor<Number> third  = call.input( Number.class, 2 );

        // Decide which input slot carries the scalar operand.
        int offset = ( third.isVirtual() || third.size() == 1 ) ? 1 : 0;
        float scalar = call.input( Number.class, 1 + offset ).at( 0 ).get().floatValue();

        int gwz = target.size(); // global work size
        call.getDevice().getKernel(call)
                .passAllOf( target ) // input 0 is deliberately passed twice, exactly as before;
                .passAllOf( target ) // presumably as both the output and the 'input1' kernel argument - TODO confirm against CLScalarBroadcast
                .pass( scalar )
                .pass( target.rank() )
                .pass( call.getValOf( Arg.DerivIdx.class ) )
                .call( gwz );

        return target;
    }
}