SumAlgorithm.java

package neureka.backend.main.algorithms;

import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.DeviceAlgorithm;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.AbstractFunDeviceAlgorithm;

public class SumAlgorithm extends AbstractFunDeviceAlgorithm<SumAlgorithm>
{
    public SumAlgorithm() {
        super("sum_algorithm");
        setIsSuitableFor(
                call -> call.validate()
                        .allNotNull( t -> Number.class.isAssignableFrom(t.getItemType()) )
                        .basicSuitability()
        )
        .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
        .setExecution( (caller, call) -> {
            Tensor<?>[] inputs = AbstractDeviceAlgorithm.flatten(caller, call).inputs();
            call = call.withInputs(inputs);
            Tensor<?> result = ((DeviceAlgorithm)call.getAlgorithm()).getImplementationFor(call.getDevice()).run(call);
            Shape originalShape = call.input(0).shape();
            return Result.of(
                            result.mut().setIsIntermediate(true)
                    )
                    .withADAction( target -> {
                        Tensor<Object> error = (Tensor<Object>) target.error();
                        assert error.size() == 1;
                        return Tensor.of(error.itemType(), originalShape, error.item()).to(error.getDevice());
                    });
        })
        .setCallPreparation( call ->
        {
            if ( call.input( 0 ) == null )
                call = call.withInputAt( 0, call.input( 1 ) );

            return call;
        })
        .buildFunAlgorithm();
    }
}