AbstractNDC.java
package neureka.ndim.config;
import neureka.Neureka;
import neureka.common.utility.Cache;
import neureka.ndim.config.types.permuted.Permuted1DConfiguration;
import neureka.ndim.config.types.permuted.Permuted2DConfiguration;
import neureka.ndim.config.types.permuted.Permuted3DConfiguration;
import neureka.ndim.config.types.permuted.PermutedNDConfiguration;
import neureka.ndim.config.types.simple.*;
import neureka.ndim.config.types.sliced.*;
import neureka.ndim.config.types.views.SimpleReshapeView;
import java.util.Arrays;
import java.util.Objects;
/**
* The following is an abstract implementation of the {@link NDConfiguration} interface which offers a basis for
* instantiation and caching of concrete implementations extending this abstract class.
* Concrete {@link NDConfiguration} implementations are expected to be immutable which ensures that sharing them is safe.
* In order to cache instances based on their field variables, this class comes with a common
* implementation of the {@link NDConfiguration#hashCode()} method.
* {@link NDConfiguration} implementation instances will be used by tensors which often
* share the same shape and the same way of mapping indices to their respective data.
* In these cases tensors can simply share their {@link NDConfiguration} instances for memory efficiency.
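* <p>
* The following is a minimal sketch (not a prescribed API usage) of how two equal
* configurations may collapse into one shared, cached instance. It uses the package-private
* {@link #construct} factory defined below, and the concrete shape, stride, spread and
* offset values are purely illustrative:
* <pre>{@code
* int[] shape   = {2, 3};
* int[] strides = {3, 1}; // contiguous row-major strides for a 2x3 tensor
* NDConfiguration first  = AbstractNDC.construct( shape, strides, strides, new int[]{1, 1}, new int[]{0, 0} );
* NDConfiguration second = AbstractNDC.construct( shape, strides, strides, new int[]{1, 1}, new int[]{0, 0} );
* // Because instances are immutable and cached, 'first' and 'second' may be the very same object.
* }</pre>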
*/
public abstract class AbstractNDC implements NDConfiguration
{
/**
* Instances implementing the {@link NDConfiguration} interface will be cached in the hashmap below.
* In production, we can expect a multitude of tensors having the same shape and also the same way of viewing their data.
* Therefore, they will have configuration instances with the same state.
* Implementations of {@link NDConfiguration} are expected to be immutable, which allows them to be
* shared between tensors (they are read-only, so sharing has no side effects).
*/
private static final Cache<NDConfiguration> _CACHED_NDCS; // Cached ND-Configurations.
static
{
_CACHED_NDCS = new Cache<>(512);
}
/**
* The following is a global cache for read-only integer arrays.
* Warning! Sharing these arrays becomes dangerous as soon as one of them is modified.
* Please copy them before exposing them to the user.
*/
private static final Cache<int[]> _CACHED_INT_ARRAYS; // ND-Configurations are often based on integer arrays representing things like shape, steps, etc...
static
{
_CACHED_INT_ARRAYS = new Cache<>(512);
}
/**
* This method receives an int array and returns either the provided array
* or an equal array already residing in the global int array cache of this class.
* Integer arrays used by configurations are not very large,
* which is why their state can be uniquely encoded in {@code long} keys.
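* <p>
* A minimal usage sketch (the concrete values are illustrative):
* <pre>{@code
* int[] first  = _cacheArray( new int[]{ 4, 2, 3 } );
* int[] second = _cacheArray( new int[]{ 4, 2, 3 } );
* // 'first' and 'second' may now refer to the same cached instance, so copy
* // the array (for example via 'first.clone()') before exposing it to users.
* }</pre>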
*
* @param data The integer array which ought to be cached.
* @return The provided array or an already present array found in the int array cache.
*/
protected static int[] _cacheArray( int[] data ) { return _CACHED_INT_ARRAYS.process( data ); }
/**
* A factory method which creates {@link NDConfiguration} instances best suited for the
* provided raw configuration data...
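* <p>
* The following is a hedged sketch of the expected dispatch; the class names are the ones
* imported above, the argument values are illustrative, and the concrete outcome also
* depends on the caching and the {@code isOnlyUsingDefaultNDConfiguration} setting checked below:
* <pre>{@code
* int[] shape   = {2, 3};
* int[] strides = {3, 1}; // contiguous row-major strides
* int[] map     = {3, 1}; // identity index mapping
* int[] spread  = {1, 1};
* int[] offset  = {0, 0};
* // A dense row-major 2D layout is expected to yield a Simple2DConfiguration:
* NDConfiguration simple = AbstractNDC.construct( shape, strides, map, spread, offset );
* // A nonzero offset marks a slice, which is expected to yield a Sliced2DConfiguration:
* NDConfiguration sliced = AbstractNDC.construct( shape, strides, map, spread, new int[]{1, 0} );
* }</pre>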
*/
static NDConfiguration construct (
int[] shape,
int[] strides,
int[] indicesMap,
int[] spread,
int[] offset
) {
for ( int dim : shape )
if ( dim == 0 )
throw new IllegalStateException(
"Trying to create tensor configuration containing shape with dimension 0.\n" +
"Shape dimensions must be greater than 0!\n"
);
for ( int i = 0; i < shape.length; i++ )
if ( shape[i] == 1 && spread[i] > 1 )
throw new IllegalStateException(
"Trying to create an '" + NDConfiguration.class.getSimpleName() + "' with a " +
"nonsensical spread/step value for dimension " + i + ", using " +
"shape " + Arrays.toString(shape) + " and spread " + Arrays.toString(spread) + ".\n" +
"A spread/step of size " + spread[i] + " does not make sense for a dimension of size 1 " +
"because you need at least 2 elements to be able to step over them!\n " +
"This is most likely a bug in the Neureka library, please report it!\n"
);
if ( Neureka.get().settings().ndim().isOnlyUsingDefaultNDConfiguration() )
return SlicedNDConfiguration.construct(shape, strides, indicesMap, spread, offset);
boolean isSimple = _isSimpleConfiguration(shape, strides, indicesMap, spread, offset);
boolean isSimpleTransposed = _isSimpleTransposedConfiguration(shape, spread, offset);
if ( isSimple )
{
if ( shape.length == 1 ) {
if ( shape[ 0 ] == 1 )
return Simple0DConfiguration.construct();
else
return Simple1DConfiguration.construct(shape, strides);
}
else if ( shape.length == 2 )
return Simple2DConfiguration.construct(shape, strides);
else if ( shape.length == 3 )
return Simple3DConfiguration.construct(shape, strides);
else
return SimpleNDConfiguration.construct(shape, strides);
}
if ( isSimpleTransposed )
{
if ( shape.length == 1 )
return Permuted1DConfiguration.construct(shape, strides, indicesMap);
else if ( shape.length == 2 )
return Permuted2DConfiguration.construct(shape, strides, indicesMap);
else if ( shape.length == 3 )
return Permuted3DConfiguration.construct(shape, strides, indicesMap);
else
return PermutedNDConfiguration.construct(shape, strides, indicesMap);
}
if ( shape.length == 1 ) {
if ( shape[ 0 ] == 1 )
return Sliced0DConfiguration.construct(shape, offset);
else
return Sliced1DConfiguration.construct(shape, strides, indicesMap, spread, offset);
}
else if ( shape.length == 2 )
return Sliced2DConfiguration.construct(shape, strides, indicesMap, spread, offset);
else if ( shape.length == 3 )
return Sliced3DConfiguration.construct(shape, strides, indicesMap, spread, offset);
// This configuration fits every shape:
return SlicedNDConfiguration.construct(shape, strides, indicesMap, spread, offset);
}
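/**
* Passes the given configuration through the global {@link NDConfiguration} cache so that an
* equal, already cached instance can be returned in place of the provided one.
*
* @param ndc The configuration which ought to be cached.
* @return The provided configuration or an equal instance already present in the cache.
*/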
protected static <T extends NDConfiguration> T _cached( T ndc ) { return _CACHED_NDCS.process( ndc ); }
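/**
* A configuration is considered "simple" if it describes a dense, row-major tensor:
* the strides and the indices map both equal the contiguous row-major strides of the shape,
* every spread/step is 1 and every offset is 0.
* As an illustrative example, shape {2, 3} is simple with strides {3, 1}, indicesMap {3, 1},
* spread {1, 1} and offset {0, 0}.
*/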
private static boolean _isSimpleConfiguration(
int[] shape,
int[] strides,
int[] indicesMap,
int[] spread,
int[] offset
) {
// Note: Column major is not simple because there are no simple column major implementations...
int[] newStrides = Layout.ROW_MAJOR.newStridesFor( shape );
int[] newSpread = new int[ shape.length ];
Arrays.fill( newSpread, 1 );
return Arrays.equals( strides, newStrides ) &&
Arrays.equals( indicesMap, newStrides ) &&
Arrays.equals( offset, new int[ shape.length ] ) &&
Arrays.equals( spread, newSpread );
}
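/**
* A configuration qualifies as a simple permutation (transposition) if it involves no slicing,
* meaning every spread/step is 1 and every offset is 0; the strides and indices map are not
* checked here, which permits permuted stride orders.
*/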
private static boolean _isSimpleTransposedConfiguration(
int[] shape, int[] spread, int[] offset
) {
int[] newSpread = new int[ shape.length ];
Arrays.fill( newSpread, 1 );
return Arrays.equals( offset, new int[ shape.length ] ) &&
Arrays.equals( spread, newSpread );
}
@Override
public final String toString() {
return "NDConfiguration@"+Integer.toHexString(hashCode())+"#"+Long.toHexString(this.hashCode())+"[" +
"layout="+getLayout().name()+","+
"shape="+Arrays.toString(shape())+","+
"strides="+Arrays.toString(strides())+","+
"indicesMap="+Arrays.toString(indicesMap())+","+
"spread="+Arrays.toString(spread())+","+
"offset="+Arrays.toString(offset())+""+
"]";
}
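/**
* Rearranges a "simple" (dense, row-major) configuration according to {@code newForm},
* whose non-negative entries reference dimensions of the original configuration, while
* negative entries mark positions that receive a default spread of 1 and an offset of 0
* (see the loops below).
* The resulting raw arrays are then passed back to {@link #construct} so that the best
* suited implementation is selected again.
*/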
protected static NDConfiguration _simpleReshape( int[] newForm, NDConfiguration ndc )
{
int[] newShape = Utility.rearrange( ndc.shape(), newForm );
int[] newStrides = ndc.getLayout().rearrange( ndc.strides(), newShape, newForm );
int[] newIndicesMap = ndc.getLayout().newStridesFor( newShape );
int[] newSpread = new int[ newForm.length ];
for ( int i = 0; i < newForm.length; i++ ) {
if ( newForm[ i ] < 0 ) newSpread[ i ] = 1; // Newly introduced dimensions get a default spread/step of 1.
else newSpread[ i ] = ndc.spread( newForm[ i ] );
}
int[] newOffset = new int[ newForm.length ];
for ( int i = 0; i < newForm.length; i++ ) {
if ( newForm[ i ] < 0 ) newOffset[ i ] = 0; // Newly introduced dimensions get a default offset of 0.
else newOffset[ i ] = ndc.offset( newForm[ i ] );
}
return AbstractNDC.construct(
newShape,
newStrides,
newIndicesMap,
newSpread,
newOffset
);
}
@Override
public NDConfiguration newReshaped( int[] newForm )
{
//TODO : shape check!
if ( _isSimpleConfiguration( shape(), strides(), indicesMap(), spread(), offset() ) )
return _simpleReshape( newForm, this );
else
return new SimpleReshapeView( newForm, this );
}
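/**
* The hash code is derived from the concrete class, the shape, strides, indices map, spread,
* offset and layout, so that configurations with equal state can be found and deduplicated
* in the {@link Cache} of this class.
*/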
@Override
public int hashCode() {
return Long.valueOf(
this.getClass().hashCode() +
Arrays.hashCode( shape() ) * 1L +
Arrays.hashCode( strides() ) * 2L +
Arrays.hashCode( indicesMap() ) * 3L +
Arrays.hashCode( spread() ) * 4L +
Arrays.hashCode( offset() ) * 5L +
getLayout().hashCode()
)
.hashCode();
}
@Override
public final boolean equals( Object other ) {
if ( other == null ) return false;
if ( !( other instanceof NDConfiguration ) ) return false;
if ( other == this ) return true;
NDConfiguration ndc = (NDConfiguration) other;
return this.equals( ndc );
}
@Override
public final boolean equals( NDConfiguration ndc ) {
if ( ndc == this ) return true;
return this.getClass() == ndc.getClass() && // TODO: Think about this -> do we require them to be of the same class?
Arrays.equals(this.shape(), ndc.shape() ) &&
Arrays.equals(this.strides(), ndc.strides() ) &&
Arrays.equals(this.indicesMap(), ndc.indicesMap() ) &&
Arrays.equals(this.spread(), ndc.spread() ) &&
Arrays.equals(this.offset(), ndc.offset() ) &&
Objects.equals(this.getLayout(), ndc.getLayout());
}
}