IDXHandle.java

package neureka.devices.file;



import neureka.Neureka;
import neureka.Shape;
import neureka.Tensor;
import neureka.dtype.DataType;
import neureka.dtype.NumericType;
import neureka.dtype.custom.*;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.stream.Collectors;

/**
 *  This class is one of many extensions of the {@link AbstractFileHandle} which
 *  is therefore ultimately an implementation of the {@link FileHandle} interface.
 *  Like other {@link FileHandle} implementations this class represents a file
 *  of a given type, in this case it represents a IDX file.
*/
public final class IDXHandle extends AbstractFileHandle<IDXHandle, Number>
{
    static {
        _LOG = LoggerFactory.getLogger( IDXHandle.class );
    }
    private DataType<?> _dataType;
    private int _dataOffset;
    private int _valueSize;
    private Shape _shape;

    private static final Map<Integer, Class<?>> TYPE_MAP;
    static {
        TYPE_MAP = new HashMap<>();
        TYPE_MAP.put( 0x08, UI8.class );  // unsigned byte
        TYPE_MAP.put( 0x09, I8.class  );  // signed byte
        TYPE_MAP.put( 0x0A, UI16.class);  //-> !! This is speculation !!
        TYPE_MAP.put( 0x0B, I16.class );  // short (2 bytes)
        TYPE_MAP.put( 0x0C, I32.class );  // int (4 bytes)
        TYPE_MAP.put( 0x0D, F32.class );  // float (4 bytes)
        TYPE_MAP.put( 0x0E, F64.class );  // double (8 bytes)
        TYPE_MAP.put( 0x0F, I64.class );  //-> !! This is speculation !!
    }

    private final static Map<Class<?>, Integer> CODE_MAP = TYPE_MAP.entrySet()
                                                            .stream()
                                                            .collect(
                                                                Collectors.toMap(
                                                                    Map.Entry::getValue,
                                                                    Map.Entry::getKey
                                                                )
                                                            );

    public IDXHandle(String fileName)
    {
        super( fileName, new IDXType() );
        try {
            _loadHead();
        } catch( Exception e ) {
            e.printStackTrace();
            System.err.print("Failed reading IDX file!");
        }
    }

    public IDXHandle(Tensor<Number> t, String filename ) {
        super( filename, new IDXType() );
        _shape = t.shape();
        _dataType = t.getDataType();
        t.getMut().setIsVirtual( false );
        store( t );
    }

    private void _loadHead() throws IOException
    {
        FileInputStream f = _loadFileInputStream();

        NumberReader numre = new NumberReader( f );

        int zeros = numre.read( new UI16() );
        assert zeros == 0;

        int typeId = numre.read( new UI8() );
        Class<?> typeClass = TYPE_MAP.get( typeId );
        _dataType = DataType.of( typeClass );

        int rank = numre.read( new UI8() );
        int[] shape = new int[ rank ];

        int size = 1;
        for ( int i = 0; i < rank; i++ ) {
            shape[ i ] = numre.read( new UI32() ).intValue();
            size *= shape[ i ];
        }

        _shape = Shape.of(shape);
        _valueSize = size;
        _dataOffset = numre.bytesRead();
    }


    @Override
    public <T extends Number> IDXHandle store( Tensor<T> tensor )
    {
        Iterator<T> data = tensor.iterator();
        FileOutputStream fos;
        try
        {
            fos = new FileOutputStream(_fileName);
        }
        catch (FileNotFoundException e)
        {
            try {
                File newFile = new File( _fileName );
                fos = new FileOutputStream( newFile );
            } catch ( Exception innerException ) {
                innerException.printStackTrace();
                return this;
            }
        }
        BufferedOutputStream f = new BufferedOutputStream(fos);

        int offset = 0;

        try {
            Class<?> representativeType = _dataType.getRepresentativeType();
            Integer code = CODE_MAP.get( representativeType );
            if ( code == null )
                throw new IllegalStateException(
                            "Unable to store nd-array of type: " + _dataType + ", because " +
                            "no suitable IDX type code could be found for it!"
                        );

            f.write( new byte[]{ 0, 0 } );
            offset += 2;
            f.write( code.byteValue() );
            offset += 1;
            byte rank = (byte) _shape.size();
            f.write( rank );
            offset += 1;
            int bodySize = 1;
            for ( int i = 0; i < rank; i++ ) {
                byte[] integer = ByteBuffer.allocate( 4 ).putInt( _shape.get( i ) ).array();
                assert integer.length == 4;
                f.write(integer);
                bodySize *= _shape.get( i );
                offset += 4;
            }
            _dataOffset = offset;
            _valueSize = bodySize;
            NumericType<Number, Object, Number, Object> type = (NumericType<Number, Object, Number, Object>) _dataType.getTypeClassInstance(NumericType.class);

            type.writeDataTo( new DataOutputStream( f ), (Iterator<Number>) data );
            f.close();
        } catch ( Exception e ) {
            e.printStackTrace();
        }
        return this;
    }

    @Override
    protected Object _loadData() throws IOException {
        FileInputStream fs = new FileInputStream( _fileName );
        Class<?> clazz = _dataType.getRepresentativeType();
        if ( NumericType.class.isAssignableFrom( clazz ) ) {
            NumericType<?,?,?,?> type = _dataType.getTypeClassInstance(NumericType.class);
            DataInput stream = new DataInputStream(
                    new BufferedInputStream( fs, _dataOffset + _valueSize * type.numberOfBytes() )
            );
            stream.skipBytes( _dataOffset );
            if ( Neureka.get().settings().dtype().getIsAutoConvertingExternalDataToJVMTypes() )
                return type.readAndConvertForeignDataFrom( stream, _valueSize);
            else
                return type.readForeignDataFrom( stream, _valueSize);
        }
        return null;
    }

    @Override
    public Tensor<Number> load() throws IOException
    {
        Object value = _loadData();
        DataType<?> type = Neureka.get().settings().dtype().getIsAutoConvertingExternalDataToJVMTypes()
                            ? DataType.of( _dataType.getTypeClassInstance(NumericType.class).getNumericTypeTarget() )
                            : _dataType;
        return Tensor.of( type, _shape, value ).getMut().upcast(Number.class);
    }

    @Override
    public int getDataSize() {
        int bytes = ( _dataType.typeClassImplements( NumericType.class ) )
                ? _dataType.getTypeClassInstance(NumericType.class).numberOfBytes()
                : 1;
        return _valueSize * bytes;
    }

    @Override
    public int getTotalSize() {
        return getDataSize() + _dataOffset;
    }

    public DataType<?> getDataType() {
        return _dataType;
    }

    public int getValueSize() {
        return _valueSize;
    }

    @Override
    public Shape getShape() {
        return _shape;
    }

    private static class IDXType implements FileType
    {
        @Override public String defaultExtension() { return "idx"; }
    }
}