Result.java

package neureka.backend.api;

import neureka.Tensor;
import neureka.autograd.ADAction;
import neureka.backend.api.fun.ADActionSupplier;
import neureka.backend.api.fun.Execution;
import neureka.common.utility.LogUtil;

/**
 *  An immutable wrapper for a tensor as a result of an {@link Execution}
 *  as well as an optional {@link ADActionSupplier} for providing auto-differentiation support.
 *  <p>
 *  Instances are created through {@link #of(Tensor)} and start out without a supplier.
 *  A supplier can be attached exactly once through {@link #withAutoDiff(ADActionSupplier)}
 *  or {@link #withADAction(ADAction)}; both return a new {@link Result} and leave the
 *  receiver untouched.
 */
public final class Result
{
    /** The wrapped tensor; checked for null in {@link #of(Tensor)}. */
    private final Tensor<?> _tensor;
    /** The auto-differentiation supplier, or {@code null} as long as none was specified. */
    private final ADActionSupplier _agent;

    /**
     *  Wraps the provided tensor in a {@link Result} which does not yet have an {@link ADActionSupplier}.
     *
     * @param tensor The tensor to wrap; a {@code null} argument is rejected by {@link LogUtil#nullArgCheck}.
     * @return A new {@link Result} holding the tensor and no supplier.
     */
    public static Result of( Tensor<?> tensor ) {
        LogUtil.nullArgCheck( tensor, "tensor", Tensor.class, "An operation may not return 'null'!" );
        return new Result( tensor, null );
    }

    private Result( Tensor<?> tensor, ADActionSupplier agent ) {
        _tensor = tensor;
        _agent = agent;
    }

    /**
     *  Convenience variant of {@link #withAutoDiff(ADActionSupplier)} for an already existing {@link ADAction}.
     *  The action is wrapped in a supplier which ignores its arguments and returns
     *  {@code ADAction.of(action)} every time it is invoked.
     *
     * @param action The action to supply; a {@code null} argument is rejected by {@link LogUtil#nullArgCheck}.
     * @return A new {@link Result} with the same tensor and a supplier based on the provided action.
     * @throws IllegalArgumentException If this result already has a supplier.
     */
    public Result withADAction( ADAction action ) {
        // Validate eagerly: otherwise a null would be captured by the lambda below
        // and only surface later, when the supplier is eventually invoked.
        LogUtil.nullArgCheck( action, "action", ADAction.class );
        return this.withAutoDiff( (caller, call) -> ADAction.of(action) );
    }

    /**
     *  Creates a copy of this result which carries the provided {@link ADActionSupplier}.
     *
     * @param agent The supplier to attach; a {@code null} argument is rejected by {@link LogUtil#nullArgCheck}.
     * @return A new {@link Result} with the same tensor and the provided supplier.
     * @throws IllegalArgumentException If this result already has a supplier.
     */
    public Result withAutoDiff( ADActionSupplier agent ) {
        LogUtil.nullArgCheck( agent, "agent", ADActionSupplier.class );
        if ( _agent != null )
            throw new IllegalArgumentException("Autograd algorithm already specified!");
        return new Result( _tensor, agent );
    }

    /**
     *  Returns the wrapped tensor, cast to the element type expected by the caller.
     *  The cast is unchecked: the tensor is stored as {@code Tensor<?>}, so it is
     *  up to the caller to request a {@code V} matching the actual element type.
     *
     * @param <V> The element type the caller expects the tensor to have.
     * @return The wrapped tensor.
     */
    @SuppressWarnings("unchecked") // The element type is not tracked by this class, the caller chooses V.
    public <V> Tensor<V> get() { return (Tensor<V>) _tensor; }

    /**
     * @return The {@link ADActionSupplier} of this result, or {@code null} if none was specified.
     */
    public ADActionSupplier getAgentSupplier() { return _agent; }

}