SliceBuilder.java

package neureka.fluent.slicing;

import neureka.Tensor;
import neureka.fluent.slicing.states.AxisOrGetTensor;
import neureka.fluent.slicing.states.FromOrAtTensor;
import neureka.math.args.Arg;
import neureka.fluent.slicing.states.FromOrAt;

import java.util.function.Function;


/**
 *  This class is the heart of the slice builder API, collecting range configurations by
 *  exposing an API consisting of multiple interfaces which form a call state transition graph.
 *  Instances of this class do not perform the actual slicing of a {@link Tensor} instance themselves,
 *  however instead they merely serve as collectors of slice configuration data.
 *  The API exposed by the {@link SliceBuilder} uses method chaining as well as a set of implemented interfaces
 *  which reference themselves in the form of the return types defined by the method signatures of said interfaces.
 *  A user of the API can only call methods exposed by the current "view" of the builder, namely a interface.
 *  This ensures a controlled order of calls to the API...
 *
 * @param <V> The type of the value(s) held by the tensor which ought to be sliced with the help of this builder.
 */
public class SliceBuilder<V> implements AxisOrGetTensor<V>
{
    private interface CreationCallback<V> {
        Tensor<V> sliceOf(int[] newShape, int[] newOffset, int[] newSpread, boolean autograd);
    }

    private final Function<Boolean, Tensor<V>> _create;
    private final AxisSliceBuilder<V>[]  _axisSliceBuilders;

    /**
     *  An instance of a slice builder does not perform the actual slicing itself!
     *  Instead, it merely serves as a collector of slice configuration data.
     *  The actual slicing will be performed by the {@link CreationCallback} passed
     *  to this constructor.
     *
     * @param toBeSliced The {@link Tensor} instance which ought to be sliced.
     */
    public SliceBuilder( Tensor<V> toBeSliced )
    {
        CreationCallback<V> sliceCreator = // A callback lambda which receives the final slice configuration to perform the actual slicing.
        ( int[] newShape, int[] newOffset, int[] newSpread, boolean allowAutograd )->
        {
            boolean isIntermediate = toBeSliced.isIntermediate();
            toBeSliced.getMut().setIsIntermediate(false); // To avoid deletion!
            Tensor<V> slice = neureka.math.Function.of("slice(I[0])", allowAutograd)
                                .with(Arg.Shape.of(newShape),Arg.Offset.of(newOffset),Arg.Stride.of(newSpread))
                                .call(toBeSliced);

            slice.getMut().setIsIntermediate(false);
            toBeSliced.getMut().setIsIntermediate(isIntermediate);
            return slice;
        };

        int[] shape = toBeSliced.getNDConf().shape();
        _axisSliceBuilders = new AxisSliceBuilder[ shape.length ];
        int[] newShape = new int[shape.length];
        int[] newSpread = new int[shape.length];
        int[] newOffset = new int[shape.length];
        for ( int i = 0; i < shape.length; i++ ) {
            int finalI = i;
            _axisSliceBuilders[ i ] = new AxisSliceBuilder<>(
                                                shape[ i ],
                                                ( from, to, step ) -> {
                                                    if ( from < 0 && to < 0 && from > to ) {
                                                        int temp = from;
                                                        from = to;
                                                        to = temp;
                                                    }
                                                    from = ( from < 0 ) ? shape[finalI] + from : from;
                                                    to = ( to < 0 ) ? shape[finalI] + to : to;
                                                    if ( to < 0 ) to += shape[ finalI ];
                                                    if ( from < 0 ) from += shape[ finalI ];
                                                    newOffset[ finalI ] = from;
                                                    newShape[ finalI ] = ( to - from + 1 ) / step;
                                                    newSpread[ finalI ] = step;
                                                    _axisSliceBuilders[ finalI ] = null;
                                                    return this;
                                                });
        }
        _create = allowAutograd -> {
            for ( AxisSliceBuilder<V> axis : _axisSliceBuilders ) {
                if ( axis != null ) axis.resolve();
            }
            return sliceCreator.sliceOf( newShape, newOffset, newSpread, allowAutograd );
        };
    }

    /**
     *  This method returns an instance of the {@link AxisSliceBuilder} disguised by the {@link FromOrAt} interface.
     *  The {@link AxisSliceBuilder} class implements the {@link FromOrAt} interface in order to ensure
     *  that the builder methods of this API are being called in the correct order.
     *
     * @param axis The index of the axis which ought to be sliced.
     * @return An instance of the {@link AxisSliceBuilder} disguised by the {@link FromOrAt} interface.
     */
    @Override
    public FromOrAtTensor<V> axis(int axis ) {
        if ( axis >= _axisSliceBuilders.length ) throw new IllegalArgumentException("");
        return _axisSliceBuilders[ axis ];
    }

    /**
     *  This method will create and return a new slice tensor based on the
     *  provided configuration through methods like {@link AxisSliceBuilder#from(int)},
     *  {@link AxisSliceBuilder#to(int)} and {@link AxisSliceBuilder#at(int)}... <br>
     *
     * @return The slice of the tensor supplied to the constructor of this {@link SliceBuilder} instance.
     */
    @Override
    public Tensor<V> get() {
        return _create.apply(true);
    }

    @Override
    public Tensor<V> detached() {
        return _create.apply(false);
    }


}