Broadcast.java

package neureka.backend.main.algorithms;

import neureka.Tensor;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.AbstractFunDeviceAlgorithm;
import neureka.backend.main.operations.other.Permute;
import neureka.devices.Device;
import neureka.dtype.NumericType;

public final class Broadcast extends AbstractFunDeviceAlgorithm<Broadcast>
{
    public Broadcast()
    {
        super("broadcast");
        setIsSuitableFor(
            call ->
            {
                boolean isInvalid =
                            !call.validate()
                                 .allNotNull( t -> t.getDataType().typeClassImplements(NumericType.class) )
                                 .isValid();

                if ( isInvalid )
                    return SuitabilityPredicate.UNSUITABLE;

                int maxRank = 0;
                for ( Tensor<?> t : call.inputs() )
                    if ( t != null && t.rank() > maxRank ) maxRank = t.rank();

                for ( int i = 0; i < maxRank; i++ )
                {
                    int currentDim = -1;
                    for( Tensor<?> t : call.inputs() )
                    {
                        if ( t != null && i < t.rank() ) {
                            if ( currentDim == -1 ) currentDim = t.shape( i );
                            else if ( currentDim != t.shape( i ) && currentDim != 1 && t.shape( i ) != 1 ) return 0.0f;
                        }
                    }
                }
                return SuitabilityPredicate.GOOD;
            }
        );
        setAutogradModeFor(
            call ->
                call.validate()
                    .all( ( first, second ) -> first.shape().equals(second.shape()) )
                    .ifValid(AutoDiffMode.FORWARD_AND_BACKWARD)
                    .orElse(AutoDiffMode.BACKWARD_ONLY)
        );
        setExecution( (outerCaller, outerCall) ->
                        Result.of(AbstractDeviceAlgorithm.prepareAndExecute(
                                outerCall,
                                innerCall -> AbstractDeviceAlgorithm.executeDeviceAlgorithm( innerCall )
                        ))
        );
        setCallPreparation(
            call ->
            {
                if ( call.arity() < 3 ) call = call.withAddedInputAt(0, null);
                int offset = ( call.input( Number.class, 0 ) == null ? 1 : 0 );
                if (
                    call.input( Number.class, offset).shape().size() != call.input( Number.class, 1+offset).shape().size()
                )
                {
                    Tensor<?>[] inputs = {call.input( Number.class, offset), call.input( Number.class, 1+offset) };
                    Permute.makeFit( inputs, true );
                    inputs = new Tensor[]{ null, inputs[0], inputs[1] };
                    call = call.withInputs( inputs );
                }

                Device device = call.getDevice();
                if ( call.input( 0 ) == null ) // Creating a new tensor:
                {
                    int[] s1 = call.input( 1 ).getNDConf().shape();
                    int[] s2 = call.input( 2 ).getNDConf().shape();

                    assert s1.length == s2.length;
                    int[] outShape = new int[s1.length];

                    for ( int i = 0; i < outShape.length; i++ )
                        assert s1[ i ] == 1 || s2[ i ] == 1 || s1[ i ] == s2[ i ];

                    for ( int i = 0; i < outShape.length; i++ )
                        outShape[ i ] = ( s1[ i ] == 1 ? s2[ i ] : s1[ i ] );

                    Class<Object> type = (Class<Object>) call.input(  1 ).getItemType();
                    Tensor<?> output = Tensor.of(type).withShape(outShape).all( 0.0 ).mut().setIsIntermediate( true );
                    output.mut().setIsVirtual( false );
                    device.store( output );
                    call = call.withInputAt( 0, output );
                }
                return call;
            }
        );
    }

}