// CLScalarBroadcast.java

package neureka.backend.main.implementations.broadcast;

import neureka.Neureka;
import neureka.Tensor;
import neureka.backend.main.implementations.ParsedCLImplementation;
import neureka.math.args.Arg;
import neureka.devices.opencl.KernelCode;
import neureka.dtype.DataType;

import java.util.Arrays;

/**
 * A {@link ParsedCLImplementation} whose kernel source is read from the
 * resource {@code kernels/scalarization_template.cl}.
 * <p>
 * When executed, the first input tensor of the call is handed to the kernel twice
 * (via {@code passAllOf}), followed by the first element of the second input
 * (converted to {@code float}), the rank of the first input and the value of
 * {@link Arg.DerivIdx}. The first input of the call is returned as the result.
 * <p>
 * The kernel template is expanded into one {@link KernelCode} per entry of
 * {@link #CL_TYPES}: every occurrence of {@link #TYPE} in the source is replaced by
 * the type name, and the kernel name receives the suffix {@code "_" + type}.
 * <p>
 * NOTE(review): the scalar argument is always passed as a {@code float}, irrespective
 * of which typed kernel variant is used. Whether that matches the scalar parameter
 * declared in the template cannot be told from this file; please confirm.
 */
public class CLScalarBroadcast extends ParsedCLImplementation
{
    /** Placeholder inside the kernel template which gets substituted by a concrete type name. */
    protected final static String TYPE = "#DATA_TYPE#";

    /** The type names for which a dedicated kernel variant is generated (in this order). */
    private static final String[] CL_TYPES = { "float", "double", "int", "long", "short", "char" };

    /**
     * @param postfix    Forwarded unchanged to the super constructor.
     * @param activation Forwarded unchanged to the super constructor.
     * @param derivation Forwarded unchanged to the super constructor.
     */
    public CLScalarBroadcast(
        String postfix, String activation, String derivation
    ) {
        super(
            call -> {
                // Input 0 is passed to the kernel twice and is also what gets returned below.
                Tensor<Number> target = call.input( Number.class, 0 );
                int workSize = target.size(); // presumably the global work size; confirm against the kernel caller
                call.getDevice()
                    .getKernel( call )
                    .passAllOf( target )
                    .passAllOf( target )
                    .pass( call.input( Number.class, 1 ).at( 0 ).get().floatValue() ) // only element 0 of input 1 is read
                    .pass( target.rank() )
                    .pass( call.getValOf( Arg.DerivIdx.class ) )
                    .call( workSize );

                return call.input( 0 );
            },
            2,
            Neureka.get().utility().readResource("kernels/scalarization_template.cl"),
            activation,
            derivation,
            postfix,
            template ->
                Arrays.stream( CL_TYPES )
                      .map( clType -> specialize( template, clType ) )
                      .toArray( KernelCode[]::new )
        );
    }

    /**
     * Creates the variant of the given kernel template for a single type name.
     *
     * @param template The kernel code still containing the {@link #TYPE} placeholder.
     * @param clType   The type name to substitute for the placeholder.
     * @return A new {@link KernelCode} named {@code <template name>_<clType>} whose source has
     *         the placeholder and the original kernel name replaced accordingly.
     */
    private static KernelCode specialize( KernelCode template, String clType )
    {
        String baseName  = template.getName();
        String typedName = baseName + ("_" + clType);
        String typedCode = template.getCode()
                                   .replace( TYPE, clType )
                                   .replace( baseName, typedName );

        return new KernelCode( typedName, typedCode, dataTypeOf( clType ) );
    }

    /**
     * Maps a type name from {@link #CL_TYPES} to the {@link DataType} registered for the kernel variant.
     *
     * @param clType One of the type names used for kernel specialization.
     * @return The matching data type; {@code "char"} maps to {@link Byte} and any
     *         unknown name falls back to {@link Float}.
     */
    private static DataType<?> dataTypeOf( String clType )
    {
        switch ( clType ) {
            case "double": return DataType.of( Double.class );
            case "int":    return DataType.of( Integer.class );
            case "long":   return DataType.of( Long.class );
            case "short":  return DataType.of( Short.class );
            case "char":   return DataType.of( Byte.class );
            case "float":
            default:       return DataType.of( Float.class );
        }
    }
}