// NdaBuilder.java

package neureka.fluent.building;

import neureka.Nda;
import neureka.Neureka;
import neureka.Shape;
import neureka.Tensor;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.common.utility.DataConverter;
import neureka.common.utility.LogUtil;
import neureka.devices.Device;
import neureka.devices.host.CPU;
import neureka.dtype.DataType;
import neureka.fluent.building.states.*;
import neureka.ndim.Filler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;

/**
 *  This is the implementation of the fluent builder API for creating {@link Nda}/{@link Tensor} instances.
 *  A simple example would be:
 * <pre>{@code
 *
 *    Tensor.of(Double.class)
 *          .withShape( 2, 3, 4 )
 *          .andFill( 5, 3, 5 )
 *
 * }</pre>
 *
 * It is also possible to define a range using the API to populate the tensor with values:
 * <pre>{@code
 *
 *    Tensor.of(Double.class)
 *          .withShape( 2, 3, 4 )
 *          .andFillFrom( 2 ).to( 9 ).step( 2 )
 *
 * }</pre>
 *
 * If one needs a simple scalar then the following shortcut is possible:
 * <pre>{@code
 *
 *    Tensor.of(Float.class).scalar( 3f )
 *
 * }</pre>
 *
 * This principle works for vectors as well:
 * <pre>{@code
 *
 *     Tensor.of(Byte.class).vector( 2, 5, 6, 7, 8 )
 *
 * }</pre>
 * For more fine-grained control over the initialization one can
 * pass an initialization lambda to the API:
 * <pre>{@code
 *
 *     Tensor.of(Byte.class).withShape(2, 3).andWhere( (i, indices) -> i * 5 - 30 )
 *
 * }</pre>
 *
 * @param <V> The type of the values which ought to be represented by the {@link Tensor} built by this {@link NdaBuilder}.
 */
public final class NdaBuilder<V> implements WithShapeOrScalarOrVectorOnDevice<V>, IterByOrIterFromOrAllTensor<V>, ToForTensor<V>, StepForTensor<V>
{
    private static final Logger _LOG = LoggerFactory.getLogger(NdaBuilder.class);

    /**
     *  Arrays longer than this are not scanned for uniformity by {@link #_isAllOne(Object[])},
     *  so that large user supplied arrays do not pay for a full equality scan.
     */
    private static final int _UNIFORMITY_SCAN_LIMIT = 42;

    private final DataType<V> _dataType;
    private Shape _shape;
    private V _from; // Inclusive start of a value range (set by andFillFrom(..)).
    private V _to;   // Inclusive end of a value range (set by to(..)).
    private Device<? super V> _device = CPU.get(); // Target device for the built tensor; defaults to the CPU.

    /**
     * @param typeClass The type of the values which ought to be represented by the {@link Tensor} built by this {@link NdaBuilder}.
     */
    public NdaBuilder( Class<V> typeClass ) {
        LogUtil.nullArgCheck( typeClass, "typeClass", Class.class, "Cannot build tensor without data type information!" );
        _dataType = DataType.of( typeClass );
    }

    /**
     *  Creates the tensor from the previously collected builder state
     *  (data type, device and shape) and the provided source of values.
     *
     * @param value The data (an array or a single item) used to populate the new tensor.
     * @return A new {@link Tensor} instance living on the configured device.
     */
    private Tensor<V> _get( Object value ) {
        LogUtil.nullArgCheck( value, "value", Object.class, "Cannot build tensor where value is null!" );
        return Tensor.of( _dataType, _device, _shape, value );
    }

    /**
     * @param values The values which will recurrently populate the returned {@link Tensor} with values until it is filled.
     * @return A new {@link Tensor} instance populated by the array of values supplied to this method.
     */
    @SafeVarargs
    @Override
    public final Tensor<V> andFill( V... values ) {
        LogUtil.nullArgCheck( values, "values", _dataType.getItemTypeClass(), "Cannot fill a tensor with a value array that is null!" );
        // A homogeneous array can be represented by a single item, which is cheaper to store.
        if ( _isAllOne(values) ) return _get( values[0] );
        return _get( values );
    }

    /**
     *  Determines cheaply whether the given array consists of a single distinct value,
     *  in which case the tensor can be backed by that one item instead of the full array.
     *  Arrays larger than {@link #_UNIFORMITY_SCAN_LIMIT} are conservatively treated as
     *  heterogeneous to avoid the cost of a full scan.
     *
     * @param values The array to inspect; must not be null.
     * @return The truth value determining if all entries are equal to the first entry.
     */
    private <T> boolean _isAllOne( T[] values ) {
        if ( values.length == 0 ) return false;
        if ( values.length == 1 ) return true;
        if ( values.length > _UNIFORMITY_SCAN_LIMIT ) return false;
        T first = values[0];
        for ( T value : values )
            if ( !Objects.equals(first, value) )
                return false;
        return true;
    }

    /**
     *  This method receives an {@link Filler} lambda which will be
     *  used to populate the {@link Tensor} instance produced by this API with values.
     *
     * @param filler The {@link Filler} which ought to populate the returned {@link Tensor}.
     * @return A new {@link Tensor} instance populated by the lambda supplied to this method.
     */
    @Override
    public Tensor<V> andWhere( Filler<V> filler ) {
        return Tensor.of( _dataType, _shape, filler ).to( _device );
    }

