// BiScalarBroadcast.java
package neureka.backend.main.algorithms;
import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractFunDeviceAlgorithm;
import neureka.devices.Device;
import neureka.dtype.NumericType;
/**
 * Algorithm registered under the name {@code "scalarization"} for calls with two
 * operands where the second operand is either a single-element tensor or a
 * virtual tensor.
 * <p>
 * The constructor wires up three things:
 * <ul>
 *   <li>the auto-diff mode, which is always {@link AutoDiffMode#FORWARD_AND_BACKWARD};</li>
 *   <li>a suitability check that accepts calls with 2 tensors (two operands) or
 *       3 tensors (an output slot followed by two operands);</li>
 *   <li>a call preparation step that creates the output tensor and places it at
 *       input index 0.</li>
 * </ul>
 */
public class BiScalarBroadcast extends AbstractFunDeviceAlgorithm<BiScalarBroadcast>
{
    /**
     * Creates the algorithm and registers its auto-diff mode, suitability
     * predicate and call preparation.
     */
    public BiScalarBroadcast() {
        super("scalarization");
        setAutogradModeFor( call -> AutoDiffMode.FORWARD_AND_BACKWARD );
        setIsSuitableFor( call ->
            call.validate()
                .allNotNull( t -> t.getDataType().typeClassImplements(NumericType.class) )
                .tensors( tensors -> {
                    // Only "a, b" (length 2) or "out, a, b" (length 3) layouts are supported.
                    if ( tensors.length != 2 && tensors.length != 3 ) return false;
                    // With 3 tensors the operands are shifted by one because index 0 is the output slot.
                    int offset = ( tensors.length - 2 );
                    Tensor<?> firstOperand  = tensors[offset];
                    Tensor<?> secondOperand = tensors[1 + offset];
                    // Both operands must be present. This is checked BEFORE touching them,
                    // so that a missing operand makes the call unsuitable instead of
                    // raising a NullPointerException.
                    if ( firstOperand == null || secondOperand == null ) return false;
                    // The second operand has to be a scalar: at most one element, or virtual.
                    return secondOperand.size() <= 1 || secondOperand.isVirtual();
                })
                .suitabilityIfValid(SuitabilityPredicate.VERY_GOOD)
        );
        setCallPreparation(
            call -> {
                // If input 0 is null it is an empty output slot; the first operand is then at index 1.
                int offset = ( call.input( Number.class, 0 ) == null ? 1 : 0 );
                Device<Number> device = call.getDeviceFor(Number.class);
                // The output takes its shape and item type from the first operand.
                Shape outShape = call.input( offset ).shape();
                // Unchecked cast: the item type is only known at runtime.
                Class<Object> type = (Class<Object>) call.input( offset ).getItemType();
                // NOTE(review): presumably 0.0 is the initial fill value of the new tensor — confirm against Tensor.of.
                Tensor output = Tensor.of( type, outShape, 0.0 ).mut().setIsIntermediate( true );
                output.mut().setIsVirtual( false );
                device.store( output );
                if ( call.arity() == 3 ) {
                    // Three inputs: the output slot already exists and must still be empty.
                    assert call.input( 0 ) == null;
                    return call.withInputAt( 0, output );
                }
                else
                    // Two inputs: add the output at index 0.
                    return call.withAddedInputAt( 0, output );
            }
        );
    }
}