// NDFrame.java

package neureka.framing;

import neureka.Tensor;
import neureka.common.composition.Component;
import neureka.common.utility.LogUtil;
import neureka.framing.fluent.AxisFrame;

import java.util.*;
import java.util.function.Function;

/**
 *  Instances of this class are components of tensors, which store aliases for the indices of the tensor.
 *  These index aliases can be anything that has an identity, meaning any plain old object. <br>
 *  There are two layers of aliasing/labeling provided by this class:
 *  <ul>
 *      <li>
 *          Labels for the axes of a tensor, which are the indices of its shape array.
 *      </li>
 *      <li>
 *          Labels for the indices of a specific axis.
 *      </li>
 *  </ul>
 *  Let's, for example, imagine a tensor of rank 2 with the shape (3, 4); its axes could then be labeled
 *  with a tuple of two {@link String} instances like: ("a","b"). <br>
 *  Labeling the indices of these axes requires 2 arrays whose lengths match the respective axis sizes. <br>
 *  The following mapping would be able to label both the axes and their indices: <br>
 *                                                                             <br>
 *  "a" : ["first", "second", "third"],                                        <br>
 *  "b" : ["one", "two", "three", "four"]                                      <br>
 *                                                                             <br>
 *
 * @param <V> The type parameter of the value type of the tensor type to whom this component should belong.
 */
public final class NDFrame<V> implements Component<Tensor<V>>
{
    /**
     *  Keys of {@link #_mapping} which are skipped when the frame is rendered by {@link #toString()}.
     *  The map based constructor registers the positional index of an axis here
     *  whenever the alias of that axis is not equal to its positional index.
     */
    private final List<Object> _hiddenKeys = new ArrayList<>();
    /**
     *  This {@link Map} contains all the aliases for axis as well as individual
     *  positions for a given axis (in the form of yet another {@link Map}).
     *  A value is either such a nested {@link Map} (index alias to {@link Integer} position)
     *  or a plain {@link Integer} taken from the shape of the host tensor
     *  for an axis that has no index aliases.
     */
    private final Map<Object, Object> _mapping;
    /**
     *  A frame can also carry a name.
     *  When loading a CSV file for example the label would be the first cell if
     *  both index and header labels are included in the file.
     */
    private final String _mainLabel;

    /**
     *  Creates a frame whose axes are addressed by their positional index and whose
     *  index aliases are taken from the provided nested list.
     *
     * @param labels One list of index aliases per axis, {@code null} lists and {@code null} aliases are skipped.
     * @param host The tensor whose rank determines for how many axes an alias map is set up.
     * @param mainLabel The name of this frame.
     */
    public NDFrame( List<List<Object>> labels, Tensor<V> host, String mainLabel ) {
        // Delegating with an empty map registers an empty alias map for every axis of the host:
        this( Collections.emptyMap(), host, mainLabel );
        _label( labels );
    }

    /**
     *  Replaces the index aliases of the axes covered by the provided list.
     *  Every covered axis first receives a fresh and empty alias map, afterwards
     *  each non-null alias is registered for the position it has within its list.
     *
     * @param labels One list of index aliases per axis, {@code null} lists and {@code null} aliases are ignored.
     * @return This very frame in order to allow for call chaining.
     */
    private NDFrame<V> _label( List<List<Object>> labels ) {
        final int axisCount = labels.size();

        // Reset all affected axes before any alias is written:
        for ( int axis = 0; axis < axisCount; axis++ )
            _mapping.put( axis, new LinkedHashMap<>() );

        for ( int axis = 0; axis < axisCount; axis++ ) {
            List<Object> aliases = labels.get( axis );
            if ( aliases == null ) continue;
            for ( int position = 0; position < aliases.size(); position++ ) {
                Object alias = aliases.get( position );
                if ( alias != null )
                    atAxis( axis ).atIndexAlias( alias ).setIndex( position );
            }
        }
        return this;
    }

    /**
     *  Creates a frame without any index aliases,
     *  every axis of the host merely receives an empty alias map.
     *
     * @param host The tensor whose rank determines for how many axes an alias map is set up.
     * @param tensorName The name of this frame.
     */
    public NDFrame( Tensor<V> host, String tensorName ) {
        this( Collections.emptyMap(), host, tensorName );
    }

