Slice.java
package neureka.backend.main.operations.other;
import neureka.Neureka;
import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.Algorithm;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.operations.ElemWiseUtil;
import neureka.devices.Device;
import neureka.framing.Relation;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.ndim.NDConstructor;
import org.slf4j.Logger;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
 *  The {@code slice} operation creates a tensor which shares the data of its single input
 *  but exposes only a sub-region of it. The region is described by three {@code int[]} call
 *  arguments: {@link Arg.Shape} (the shape of the slice), {@link Arg.Offset} (where the slice
 *  starts, relative to the input) and {@link Arg.Stride} (the per-axis step between
 *  consecutive slice elements, where a step of 0 is treated as 1).
 *  Autograd is supported in backward mode only.
 */
public class Slice extends AbstractOperation
{
    private static final Logger _LOG = org.slf4j.LoggerFactory.getLogger( Slice.class );

    /**
     *  Registers the differentiable, non-inline {@code slice} operation (arity 1)
     *  together with its single "slice" algorithm.
     */
    public Slice()
    {
        super(
            new OperationBuilder()
                .identifier( "slice" )
                .operator( "slice" )
                .arity( 1 )
                .isOperator( false )
                .isIndexer( false )
                .isDifferentiable( true )
                .isInline( false )
        );
        setAlgorithm(
            Algorithm.withName("slice")
            .setIsSuitableFor( call -> SuitabilityPredicate.GOOD )
            .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
            .setExecution(
                ( caller, call ) ->
                {
                    int[] newShape = call.getValOf(Arg.Shape.class);
                    /*
                        '_slice' rewrites the offset and stride arrays in place: strides of 0 become 1
                        and the offsets are shifted by the input's own offsets (so that they address the
                        shared data directly). We therefore hand it private copies. This keeps the
                        caller's arguments untouched and preserves the offsets relative to 'input',
                        which is what the axis labels and the backward pass need, because both of them
                        index into something shaped like 'input' rather than into the shared data.
                    */
                    int[] relativeOffset = call.getValOf(Arg.Offset.class).clone();
                    int[] newOffset      = relativeOffset.clone();
                    int[] newSpread      = call.getValOf(Arg.Stride.class).clone();
                    Tensor<Object> input = (Tensor<Object>) call.input(0);
                    Tensor<?> subset = _slice(newShape, newOffset, newSpread, input);
                    //--- Captured for the backward pass, which rebuilds an error tensor like 'input':
                    Class<?> typeClass = input.itemType();
                    Shape shape = input.shape();
                    boolean isOutsourced = input.isOutsourced();
                    Device<Object> device = input.getDevice();
                    //---
                    _sliceFrame( input, subset, newShape, relativeOffset, newSpread );
                    return
                        Result.of(subset.mut().setIsIntermediate(true))
                        .withADAction( t -> {
                            // A new tensor shaped like 'input' (initialized with 0, presumably) receives the
                            // incoming error at the position of the slice:
                            Tensor<Object> newError = ElemWiseUtil.newTensorLike((Class<Object>) typeClass, shape, isOutsourced, device, 0);
                            boolean isIntermediate = newError.isIntermediate();
                            newError.mut().setIsIntermediate(false); // To avoid deletion!
                            // Same slice, but relative to 'newError', which has no offsets of its own:
                            Tensor<Object> slice = Function.of("slice(I[0])", false)
                                                    .with(Arg.Shape.of(newShape),Arg.Offset.of(relativeOffset),Arg.Stride.of(newSpread))
                                                    .call(newError);
                            newError.mut().setIsIntermediate(isIntermediate);
                            slice.mut().setIsIntermediate(false);
                            // Writes the (de-virtualized) error into the view of 'newError':
                            Neureka.get().backend().getFunction().idy().execute( slice, t.error().mut().setIsVirtual(false) );
                            return newError;
                        });
                }
            )
            .buildFunAlgorithm()
        );
    }

