ScalarSumAlgorithm.java

package neureka.backend.main.algorithms;

import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.AbstractFunAlgorithm;

public class ScalarSumAlgorithm extends AbstractFunAlgorithm
{
    public ScalarSumAlgorithm() {
        super("scalar_sum_algorithm");
        setIsSuitableFor(
            call ->
                call.validate()
                .allNotNull( t -> Number.class.isAssignableFrom(t.getItemType()) )
                .allNotNull( t -> t.size() == 1 || t.isVirtual() )
                .suitabilityIfValid( PERFECT ) // You cannot come up with something faster than this! ;D
        )
        .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
        .setExecution( (caller, call) -> {
            Tensor<?>[] inputs = AbstractDeviceAlgorithm.flatten(caller, call).inputs();
            call = call.withInputs(inputs);
            if ( call.input( 0 ) == null )
                call = call.withInputAt( 0, call.input( 1 ) );

            Tensor<?> in = call.input(0);
            Shape originalShape = in.shape();
            Number item = (Number) in.item();
            double sum = item.doubleValue() * in.size();
            Tensor<?> result = Tensor.of( in.itemType(), Shape.of( 1 ), sum ).to( in.getDevice() );
            return Result.of( result.mut().setIsIntermediate(true) )
                    .withADAction( target -> {
                        Tensor<Object> error = (Tensor<Object>) target.error();
                        assert error.size() == 1;
                        return Tensor.of(error.itemType(), originalShape, error.item()).to(error.getDevice());
                    });
        })
        .buildFunAlgorithm();
    }
}