CPUElementwiseAssignFun.java
package neureka.backend.main.implementations.elementwise;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.main.implementations.fun.api.ScalarFun;
import neureka.devices.host.CPU;
import neureka.ndim.NDimensional;
public class CPUElementwiseAssignFun extends CPUElementwiseFunction
{
public CPUElementwiseAssignFun() { super(ScalarFun.IDENTITY); }
@Override
public Tensor<?> run(ExecutionCall<CPU> call )
{
assert call.arity() == 2;
boolean allVirtual = call.validate().all( Tensor::isVirtual ).isValid();
if ( allVirtual ) {
call.input(Object.class, 0).mut().setDataAt(0, call.input(1).item() );
assert call.input(0).isVirtual();
assert call.input(1).isVirtual();
return call.input(0);
}
call.input(0).mut().setIsVirtual(false);
boolean isSimple = call.validate()
.allShare(Tensor::isVirtual)
.allShare(NDimensional::getNDConf)
.all( t -> t.getNDConf().isSimple() )
.isValid();
if ( isSimple ) {
Class<?> type = call.input(0).itemType();
if ( type == Double.class ) {
double[] output = call.input(0).mut().getDataForWriting(double[].class);
double[] input = call.input(1).mut().getDataAs(double[].class);
if ( input.length >= output.length ) {
System.arraycopy( input, 0, output, 0, call.input(0).size() );
return call.input(0);
}
} else if ( type == Integer.class ) {
int[] output = call.input(0).mut().getDataForWriting(int[].class);
int[] input = call.input(1).mut().getDataAs(int[].class);
if ( input.length >= output.length ) {
System.arraycopy( input, 0, output, 0, call.input(0).size() );
return call.input(0);
}
} else if ( type == Float.class ) {
float[] output = call.input(0).mut().getDataForWriting(float[].class);
float[] input = call.input(1).mut().getDataAs(float[].class);
if ( input.length >= output.length ) {
System.arraycopy( input, 0, output, 0, call.input(0).size() );
return call.input(0);
}
} else if ( type == Long.class ) {
long[] output = call.input(0).mut().getDataForWriting(long[].class);
long[] input = call.input(1).mut().getDataAs(long[].class);
if ( input.length >= output.length ) {
System.arraycopy( input, 0, output, 0, call.input(0).size() );
return call.input(0);
}
} else if ( type == Boolean.class ) {
boolean[] output = call.input(0).mut().getDataForWriting(boolean[].class);
boolean[] input = call.input(1).mut().getDataAs(boolean[].class);
if ( input.length >= output.length ) {
System.arraycopy( input, 0, output, 0, call.input(0).size() );
return call.input(0);
}
} else if ( type == Character.class ) {
char[] output = call.input(0).mut().getDataForWriting(char[].class);
char[] input = call.input(1).mut().getDataAs(char[].class);
if ( input.length >= output.length ) {
System.arraycopy( input, 0, output, 0, call.input(0).size() );
return call.input(0);
}
} else if ( type == Byte.class ) {
byte[] output = call.input(0).mut().getDataForWriting(byte[].class);
byte[] input = call.input(1).mut().getDataAs(byte[].class);
if ( input.length >= output.length ) {
System.arraycopy( input, 0, output, 0, call.input(0).size() );
return call.input(0);
}
} else if ( type == Short.class ) {
short[] output = call.input(0).mut().getDataForWriting(short[].class);
short[] input = call.input(1).mut().getDataAs(short[].class);
if ( input.length >= output.length ) {
System.arraycopy( input, 0, output, 0, call.input(0).size() );
return call.input(0);
}
}
}
return super.run( call );
}
}