ScalarBroadcast.java
package neureka.backend.main.algorithms;
import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.AbstractFunDeviceAlgorithm;
import neureka.backend.api.template.algorithms.FallbackAlgorithm;
import neureka.backend.main.implementations.fun.api.CPUFun;
import neureka.backend.main.implementations.fun.api.ScalarFun;
import neureka.backend.main.implementations.scalar.CPUScalarBroadcastFunction;
import neureka.math.args.Arg;
import neureka.devices.Device;
import neureka.devices.host.CPU;
import neureka.devices.opencl.OpenCLDevice;
import neureka.dtype.NumericType;
import neureka.ndim.NDimensional;
/**
 *  A device algorithm named "scalar broadcast" which takes a virtual source tensor (index 1),
 *  applies a {@link ScalarFun} (its activation, or its derivative when a derivative index is
 *  requested) to the source's value and writes the result into every item of a non-virtual
 *  output tensor (index 0). If no output tensor is supplied, one with the source's shape and
 *  item type is created during call preparation.
 */
public class ScalarBroadcast extends AbstractFunDeviceAlgorithm<ScalarBroadcast>
{
    /**
     * @param fun The scalar function applied to the broadcast value; its activation is used when
     *            the {@link Arg.DerivIdx} argument is negative, otherwise its derivative.
     */
    public ScalarBroadcast( ScalarFun fun ) {
        super("scalar broadcast");
        // Forward-mode AD is only possible when all non-null operands share the same shape.
        setAutogradModeFor(
            call -> call
                    .validate().allNotNullHaveSame(NDimensional::shape)
                    .ifValid(AutoDiffMode.FORWARD_AND_BACKWARD)
                    .orElse(AutoDiffMode.BACKWARD_ONLY)
        );
        setIsSuitableFor( call ->
            call.validate()
                .allNotNull( t -> t.getDataType().typeClassImplements(NumericType.class) )
                .tensors( tensors -> {
                    // Expect exactly two slots: an optional output (0) and the source (1).
                    if ( tensors.length != 2 ) return false;
                    // The source must be present and virtual. Presumably a virtual tensor is
                    // backed by a single value; the OpenCL path below only reads item 0 of it.
                    if ( tensors[1] == null || !tensors[1].isVirtual() ) return false;
                    // A supplied output must be a real (non-virtual) tensor...
                    if ( tensors[0] != null && tensors[0].isVirtual() ) return false;
                    // ...whose shape matches the source; a missing output is created later.
                    return tensors[0] == null || tensors[0].shape().equals( tensors[1].shape() );
                })
                .suitabilityIfValid( SuitabilityPredicate.VERY_GOOD )
        );
        setCallPreparation(
            call -> {
                // The suitability check accepts a caller-supplied output with a matching shape,
                // so reuse it instead of asserting it is absent (which either threw under -ea or
                // silently replaced the caller's tensor without -ea).
                if ( call.input( 0 ) != null ) return call;
                // Creating a new tensor:
                Device<Number> device = call.getDeviceFor(Number.class);
                Shape outShape = call.input( 1 ).shape();
                @SuppressWarnings("unchecked")
                Class<Object> type = (Class<Object>) call.input( 1 ).getItemType();
                Tensor output = Tensor.of( type, outShape, 0.0 ).mut().setIsIntermediate( true );
                // The output receives one value per item, so it must not stay virtual.
                output.mut().setIsVirtual( false );
                try {
                    device.store( output );
                } catch( Exception e ) {
                    // NOTE(review): storage failures are only reported here, not propagated;
                    // behavior kept as in the original.
                    e.printStackTrace();
                }
                return call.withInputAt( 0, output );
            }
        );
        setExecution(
            (caller, call) ->
                Result.of(AbstractDeviceAlgorithm.prepareAndExecute(call,AbstractDeviceAlgorithm::executeDeviceAlgorithm))
                      .withAutoDiff( FallbackAlgorithm::ADAction )
        );
        setImplementationFor( CPU.class, new CPUScalarBroadcastFunction( fun ) );
        setImplementationFor(
            OpenCLDevice.class,
            call -> {
                // A negative derivative index selects the plain activation, otherwise the derivative.
                int d = call.getValOf(Arg.DerivIdx.class);
                CPUFun f = d < 0 ? fun.getActivation() : fun.getDerivative();
                // The function is evaluated once on the host; only the result is sent to the kernel.
                double value = f.invoke( call.input( Number.class, 1 ).at(0).get().doubleValue() );
                Tensor<Number> t = call.input( Number.class, 0 );
                int gwz = t.size();
                call.getDevice()
                    .getKernel("scalar_broadcast")
                    .passAllOf(t)
                    .pass((float) value)
                    .pass(t.rank())
                    .call( gwz );
                return call.input(0);
            }
        );
    }
}