CLScalarFunction.java

  1. package neureka.backend.main.implementations.scalar;

  2. import neureka.Tensor;
  3. import neureka.backend.api.ExecutionCall;
  4. import neureka.backend.api.ImplementationFor;
  5. import neureka.backend.main.implementations.fun.api.CPUFun;
  6. import neureka.backend.main.implementations.fun.api.ScalarFun;
  7. import neureka.math.args.Arg;
  8. import neureka.devices.opencl.OpenCLDevice;

  9. public class CLScalarFunction implements ImplementationFor<OpenCLDevice>
  10. {
  11.     private final ScalarFun _fun;

  12.     public CLScalarFunction(ScalarFun fun) {
  13.         _fun = fun;
  14.     }

  15.     @Override
  16.     public Tensor<?> run(ExecutionCall<OpenCLDevice> call) {
  17.         int d = call.getValOf(Arg.DerivIdx.class);
  18.         CPUFun f = d < 0 ? _fun.getActivation() : _fun.getDerivative();
  19.         Number value =  f.invoke(call.input( Number.class, 1 ).item(0).doubleValue());
  20.         Tensor<Number> out = call.input( Number.class, 0 );
  21.         out.mut().setDataAt(0, value);
  22.         return call.input(0);
  23.     }
  24. }