    /**
     *  Creates a frame from a mapping of axis aliases to lists of index aliases.
     *  The iteration order of the provided map defines the positional index of each axis.
     *  An axis whose list is {@code null} stores the corresponding entry of the host's shape
     *  instead of an alias map. Every axis is additionally reachable through its positional index,
     *  and axes of the host which are not covered by the map receive an empty alias map.
     *
     * @param labels The axis aliases mapped to their index aliases (or {@code null} for "no index aliases").
     * @param host The tensor which is consulted for its shape and rank.
     * @param ndaMainLabel The name of this frame.
     */
    public NDFrame(
        Map<Object, List<Object>> labels,
        Tensor<V> host,
        String ndaMainLabel
    ) {
        _mainLabel = ndaMainLabel;
        _mapping = new LinkedHashMap<>( labels.size() * 3 );

        // First pass: register every axis under its alias.
        int axis = 0;
        for ( Map.Entry<Object, List<Object>> entry : labels.entrySet() ) {
            Object axisAlias = entry.getKey();
            List<Object> indexAliases = entry.getValue();

            // The positional key is only a secondary handle if the alias differs from it:
            if ( !axisAlias.equals( axis ) ) _hiddenKeys.add( axis );

            if ( indexAliases == null )
                _mapping.put( axisAlias, host.getNDConf().shape( axis ) );
            else {
                Map<Object, Integer> positions = new LinkedHashMap<>( indexAliases.size() * 3 );
                for ( int i = 0; i < indexAliases.size(); i++ ) positions.put( indexAliases.get( i ), i );
                _mapping.put( axisAlias, positions );
            }
            axis++;
        }

        // Second pass: the default integer index should also always work!
        axis = 0;
        for ( Object axisAlias : labels.keySet() ) {
            _mapping.put( axis, _mapping.get( axisAlias ) );
            axis++;
        }

        // Remaining axes of the host start out with an empty alias map:
        for ( int i = axis; i < host.rank(); i++ )
            if ( !_mapping.containsKey( i ) )
                _mapping.put( i, new LinkedHashMap<>() );
    }

    /**
     *  Internal copy constructor used for deriving a frame from an existing one.
     *  NOTE(review): only the outer map is copied, the nested alias maps stay shared
     *  with the frame the mapping was taken from.
     *
     * @param hiddenKeys The keys which should not show up in the string representation.
     * @param mapping The axis mapping which is copied into this frame.
     * @param tensorName The name of this frame.
     */
    private NDFrame( List<Object> hiddenKeys, Map<Object, Object> mapping, String tensorName ) {
        _mainLabel = tensorName;
        _mapping = new LinkedHashMap<>( mapping );
        _hiddenKeys.addAll( hiddenKeys );
    }

    /**
     *  Derives a frame with the same hidden keys and axis mapping but another name.
     *
     * @param newLabel The name of the returned frame.
     * @return A new frame carrying the provided name.
     */
    public NDFrame<V> withLabel( String newLabel ) {
        NDFrame<V> renamed = new NDFrame<>( _hiddenKeys, _mapping, newLabel );
        return renamed;
    }

    /**
     *  Derives a frame with the same name in which the index aliases of the
     *  axes covered by the provided list are replaced.
     *
     * @param labels One list of index aliases per axis.
     * @return A new frame carrying the provided index aliases.
     */
    public NDFrame<V> withAxesLabels( List<List<Object>> labels ) {
        NDFrame<V> relabeled = new NDFrame<>( _hiddenKeys, _mapping, _mainLabel );
        return relabeled._label( labels );
    }

    /**
     *  List based convenience variant of {@link #get(Object...)}.
     *
     * @param keys One index alias per axis, ordered by axis.
     * @return The result of the array based lookup for the same keys.
     */
    public int[] get( List<Object> keys ) {
        LogUtil.nullArgCheck( keys, "keys", List.class );
        Object[] keyArray = keys.toArray( new Object[0] );
        return get( keyArray );
    }