    /**
     *  Defines the inclusive start of a value range which will later be completed
     *  by calls to {@link #to(Object)} and {@link #step(double)}.
     *
     * @param index The first value of the range.
     * @return This builder, typed for specifying the end of the range.
     */
    @Override
    public ToForTensor<V> andFillFrom( V index ) {
        LogUtil.nullArgCheck(index, "index", _dataType.getItemTypeClass(), "Cannot create a range where the first index is undefined!");
        _from = _checked(index);
        return this;
    }

    /**
     * @param value The single value which will fill every item of the returned {@link Tensor}.
     * @return A new {@link Tensor} where every item is the provided value.
     */
    @Override
    public Tensor<V> all( V value ) { return _get( value ); }

    /**
     *  Populates the tensor with pseudo random values derived from the given seed.
     *  {@code Double} and {@code Float} tensors seeded by a {@code Long} use the
     *  backend's random function directly; all other combinations fall back to a
     *  {@link String} based seed.
     *
     * @param seed An object whose value (or {@link Object#toString()}) seeds the generator.
     * @return A new randomly populated {@link Tensor}.
     * @throws IllegalArgumentException If no random tensor can be created for the item type.
     */
    @Override
    public Tensor<V> andSeed( Object seed ) {
        LogUtil.nullArgCheck( seed, "seed", Object.class, "Cannot seed a tensor with a null seed object!" );
        Class<V> type = _dataType.getItemTypeClass();
        Class<?> seedType = seed.getClass();
        try {
            Function random = Neureka.get().backend().getFunction().random();
            if (type == Double.class && seedType == Long.class)
                return random.with( Arg.Seed.of((Long) seed) ).call( _get( 0d ) );
            else if (type == Float.class && seedType == Long.class)
                return random.with( Arg.Seed.of((Long) seed) ).call( _get( 0f ) );
            else
                return Tensor.of( type, _shape, Arg.Seed.of(seed.toString()) ).to( _device );
        } catch ( Exception e ) {
            IllegalArgumentException exception =
                    new IllegalArgumentException(
                         "Could not create a random tensor for type '"+type+"'!"
                    );
            _LOG.error( exception.getMessage(), e );
            throw exception;
        }
    }

    /**
     * @param shape The shape (list of axis sizes) of the tensor which ought to be built.
     * @return This builder, typed for specifying how the tensor is populated.
     * @throws IllegalArgumentException If no shape arguments were provided.
     */
    @Override
    public IterByOrIterFromOrAllTensor<V> withShape( int... shape ) {
        LogUtil.nullArgCheck(shape, "shape", int[].class, "Cannot create a tensor without shape!");
        if ( shape.length == 0 )
            throw new IllegalArgumentException("Cannot instantiate a tensor without shape arguments.");
        _shape = Shape.of(shape);
        return this;
    }

    /**
     * @param values The values populating a 1 dimensional tensor whose length is the array length.
     * @return A new vector {@link Tensor} holding the provided values.
     */
    @Override
    public Tensor<V> vector( Object[] values ) {
        LogUtil.nullArgCheck(values, "values", Object[].class, "Cannot create a vector without data array!");
        _shape = Shape.of( values.length );
        // A homogeneous array can be represented by a single item, which is cheaper to store.
        if ( _isAllOne(values) ) return _get( values[0] );
        return _get( values );
    }

    /**
     * @param value The single item wrapped by the returned scalar (shape {@code [1]}) tensor.
     * @return A new scalar {@link Tensor}.
     * @throws IllegalArgumentException If the value cannot be converted to the builder's data type.
     */
    @Override
    public Tensor<V> scalar( V value ) {
        if ( value != null ) {
            value = _checked( value );
            if ( !_dataType.getItemTypeClass().isAssignableFrom(value.getClass()) ) {
                try {
                    value = DataConverter.get().convert( value, _dataType.getItemTypeClass() );
                } catch (Exception e) {
                    throw new IllegalArgumentException(
                        "Provided value is incompatible with the specified data-type!\n" +
                        "Expected type "+_dataType.getItemTypeClass().getSimpleName()+"\n" +
                        "but encountered "+value.getClass().getSimpleName() + " instead!"
                    );
                }
            }
        }
        _shape = Shape.of( 1 );
        return _get( value );
    }

    /**
     *  This method makes sure that the data provided by the user is indeed of the right type
     *  by converting it if possible to the previously provided data type.
     *
     * @param o The scalar value which may need to be converted to the provided data type.
     * @return The value converted to the type defined by the provided {@link #_dataType}.
     */
    @SuppressWarnings("unchecked") // Each cast is guarded by the preceding jvmType check.
    private V _checked( V o ) {
        Class<?> jvmType = _dataType.getItemTypeClass();
        if ( Number.class.isAssignableFrom(jvmType) ) {
            if ( o instanceof Number && o.getClass() != jvmType ) {
                Number n = (Number) o;
                if ( jvmType == Integer.class ) return (V) ((Integer) n.intValue()   );
                if ( jvmType == Double.class  ) return (V) ((Double)  n.doubleValue());
                if ( jvmType == Short.class   ) return (V) ((Short)   n.shortValue() );
                if ( jvmType == Byte.class    ) return (V) ((Byte)    n.byteValue()  );
                if ( jvmType == Long.class    ) return (V) ((Long)    n.longValue()  );
                if ( jvmType == Float.class   ) return (V) ((Float)   n.floatValue() );
            }
        }
        return o;
    }

