ADAMFactory.java
package neureka.optimization.implementations;
import neureka.Tensor;
import neureka.optimization.OptimizerFactory;
/**
 * An immutable {@link OptimizerFactory} which creates {@link ADAM} optimizers.
 * <p>
 * A factory is configured through the "wither" methods
 * {@link #withLearningRate(double)} and {@link #withTime(long)}. Each of them
 * returns a new factory instance and leaves the receiver unchanged. The public
 * no-argument constructor yields a factory with a learning rate of {@code 0.01}
 * and a time of {@code 0}.
 */
public class ADAMFactory implements OptimizerFactory
{
    private final double _learningRate;
    private final long _time;

    /**
     * Creates a factory with the default settings:
     * a learning rate of {@code 0.01} and a time of {@code 0}.
     */
    public ADAMFactory() { this( 0.01, 0 ); }

    /**
     * All-arguments constructor. It is deliberately private so that
     * configuration only happens through the withers.
     *
     * @param learningRate The learning rate handed to every {@link ADAM} created by this factory.
     * @param time The time value handed to every {@link ADAM} created by this factory
     *             (presumably the initial ADAM time step — TODO confirm against {@link ADAM}).
     * @throws IllegalArgumentException If {@code time} is negative.
     */
    private ADAMFactory( double learningRate, long time ) {
        // Zero is a legal value (it is the default), so only negative values are rejected.
        if ( time < 0 )
            throw new IllegalArgumentException(
                    "The time must be a non-negative number, but was " + time + "!"
                );
        _learningRate = learningRate;
        _time = time;
    }

    /**
     * Returns a new factory which uses the given learning rate and keeps this factory's time.
     *
     * @param learningRate The learning rate for optimizers created by the returned factory.
     * @return A new {@link ADAMFactory}; this instance is not modified.
     */
    public ADAMFactory withLearningRate( double learningRate ) {
        return new ADAMFactory( learningRate, _time );
    }

    /**
     * Returns a new factory which uses the given time and keeps this factory's learning rate.
     *
     * @param time The time value for optimizers created by the returned factory; must not be negative.
     * @return A new {@link ADAMFactory}; this instance is not modified.
     * @throws IllegalArgumentException If {@code time} is negative.
     */
    public ADAMFactory withTime( long time ) {
        return new ADAMFactory( _learningRate, time );
    }

    /**
     * Creates an {@link ADAM} optimizer for the given target tensor,
     * configured with this factory's time and learning rate.
     *
     * @param target The tensor the optimizer is created for.
     * @param <V> The numeric value type of the tensor.
     * @return A new {@link ADAM} instance.
     */
    @Override
    public <V extends Number> ADAM<V> create( Tensor<V> target ) {
        return new ADAM<>( _time, _learningRate, target );
    }

    /**
     * Creates an {@link ADAM} optimizer from explicitly supplied momentum and velocity
     * tensors, configured with this factory's time and learning rate.
     *
     * @param momentum The tensor passed to {@link ADAM} as its momentum state.
     * @param velocity The tensor passed to {@link ADAM} as its velocity state.
     * @param <V> The numeric value type of the tensors.
     * @return A new {@link ADAM} instance.
     */
    public <V extends Number> ADAM<V> create( Tensor<V> momentum, Tensor<V> velocity ) {
        return new ADAM<>( _time, _learningRate, momentum, velocity );
    }
}