package neureka.backend.main.operations.other;
import neureka.Neureka;
import neureka.Tensor;
import neureka.backend.api.Algorithm;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.framing.Relation;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.math.parsing.FunctionParser;
import neureka.ndim.NDUtil;
import neureka.ndim.config.NDConfiguration;
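
/**
 *  The 'permute' operation rearranges the axes of a tensor according to a provided
 *  order of axis indices. It does not copy the underlying data; instead it produces a
 *  shallow copy whose {@link NDConfiguration} exposes the permuted view.
 *  In function string form (see {@link #stringify(String[])}) a call looks like
 *  {@code ([1,0]:(I[0]))}, which swaps the two axes of the first input tensor.
 */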
public class Permute extends AbstractOperation
{
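    /**
     *  Registers the 'permute' operation together with its single backend algorithm.
     *  The algorithm reads the target axis order either from the {@link Arg.Indices}
     *  meta argument or from the leading scalar input tensors, and reverses the
     *  permutation when executed for a derivative (backward) call.
     */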
    public Permute()
    {
        super(
            new OperationBuilder()
                .identifier( "permute" )
                .operator( "," )
                .arity( -1 )
                .isOperator( true )
                .isIndexer( false )
                .isDifferentiable( true )
                .isInline( false )
        );
        setAlgorithm(
            Algorithm
                .withName( "permute" )
                .setIsSuitableFor( call -> SuitabilityPredicate.GOOD )
                .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
                .setExecution(
                    ( caller, call ) ->
                    {
                        Tensor<?>[] inputs = AbstractDeviceAlgorithm.flatten( caller, call ).inputs();
                        int[] axisIndicesOrder = call.getValOf( Arg.Indices.class );
                        if ( axisIndicesOrder == null ) {
                            // The axis order is encoded in the leading argument tensors (one index each).
                            axisIndicesOrder = new int[ inputs.length - 1 ];
                            for ( int i = 0; i < inputs.length - 1; i++ )
                                axisIndicesOrder[ i ] = ( (Number) inputs[ i ].item( 0 ) ).intValue();
                        }
                        else if ( inputs.length > 1 )
                            throw new IllegalArgumentException(
                                "Conflicting arguments detected: either the first inputs are tensors representing indices, " +
                                "or the indices are given as meta arguments, in which case only a single tensor is expected as input!"
                            );
                        if ( call.getValOf( Arg.DerivIdx.class ) >= 0 ) // reverse permute:
                            axisIndicesOrder = invert( axisIndicesOrder );
                        return Result.of( _rearrangeAxisOf( inputs[ inputs.length - 1 ], axisIndicesOrder, true ) )
                                .withADAction( target ->
                                    // Backward pass: re-derive the permute expression and apply it to the incoming error.
                                    // Derivative calls carry a non-negative DerivIdx, which inverts the axis order above.
                                    new FunctionParser( Neureka.get().backend() )
                                        .parse( caller.toString(), false )
                                        .derive( new Tensor[]{ target.error() }, 0 )
                                );
                    }
                )
                .buildFunAlgorithm()
        );
    }
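
    /**
     *  Applies the given axis order to the {@link NDConfiguration} of the supplied tensor.
     *  When {@code newTensor} is true, a shallow copy (sharing the underlying data) is
     *  permuted and registered as a permute relation on the parent; otherwise the
     *  tensor is modified in place.
     */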
    private static Tensor<?> _rearrangeAxisOf( Tensor<?> tensor, int[] indicesOrder, boolean newTensor )
    {
        Tensor<?> parent = tensor;
        tensor = newTensor ? tensor.shallowCopy().mut().setIsIntermediate( true ) : tensor;
        NDConfiguration newNDC = tensor.getNDConf().newReshaped( indicesOrder );
        _shapeCheck( newNDC.shape(), tensor );
        tensor.mut().setNDConf( newNDC );
        if ( newTensor ) {
            Relation r = parent.get( Relation.class );
            r.addPermuteRelationFor( tensor, indicesOrder );
        }
        return tensor;
    }
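
    /**
     *  Pads the shapes of lower rank tensors with dummy dimensions (encoded as {@code -1}
     *  in the reshape expression) so that every tensor matches the rank of the largest one.
     *  Whether the padding goes in front or behind depends on where the largest shape
     *  was trimmed (see {@link DimTrim#endsFrom(int[])}).
     */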
    public static void makeFit( Tensor<?>[] tensors, boolean doesAD )
    {
        int largest = -1;
        int[] shape = null;
        for ( Tensor<?> t : tensors )
            if ( t.rank() > largest ) {
                largest = t.rank();
                shape = t.getNDConf().shape();
            }

        int[] endings = DimTrim.endsFrom( shape );
        int prefix  = endings[ 0 ];
        int postfix = endings[ 1 ];
        for ( int i = 0; i < tensors.length; i++ ) {
            if ( tensors[ i ].rank() != largest ) {
                int[] oldShape = tensors[ i ].getNDConf().shape();
                int[] newReshape = new int[ largest ];
                int padding = largest - oldShape.length;
                int handle = ( postfix <= prefix ) ? padding : largest - padding;
                for ( int ii = 0; ii < handle; ii++ )       newReshape[ ii ] = ( postfix <= prefix ) ? -1 : ii;
                for ( int ii = handle; ii < largest; ii++ ) newReshape[ ii ] = ( postfix <= prefix ) ? ii - padding : -1;
                Function f = Function.of(
                                    NDUtil.shapeString( newReshape ) + ":(I[ 0 ])",
                                    doesAD
                                );
                tensors[ i ] = f.execute( tensors[ i ] );
            }
        }
    }
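
    /**
     *  Computes the inverse of the given permutation, ignoring negative (dummy) entries,
     *  so that applying the result after the original order restores the initial axis layout.
     */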
    public static int[] invert( int[] axisIndicesOrder )
    {
        int reverseLength = 0;
        for ( int e : axisIndicesOrder )
            if ( e >= 0 ) reverseLength++;

        int[] reversed = new int[ reverseLength ];
        int currentIndex = 0;
        int reverseIndex = 0;
        while ( reverseIndex < reverseLength ) {
            if ( axisIndicesOrder[ currentIndex ] >= 0 ) {
                reversed[ axisIndicesOrder[ currentIndex ] ] = currentIndex;
                reverseIndex++;
            }
            currentIndex++;
        }
        return reversed;
    }
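
    /**
     *  Turns the child expressions into this operation's function string form,
     *  e.g. {@code ([1,0]:(I[0]))}, where constant indices are trimmed to their integer part.
     */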
    @Override
    public String stringify( String[] children ) {
        java.util.function.Function<String, Boolean> isConstantNumeric =
                s -> {
                    try {
                        Double.parseDouble( s );
                        return true;
                    } catch ( Exception e ) { return false; }
                };
        StringBuilder reconstructed = new StringBuilder();
        reconstructed.insert( 0, "[" );
        for ( int i = 0; i < children.length; ++i ) {
            if ( i == children.length - 1 ) {
                reconstructed.append( "]:(" ).append(
                        ( isConstantNumeric.apply( children[ i ] ) )
                                ? children[ i ].split( "\\." )[ 0 ]
                                : children[ i ]
                ).append( ")" );
            } else
                reconstructed.append(
                        ( isConstantNumeric.apply( children[ i ] ) )
                                ? children[ i ].split( "\\." )[ 0 ]
                                : children[ i ]
                );
            if ( i < children.length - 2 )
                reconstructed.append( "," );
        }
        return "(" + reconstructed + ")";
    }
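
    /**
     *  Permuting axes does not change any tensor values, so the scalar calculation
     *  simply delegates to the first source function.
     */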
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src )
    {
        return src[ 0 ].call( inputs, j );
    }
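
    /**
     *  Verifies that the permuted shape still covers the same number of elements as the tensor.
     */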
    private static void _shapeCheck( int[] newShp, Tensor<?> t ) {
        if ( NDConfiguration.Utility.sizeOfShape( newShp ) != t.size() ) {
            throw new IllegalArgumentException(
                    "New shape does not match tensor size!" +
                    " (" +
                    NDUtil.shapeString( newShp ) +
                    ( ( NDConfiguration.Utility.sizeOfShape( newShp ) < t.size() ) ? "<" : ">" ) +
                    NDUtil.shapeString( t.getNDConf().shape() ) +
                    ")"
            );
        }
    }
}