    /**
     *  Translates index aliases into the integer positions they stand for.
     *  The alias at position {@code i} of the provided array is looked up in the
     *  entry which is registered under the integer key {@code i}.
     *  Positions for which no entry is registered stay {@code 0}.
     *
     * @param keys One index alias per axis, ordered by axis.
     * @return An array of the same length as {@code keys} holding the resolved positions.
     * @throws IllegalArgumentException If an axis has an alias map which does not contain the provided alias.
     */
    public int[] get( Object... keys ) {//Todo: iterate over _mapping
        LogUtil.nullArgCheck( keys, "keys", Object[].class );
        int[] indices = new int[ keys.length ];
        for ( int i = 0; i < indices.length; i++ ) {
            Object am = _mapping.get( i );
            if ( am instanceof Map ) {
                Integer position = ( (Map<Object, Integer>) am ).get( keys[ i ] );
                // An unknown alias yields null, unboxing it blindly would cause a NullPointerException without any hint:
                if ( position == null )
                    throw new IllegalArgumentException(
                        "There is no index alias '" + keys[ i ] + "' registered for the axis at position " + i + "!"
                    );
                indices[ i ] = position;
            }
            else if ( am instanceof Integer )
                // NOTE(review): the stored number itself is returned here, regardless of the provided key. Confirm that this is intended.
                indices[ i ] = (Integer) am;
        }
        return indices;
    }

    /**
     *  Checks if the axis view targeted by the provided alias reports any aliases.
     *
     * @param axisAlias The alias of the axis which should be checked.
     * @return {@code true} if the list of all aliases of the targeted axis is not empty.
     */
    public boolean hasLabelsForAxis( Object axisAlias ) {
        LogUtil.nullArgCheck( axisAlias, "axisAlias", Object.class );
        boolean noAliases = atAxis( axisAlias ).getAllAliases().isEmpty();
        return !noAliases;
    }

    /**
     *  A {@link NDFrame} exposes aliases for axes as well as aliases for individual positions within an axis.
     *  This method returns a view on an axis which is targeted by an axis alias as key.
     *  This view is an instance of the {@link AxisFrame} class which provides useful methods
     *  for getting or setting alias objects for individual positions of the given axis.
     *  This is useful when for example replacing certain aliases or simply taking a look at them.
     *  The view reads from and writes to this frame, it does not hold a copy of the aliases.
     *
     * @param axisAlias The axis alias object which targets an {@link AxisFrame} of {@link NDFrame}.
     * @return A view of the targeted axis in the form of an {@link AxisFrame} which provides getters and setters for aliases.
     */
    public AxisFrame<Integer, V> atAxis( Object axisAlias )
    {
        LogUtil.nullArgCheck( axisAlias, "axisAlias", Object.class );
        return AxisFrame.<Integer, Integer, V>builder()
                .getter(
                        atKey -> () ->
                        {
                            Object am = _mapping.get( axisAlias );
                            if ( am instanceof Map )
                                return ( (Map<Object, Integer>) am ).get( atKey );
                            else
                                return 0;
                        }
                )
                .setter(
                        atKey -> (int setValue) ->
                        {
                            // An axis without alias map gets one on the first write:
                            Map<Object, Integer> am = _initializeIndexMap( axisAlias, atKey, setValue );
                            am.put( atKey, setValue );
                            return this;
                        }
                )
                .replacer(
                        ( currentIndexKey ) -> ( newIndexKey ) -> {
                            Object registered = _mapping.get( axisAlias );
                            Map<Object, Integer> am;
                            if ( registered instanceof Map )
                                am = (Map<Object, Integer>) registered; // Aliases may be any object, no integer needed here!
                            else {
                                // Only an axis without alias map is addressed by plain integer positions.
                                // Any other key cannot be present, -1 then simply leads to an unmodified default map.
                                Object currentKey = currentIndexKey;
                                int currentIndex = ( currentKey instanceof Integer ) ? (Integer) currentKey : -1;
                                am = _initializeIndexMap( axisAlias, currentIndexKey, currentIndex );
                            }
                            if ( am.containsKey( currentIndexKey ) )
                                am.put( newIndexKey, am.remove( currentIndexKey ) ); // The position moves over to the new alias.
                            return this;
                        }
                )
                .allAliasGetter(
                        () -> {
                            Object am = _mapping.get( axisAlias );
                            if ( am == null ) return new ArrayList<>();
                            List<Object> keys = new ArrayList<>();
                            if ( am instanceof Map ) ( (Map<Object, Integer>) am ).forEach( ( k, v ) -> keys.add( k ) );
                            else for ( int i = 0; i < ( (Integer) am ); i++ ) keys.add( i ); // The positions are their own aliases.
                            return keys;
                        }
                )
                .allAliasGetterFor(
                        (index) -> {
                            List<Object> keys = new ArrayList<>();
                            Object am = _mapping.get( axisAlias );
                            if ( am instanceof Map ) ( (Map<Object, Integer>) am ).forEach( (k, v) -> { if ( v.equals( index ) ) keys.add( k ); } );
                            else keys.add( index );
                            return keys;
                        }
                )
                .build();
    }

