ScalarAlgorithm.java
package neureka.backend.main.algorithms;
import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.Result;
import neureka.backend.api.fun.SuitabilityPredicate;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.algorithms.AbstractFunDeviceAlgorithm;
import neureka.backend.api.template.algorithms.FallbackAlgorithm;
import neureka.devices.Device;
import neureka.dtype.NumericType;
import neureka.ndim.NDimensional;
/**
 *  A device algorithm registered under the name {@code "scalar activation"}.
 *  It handles calls with exactly two tensors:
 *  <ul>
 *      <li>index 0: the output, which is either absent ({@code null}) or a virtual tensor
 *          of the same shape as the input,</li>
 *      <li>index 1: the input, which must be a non-null virtual tensor.</li>
 *  </ul>
 *  NOTE(review): "virtual" presumably means a tensor whose elements all share a single
 *  value (hence "scalar"). Confirm against {@code Tensor#isVirtual()}.
 */
public class ScalarAlgorithm extends AbstractFunDeviceAlgorithm<ScalarAlgorithm>
{
    /**
     *  Configures autograd mode, suitability, call preparation and execution for this algorithm.
     */
    public ScalarAlgorithm() {
        super("scalar activation");

        // Forward-mode AD is only offered when every non-null operand has the same shape.
        setAutogradModeFor(
            call -> call
                    .validate().allNotNullHaveSame(NDimensional::shape)
                    .ifValid(AutoDiffMode.FORWARD_AND_BACKWARD)
                    .orElse(AutoDiffMode.BACKWARD_ONLY)
        );

        // Rated EXCELLENT when:
        //  - every non-null tensor has a numeric data type,
        //  - there are exactly two tensors,
        //  - the input (index 1) is non-null and virtual,
        //  - the output (index 0) is either absent or virtual with the input's shape.
        setIsSuitableFor( call ->
            call.validate()
                .allNotNull( t -> t.getDataType().typeClassImplements(NumericType.class) )
                .tensors( tensors -> {
                    if ( tensors.length != 2 ) return false;
                    // Guard against a missing input instead of throwing a NullPointerException.
                    if ( tensors[1] == null || !tensors[1].isVirtual() ) return false;
                    // No output yet: the call preparation below allocates one.
                    if ( tensors[0] == null ) return true;
                    // A pre-existing output must itself be virtual and match the input's shape.
                    return tensors[0].isVirtual() && tensors[0].shape().equals( tensors[1].shape() );
                })
                .suitabilityIfValid( SuitabilityPredicate.EXCELLENT )
        );

        // Allocates the output tensor at index 0, using the input's shape and item type.
        // The output is zero-initialised, marked as intermediate, and stored on the call's device.
        // NOTE(review): the suitability check above also accepts a non-null virtual input 0,
        // but this step assumes input 0 is null. That is only verified when assertions are
        // enabled (-ea); otherwise any existing input 0 is replaced. Confirm intended.
        setCallPreparation(
            call -> {
                Device<Number> device = call.getDeviceFor(Number.class);
                assert call.input( 0 ) == null; // Creating a new tensor:
                Shape outShape = call.input( 1 ).shape();
                // The item type is only known at runtime; Class<Object> lets Tensor.of accept it.
                @SuppressWarnings("unchecked")
                Class<Object> type = (Class<Object>) call.input( 1 ).getItemType();
                Tensor output = Tensor.of( type, outShape, 0.0 ).mut().setIsIntermediate( true );
                device.store( output );
                return call.withInputAt( 0, output );
            }
        );

        // Runs the prepared call through the generic device-algorithm path.
        // Autodiff is delegated to FallbackAlgorithm::ADAction.
        setExecution(
            (caller, call) ->
                Result.of(AbstractDeviceAlgorithm.prepareAndExecute(call,AbstractDeviceAlgorithm::executeDeviceAlgorithm))
                        .withAutoDiff( FallbackAlgorithm::ADAction )
        );
    }
}