AbstractImageFileHandle.java
package neureka.devices.file;
import neureka.Shape;
import neureka.Tensor;
import neureka.common.utility.LogUtil;
import neureka.devices.Storage;
import neureka.devices.host.CPU;
import neureka.dtype.DataType;
import neureka.dtype.custom.UI8;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.awt.image.Raster;
import java.io.File;
import java.io.IOException;
abstract class AbstractImageFileHandle<C> extends AbstractFileHandle<C, Number>
{
private final ImageFileType _type;
private int _width;
private int _height;
protected AbstractImageFileHandle(Tensor<Number> t, String filename, ImageFileType type ) {
super( filename, type );
LogUtil.nullArgCheck( type, "type", ImageFileType.class );
_type = type;
if ( t == null ) _loadHead();
else
{
if ( t.rank() != 3 || t.rank() == 2 )
throw new IllegalArgumentException(
"Expected tensor of rank 3, or 2 but encountered rank " + t.rank() + ". " +
"Cannot interpret tensor as image!"
);
if ( t.shape(t.rank()-1) != _type.numberOfChannels() )
throw new IllegalArgumentException(
"Expected last tensor axes length " + t.shape(t.rank()-1) + " to be equal " +
"to " + _type.numberOfChannels() + ", the number of expected color channels!"
);
_height = t.shape(0);
_width = t.shape(1);
t.getMut().setIsVirtual(false);
store(t);
}
}
private void _loadHead()
{
final File found = _loadFile();
final BufferedImage image;
try {
image = ImageIO.read(found);
Raster data = image.getData();
_height = data.getHeight();
_width = data.getWidth();
} catch ( Exception exception ) {
String message = _type.imageTypeName().toUpperCase() + " '"+_fileName+"' could not be read from file!";
_LOG.error( message, exception );
throw new IllegalStateException( message );
}
if ( _height < 1 || _width < 1 ) {
String message = "The height and width of the " + _type + " at '"+_fileName+"' is "+_height+" & "+_width+"." +
"However both dimensions must at least be of size 1!";
IllegalStateException e = new IllegalStateException( message );
_LOG.error( message, e );
throw e;
}
}
/** {@inheritDoc} */
@Override
public Tensor<Number> load() throws IOException {
Object value = _loadData(); // This is simply some kind of primitive array.
Tensor<?> t = Tensor.of(
_type.targetedValueType(),
Shape.of(_height, _width, _type.numberOfChannels()),
value
);
return t.getMut().upcast(Number.class);
}
@Override protected Object _loadData() throws IOException
{
File found = _loadFile();
BufferedImage image;
try
{
image = ImageIO.read( found );
byte[] data = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
if ( data.length != (_height * _width * _type.numberOfChannels()) )
throw new IllegalStateException("Loaded image data array does not match expected number of elements!");
if ( _type.targetedValueType() == Short.class ) {
short[] newData = new short[data.length];
UI8 ui8 = new UI8();
CPU.get().getExecutor().threaded(
data.length,
(start, end) -> { for (int i = start; i < end; i++) newData[i] = ui8.toTarget(data[i]); }
);
return newData;
}
else throw new IllegalStateException("Alternative types not yet supported!");
}
catch ( IOException e )
{
_LOG.error( "Failed loading " + _type + " file!", e );
throw e;
}
}
/** {@inheritDoc} */
@Override public int getValueSize() { return _width * _height * _type.numberOfChannels(); }
/** {@inheritDoc} */
@Override public int getDataSize() { return _width * _height * _type.numberOfChannels(); }
/** {@inheritDoc} */
@Override public int getTotalSize() { return _width * _height * _type.numberOfChannels(); }
/** {@inheritDoc} */
@Override public DataType<?> getDataType() { return DataType.of( UI8.class ); }
/** {@inheritDoc} */
@Override public Shape getShape() { return Shape.of( _height, _width, _type.numberOfChannels() ); }
/** {@inheritDoc} */
@Override
public <T extends Number> Storage<Number> store( Tensor<T> tensor )
{
LogUtil.nullArgCheck( tensor, "tensor", Tensor.class );
if ( _width != tensor.shape(1) )
throw new IllegalArgumentException(
"Cannot store tensor, because length " + tensor.shape(1) + " " +
"of axis 1 is not equal to image width " + _width + "."
);
if ( _height != tensor.shape(0) )
throw new IllegalArgumentException(
"Cannot store tensor, because length " + tensor.shape(0) + " " +
"of axis 0 is not equal to image width " + _height + "."
);
BufferedImage buff = tensor.asImage( _type.imageType() );
try {
ImageIO.write( buff, extension(), new File( _fileName ) );
} catch ( Exception e ) {
String message = "Failed writing tensor as " + extension() + " file!";
_LOG.error(message, e);
throw new IllegalStateException(message);
}
return this;
}
}