ScalarSeLU.java

package neureka.backend.main.implementations.fun;

import neureka.backend.main.implementations.fun.api.CPUFun;
import neureka.backend.main.implementations.fun.api.ScalarFun;

/**
 * The Scaled Exponential Linear Unit, or SELU, is an activation
 * function that induces self-normalizing properties.
 * The SELU activation function is implemented as:
 * <i>{@code
 *      if      ( x >  0 ) return SCALE * x;
 *      else if ( x <= 0 ) return SCALE * ALPHA * (Math.exp(x) - 1);
 *      else               return NaN;
 * }</i><br>
 * ...where {@code ALPHA} is approximately {@code 1.6733}
 * and {@code SCALE} is approximately {@code 1.0507}.
 * <p>
 * The derivative is {@code SCALE} for positive inputs and
 * {@code SCALE * ALPHA * exp(x)} for non-positive inputs.
 * On the CPU, NaN inputs yield NaN for both the function and its derivative,
 * and the term {@code exp(x) - 1} is evaluated with {@link Math#expm1(double)}
 * to avoid cancellation error for inputs close to zero.
 */
public class ScalarSeLU implements ScalarFun
{
    private static final double ALPHA = 1.6732632423543772848170429916717;
    private static final double SCALE = 1.0507009873554804934193349852946;
    // Single-precision copies, inlined as literals into the generated kernel source below.
    private static final float  ALPHA_F32 = (float) ALPHA;
    private static final float  SCALE_F32 = (float) SCALE;
    // Product shared by the negative branch of the function and of its derivative.
    // Java evaluates `SCALE * ALPHA * y` as `(SCALE * ALPHA) * y`, so hoisting it is value-preserving.
    private static final double SCALE_ALPHA = SCALE * ALPHA;


    @Override public String id() { return "selu"; }

    /**
     * @return Source snippet computing SELU of {@code input} into {@code output}.
     *         NOTE(review): the final branch (reached only for NaN input) writes {@code 0.0f},
     *         whereas the CPU implementation returns NaN — confirm whether this is intended.
     *         The snippet still uses {@code exp(input) - 1.0f}; if the target language offers
     *         {@code expm1}, switching would match the CPU path's accuracy — TODO confirm.
     */
    @Override public String activationCode() {
        return "if      ( input > 0  ) output = "+SCALE_F32+"f * input;\n" +
               "else if ( input <= 0 ) output = "+SCALE_F32+"f * "+ALPHA_F32+"f * (exp(input) - 1.0f);\n" +
               "else                   output = 0.0f;\n";
    }

    /**
     * @return Source snippet computing the SELU derivative at {@code input} into {@code output}.
     *         NOTE(review): the final branch (reached only for NaN input) writes {@code 1.0f},
     *         whereas the CPU implementation returns NaN — confirm whether this is intended.
     */
    @Override public String derivationCode() {
        return "if      ( input >  0 ) output = "+SCALE_F32+"f;\n" +
               "else if ( input <= 0 ) output = "+SCALE_F32+"f * "+ALPHA_F32+"f * exp(input);\n" +
               "else                   output = 1.0f;\n";
    }

    /**
     * @return A CPU implementation of SELU; the float overload computes in double
     *         precision and narrows the result.
     */
    @Override
    public CPUFun getActivation() {
        return new CPUFun() {
            @Override public double invoke(double x) { return selu(x); }
            @Override public float invoke(float x) { return (float) selu(x); }
        };
    }

    /**
     * @return A CPU implementation of the SELU derivative; the float overload computes in
     *         double precision and narrows the result ({@code (float) SCALE == SCALE_F32}).
     */
    @Override
    public CPUFun getDerivative() {
        return new CPUFun() {
            @Override public double invoke(double x) { return seluDerivative(x); }
            @Override public float invoke(float x) { return (float) seluDerivative(x); }
        };
    }


    /**
     * Computes SELU in double precision.
     *
     * @param x The input value.
     * @return {@code SCALE * x} if {@code x > 0},
     *         {@code SCALE * ALPHA * (e^x - 1)} if {@code x <= 0},
     *         and NaN if {@code x} is NaN.
     */
    public static double selu(double x) {
        if      ( x >  0 ) return SCALE * x;
        // expm1 avoids the cancellation of `exp(x) - 1` when x is close to 0.
        else if ( x <= 0 ) return SCALE_ALPHA * Math.expm1(x);
        else               return Double.NaN; // only NaN fails both comparisons
    }

    /**
     * Computes the derivative of SELU in double precision.
     *
     * @param x The input value.
     * @return {@code SCALE} if {@code x > 0}, {@code SCALE * ALPHA * e^x} if {@code x <= 0},
     *         and NaN if {@code x} is NaN.
     */
    private static double seluDerivative(double x) {
        if      ( x >  0 ) return SCALE;
        else if ( x <= 0 ) return SCALE_ALPHA * Math.exp(x);
        else               return Double.NaN; // only NaN fails both comparisons
    }

}