Cat.java
package neureka.backend.main.operations.other;
import neureka.Neureka;
import neureka.Tensor;
import neureka.backend.api.Algorithm;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.framing.NDFrame;
import neureka.math.Function;
import neureka.math.args.Arg;
import java.util.*;
import java.util.stream.Collectors;
/**
 *  The {@code concat} operation joins any number of tensors along a single axis,
 *  which is taken from the {@link Arg.Axis} argument of the execution call.
 *  All inputs must have the same rank and must agree in size on every axis
 *  except the concatenation axis. If any of the inputs carries an {@link NDFrame},
 *  a best-effort attempt is made to merge the frames into a frame for the result.
 */
public class Cat extends AbstractOperation
{
    public Cat()
    {
        super(
            new OperationBuilder()
                .identifier( "concat" )
                .operator( "concat" )
                .arity( -1 ) // Any number of arguments
                .isOperator( false )
                .isIndexer( false )
                .isDifferentiable( true )
                .isInline( false )
        );
        setAlgorithm(
            Algorithm
            .withName("concat")
            .setIsSuitableFor( call -> {
                Integer dim = call.getValOf(Arg.Axis.class);
                Tensor<?>[] inputs = call.inputs();
                // Without inputs or without an axis there is nothing to concatenate:
                if ( inputs.length == 0 || dim == null ) return SuitabilityPredicate.UNSUITABLE;
                Tensor<?> first = inputs[0];
                // The execution reads 'shape(dim)' of every input, so the axis must exist:
                if ( dim < 0 || dim >= first.rank() ) return SuitabilityPredicate.UNSUITABLE;
                // The operation accepts any number of inputs, so ALL of them (not just
                // the first two) must match the first one on every axis except 'dim':
                for ( int ti = 1; ti < inputs.length; ti++ ) {
                    Tensor<?> other = inputs[ti];
                    if ( other.rank() != first.rank() ) return SuitabilityPredicate.UNSUITABLE;
                    for ( int i = 0; i < first.rank(); i++ )
                        if ( i != dim && other.shape(i) != first.shape(i) )
                            return SuitabilityPredicate.UNSUITABLE;
                }
                return SuitabilityPredicate.GOOD;
            })
            .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
            .setExecution(
                ( caller, call ) ->
                {
                    // The axis along which we want to concatenate:
                    Integer dim = call.getValOf(Arg.Axis.class);
                    Tensor<?>[] inputs = call.inputs();
                    // The size of every input along the concatenation axis:
                    List<Integer> axes = Arrays.stream(inputs).map( t -> t.shape(dim) ).collect(Collectors.toList());
                    int newAxisSize = axes.stream().mapToInt( i -> i ).sum();
                    // The result has the shape of the first input, except along 'dim',
                    // where it is the sum of all input sizes:
                    List<Integer> newShape = new ArrayList<>();
                    for ( int i = 0; i < call.input(0).rank(); i++ )
                        newShape.add( i == dim ? newAxisSize : call.input(0).shape(i) );
                    // We create the (zero filled) output tensor:
                    Tensor<?> c = Tensor.of( call.input(0).getItemType(), newShape, 0 );
                    // Turn the sizes into cumulative (exclusive) end offsets along 'dim':
                    for ( int i = 1; i < axes.size(); i++ )
                        axes.set( i, axes.get( i - 1 ) + axes.get( i ) );
                    // Copy every input into its own slice of c:
                    for ( int i = 0; i < inputs.length; i++ ) {
                        int start = i == 0 ? 0 : axes.get( i - 1 );
                        int end = axes.get( i ) - 1; // inclusive upper index of the slice
                        Tensor<?> slice = c.slice().axis( dim ).from( start ).to( end ).detached();
                        // NOTE(review): 'idy' is assumed to assign the input into the slice (which views c).
                        Neureka.get().backend().getFunction().idy().execute( slice, call.input( i ) );
                    }
                    c.mut().setIsIntermediate(true);
                    try {
                        _catFrames( inputs, c, dim );
                    } catch ( Exception e ) {
                        e.printStackTrace();
                        // Framing is not that important, a result however is!
                        // So an exception in the frame concatenation is not fatal!
                    }
                    return
                        Result.of(c)
                            .withADAction( target -> {
                                // The error of input i is the slice of the result error it was copied into:
                                int i = target.inputIndex();
                                int start = i == 0 ? 0 : axes.get( i - 1 );
                                int end = axes.get( i ) - 1;
                                return target.error().slice().axis(dim).from(start).to(end).detached();
                            });
                }
            )
            .buildFunAlgorithm()
        );
    }

    /**
     *  Executes the concatenation after checking that a flat function
     *  receives exactly as many inputs as it declares arguments.
     *
     * @param caller The function which triggered this operation.
     * @param call The execution call carrying the inputs and the axis argument.
     * @return The result of the concatenation.
     * @throws IllegalArgumentException If the caller is flat and the number of inputs
     *                                  does not match its number of arguments.
     */
    @Override public Result execute( Function caller, ExecutionCall<?> call )
    {
        if ( caller.isFlat() && caller.numberOfArgs() != call.inputs().length )
            throw new IllegalArgumentException("The number of arguments of the function call does not match the number of inputs!");
        return super.execute( caller, call );
    }