    /**
     *  Supplies the alias map of the targeted axis and creates it if the axis
     *  is so far only represented by a plain {@link Integer}.
     *  A newly created map has one entry per position below that number, every position
     *  is its own alias except for the provided index, which is registered under the provided key.
     *  The new map replaces the number in the mapping of this frame.
     *
     * @param axis The key of the axis whose alias map is requested.
     * @param key The alias which is used for the position {@code index} in a newly created map.
     * @param index The position which should be registered under {@code key} in a newly created map.
     * @return The already existing or newly created alias map of the axis.
     */
    private Map<Object, Integer> _initializeIndexMap( Object axis, Object key, int index ) {
        Object registered = _mapping.get( axis );
        if ( registered instanceof Map )
            return (Map<Object, Integer>) registered;

        final int size = (Integer) registered;
        Map<Object, Integer> created = new LinkedHashMap<>( size * 3 );
        for ( int position = 0; position < size; position++ ) {
            Object alias = ( position == index ) ? key : (Object) position;
            created.put( alias, position );
        }
        _mapping.put( axis, created );
        return created;
    }

    /**
     *  Centers the provided string within a cell of the given width by surrounding it with white spaces.
     *  If the padding cannot be split evenly, the odd white space ends up on one of the two sides.
     *  A string which is at least as long as the cell is returned as is.
     *
     * @param string The {@link String} which ought to be padded by white spaces.
     * @param cellSize The length of the padded {@link String} which will be returned.
     * @return The padded {@link String}.
     */
    private String _paddedCentered( String string, int cellSize ) {
        final int length = string.length();
        if ( length >= cellSize ) return string;

        // Both halves of the cell give up the space taken by "their" half of the string:
        int leftPad  = cellSize / 2 - length / 2;
        int rightPad = ( cellSize - cellSize / 2 ) - ( length - length / 2 );

        StringBuilder padded = new StringBuilder( cellSize );
        for ( int i = 0; i < leftPad; i++ ) padded.append( ' ' );
        padded.append( string );
        for ( int i = 0; i < rightPad; i++ ) padded.append( ' ' );
        return padded.toString();
    }


