CLScalarFunction.java
- package neureka.backend.main.implementations.scalar;
- import neureka.Tensor;
- import neureka.backend.api.ExecutionCall;
- import neureka.backend.api.ImplementationFor;
- import neureka.backend.main.implementations.fun.api.CPUFun;
- import neureka.backend.main.implementations.fun.api.ScalarFun;
- import neureka.math.args.Arg;
- import neureka.devices.opencl.OpenCLDevice;
- public class CLScalarFunction implements ImplementationFor<OpenCLDevice>
- {
- private final ScalarFun _fun;
- public CLScalarFunction(ScalarFun fun) {
- _fun = fun;
- }
- @Override
- public Tensor<?> run(ExecutionCall<OpenCLDevice> call) {
- int d = call.getValOf(Arg.DerivIdx.class);
- CPUFun f = d < 0 ? _fun.getActivation() : _fun.getDerivative();
- Number value = f.invoke(call.input( Number.class, 1 ).item(0).doubleValue());
- Tensor<Number> out = call.input( Number.class, 0 );
- out.mut().setDataAt(0, value);
- return call.input(0);
- }
- }