// NodePayload.java

package neureka.autograd;

import neureka.Tensor;
import neureka.dtype.DataType;

import java.lang.ref.WeakReference;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/**
 *  An immutable holder for the tensor a graph node is associated with.  <br><br>
 *  The tensor itself is only referenced weakly, whereas its shape, version and data type
 *  are read once in the constructor and kept in this object,
 *  so they stay available even after the tensor has been garbage collected.
 *
 * @param <V> The value type parameter of the payload tensor.
 */
final class NodePayload<V> {

    /** Result of {@code getVersion()} on the payload at construction time, or -1 if there was no payload. */
    private final int _payloadReferenceVersion;

    /** Result of {@code getNDConf().shape()} on the payload at construction time, or null if there was no payload. */
    private final int[] _payloadShape;

    /** Result of {@code getDataType()} on the payload at construction time, or null if there was no payload. */
    private final DataType<V> _payloadDataType;

    /** Weak handle on the payload tensor, or null if this object was created without a payload. */
    private final WeakReference<Tensor<V>> _payload;


    /**
     *  Creates a payload holder for the given tensor.
     *  Passing null is allowed and produces an "empty" holder
     *  (no reference, no shape, no data type and a version of -1).
     *
     * @param p The tensor to be held weakly, may be null.
     */
    public NodePayload( Tensor<V> p ) {
        if ( p != null ) {
            assert !p.isUndefined();
            _payload = new WeakReference<>( p );
            _payloadShape = p.getNDConf().shape();
            _payloadReferenceVersion = p.getVersion();
            _payloadDataType = p.getDataType();
        }
        else {
            // No tensor was supplied: every piece of state gets its "absent" marker value.
            _payload = null;
            _payloadShape = null;
            _payloadReferenceVersion = -1;
            _payloadDataType = null;
        }
    }

    /**
     * @return The data type the payload tensor reported at construction time,
     *         or null if this object was created without a payload.
     */
    public DataType<V> payloadDataType() {
        return _payloadDataType;
    }

    /**
     * @return The version the payload tensor reported at construction time,
     *         or -1 if this object was created without a payload.
     */
    public int payloadReferenceVersion() {
        return _payloadReferenceVersion;
    }

    /**
     *  The value of a graph node is the tensor to which it belongs (is a component of),
     *  meaning it is the tensor owning the {@link GraphNode} component.  <br><br>
     *  The tensor is referenced weakly because it might not be needed any more
     *  (for example when it is not referenced inside an AD-Agent)
     *  and may therefore be garbage collected.
     *  <br><br>
     *  Warning: This method returns null if the payload has been garbage collected
     *           or if this object was created without a payload in the first place!
     *
     * @return The tensor payload of this graph-node, or null if it is not available.
     */
    public Tensor<V> getPayload() {
        if ( _payload == null ) {
            return null;
        }
        return _payload.get();
    }

    /**
     *  Note: The shape is remembered independently of the weakly referenced payload,
     *  so it is still returned after the actual payload tensor was garbage collected.
     *  The only case in which null is returned is when this object
     *  was created without a payload tensor. A new list is built on every call.
     *
     *  @return The shape of the payload tensor represented by the {@link GraphNode},
     *          or null if there never was a payload.
     */
    public List<Integer> getPayloadShape() {
        if ( _payloadShape == null ) {
            return null;
        }
        List<Integer> shapeAsList = Arrays.stream( _payloadShape )
                                          .boxed()
                                          .collect( Collectors.toList() );
        return shapeAsList;
    }

}