Min.java

  1. package neureka.backend.main.operations.other;

  2. import neureka.Shape;
  3. import neureka.Tensor;
  4. import neureka.backend.api.AutoDiffMode;
  5. import neureka.backend.api.DeviceAlgorithm;
  6. import neureka.backend.api.Result;
  7. import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
  8. import neureka.backend.api.template.operations.AbstractOperation;
  9. import neureka.backend.api.template.operations.OperationBuilder;
  10. import neureka.backend.main.operations.ElemWiseUtil;
  11. import neureka.backend.main.operations.linear.internal.opencl.CLReduce;
  12. import neureka.backend.main.operations.other.internal.CPUReduce;
  13. import neureka.math.Function;
  14. import neureka.devices.Device;
  15. import neureka.devices.host.CPU;
  16. import neureka.devices.opencl.OpenCLDevice;

  17. public class Min extends AbstractOperation
  18. {
  19.     public Min()
  20.     {
  21.         super(
  22.                 new OperationBuilder()
  23.                         .identifier(       "min"       )
  24.                         .operator(         "min"       )
  25.                         .arity(            1           )
  26.                         .isOperator(       false       )
  27.                         .isIndexer(        false       )
  28.                         .isDifferentiable( true        )
  29.                         .isInline(         false       )
  30.         );

  31.         setAlgorithm(
  32.             DeviceAlgorithm
  33.             .withName("min_algorithm")
  34.             .setIsSuitableFor(
  35.                     call -> call.validate()
  36.                             .allNotNull( t -> Number.class.isAssignableFrom(t.getItemType()) )
  37.                             .basicSuitability()
  38.             )
  39.             .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
  40.             .setExecution( (caller, call) -> {
  41.                 Tensor<?>[] inputs = AbstractDeviceAlgorithm.flatten(caller, call).inputs();
  42.                 call = call.withInputs(inputs);
  43.                 Tensor<Integer> index = ((DeviceAlgorithm)call.getAlgorithm()).getImplementationFor(call.getDevice()).run(call);
  44.                 int i = index.item(0);
  45.                 Tensor<?> in = inputs[0] == null ? inputs[1] : inputs[0];
  46.                 Class<Object> typeClass = (Class<Object>) in.itemType();
  47.                 Shape shape = in.shape();
  48.                 Device<Object> device = (Device<Object>) call.getDevice();
  49.                 return Result.of(
  50.                             Tensor.of(in.itemType(), Shape.of( 1 ), in.item(i)).to(call.getDevice()).mut().setIsIntermediate(true)
  51.                         )
  52.                         .withADAction( target -> {
  53.                             Tensor<Object> error = (Tensor<Object>) target.error();
  54.                             assert error.size() == 1;
  55.                             Tensor<Object> newError = ElemWiseUtil.newTensorLike(typeClass, shape, true, device, 0);
  56.                             newError.mut().setIsVirtual(false);
  57.                             newError.mut().setItemAt(i, error.item(0));
  58.                             return newError;
  59.                         });
  60.             })
  61.             .setCallPreparation( call ->
  62.             {
  63.                 if ( call.input( 0 ) == null )
  64.                     call = call.withInputAt( 0, call.input( 1 ) );

  65.                 return call;
  66.             })
  67.             .buildFunAlgorithm()
  68.             .setImplementationFor( CPU.class, new CPUReduce(CPUReduce.Type.MIN) )
  69.             .setImplementationFor( OpenCLDevice.class, new CLReduce(CLReduce.Type.MIN) )
  70.         );
  71.     }

  72.     @Override
  73.     public double calculate( double[] inputs, int j, int d, Function[] src ) { return src[ 0 ].call( inputs, j ); }
  74. }