MemValidator.java

package neureka.backend.main.memory;

import neureka.Tensor;
import neureka.backend.api.Result;

import java.util.Arrays;
import java.util.function.Supplier;
import java.util.stream.IntStream;

/**
 *  This class validates the states of tensors with respect to memory management
 *  before and after a lambda executes a function or some kind of algorithm on said tensors.
 *  This validity refers to the {@link Tensor#isIntermediate()} flag, whose state should
 *  adhere to strict rules in order to allow for safe deletion of tensors.
 *  The lambda wrapped by this may be a {@link neureka.math.Function} call or a lower level
 *  procedure defined by a {@link neureka.backend.api.Algorithm} implementation.
 *  <p>
 *  The validation itself happens eagerly inside the constructor: the flags of the inputs
 *  (and of their gradients) are recorded, the wrapped lambda is executed, and the returned
 *  tensor is then compared against the recorded state. The outcome can be queried through
 *  {@link #isWronglyIntermediate()} and {@link #isWronglyNonIntermediate()}.
 *  <br><br>
 *  <b>Warning! This is an internal class. Do not depend on it.</b>
 */
public class MemValidator {

    private final Result _result;
    private final boolean _wronglyIntermediate;
    private final boolean _wronglyNonIntermediate;

    /**
     * Executes the supplied callback and validates the tensor it returns against the given inputs.
     *
     * @param inputs The inputs used by the {@link Supplier} implementation to provide a result.
     *               The array itself must not be {@code null}, individual {@code null} entries
     *               are tolerated and simply skipped during validation.
     * @param resultProvider The callback providing the result which ought to be validated.
     * @return The {@link MemValidator} which ought to validate the provided result.
     * @throws IllegalStateException If the callback returns {@code null} or a {@link Result} without a tensor.
     */
    public static MemValidator forInputs(Tensor<?>[] inputs, Supplier<Result> resultProvider ) {
        return new MemValidator( inputs, resultProvider );
    }

    private MemValidator(Tensor<?>[] tensors, Supplier<Result> execution ) {
        /*
            Before calling the function we take a snapshot of the inputs and their
            'intermediate' flags (as well as the flags of their gradients).
            The array is cloned so that modifications made to the caller's array
            during execution cannot desynchronize the snapshot from the recorded flags.
        */
        Tensor<?>[] inputs = tensors.clone();
        boolean[] wereIntermediate      = new boolean[ inputs.length ];
        boolean[] gradsWereIntermediate = new boolean[ inputs.length ];
        for ( int i = 0; i < inputs.length; i++ ) {
            Tensor<?> input = inputs[ i ];
            if ( input == null ) {
                continue; // Empty slots carry no flags, they stay 'false' and can never match the result.
            }
            wereIntermediate[ i ]      = input.isIntermediate();
            gradsWereIntermediate[ i ] = input.hasGradient() && input.gradient().get().isIntermediate();
        }
        /*
            Now we dispatch the call to the wrapped implementation to get the result
            and check if the execution produced anything at all:
        */
        Result result = execution.get();
        if ( result == null )
            throw new IllegalStateException( "Failed to execute function! Returned result was null." );

        Tensor<?> output = result.get();
        if ( output == null )
            throw new IllegalStateException( "Failed to execute function! Returned result was null." );

        /*
            After that we analyse the validity of the result with respect to memory safety!
            We expect internally created tensors to be flagged as 'intermediate', so we first
            find out whether the result already existed before the call, either as the
            gradient of an input or as an input itself (the gradient match takes precedence).
            Both lookups use the snapshot, so the found index always refers to the recorded flags.
         */
        int gradientOwner = _indexOfGradientOwner( inputs, output );
        int memberIndex   = gradientOwner >= 0 ? -1 : _indexOfMember( inputs, output );

        boolean wronglyIntermediate    = false;
        boolean wronglyNonIntermediate = false;
        if ( gradientOwner >= 0 || memberIndex >= 0 ) {
            // The result is a pre-existing tensor: it must not have become intermediate during the call.
            boolean resultWasIntermediate =
                            gradientOwner >= 0
                                ? gradsWereIntermediate[ gradientOwner ]
                                : wereIntermediate[ memberIndex ];

            wronglyIntermediate = output.isIntermediate() && !resultWasIntermediate;
        } else if ( !output.isIntermediate() ) {
            // The result is not one of the inputs, so it is expected to be flagged as intermediate.
            wronglyNonIntermediate = true;
        }
        _wronglyIntermediate    = wronglyIntermediate;
        _wronglyNonIntermediate = wronglyNonIntermediate;
        /*
            Last but not least we store the result so that it can be retrieved by the caller.
        */
        _result = result;
    }

    /**
     * @return The index of the first non-null input whose current gradient is the given tensor
     *         (compared by identity), or {@code -1} if there is no such input.
     */
    private static int _indexOfGradientOwner( Tensor<?>[] inputs, Tensor<?> output ) {
        for ( int i = 0; i < inputs.length; i++ ) {
            if ( inputs[ i ] != null && inputs[ i ].gradient().orElse(null) == output ) {
                return i;
            }
        }
        return -1;
    }

    /**
     * @return The index of the first input which is the given tensor (compared by identity),
     *         or {@code -1} if the tensor is not among the inputs.
     */
    private static int _indexOfMember( Tensor<?>[] inputs, Tensor<?> output ) {
        for ( int i = 0; i < inputs.length; i++ ) {
            if ( inputs[ i ] == output ) {
                return i;
            }
        }
        return -1;
    }

    /**
     * @return Is {@code true} if the result tensor is wrongfully flagged as intermediate (see {@link Tensor#isIntermediate()}).
     */
    public boolean isWronglyIntermediate() { return _wronglyIntermediate; }

    /**
     * @return Is {@code true} if the result tensor is wrongfully flagged as non-intermediate (see {@link Tensor#isIntermediate()}).
     */
    public boolean isWronglyNonIntermediate() { return _wronglyNonIntermediate; }

    /**
     * @return The result tensor returned by the {@link Supplier} lambda passed to this {@link MemValidator}.
     */
    public Result getResult() { return _result; }

}