    /**
     *  Renders this frame as a text table with one column per axis key which is not hidden
     *  and one row per index position. The cell of a row contains the alias which the
     *  column's axis has for the position of that row, or "---" if there is none.
     *  Rows are written until a row is reached in which no column has an alias,
     *  this last row of placeholders closes the table.
     *
     * @return The table representation of the axis and index aliases of this frame.
     */
    @Override
    public String toString()
    {
        final int TABLE_CELL_WIDTH = 16;
        final String WALL = " | ";
        final String HEADLINE = "=";
        final String ROWLINE = "-";
        final String CROSS = "+";

        int indexShift = WALL.length() / 2;
        int crossMod = TABLE_CELL_WIDTH + WALL.length();
        // Separator lines carry a cross wherever a wall of the rows above and below is located:
        Function<Integer, Boolean> isCross = i -> ( i - indexShift ) % crossMod == 0;

        // Only the axis keys which are not hidden become columns of the table:
        List<Map.Entry<Object, Object>> columns = new ArrayList<>();
        for ( Map.Entry<Object, Object> axis : _mapping.entrySet() )
            if ( !_hiddenKeys.contains( axis.getKey() ) ) columns.add( axis );

        StringBuilder builder = new StringBuilder();
        builder.append( WALL );
        for ( Map.Entry<Object, Object> column : columns ) {
            builder.append( _paddedCentered( column.getKey().toString(), TABLE_CELL_WIDTH ) );
            builder.append( WALL );
        }
        int lineLength = builder.length();
        builder.append( "\n" );
        for ( int i = 0; i < lineLength; i++ ) builder.append( ( isCross.apply( i ) ) ? CROSS : HEADLINE );
        builder.append( "\n" );

        boolean hasMoreIndexes = true;
        for ( int depth = 0; hasMoreIndexes; depth++ ) {
            // The table continues as long as ANY column still has an alias for the current row.
            // Looking at the last column only would cut off longer axes located further left.
            hasMoreIndexes = false;
            builder.append( WALL );
            for ( Map.Entry<Object, Object> column : columns ) {
                Object keyOfDepth = null;
                Object v = column.getValue();
                if ( v instanceof Map ) {
                    for ( Map.Entry<Object, Integer> alias : ( (Map<Object, Integer>) v ).entrySet() )
                        if ( alias.getValue() == depth ) keyOfDepth = alias.getKey();
                }
                else if ( v instanceof Integer ) {
                    // An axis without alias map uses its positions as aliases:
                    if ( depth < ( (Integer) v ) ) keyOfDepth = depth;
                }
                if ( keyOfDepth != null ) {
                    hasMoreIndexes = true;
                    builder.append( _paddedCentered( keyOfDepth.toString(), TABLE_CELL_WIDTH ) );
                }
                else
                    builder.append( _paddedCentered( "---", TABLE_CELL_WIDTH ) );

                builder.append( WALL );
            }
            builder.append( "\n" );
            for ( int i = 0; i < lineLength; i++ ) builder.append( ( isCross.apply( i ) ) ? CROSS : ROWLINE );
            builder.append( "\n" );
        }

        StringBuilder result = new StringBuilder().append( "\nTensor IndexAlias: axis/indexes" );
        result.append( "\n" );
        for ( int i = 0; i < lineLength; i++ ) result.append( HEADLINE );
        result.append( "\n" );

        result.append( builder );
        return result.toString();
    }


    /**
     *  Accepts every owner change by simply executing the provided request.
     *
     * @param changeRequest The request describing the change of the owner of this component.
     * @return Always {@code true}.
     */
    @Override
    public boolean update( OwnerChangeRequest<Tensor<V>> changeRequest ) {
        // The request may be an 'add', 'remove' or 'transfer' of this component,
        // none of which requires additional work when the owner is switched.
        changeRequest.executeChange();
        return true;
    }

    /**
     * @return A read-only view on the axis mapping of this frame.
     */
    private Map<Object, Object> _mapping() {
        return Collections.unmodifiableMap( _mapping );
    }

    /**
     *  Creates a simplified snapshot of the axis mapping of this frame.
     *  Every key of the mapping is mapped to the list of index aliases of its alias map,
     *  or to {@code null} if the key is represented by a plain {@link Integer}.
     *
     * @return A new map from the keys of the mapping to newly created lists of index aliases (or {@code null}).
     */
    public Map<Object, List<Object>> getState() {
        Map<Object, List<Object>> simpleState = new LinkedHashMap<>();
        for ( Map.Entry<Object, Object> axis : _mapping().entrySet() ) {
            Object registered = axis.getValue();
            if ( registered instanceof Integer ) {
                simpleState.put( axis.getKey(), null ); // There are no index aliases for this one.
                continue;
            }
            List<Object> aliases = new ArrayList<>( ( (Map<Object, Object>) registered ).keySet() );
            simpleState.put( axis.getKey(), aliases );
        }
        return simpleState;
    }

    /**
     * @return The name this frame carries, exactly as it was passed to the constructor.
     */
    public String getLabel() {
        final String label = _mainLabel;
        return label;
    }
}