    /**
     * @param index The inclusive end of the value range started by {@link #andFillFrom(Object)}.
     * @return This builder, typed for specifying the step size of the range.
     */
    @Override
    public StepForTensor<V> to( V index ) { _to = _checked(index); return this; }

    /**
     *  Builds the tensor by stepping from {@link #_from} to {@link #_to} with the given
     *  step size and repeating the resulting range until the previously specified
     *  shape is completely filled.
     *
     * @param size The step size used to advance from the start towards the end of the range.
     * @return A new {@link Tensor} populated by the (possibly repeated) range of values.
     * @throws IllegalArgumentException If the range is empty (start greater than end).
     * @throws IllegalStateException If the item type does not support ranges.
     */
    @Override
    public Tensor<V> step( double size ) {
        int tensorSize = _size(); // Total number of items the configured shape requires.
        Object data = null;
        if ( _dataType == DataType.of( Integer.class ) ) {
            List<Integer> range = new ArrayList<>();
            for ( int index = ((Integer) _from); index <= ((Integer) _to) && range.size() < tensorSize; index += size )
                range.add( index );
            _checkRangeNotEmpty( range );
            data = IntStream.range( 0, tensorSize )
                            .map( i -> range.get( i % range.size() ) )
                            .toArray();
        }
        else if ( _dataType == DataType.of( Double.class ) ) {
            List<Double> range = new ArrayList<>();
            for ( double index = ((Double) _from); index <= ((Double) _to) && range.size() < tensorSize; index += size )
                range.add( index );
            _checkRangeNotEmpty( range );
            data = IntStream.range( 0, tensorSize )
                            .mapToDouble( i -> range.get( i % range.size() ) )
                            .toArray();
        }
        else if ( _dataType == DataType.of( Long.class ) ) {
            List<Long> range = new ArrayList<>();
            for ( long index = ((Long) _from); index <= ((Long) _to) && range.size() < tensorSize; index += size )
                range.add( index );
            _checkRangeNotEmpty( range );
            data = IntStream.range( 0, tensorSize )
                            .mapToLong( i -> range.get( i % range.size() ) )
                            .toArray();
        }
        else if ( _dataType == DataType.of( Float.class ) ) {
            List<Float> range = new ArrayList<>();
            // A double accumulator is used so that fractional step sizes do not drift as fast.
            for ( double index = ((Float) _from); index <= ((Float) _to) && range.size() < tensorSize; index += size )
                range.add( (float) index );
            _checkRangeNotEmpty( range );
            float[] primData = new float[ tensorSize ];
            for ( int ii = 0; ii < tensorSize; ii++ )
                primData[ ii ] = range.get( ii % range.size() );

            data = primData;
        }
        else if ( _dataType == DataType.of( Byte.class ) ) {
            List<Byte> range = new ArrayList<>();
            for ( byte index = ((Byte) _from); index <= ((Byte) _to) && range.size() < tensorSize; index += size )
                range.add( index );
            _checkRangeNotEmpty( range );
            byte[] primData = new byte[ tensorSize ];
            for ( int ii = 0; ii < tensorSize; ii++ )
                primData[ ii ] = range.get( ii % range.size() );

            data = primData;
        }
        else if ( _from instanceof Comparable && _to instanceof Comparable ) {
            //data = new ObjectRange( (Comparable<V>) _from, (Comparable<V>) _to ).step( (int) size );
            throw new IllegalStateException("Cannot form a range for the provided elements...");
            // TODO: make it possible to have ranges like 'a' to 'z'...
        }
        return _get( data );
    }

    /**
     *  Fails fast when a collected range is empty (which happens when the start of the
     *  range is greater than its end), because an empty range would otherwise surface
     *  later as an obscure division-by-zero error.
     */
    private void _checkRangeNotEmpty( List<?> range ) {
        if ( range.isEmpty() )
            throw new IllegalArgumentException(
                "Cannot create a range from '" + _from + "' to '" + _to + "'! " +
                "The start of the range must not be greater than its end."
            );
    }

    /** @return The total number of items a tensor of the currently configured {@link #_shape} holds. */
    private int _size() {
        int size = 1;
        for ( int axis : _shape ) size *= axis;
        return size;
    }

    /**
     * @param device The {@link Device} on which the built tensor should be stored.
     * @return This builder, typed for specifying the shape of the tensor.
     */
    @Override
    public WithShapeOrScalarOrVectorTensor<V> on( Device<? super V> device ) {
        LogUtil.nullArgCheck(device, "device", Device.class, "Cannot create a tensor with an undefined device!");
        _device = device;
        return this;
    }
}