ScalarGaTU.java
package neureka.backend.main.implementations.fun;
import neureka.backend.main.implementations.fun.api.CPUFun;
import neureka.backend.main.implementations.fun.api.ScalarFun;
/**
*  The Self Gated {@link ScalarTanh} Unit is based on the {@link ScalarTanh} function,
*  making it an exponentiation based version of the {@link ScalarGaSU} function, which
*  is itself based on the {@link ScalarSoftsign} function
*  (a computationally cheap, non-exponential quasi {@link ScalarTanh}).
*  Similar to the {@link ScalarSoftsign} and {@link ScalarTanh} functions, {@link ScalarGaTU}
*  is centered around 0 and bounded by -1 and +1.
*/
public class ScalarGaTU implements ScalarFun
{
    /** Code snippet assigning {@code tanh(input^3)} to {@code output}. */
    private static final String ACTIVATION_SNIPPET = "output = tanh(input*input*input);\n";

    /**
     *  Code snippet assigning the derivative of {@code tanh(input^3)} to {@code output},
     *  i.e. {@code 3 * input^2 * (1 - tanh(input^3)^2)} written as {@code -temp * tanh2 + temp}.
     */
    private static final String DERIVATION_SNIPPET =
            "float x2 = input * input; \n"
            + "float x3 = x2 * input; \n"
            + "float temp = 3 * x2; \n"
            + "float tanh2 = pow(tanh(x3), 2); \n"
            + "output = -temp * tanh2 + temp; \n";

    /** @return The identifier of this function, which is {@code "gatu"}. */
    @Override
    public String id() {
        return "gatu";
    }

    /** @return The source snippet (as a string) which computes the activation. */
    @Override
    public String activationCode() {
        return ACTIVATION_SNIPPET;
    }

    /** @return The source snippet (as a string) which computes the derivative. */
    @Override
    public String derivationCode() {
        return DERIVATION_SNIPPET;
    }

    /**
     *  @return A {@link CPUFun} evaluating {@code tanh(x^3)} through {@link ScalarTanh}
     *          for both {@code double} and {@code float} arguments.
     */
    @Override
    public CPUFun getActivation() {
        return new CPUFun() {
            @Override
            public double invoke( double x ) {
                return ScalarTanh.tanh( x * x * x );
            }

            @Override
            public float invoke( float x ) {
                return ScalarTanh.tanh( x * x * x );
            }
        };
    }

    /**
     *  @return A {@link CPUFun} evaluating the derivative of {@code tanh(x^3)}
     *          for both {@code double} and {@code float} arguments.
     */
    @Override
    public CPUFun getDerivative() {
        return new CPUFun() {
            @Override
            public double invoke( double x ) {
                return slopeAt( x );
            }

            @Override
            public float invoke( float x ) {
                return slopeAt( x );
            }
        };
    }

    /**
     *  Derivative of {@code tanh(x^3)} in double precision.
     *
     *  @param x The point at which the derivative is evaluated.
     *  @return {@code 3x^2 - 3x^2 * tanh(x^3)^2}
     */
    private static double slopeAt( double x ) {
        final double squared = x * x;
        final double cubed = squared * x;
        final double scale = 3 * squared;
        final double tanhSquared = Math.pow( ScalarTanh.tanh( cubed ), 2 );
        // Same value as (-scale * tanhSquared + scale): negation is exact and addition commutes.
        return scale - scale * tanhSquared;
    }

    /**
     *  Derivative of {@code tanh(x^3)} in single precision.
     *  The square of the tanh is taken via {@link Math#pow(double, double)}
     *  and then narrowed back to {@code float}.
     *
     *  @param x The point at which the derivative is evaluated.
     *  @return {@code 3x^2 - 3x^2 * tanh(x^3)^2}
     */
    private static float slopeAt( float x ) {
        final float squared = x * x;
        final float cubed = squared * x;
        final float scale = 3 * squared;
        final float tanhSquared = (float) Math.pow( ScalarTanh.tanh( cubed ), 2 );
        // Same value as (-scale * tanhSquared + scale): negation is exact and addition commutes.
        return scale - scale * tanhSquared;
    }
}