RMSPropFactory.java
package neureka.optimization.implementations;
import neureka.Tensor;
import neureka.optimization.OptimizerFactory;
public class RMSPropFactory implements OptimizerFactory
{
private final double _learningRate;
private final double _decayRate;
public RMSPropFactory() {
_learningRate = 0.001;
_decayRate = 0.9;
}
// The copy constructor should be private, use withers instead!
private RMSPropFactory(double learningRate, double decayRate) {
_learningRate = learningRate;
_decayRate = decayRate;
}
// Withers:
public RMSPropFactory withLearningRate(double learningRate) {
return new RMSPropFactory(learningRate, _decayRate);
}
public RMSPropFactory withDecayRate(double decayRate) {
return new RMSPropFactory(_learningRate, decayRate);
}
@Override
public <V extends Number> RMSProp<V> create(Tensor<V> target) {
return new RMSProp<>((Tensor<Number>) target, _learningRate, _decayRate);
}
}