ScalarReLU.java

  1. package neureka.backend.main.implementations.fun;

  2. import neureka.backend.main.implementations.fun.api.CPUFun;
  3. import neureka.backend.main.implementations.fun.api.ScalarFun;

  4. public final class ScalarReLU implements ScalarFun
  5. {
  6.     @Override public String id() { return "relu"; }

  7.     @Override
  8.     public String activationCode() {
  9.         return "if (input >= 0) {  output = input; } else { output = input * (float)0.01; }\n";
  10.     }

  11.     @Override
  12.     public String derivationCode() {
  13.         return "if (input >= 0) { output = (float)1; } else { output = (float)0.01; }\n";
  14.     }

  15.     @Override
  16.     public CPUFun getActivation() {
  17.         return new CPUFun() {
  18.             @Override public double invoke(double x) { return ( x >= 0 ? x : x * .01 ); }
  19.             @Override public float invoke(float x) { return ( x >= 0 ? x : x * .01f ); }
  20.         };
  21.     }

  22.     @Override
  23.     public CPUFun getDerivative() {
  24.         return new CPUFun() {
  25.             @Override public double invoke(double x) { return ( x >= 0 ? 1 : .01); }
  26.             @Override public float invoke(float x) { return ( x >= 0 ? 1f : .01f ); }
  27.         };
  28.     }

  29. }