package neureka.backend.main.operations.other;
import neureka.Tensor;
import neureka.backend.api.Algorithm;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.framing.Relation;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.ndim.NDConstructor;
import neureka.ndim.config.NDConfiguration;
/**
 *  Operation registered under the identifier "reshape" which produces a tensor
 *  with a new shape from a single input tensor and supports backward-only
 *  automatic differentiation.
 */
public class Reshape extends AbstractOperation
/**
 *  Configures the operation meta data (identifier, operator, arity and flags)
 *  and defines the "reshape" algorithm together with its auto-diff action.
 */
public Reshape()
// Operation meta data: unary, not an operator, not an indexer, differentiable, not inline.
new OperationBuilder()
.identifier( "reshape" )
.operator( "reshape" )
.arity( 1 )
.isOperator( false )
.isIndexer( false )
.isDifferentiable( true )
.isInline( false )
// Algorithm definition: reported as GOOD for every call and differentiated in backward mode only.
.withName( "reshape" )
.setIsSuitableFor( call -> SuitabilityPredicate.GOOD )
.setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
// Execution: receives the calling function and the execution call.
( caller, call ) ->
// Flatten the call and take its first input as the tensor to be reshaped.
Tensor<?>[] inputs = AbstractDeviceAlgorithm.flatten(caller, call).inputs();
Tensor<Object> input = (Tensor<Object>) inputs[0];
// The requested shape is passed as a meta argument of the call and is mandatory.
int[] foundShape = call.getValOf( Arg.Shape.class );
if ( foundShape == null )
throw new IllegalArgumentException("Shape argument is missing!");
// Replace a possible -1 placeholder based on the number of elements of the input.
int[] shape = _resolveNewShape(input.size(), foundShape);
// Build the reshaped tensor using an ND-constructor created from the resolved shape.
Tensor reshaped = Tensor.of(
NDConstructor.of( shape ),
// Link both tensors: the result gets a child-to-parent relation referencing the input...
reshaped.set( Relation.newChildToParent( input ) );
// ...and the input reuses its existing relation, or gets a new parent-to-children relation, listing the result as a child.
Relation parent = input.find( Relation.class ).orElseGet(Relation::newParentToChildren);
parent.addChild( reshaped );
input.set( parent );
// If the input is outsourced, the result is stored on the device of the input as well.
if ( input.isOutsourced() )
input.getDevice().store( reshaped );
// Captured for the auto-diff action below, which rebuilds the error with this configuration.
NDConfiguration originalConfig = input.getNDConf();
// The result is flagged as intermediate before it is wrapped into the returned Result.
return Result.of(reshaped.mut().setIsIntermediate(true))
.withADAction( target -> {
// Backward step: the error of the target is turned into a tensor based on the original configuration of the input.
Tensor<Object> error = (Tensor<Object>) target.error();
return Tensor.of(
NDConstructor.of( originalConfig ),
* If the provided shape array contains a -1 as one of its elements,
* then this method will resolve the -1 to the correct value
* which results in a shape array which is compatible with the provided size,
* meaning that when we multiply all the elements of the resolved shape array
* we will get the provided size.
* @param size The size which the resolved shape array should be compatible with.
* @param shape The shape array which may contain a -1.
* @return The resolved shape array.
private static int[] _resolveNewShape( int size, int[] shape )
int[] resolvedShape = new int[ shape.length ];
int minusOneIndex = -1;
int minusOneCount = 0;
int product = 1;
for ( int i = 0; i < shape.length; i++ )
if ( shape[ i ] == -1 )
minusOneIndex = i;
resolvedShape[ i ] = shape[ i ];
product *= shape[ i ];
if ( minusOneCount > 1 )
throw new IllegalArgumentException("The shape array contains more than one -1!");
if ( minusOneCount == 1 )
resolvedShape[ minusOneIndex ] = size / product;
return resolvedShape;
public double calculate( double[] inputs, int j, int d, Function[] src )
return src[ 0 ].call( inputs, j );