Arg.java
package neureka.math.args;
import neureka.Tensor;
import neureka.common.composition.Component;
import neureka.devices.Device;
import neureka.ndim.config.NDConfiguration;

import java.util.Arrays;
/**
 * Extend this class to define additional meta arguments for {@link neureka.math.Functions}.
 * More complex types of operations need additional parameters/arguments.
 * The {@link neureka.backend.main.operations.other.Randomization}
 * operation, for example, receives the {@link Seed} argument as a basis
 * for deterministic pseudo-random number generation.
 * <p>
 * Primitive array values are defensively copied both when an argument is created
 * and when its value is read through {@link #get()}. As a result, an argument can
 * never be modified through an array reference held by a caller.
 *
 * @param <T> The type parameter defining the type of argument.
 */
public abstract class Arg<T> implements Component<Args> {

    private final T _value;

    /**
     * @param arg The value wrapped by this argument. If it is a primitive array it is
     *            copied, so later changes to the caller's array do not affect this argument.
     */
    public Arg( T arg ) { _value = _copyIfPrimitiveArray( arg ); }

    /**
     * @return The wrapped value. Primitive arrays are returned as fresh copies.
     *         Any other value (including {@code null}) is returned as is.
     */
    public T get() { return _copyIfPrimitiveArray( _value ); }

    /**
     * Clones {@code value} if it is a primitive array, otherwise returns it unchanged.
     */
    // Each cast re-types a clone of the very same runtime array class as the input, so it is safe.
    @SuppressWarnings("unchecked")
    private static <V> V _copyIfPrimitiveArray( V value ) {
        if ( value instanceof int[] )     return (V) ((int[]) value).clone();
        if ( value instanceof float[] )   return (V) ((float[]) value).clone();
        if ( value instanceof double[] )  return (V) ((double[]) value).clone();
        if ( value instanceof long[] )    return (V) ((long[]) value).clone();
        if ( value instanceof short[] )   return (V) ((short[]) value).clone();
        if ( value instanceof byte[] )    return (V) ((byte[]) value).clone();
        if ( value instanceof char[] )    return (V) ((char[]) value).clone();
        if ( value instanceof boolean[] ) return (V) ((boolean[]) value).clone();
        return value;
    }

    /**
     * Accepts every owner change request unconditionally.
     *
     * @param changeRequest The request describing the owner change (not inspected).
     * @return Always {@code true}.
     */
    @Override
    public boolean update( OwnerChangeRequest<Args> changeRequest ) { return true; }

    /**
     * Wraps a {@link Tensor} supplied to an operation as a derivative argument.
     *
     * @param <V> The item type of the wrapped tensor.
     */
    public static class Derivative<V> extends Arg<Tensor<V>> {
        public static <V> Derivative<V> of( Tensor<V> arg ) { return new Derivative<>( arg ); }
        private Derivative( Tensor<V> arg ) { super( arg ); }
    }

    /**
     * This is an important argument whose role might not be clear at first:
     * an operation can have multiple inputs; however, when calculating the
     * derivative for a forward or backward pass, one must know which
     * derivative ought to be calculated.
     * The "derivative index" therefore targets said input.
     * This property is -1 when no derivative should be calculated,
     * and 0 or greater when targeting an input to calculate the derivative of.
     */
    public static class DerivIdx extends Arg<Integer> {
        public static DerivIdx of( int index ) { return new DerivIdx( index ); }
        private DerivIdx( int arg ) { super( arg ); }
    }

    /**
     * Wraps an integer axis index.
     */
    public static class Axis extends Arg<Integer> {
        public static Axis of( int index ) { return new Axis( index ); }
        private Axis( int arg ) { super( arg ); }
    }

    /**
     * Wraps an {@code int[]} of end positions. The array is copied on creation and on every read.
     */
    public static class Ends extends Arg<int[]> {
        public static Ends of( int[] arg ) { return new Ends( arg ); }
        private Ends( int[] arg ) { super( arg ); }
    }

    /**
     * Wraps the {@link Device} an operation should target.
     */
    public static class TargetDevice extends Arg<Device<?>> {
        public static TargetDevice of( Device<?> arg ) { return new TargetDevice( arg ); }
        private TargetDevice( Device<?> arg ) { super( arg ); }
    }

    /**
     * The following argument is relevant for a particular type of operation, namely an "indexer". <br>
     * An indexer automatically applies an operation on all inputs for a given function.
     * The (indexer) function will execute the sub functions (of the AST) for every input index.
     * If a particular index is not targeted, however, this variable will simply default to -1.
     */
    public static class VarIdx extends Arg<Integer> {
        public static VarIdx of( int arg ) { return new VarIdx( arg ); }
        private VarIdx( int arg ) { super( arg ); }
    }

    /**
     * Wraps a minimum rank as an integer.
     */
    public static class MinRank extends Arg<Integer> {
        public static MinRank of( int arg ) { return new MinRank( arg ); }
        private MinRank( int arg ) { super( arg ); }
    }

    /**
     * Wraps a {@code long} seed. A seed may be given directly or derived
     * deterministically from a string via a 64-bit polynomial hash.
     */
    public static class Seed extends Arg<Long> {
        public static Seed of( String arg ) { return new Seed( _longStringHash( arg ) ); }
        public static Seed of( long arg ) { return new Seed( arg ); }
        private Seed( long arg ) { super( arg ); }

        /**
         * Computes a 64-bit polynomial (base 31) hash of the given string.
         * The same string always yields the same value.
         */
        private static long _longStringHash( String string )
        {
            long h = 1125899906842597L; // prime
            int len = string.length();
            for ( int i = 0; i < len; i++ ) h = 31 * h + string.charAt( i );
            return h;
        }
    }

    /**
     * Wraps an {@code int[]} shape. The array is copied on creation and on every read.
     */
    public static class Shape extends Arg<int[]> {
        public static Shape of( int... arg ) { return new Shape( arg ); }
        private Shape( int[] arg ) { super( arg ); }
    }

    /**
     * Wraps an {@code int[]} offset. The array is copied on creation and on every read.
     */
    public static class Offset extends Arg<int[]> {
        public static Offset of( int... arg ) { return new Offset( arg ); }
        private Offset( int[] arg ) { super( arg ); }
    }

    /**
     * Wraps an {@code int[]} stride. The array is copied on creation and on every read.
     */
    public static class Stride extends Arg<int[]> {
        public static Stride of( int... arg ) { return new Stride( arg ); }
        private Stride( int[] arg ) { super( arg ); }
    }

    /**
     * Wraps an {@code int[]} of indices. The array is copied on creation and on every read.
     */
    public static class Indices extends Arg<int[]> {
        public static Indices of( int... arg ) { return new Indices( arg ); }
        private Indices( int[] arg ) { super( arg ); }
    }

    /**
     * Wraps an {@link NDConfiguration.Layout} value.
     */
    public static class Layout extends Arg<NDConfiguration.Layout> {
        public static Layout of( NDConfiguration.Layout arg ) { return new Layout( arg ); }
        private Layout( NDConfiguration.Layout arg ) { super( arg ); }
    }

    /**
     * @return The simple class name followed by the wrapped value in brackets,
     *         e.g. {@code "DerivIdx[0]"} or {@code "Shape[[2, 3]]"}.
     */
    @Override
    public String toString() {
        // deepToString renders array values by content rather than by identity hash.
        // Non-array values (and null) are rendered exactly like String.valueOf would render them.
        return this.getClass().getSimpleName() + Arrays.deepToString( new Object[]{ _value } );
    }
}