Randomization.java
package neureka.backend.main.operations.other;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.FallbackAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.algorithms.ElementwiseAlgorithm;
import neureka.backend.main.implementations.elementwise.CPURandomization;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.dtype.NumericType;
import java.util.Arrays;
/**
 * An {@link neureka.backend.api.Operation} (identifier {@code "random"}, operator {@code "rand"})
 * which fills a single tensor with pseudo-random values.
 * <p>
 * The seed handed to the device implementation is derived from the hash of the target
 * tensor's shape. If the caller supplies an {@link Arg.Seed}, that value is added to the hash.
 * The result is then passed through {@link CPURandomization#initialScramble}.
 * <p>
 * As designed by the original author, values are derived from the seed and the indices of
 * the individual elements. This lets them be generated in parallel, unlike the JDK's
 * {@link java.util.Random}. For floating point tensors they follow a Gaussian distribution
 * with mean 0 and standard deviation 1.
 * <p>
 * Tensors whose data type implements {@link NumericType} are accepted, as are
 * {@code Character} and {@code Boolean} tensors. What the latter receive is determined by the
 * device implementation (NOTE(review): confirm against {@link CPURandomization}).
 * The operation is not differentiable.
 */
public class Randomization extends AbstractOperation
{
    public Randomization()
    {
        super(
            new OperationBuilder()
                .identifier( "random" )
                .operator( "rand" )
                .arity( 1 )
                .isOperator( true )
                .isIndexer( false )
                .isDifferentiable( false )
                .isInline( true )
        );

        setAlgorithm(
            new ElementwiseAlgorithm()
                .setIsSuitableFor( call ->
                    call.validate()
                        .allNotNull( tensor -> {
                            // Numeric tensors, character tensors and boolean tensors can all be randomized.
                            if ( tensor.getDataType().typeClassImplements( NumericType.class ) ) return true;
                            if ( tensor.itemType() == Character.class ) return true;
                            return tensor.itemType() == Boolean.class;
                        })
                        .basicSuitability()
                )
                .setAutogradModeFor( call -> AutoDiffMode.NOT_SUPPORTED )
                .setExecution( (caller, call) ->
                    Result.of(
                            AbstractDeviceAlgorithm.prepareAndExecute(
                                call,
                                AbstractDeviceAlgorithm::executeDeviceAlgorithm
                            )
                        )
                        .withAutoDiff( FallbackAlgorithm::ADAction )
                )
                .setCallPreparation( call ->
                {
                    // If there is no tensor at index 0, the tensor at index 1 takes its place.
                    if ( call.input( 0 ) == null ) call = call.withInputAt( 0, call.input( 1 ) );

                    call.input( 0 ).mut().incrementVersion( call );

                    // The shape hash is always part of the seed; a user seed, if present, is added to it.
                    int shapeHash = Arrays.hashCode( call.input( 0 ).getNDConf().shape() );
                    Arg.Seed userSeed = call.get( Arg.Seed.class );
                    long rawSeed = ( userSeed == null ) ? shapeHash : userSeed.get() + shapeHash;

                    return call.withArgs( Arg.Seed.of( CPURandomization.initialScramble( rawSeed ) ) );
                })
                .buildFunAlgorithm()
        );
    }

    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) {
        // Scalar evaluation performs no randomization; it forwards the value of the first source.
        Function first = src[ 0 ];
        return first.call( inputs, j );
    }
}