Max.java
package neureka.backend.main.operations.other;
import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.AutoDiffMode;
import neureka.backend.api.DeviceAlgorithm;
import neureka.backend.api.Result;
import neureka.backend.api.template.algorithms.AbstractDeviceAlgorithm;
import neureka.backend.api.template.operations.AbstractOperation;
import neureka.backend.api.template.operations.OperationBuilder;
import neureka.backend.main.operations.ElemWiseUtil;
import neureka.backend.main.operations.linear.internal.opencl.CLReduce;
import neureka.backend.main.operations.other.internal.CPUReduce;
import neureka.math.Function;
import neureka.devices.Device;
import neureka.devices.host.CPU;
import neureka.devices.opencl.OpenCLDevice;
/**
 * The {@code max} operation: reduces its single input tensor to a one-element tensor
 * holding the item selected by the device's {@code MAX} reduce implementation.
 * <p>
 * During back-propagation the incoming error is written only to the position of the
 * selected item; every other position of the propagated error tensor is zero.
 */
public class Max extends AbstractOperation
{
    public Max()
    {
        super(
            new OperationBuilder()
                .identifier( "max" )
                .operator( "max" )
                .arity( 1 )
                .isOperator( false )
                .isIndexer( false )
                .isDifferentiable( true )
                .isInline( false )
        );
        setAlgorithm(
            DeviceAlgorithm
                .withName("max_algorithm")
                .setIsSuitableFor(
                    // Requires Number-typed items on the (non-null) input tensors.
                    call -> call.validate()
                                .allNotNull( t -> Number.class.isAssignableFrom(t.getItemType()) )
                                .basicSuitability()
                )
                .setAutogradModeFor( call -> AutoDiffMode.BACKWARD_ONLY )
                .setExecution( (caller, call) -> {
                    Tensor<?>[] inputs = AbstractDeviceAlgorithm.flatten(caller, call).inputs();
                    call = call.withInputs(inputs);
                    // The device-specific reduce implementation yields a tensor whose item is used
                    // as the position of the maximum inside the input (presumably a flat index).
                    Tensor<Integer> index = ((DeviceAlgorithm)call.getAlgorithm()).getImplementationFor(call.getDevice()).run(call);
                    int i = index.item();
                    Tensor<?> in = inputs[0] == null ? inputs[1] : inputs[0];
                    // Item type, shape and device are captured now so the backward action below
                    // can build an error tensor that matches the input.
                    @SuppressWarnings("unchecked")
                    Class<Object> typeClass = (Class<Object>) in.itemType();
                    Shape shape = in.shape();
                    @SuppressWarnings("unchecked")
                    Device<Object> device = (Device<Object>) call.getDevice();
                    return Result.of(
                                Tensor.of(in.itemType(), Shape.of( 1 ), in.item(i)).to(call.getDevice()).mut().setIsIntermediate(true)
                            )
                            .withADAction( target -> {
                                @SuppressWarnings("unchecked")
                                Tensor<Object> error = (Tensor<Object>) target.error();
                                // The forward result has exactly one item, so its error must too.
                                assert error.size() == 1;
                                // Zero-filled tensor shaped like the input; made non-virtual so
                                // that a single item can be written into it.
                                Tensor<Object> newError = ElemWiseUtil.newTensorLike(typeClass, shape, true, device, 0);
                                newError.mut().setIsVirtual(false);
                                // Only the selected position receives the incoming error.
                                newError.mut().setItemAt(i, error.item(0));
                                return newError;
                            });
                })
                .setCallPreparation( call ->
                {
                    // If the first input slot is empty, fall back to the second one.
                    if ( call.input( 0 ) == null )
                        call = call.withInputAt( 0, call.input( 1 ) );
                    return call;
                })
                .buildFunAlgorithm()
                .setImplementationFor( CPU.class, new CPUReduce(CPUReduce.Type.MAX) )
                .setImplementationFor( OpenCLDevice.class, new CLReduce(CLReduce.Type.MAX) )
        );
    }

    /**
     * Scalar evaluation of {@code max} over its single source function.
     * <p>
     * With one argument, max(x) = x. The value is therefore the source's value, and the
     * derivative with respect to input {@code d} is the source's derivative.
     *
     * @param inputs The scalar input values passed through to the source function.
     * @param j      An index forwarded unchanged to the source function.
     * @param d      Index of the input to differentiate with respect to; a negative value is
     *               treated here as a request for the plain value.
     * @param src    The source functions; only {@code src[0]} is used.
     * @return The value of {@code src[0]}, or its derivative w.r.t. input {@code d} when
     *         {@code d >= 0}.
     */
    @Override
    public double calculate( double[] inputs, int j, int d, Function[] src ) {
        if ( d < 0 ) return src[ 0 ].call( inputs, j );
        // d(max(x))/dx_d == d(x)/dx_d for a single argument.
        return src[ 0 ].derive( inputs, d, j );
    }
}