CPUScalarBroadcastFunction.java
package neureka.backend.main.implementations.scalar;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.ImplementationFor;
import neureka.backend.main.implementations.fun.api.CPUFun;
import neureka.backend.main.implementations.fun.api.ScalarFun;
import neureka.math.args.Arg;
import neureka.devices.host.CPU;
import neureka.ndim.iterator.NDIterator;
public class CPUScalarBroadcastFunction implements ImplementationFor<CPU>
{
private final ScalarFun _fun;
public CPUScalarBroadcastFunction(ScalarFun fun ) { _fun = fun; }
@Override
public Tensor<?> run(ExecutionCall<CPU> call) {
call.getDevice()
.getExecutor()
.threaded(
call.input(0).size(),
_workloadFor(call)
);
return call.input(0);
}
private CPU.RangeWorkload _workloadFor(
ExecutionCall<CPU> call
) {
Tensor<Number> t0_drn = call.input( Number.class, 0 );
Tensor<Number> src = call.input( Number.class, 1 );
Class<?> typeClass = t0_drn.getItemType();
CPU.RangeWorkload workload = null;
int d = call.getValOf(Arg.DerivIdx.class);
CPUFun f = d < 0 ? _fun.getActivation() : _fun.getDerivative();
if ( typeClass == Double.class ) {
double value = src.at(0).get().doubleValue();
double[] t0_value = t0_drn.mut().getDataForWriting(double[].class);
double finalValue = f.invoke(value);
workload = ( i, end ) -> {
NDIterator t0Idx = NDIterator.of(t0_drn);
NDIterator srcIdx = NDIterator.of(src);
t0Idx.set(t0_drn.indicesOfIndex(i));
srcIdx.set(src.indicesOfIndex(i));
while ( i < end ) // increment on drain accordingly:
{
// setInto _value in drn:
t0_value[t0Idx.i()] = finalValue;
// increment on drain:
t0Idx.increment();
srcIdx.increment();
//NDConfiguration.Utility.increment(t0Idx, t0Shp);
i++;
}
};
}
if ( typeClass == Float.class ) {
float value = src.at(0).get().floatValue();
float[] t0_value = t0_drn.mut().getDataForWriting(float[].class);
float finalValue = f.invoke(value);
workload = ( i, end ) -> {
NDIterator t0Idx = NDIterator.of(t0_drn);
NDIterator srcIdx = NDIterator.of(src);
t0Idx.set(t0_drn.indicesOfIndex(i));
srcIdx.set(src.indicesOfIndex(i));
while (i < end) // increment on drain accordingly:
{
// setInto _value in drn:
t0_value[t0Idx.i()] = finalValue;
// increment on drain:
t0Idx.increment();
srcIdx.increment();
//NDConfiguration.Utility.increment(t0Idx, t0Shp);
i++;
}
};
}
if ( typeClass == Integer.class ) {
int value = src.at(0).get().intValue();
int[] t0_value = t0_drn.mut().getDataForWriting(int[].class);
int finalValue = f.invoke(value);
workload = ( i, end ) -> {
NDIterator t0Idx = NDIterator.of(t0_drn);
NDIterator srcIdx = NDIterator.of(src);
t0Idx.set(t0_drn.indicesOfIndex(i));
srcIdx.set(src.indicesOfIndex(i));
while (i < end) // increment on drain accordingly:
{
// setInto _value in drn:
t0_value[t0Idx.i()] = finalValue;
// increment on drain:
t0Idx.increment();
srcIdx.increment();
//NDConfiguration.Utility.increment(t0Idx, t0Shp);
i++;
}
};
}
if ( t0_drn.mut().getData().getOrNull().getClass() == Object[].class ) {
Object value = src.at(0).get();
Object[] t0_value = t0_drn.mut().getDataForWriting(Object[].class);
Object finalValue = f.invoke(value);
workload = (i, end ) -> {
NDIterator t0Idx = NDIterator.of(t0_drn);
NDIterator srcIdx = NDIterator.of(src);
t0Idx.set(t0_drn.indicesOfIndex(i));
srcIdx.set(src.indicesOfIndex(i));
while (i < end) // increment on drain accordingly:
{
// setInto _value in drn:
t0_value[t0Idx.i()] = finalValue;
// increment on drain:
t0Idx.increment();
srcIdx.increment();
//NDConfiguration.Utility.increment(t0Idx, t0Shp);
i++;
}
};
}
if ( workload == null )
throw new IllegalArgumentException("");
else
return workload;
}
}