    /**
     *  Creates a slice of {@code input} which shares the data of {@code input}.
     *  <p>
     *  <b>Side effects on the arguments:</b> the first {@code input.rank()} entries of
     *  {@code newSpread} are normalized in place (0 becomes 1), and every entry of
     *  {@code newOffset} is increased in place by the corresponding offset of {@code input}.
     *  {@code input} itself is made non-virtual and is registered as parent of the slice.
     *
     * @param newShape  The shape of the new slice.
     * @param newOffset The offsets of the slice relative to {@code input}; overwritten in place.
     * @param newSpread The per-axis steps of the slice; zeros are replaced by 1 in place.
     * @param input     The tensor to be sliced.
     * @return The new slice, linked to {@code input} through their {@link Relation} components.
     * @throws IllegalArgumentException If {@code input} is not a higher order slice, has the same
     *                                  rank as the slice, and the slice is found to exceed its bounds.
     */
    private static Tensor<?> _slice(
        int[] newShape,
        int[] newOffset,
        int[] newSpread,
        Tensor<Object> input
    ) {
        input.mut().setIsVirtual( false );
        int[] newStrides = input.getNDConf().strides();
        int[] newIndicesMap = input.getNDConf().getLayout().newStridesFor( newShape );
        // A spread (step) of 0 is interpreted as 1:
        for ( int i = 0; i < input.rank(); i++ )
            newSpread[ i ] = ( newSpread[i] == 0 ) ? 1 : newSpread[ i ];
        // Offset is being inherited! The slice shares the data of 'input', so its offsets must include those of 'input'.
        for ( int i = 0; i < newOffset.length; i++ )
            newOffset[ i ] = newOffset[ i ] + input.getNDConf().offset( i );

        Relation<?> inputRelation = input.get( Relation.class );
        Tensor<?> rootTensor   = ( input.isSlice() ? inputRelation.findRootTensor().orElseThrow(IllegalStateException::new) : input );
        Tensor<?> parentTensor = ( input.isSlice() ? inputRelation.getParent().orElseThrow(IllegalStateException::new) : input );
        /*
            The following code checks the validity of the slice shape ranges with
            respect to the 'parentTensor' of this new slice.
        */
        if ( parentTensor.rank() == newShape.length && rootTensor == parentTensor ) {
            Shape parentShape = parentTensor.shape();
            if ( Shape.of(newShape).elements() > parentShape.elements() )
                throw new IllegalArgumentException(
                    "The new shape of the slice exceeds the number of elements of the parent tensor!"
                );
            // If any axis of the slice is larger than the same axis of the parent,
            // the per-axis bounds check below would be meaningless and is skipped.
            boolean sliceSeemsToBeCompletelyReshaped = false;
            for ( int i = 0; i < newShape.length; i++ ) {
                if ( newShape[i] > parentShape.get(i) ) {
                    sliceSeemsToBeCompletelyReshaped = true;
                    break;
                }
            }
            /*
                1. Inside this branch 'input' is either not a slice at all or a first order slice
                   (so it is not a slice of a slice... reason : 'rootTensor == parentTensor' )
                2. There is however uncertainty about the 'true shape' of this parent tensor!
                   Meaning : It might have been permuted and could therefore be distorted with
                   respect to the slice that is currently being prepared!
                   -> This means we have to take this possible reshaping into account!
                   The following uses the 'permute array' which is expected to have been recorded
                   inside the parent's 'Relation' component by the 'Reshape' operation
                   (custom shape operations need to record it as well!). Its inverse maps
                   each slice axis to the corresponding axis of the parent.
                   Example: "Tensor.of(...).T().getAt(...);"
                   Transposing a tensor creates a permuted slice of the original tensor.
                   Slicing that tensor again via 'getAt(...)' leads to a situation where
                   the following variable is NOT NULL!
            */
            int[] permute = ( input.isSlice() ? parentTensor.get( Relation.class ).getPermuteRelationFor( input ) : null );
            permute = ( permute != null ) ? Permute.invert( permute ) : null;
            if ( !sliceSeemsToBeCompletelyReshaped ) // If the slice is not reshaped we can do some basic verification:
                for ( int i = 0; i < parentShape.size(); i++ ) {
                    int ii = ( permute != null ) ? permute[ i ] : i;
                    int top = newOffset[ i ] + newShape[ i ];
                    if ( top > parentShape.get( ii ) ) {
                        throw new IllegalArgumentException(
                            "Cannot create slice because ranges are out of the bounds of the targeted tensor.\n" +
                            "At index '" + i + "' : offset '" + newOffset[ i ] + "' + shape '" + newShape[ i ] + "' = '" + top + "',\n" +
                            "which is larger than the target shape '" + parentTensor.shape( ii ) + "' at the same index!"
                        );
                    }
                }
        }
        else if ( rootTensor != parentTensor ) {
            // TODO! This requires some more thought about how to handle slices of slices!
            _LOG.warn(
                "Exceptional higher order slice request detected. " +
                "This type of tensor cannot yet be sliced. " +
                "Please copy this tensor before slicing."
            );
        }

        // The slice is a new tensor view on the very same data array as 'input':
        Tensor<Object> subset =
            Tensor.of(
                input.getDataType(),
                NDConstructor.of( newShape, newStrides, newIndicesMap, newSpread, newOffset ),
                input.mut().getData()
            );
        // Link child and parent in both directions:
        subset.set( Relation.newChildToParent( input ) );
        Relation<Object> parent = input.find( Relation.class ).map(r->(Relation<Object>)r).orElseGet(Relation::newParentToChildren);
        parent.addChild( subset );
        input.set( parent );
        if ( input.isOutsourced() )
            input.getDevice().store( subset );
        // NOTE(review): 'input' was made non-virtual at the top of this method,
        // so this branch is presumably unreachable — confirm before relying on it.
        if ( input.isVirtual() ) subset.mut().setIsVirtual( true );
        return subset;
    }

