// CSVHandle.java
package neureka.devices.file;
import neureka.Shape;
import neureka.Tensor;
import neureka.devices.Storage;
import neureka.dtype.DataType;
import neureka.framing.NDFrame;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.lang.ref.WeakReference;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
* This class is one of many extensions of the {@link AbstractFileHandle} which
* is therefore ultimately an implementation of the {@link FileHandle} interface.
* Like other {@link FileHandle} implementations this class represents a file
* of a given type, in this case it represents a CSV file.
*/
public final class CSVHandle extends AbstractFileHandle<CSVHandle, String>
{
static {
_LOG = LoggerFactory.getLogger( CSVHandle.class );
}
private String _label;
private final String _delimiter;
private final boolean _firstRowIsLabels;
private String[] _colLabels;
private final boolean _firstColIsIndex;
private String[] _rowLabels;
private Integer _numberOfRows = null;
private Integer _numberOfColumns = null;
private Integer _numberOfBytes = null;
private WeakReference<String[]> _rawData = null;
CSVHandle(Tensor<?> tensor, String filename )
{
super( filename, new CSVType() );
assert tensor.rank() == 2;
_delimiter = ",";
NDFrame<?> alias = tensor.get( NDFrame.class );
List<Object> index = ( alias != null ? alias.atAxis( 0 ).getAllAliases() : null );
List<Object> labels = ( alias != null ? alias.atAxis( 1 ).getAllAliases() : null );
_label = (alias != null) ? alias.getLabel() : null;
_firstRowIsLabels = labels != null;
_firstColIsIndex = index != null;
StringBuilder asCsv = new StringBuilder();
if ( _firstRowIsLabels ) {
if ( _firstColIsIndex ) labels.add( 0, (_label == null) ? "" : _label);
asCsv.append(labels.stream().map(Object::toString).collect(Collectors.joining(_delimiter)))
.append("\n");
}
int[] shape = tensor.getNDConf().shape();
assert shape.length == 2;
if ( _firstColIsIndex ) assert index.size() == shape[ 0 ];
int[] indices = new int[ 2 ];
for ( int i = 0; i < shape[ 0 ]; i++ ) {
indices[ 0 ] = i;
if ( _firstColIsIndex ) asCsv.append( index.get(i).toString() ).append( "," );
for ( int ii = 0; ii < shape[ 1 ]; ii++ ) {
indices[ 1 ] = ii;
asCsv.append( tensor.item( indices ) );
if ( ii < shape[ 1 ] - 1 ) asCsv.append( _delimiter );
}
asCsv.append( "\n" );
}
try {
PrintWriter out = new PrintWriter( filename );
out.print( asCsv );
out.close();
} catch ( Exception e ) {
e.printStackTrace();
}
}
public CSVHandle(
String fileName,
Map<String, Object> settings
) {
super( fileName, new CSVType() );
if ( settings != null ) {
_delimiter = (String) settings.getOrDefault( "delimiter", "," );
_firstRowIsLabels = (boolean) settings.getOrDefault( "firstRowIsLabels", false );
_firstColIsIndex = (boolean) settings.getOrDefault( "firstColIsIndex", false );
} else {
_delimiter = ",";
_firstRowIsLabels = false;
_firstColIsIndex = false;
}
}
private String[] _lazyLoad() {
if ( _rawData != null ) {
String[] alreadyLoaded = _rawData.get();
if ( alreadyLoaded != null ) return alreadyLoaded;
}
FileInputStream fis;
try {
fis = _loadFileInputStream();
} catch( Exception e ) {
e.printStackTrace();
System.err.print( "Failed reading CSV file!" );
_LOG.error( "Failed reading CSV file!" );
return new String[0];
}
List<String[]> table = new ArrayList<>();
List<String> rowLabels = ( _firstColIsIndex ) ? new ArrayList<>() : null;
try (
BufferedReader br = new BufferedReader( new InputStreamReader( fis, StandardCharsets.UTF_8 ) )
) {
String line;
while( ( line = br.readLine() ) != null ) {
table.add( line.split( _delimiter ) );
}
} catch ( IOException e ) {
e.printStackTrace();
}
int rowLength = -1;
int colHeight = 0;
int size = 0;
int numberOfBytes = 0;
if ( _firstRowIsLabels ) {
_colLabels = table.remove( 0 );
if ( _firstColIsIndex ) {
if ( !_colLabels[0].trim().equals("") ) _label = _colLabels[0].trim();
else _parseTensorNameFromFileName();
String[] newLabels = new String[ _colLabels.length - 1 ];
System.arraycopy( _colLabels, 1, newLabels, 0, newLabels.length );
_colLabels = newLabels;
}
else _parseTensorNameFromFileName();
}
else _parseTensorNameFromFileName();
for ( int ri = 0; ri < table.size(); ri++ ) {
String[] row = table.get( ri );
if ( _firstColIsIndex ) {
rowLabels.add( row[0] );
String[] newRow = new String[ row.length - 1 ];
System.arraycopy( row, 1, newRow, 0, newRow.length );
row = newRow;
table.set( ri, newRow );
}
if ( rowLength < 0 ) rowLength = row.length;
if ( rowLength == row.length ) {
size += row.length;
for ( String element : row )
numberOfBytes += element.getBytes( StandardCharsets.UTF_8 ).length;
colHeight++;
}
}
if ( rowLabels != null ) _rowLabels = rowLabels.toArray( new String[rowLabels.size()] );
_numberOfColumns = rowLength;
_numberOfRows = colHeight;
_numberOfBytes = numberOfBytes;
String[] rawData = new String[ size ];
_rawData = new WeakReference<>( rawData );
for ( int ri = 0; ri < _numberOfRows; ri++ ) {
for ( int ci = 0; ci < _numberOfColumns; ci++ ) {
rawData[ ri * rowLength + ci ] = table.get( ri )[ ci ];
}
}
return rawData;
}
private void _parseTensorNameFromFileName() {
String[] parts = _fileName.replace("\\", "/").split("/");
if ( parts.length > 0 ) parts = parts[ parts.length - 1 ].split("\\.");
_label = (parts.length > 0)? parts[0] : _label;
}
@Override
public <T extends String> Storage<String> store( Tensor<T> tensor ) {
throw new UnsupportedOperationException( "CSVHandle does not support storing tensors!" );
}
@Override protected Object _loadData() { return _lazyLoad(); }
@Override
public Tensor<String> load() throws IOException {
String[] data = _lazyLoad();
Tensor<String> loaded = Tensor.of(DataType.of( String.class ), getShape(), data);
String[] index;
String[] labels;
if ( !_firstColIsIndex ) {
index = new String[ _numberOfRows ];
for ( int i = 0; i < index.length; i++ ) index[ i ] = String.valueOf( i );
}
else index = _rowLabels;
if ( !_firstRowIsLabels ) {
labels = new String[ _numberOfColumns ];
StringBuilder prefix = new StringBuilder( );
for ( int i=0; i < labels.length; i++ ) {
int position = i % 26;
if ( position == 25 ) prefix.append( (char) ( i / 26 ) % 26 );
labels[ i ] = String.join( "", prefix.toString() + ( (char)( 'a' + position )) );
}
}
else labels = _colLabels;
loaded.getMut().labelAxes( index, labels );
loaded.getMut().label( _label );
return loaded;
}
@Override
public int getValueSize() {
String[] rawData;
if ( _rawData == null ) rawData = _lazyLoad();
else rawData = _rawData.get();
if ( rawData == null ) return 0;
return rawData.length;
}
@Override
public int getDataSize() {
if ( _numberOfBytes != null ) return _numberOfBytes;
else _lazyLoad();
return _numberOfBytes;
}
@Override
public int getTotalSize() {
return getDataSize();
}
@Override
public DataType<?> getDataType() {
return DataType.of( String.class );
}
@Override
public Shape getShape() {
return Shape.of( _numberOfRows, _numberOfColumns );
}
public String getDelimiter() {
return _delimiter;
}
public boolean isFirstRowIsLabels() {
return _firstRowIsLabels;
}
public String[] getColLabels() {
return _colLabels;
}
public boolean isFirstColIsIndex() {
return _firstColIsIndex;
}
public String[] getRowLabels() {
return _rowLabels;
}
public Integer getNumberOfRows() {
return _numberOfRows;
}
public Integer getNumberOfColumns() {
return _numberOfColumns;
}
private static class CSVType implements FileType
{
@Override public String defaultExtension() { return "csv"; }
}
}