    /**
     *  Merges the frames (labels) of the inputs into a frame for the concatenated tensor.
     *  Does nothing if none of the inputs is framed.
     *  <ul>
     *      <li>The tensor label becomes the "+"-joined labels of all framed inputs.</li>
     *      <li>For each axis, the axis key is the "+"-joined distinct keys if they are all
     *          strings, otherwise the first distinct key.</li>
     *      <li>Along {@code dim}, the value lists of all framed inputs are appended.</li>
     *      <li>Along any other axis, per index: equal values are kept once, differing
     *          strings are joined with "+", and otherwise the first value is taken.</li>
     *  </ul>
     *  Exceptions are not handled here; the caller treats them as non-fatal.
     *
     * @param inputs The concatenated input tensors.
     * @param concat The result tensor which receives the merged label and axis labels.
     * @param dim The concatenation axis.
     */
    private void _catFrames(Tensor<?>[] inputs, Tensor<?> concat, int dim )
    {
        boolean inputsAreFramed = Arrays.stream(inputs).anyMatch( t -> t.frame().isPresent() );
        if ( !inputsAreFramed ) return;
        String label =
            Arrays.stream(inputs)
                .map(Tensor::frame)
                .filter(Optional::isPresent)
                .map(Optional::get)
                .map(NDFrame::getLabel)
                .collect(Collectors.joining("+"));
        if ( !label.isEmpty() ) concat.mut().label(label);
        // The axis-key -> index-values state of every framed input (unframed inputs are skipped):
        List<Map<Object, List<Object>>> labels =
            Arrays.stream(inputs)
                .map(Tensor::frame)
                .filter(Optional::isPresent)
                .map(Optional::get)
                .map(NDFrame::getState)
                .collect(Collectors.toList());
        // Axis keys of every frame, in map iteration order (assumed to be axis order — TODO confirm):
        List<List<Object>> allKeys = labels.stream().map( l -> new ArrayList<>(l.keySet()) ).collect(Collectors.toList());
        Map<Object, List<Object>> concatFrame = new LinkedHashMap<>();
        for ( int ci = 0; ci < concat.rank(); ci++ ) {
            int finalCi = ci;
            List<Object> distinctKeys = allKeys.stream().map(ks->ks.get(finalCi) ).distinct().collect(Collectors.toList());
            Object key;
            {
                boolean allString = distinctKeys.stream().allMatch(k -> k instanceof String);
                if (allString) // We join using the "+" operator:
                    key = distinctKeys.stream().map(k -> (String) k).collect(Collectors.joining("+"));
                else // We simply take the first one:
                    key = distinctKeys.get(0);
            }
            List<Object> values = new ArrayList<>();
            if ( ci == dim ) {
                /*
                    We need to join the value lists of all the frames
                    and then set the state of the concatenated tensor frame.
                */
                for ( int i = 0; i < labels.size(); i++ ) {
                    Map<Object, List<Object>> current = labels.get(i);
                    List<Object> currentKeys = allKeys.get(i);
                    List<Object> currentValues = current.get(currentKeys.get(ci));
                    values.addAll(currentValues);
                }
            } else {
                /*
                    This is not as simple as the above case!
                    We have conflicting values for the same key, so we do the following:
                    1. If the values are all equal we just take the first one.
                    2. If the values are not equal but all of type string, we join them with a "+".
                    3. If the values are not equal and not all of type string, we just take the first one.
                */
                for ( int j = 0; j < concat.shape(ci); j++ ) {
                    List<Object> valuesForThisIndex = new ArrayList<>();
                    for ( int i = 0; i < labels.size(); i++ ) {
                        Map<Object, List<Object>> current = labels.get(i);
                        List<Object> currentKeys = allKeys.get(i);
                        List<Object> currentValues = current.get(currentKeys.get(ci));
                        if ( j < currentValues.size() )
                            valuesForThisIndex.add(currentValues.get(j));
                    }
                    boolean allEqual = valuesForThisIndex.stream().distinct().count() == 1;
                    if ( allEqual )
                        values.add(valuesForThisIndex.get(0));
                    else if ( !valuesForThisIndex.isEmpty() ) {
                        boolean allString = valuesForThisIndex.stream().allMatch( v -> v instanceof String );
                        if ( allString )
                            values.add(valuesForThisIndex.stream().map( v -> (String) v ).collect(Collectors.joining("+")));
                        else
                            values.add(valuesForThisIndex.get(0));
                    }
                }
            }
            concatFrame.put(key, values);
        }
        if ( !concatFrame.isEmpty() ) concat.mut().labelAxes(concatFrame);
    }

    /**
     *  Scalar evaluation: simply delegates to the first source function.
     *  NOTE(review): the derivative index {@code d} is ignored here — confirm this is intended.
     */
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) { return src[ 0 ].call( inputs, j ); }
}