PendingError.java

package neureka.autograd;

import neureka.Tensor;
import neureka.backend.main.memory.MemUtil;

/**
 *  A wrapper for a tensor which is used to accumulate error values
 *  during the back propagation phase of the autograd algorithm.
 *  This is a library internal class, do not depend on this.
 *  <p>
 *  The {@link PendingError} class also keeps track of how many
 *  more error values need to be accumulated before the error
 *  value is fully accumulated.
 *
 * @param <V> The data type of the tensor which is used to accumulate error values.
 */
final class PendingError<V>
{
    private final int _expectedToBeReceived;
    private int _received;
    private final Tensor<V> _accumulatedError;
    private final int _generation;

    public PendingError( Tensor<V> error, int toBeReceived, int generation ) {
        _expectedToBeReceived = toBeReceived;
        _received = 1; // 1 because the first error value is already given to the constructor.
        _accumulatedError = error;
        _generation = generation;
    }

    public void accumulate( Tensor<?> error ) {
        Tensor[] inputs = { _accumulatedError, error };
        MemUtil.keep( inputs, () -> {
                    _accumulatedError.mut().plusAssign((Tensor<V>)error);
                    return null;
                });
        _received++;
        if ( _received > _expectedToBeReceived ) {
            throw new IllegalStateException(
                    "Received more error values than expected! " +
                    "Expected: " + _expectedToBeReceived + ", " +
                    "Received: " + _received + "."
            );
        }
    }

    public boolean isFullyAccumulated() {
        return _received == _expectedToBeReceived;
    }

    public int getGeneration() { return _generation; }

    public String toString() {
        return this.getClass().getSimpleName()+"[" +
                    "received=" + _received + "," +
                    "accumulatedError=" + _accumulatedError + "," +
                    "generation=" + _generation +
                "]";
    }

    public int getReceived() { return _received; }

    public int getExpectedToBeReceived() { return _expectedToBeReceived; }

    public Tensor<V> getAccumulatedError() { return _accumulatedError; }

}