    /**
     *  Carries the label and the axis labels of {@code input} over to {@code subset}.
     *  A non-empty label gets {@code ":slice"} appended. For each of the first
     *  {@code newShape.length} axes of the input frame, the labels at the positions
     *  {@code newOffset[i] + j * newSpread[i]} (for {@code j} in {@code [0, newShape[i])})
     *  are picked; axes without labels stay {@code null}.
     *
     * @param input     The sliced tensor whose label and frame are read.
     * @param subset    The slice which receives the derived label and axis labels.
     * @param newShape  The shape of the slice.
     * @param newOffset The slice offsets, relative to {@code input}.
     * @param newSpread The per-axis steps of the slice.
     */
    private static void _sliceFrame(
        Tensor<?> input, Tensor<?> subset, int[] newShape, int[] newOffset, int[] newSpread
    ) {
        // Now if the parent tensor has a name and or axes labels we carry them over to the subset:
        String label = input.label();
        if ( !label.isEmpty() ) subset.mut().label( label + ":slice" );
        input.frame().ifPresent( frame -> {
            Map<Object, List<Object>> state = frame.getState();
            Map<Object, List<Object>> sliceState = new LinkedHashMap<>();
            int i = 0;
            for ( Map.Entry<Object, List<Object>> entry : state.entrySet() ) {
                List<Object> axesLabels = entry.getValue();
                if ( axesLabels == null )
                    sliceState.put( entry.getKey(), null ); // newShape[i]
                else {
                    List<Object> slicedLabels = new ArrayList<>();
                    if ( !axesLabels.isEmpty() ) {
                        for ( int j = 0; j < newShape[i]; j++ ) {
                            int index = newOffset[i] + j * newSpread[i];
                            slicedLabels.add( axesLabels.get(index) );
                        }
                    }
                    sliceState.put( entry.getKey(), slicedLabels );
                }
                i++;
                if ( i == newShape.length ) break;
            }
            subset.mut().labelAxes( sliceState );
        });
    }

    /**
     *  Scalar evaluation forwards to the single source function;
     *  the derivative index {@code d} is not used.
     */
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) { return src[ 0 ].call( inputs, j ); }
}