package neureka.backend.main.operations.other;

import neureka.Neureka;
import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.Algorithm;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
import neureka.ndim.config.NDConfiguration;

import java.util.ArrayList;
import java.util.List;
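
/**
 *  The {@code dimtrim} operation removes all leading and trailing
 *  dimensions of size 1 from the shape of a tensor, so that a tensor
 *  of shape (1, 1, 4, 3, 1) is exposed as one of shape (4, 3).
 *  Only the {@link NDConfiguration} of the tensor is rewritten;
 *  the underlying data array is neither copied nor moved.
 *  <p>
 *  A minimal usage sketch, assuming the string based {@link Function}
 *  API (the exact convenience calls may differ):
 *  <pre>{@code
 *  Tensor<Double> t = Tensor.of( Double.class ).withShape( 1, 1, 4, 3, 1 ).all( 0d );
 *  Tensor<Double> trimmed = Function.of( "dimtrim(I[0])" ).call( t ); // shape: (4, 3)
 *  }</pre>
 */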
public class DimTrim extends AbstractOperation
{
    public DimTrim()
    {
        super(
            new OperationBuilder()
                .identifier( "dimtrim" )
                .operator( "dimtrim" )
                .arity( 1 )
                .isOperator( false )
                .isIndexer( false )
                .isDifferentiable( true )
                .isInline( false )
        );
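        // The single algorithm backing this operation: it always reports itself
        // as a good fit and advertises backward mode autodiff, since trimming
        // is a pure reshape whose gradient is simply the padded error.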
        setAlgorithm(
            Algorithm
                .withName( "dimTrim" )
                .setIsSuitableFor( call -> SuitabilityPredicate.GOOD )
                .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
                .setExecution(
                    ( caller, call ) ->
                    {
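                        // Autodiff rule: the backward pass of a trim is a pad, meaning
                        // the error is padded back to the rank of the original input so
                        // that gradients line up with the un-trimmed shape. In forward
                        // mode the caller expression is re-parsed and derived instead.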
                        ADAction autoDiff = target ->
                        {
                            int[] endings = endsFrom( call.input( 0 ).getNDConf().shape() );
                            int prefix  = endings[ 0 ];
                            int postfix = endings[ 1 ];
                            return
                                call.autogradMode() == AutoDiffMode.FORWARD_ONLY
                                    ? new FunctionParser( Neureka.get().backend() )
                                          .parse( caller.toString(), false )
                                          .derive( new Tensor[]{ target.error() }, 0 )
                                    : _pad( target.error(), new int[]{ prefix, postfix }, true );
                        };
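                        // Resolve any nested calls so that we receive the single
                        // concrete input tensor this operation works on.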
                        Tensor<?>[] inputs = AbstractDeviceAlgorithm.flatten( caller, call ).inputs();
                        assert inputs.length == 1;
                        Tensor<?> t = inputs[ 0 ];
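                        // When a derivative with respect to input 0 is requested,
                        // the inverse operation (padding, based on the Ends argument
                        // supplied with the call) is applied instead of trimming.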
                        if ( call.getValOf( Arg.DerivIdx.class ) == 0 ) {
                            int[] ends = call.getValOf( Arg.Ends.class );
                            int prefix  = ends[ 0 ];
                            int postfix = ends[ 1 ];
                            return Result.of( _pad( t, new int[]{ prefix, postfix }, true ) ).withADAction( autoDiff );
                        }
                        else
                            return Result.of( _trim( t, true ) ).withADAction( autoDiff );
                    }
                )
                .buildFunAlgorithm()
        );
    }
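
    /**
     *  Re-inserts {@code ends[ 0 ]} leading and {@code ends[ 1 ]} trailing
     *  dimensions of size 1 into the shape of the given tensor by rewriting
     *  its {@link NDConfiguration}; this is the inverse of {@link #_trim}
     *  and is used by the autodiff rule to restore the rank of an error tensor.
     *
     * @param tensor The tensor whose shape should be padded.
     * @param ends The number of leading ({@code ends[ 0 ]}) and trailing ({@code ends[ 1 ]}) unit dimensions to add.
     * @param newTensor If {@code true} a shallow copy is padded instead of the tensor itself.
     * @return The tensor with the padded shape.
     */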
    private static <T> Tensor<T> _pad( Tensor<T> tensor, int[] ends, boolean newTensor ) {
        if ( tensor.getNDConf().getLayout() == NDConfiguration.Layout.COLUMN_MAJOR )
            throw new IllegalArgumentException("Column major layout is not yet supported for shape trimming!");
        // An empty access key yields a slice of the whole tensor, effectively a shallow copy:
        tensor = ( newTensor ? tensor.getAt( new ArrayList<>() ) : tensor );
        List<Integer> newShape      = new ArrayList<>();
        List<Integer> newStrides    = new ArrayList<>();
        List<Integer> newIndicesMap = new ArrayList<>();
        List<Integer> newSpread     = new ArrayList<>();
        List<Integer> newOffset     = new ArrayList<>();
        int[] shape = tensor.getNDConf().shape();
        int prefix  = ends[ 0 ];
        int postfix = ends[ 1 ];
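        // Prepend the requested number of size 1 dimensions: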
        for ( int i = 0; i < prefix; i++ ) {
            newShape.add( 1 );
            newStrides.add( 1 );
            newIndicesMap.add( 1 );
            newSpread.add( 0 );
            newOffset.add( 0 );
        }
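        // Copy the existing dimensions unchanged: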
        for ( int i = 0; i < shape.length; i++ ) {
            newShape.add( shape[ i ] );
            newStrides.add( tensor.getNDConf().strides( i ) );
            newIndicesMap.add( tensor.getNDConf().indicesMap( i ) );
            newSpread.add( tensor.getNDConf().spread( i ) );
            newOffset.add( tensor.getNDConf().offset( i ) );
        }
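        // Append the requested number of size 1 dimensions: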
        for ( int i = 0; i < postfix; i++ ) {
            newShape.add( 1 );
            newStrides.add( 1 );
            newIndicesMap.add( 1 );
            newSpread.add( 0 );
            newOffset.add( 0 );
        }
        tensor
            .mut()
            .setNDConf(
                NDConfiguration.of(
                    newShape.stream().mapToInt( i -> i ).toArray(),
                    newStrides.stream().mapToInt( i -> i ).toArray(),
                    newIndicesMap.stream().mapToInt( i -> i ).toArray(),
                    newSpread.stream().mapToInt( i -> i ).toArray(),
                    newOffset.stream().mapToInt( i -> i ).toArray()
                )
            );
        return tensor;
    }
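
    /**
     *  Removes all leading and trailing dimensions of size 1 from the shape
     *  of the given tensor by rewriting its {@link NDConfiguration}.
     *  The offsets of the removed leading dimensions are folded into the
     *  first remaining dimension so that the view still addresses the same data.
     *
     * @param tensor The tensor whose shape should be trimmed.
     * @param newTensor If {@code true} a shallow copy (marked as intermediate) is trimmed instead of the tensor itself.
     * @return The tensor with the trimmed shape.
     */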
    private static Tensor<?> _trim( Tensor<?> tensor, boolean newTensor )
    {
        if ( tensor.getNDConf().getLayout() == NDConfiguration.Layout.COLUMN_MAJOR )
            throw new IllegalArgumentException("Column major layout is not yet supported for shape trimming!");
        tensor = ( newTensor ? tensor.getAt( new ArrayList<>() ).mut().setIsIntermediate( true ) : tensor );
        List<Integer> newShape      = new ArrayList<>();
        List<Integer> newStrides    = new ArrayList<>();
        List<Integer> newIndicesMap = new ArrayList<>();
        List<Integer> newSpread     = new ArrayList<>();
        List<Integer> newOffset     = new ArrayList<>();
        int[] shape = tensor.getNDConf().shape();
        int[] endings = endsFrom( shape );
        int prefix  = endings[ 0 ];
        int postfix = endings[ 1 ];
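        // Keep only the dimensions between the trimmed prefix and postfix: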
        for ( int i = prefix; i < shape.length - postfix; i++ ) {
            newShape.add( shape[ i ] );
            newStrides.add( tensor.getNDConf().strides( i ) );
            newIndicesMap.add( tensor.getNDConf().indicesMap( i ) );
            newSpread.add( tensor.getNDConf().spread( i ) );
            newOffset.add( tensor.getNDConf().offset( i ) );
        }
        if ( !newOffset.isEmpty() ) {
            // We determine the prefix offset:
            int prefixOffset = 0;
            for ( int i = 0; i < prefix; i++ )
                prefixOffset += tensor.getNDConf().strides( i ) * tensor.getNDConf().offset( i );
            // We adjust the offset of the first non-trimmed dimension:
            newOffset.set( 0, newOffset.get( 0 ) + prefixOffset );
        }
        tensor
            .mut()
            .setNDConf(
                NDConfiguration.of(
                    newShape.stream().mapToInt( i -> i ).toArray(),
                    newStrides.stream().mapToInt( i -> i ).toArray(),
                    newIndicesMap.stream().mapToInt( i -> i ).toArray(),
                    newSpread.stream().mapToInt( i -> i ).toArray(),
                    newOffset.stream().mapToInt( i -> i ).toArray()
                )
            );
        return tensor;
    }
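
    /**
     *  Counts the number of leading and trailing dimensions of size 1 in the
     *  given shape. For the shape {@code {1, 1, 4, 3, 1}} this returns
     *  {@code {2, 1}}. Note that for a shape consisting only of 1s
     *  (like {@code {1, 1}}) both scans cover the entire shape,
     *  so the two counts overlap.
     *
     * @param shape The shape whose unit dimension prefix and postfix should be counted.
     * @return An array of the form {@code { prefix, postfix }}.
     */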
    public static int[] endsFrom( int[] shape ) {
        int prefix = 0;
        for ( int s : shape ) if ( s == 1 ) prefix++; else break;
        int postfix = 0;
        for ( int i = shape.length - 1; i >= 0; i-- ) if ( shape[ i ] == 1 ) postfix++; else break;
        return new int[]{ prefix, postfix };
    }
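
    /**
     *  Trimming does not change any values, so the scalar case simply
     *  forwards to the source function.
     */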
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) {
        return src[ 0 ].call( inputs, j );
    }
}