CPUReduce.java

package neureka.backend.main.operations.other.internal;

import neureka.Shape;
import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.ImplementationFor;
import neureka.devices.host.CPU;

/**
 *  An implementation of the min and max algorithm running on the CPU.
 *  This algorithm splits the provided input tensor into chunks which
 *  are then reduced to local min or max values (depending on the {@link Type}).
 *  This happens recursively until only a single value is left.
 *  Each workload also returns the index of the found min/max value,
 *  which is important for backpropagation.
 */
public class CPUReduce implements ImplementationFor<CPU>
{
    /** Reduction tile size: each workload scans at most this many consecutive items. */
    private static final int RTS = 64;

    /*
     *  Comparator contract: {@code compare(current, value)} returns true when {@code current}
     *  should replace the best {@code value} found so far. Comparisons are strict, so on ties
     *  the earliest index is kept.
     */
    @FunctionalInterface private interface ComparatorF32 { boolean compare(float current, float value); }
    @FunctionalInterface private interface ComparatorF64 { boolean compare(double current, double value); }
    @FunctionalInterface private interface ComparatorI32 { boolean compare(int current, int value); }
    @FunctionalInterface private interface ComparatorI64 { boolean compare(long current, long value); }
    @FunctionalInterface private interface ComparatorI8  { boolean compare(byte current, byte value); }
    @FunctionalInterface private interface ComparatorI16 { boolean compare(short current, short value); }

    /**
     *  The kind of reduction to perform.
     *  {@link #MIN} searches for the index of the smallest item.
     *  {@link #MAX} searches for the index of the largest item.
     */
    public enum Type {
        MIN, MAX;

        private ComparatorF32 getFloatComparator() {
            switch (this) {
                case MIN: return (current, value) -> current < value;
                case MAX: return (current, value) -> current > value;
                default: throw new IllegalArgumentException("Unsupported reduction type: "+this);
            }
        }
        private ComparatorF64 getDoubleComparator() {
            switch (this) {
                case MIN: return (current, value) -> current < value;
                case MAX: return (current, value) -> current > value;
                default: throw new IllegalArgumentException("Unsupported reduction type: "+this);
            }
        }
        private ComparatorI32 getIntComparator() {
            switch (this) {
                case MIN: return (current, value) -> current < value;
                case MAX: return (current, value) -> current > value;
                default: throw new IllegalArgumentException("Unsupported reduction type: "+this);
            }
        }
        private ComparatorI64 getLongComparator() {
            switch (this) {
                case MIN: return (current, value) -> current < value;
                case MAX: return (current, value) -> current > value;
                default: throw new IllegalArgumentException("Unsupported reduction type: "+this);
            }
        }
        private ComparatorI8 getByteComparator() {
            switch (this) {
                case MIN: return (current, value) -> current < value;
                case MAX: return (current, value) -> current > value;
                default: throw new IllegalArgumentException("Unsupported reduction type: "+this);
            }
        }
        private ComparatorI16 getShortComparator() {
            switch (this) {
                case MIN: return (current, value) -> current < value;
                case MAX: return (current, value) -> current > value;
                default: throw new IllegalArgumentException("Unsupported reduction type: "+this);
            }
        }
    }

    private final Type _type;


    /**
     * @param type Whether this implementation searches for the min or the max item.
     */
    public CPUReduce(Type type) {
        _type = type;
    }


    /**
     *  Finds the index of the min/max item of the call's input tensor.
     *  The input is {@code call.input(0)}, or {@code call.input(1)} if the former is {@code null}.
     *
     * @param call The execution call; its device must be the {@link CPU} singleton.
     * @return A tensor of shape {@code [1]} holding the index of the found min/max item.
     * @throws IllegalArgumentException if the call does not target the CPU, the input is empty,
     *                                  or the input's item type is not supported.
     */
    @Override
    public Tensor<Integer> run(ExecutionCall<CPU> call) {
        if ( call.getDevice() != CPU.get() )
            throw new IllegalArgumentException("This implementation is only available for the CPU!");
        Tensor<?> in = call.input(0) == null ? call.input(1) : call.input(0);
        int index = _runRecursively(in, CPU.get());
        return Tensor.of(Integer.class, Shape.of(1), index);
    }

    /**
     *  Dispatches on the item type of {@code in} and returns the index of its min/max item.
     *
     * @param in The tensor to reduce; must contain at least one item.
     * @param device The CPU whose executor runs the tile workloads.
     * @return The index of the min/max item within {@code in}'s data array.
     * @throws IllegalArgumentException if {@code in} is empty or its item type is unsupported.
     */
    private int _runRecursively(Tensor<?> in, CPU device)
    {
        final int size = in.size();
        if ( size < 1 )
            throw new IllegalArgumentException("Cannot find the "+_type+" index of an empty tensor!");
        if ( size == 1 )
            return 0; // A single item is trivially both the min and the max.

        Class<?> type = in.itemType();
        if ( type == Float.class   ) return _reduceF32( in.mut().getDataForWriting(float[].class),  size, device );
        if ( type == Double.class  ) return _reduceF64( in.mut().getDataForWriting(double[].class), size, device );
        if ( type == Integer.class ) return _reduceI32( in.mut().getDataForWriting(int[].class),    size, device );
        if ( type == Long.class    ) return _reduceI64( in.mut().getDataForWriting(long[].class),   size, device );
        if ( type == Short.class   ) return _reduceI16( in.mut().getDataForWriting(short[].class),  size, device );
        if ( type == Byte.class    ) return _reduceI8(  in.mut().getDataForWriting(byte[].class),   size, device );

        // Previously, unknown types silently yielded index 0, which is a wrong answer rather than an error.
        throw new IllegalArgumentException(
                "Unsupported item type '"+type+"' for "+_type+" reduction on the CPU!"
            );
    }

    /** Returns the number of {@link #RTS}-sized tiles needed to cover {@code size} items (overflow-safe ceiling division). */
    private static int _tilesFor( int size ) {
        return size / RTS + ( size % RTS == 0 ? 0 : 1 );
    }

    /** Finds the index of the min/max item among the first {@code size} items of a {@code float} array. */
    private int _reduceF32( float[] data, int size, CPU device )
    {
        ComparatorF32 comparator = _type.getFloatComparator();
        CPU.JVMExecutor executor = device.getExecutor();
        int tiles = _tilesFor( size );
        int[] winners = new int[tiles]; // Per tile: index (into data) of that tile's min/max item.
        executor.threaded( tiles, ti -> {
            int start = ti * RTS;
            int limit = Math.min( start + RTS, size ); // The last tile may be partial.
            float best = data[start];
            int bestIndex = start;
            for ( int i = start + 1; i < limit; ++i ) {
                if ( comparator.compare( data[i], best ) ) { best = data[i]; bestIndex = i; }
            }
            winners[ti] = bestIndex;
        });
        if ( tiles == 1 ) return winners[0];
        // Gather the tile winners, reduce them recursively and map the result back to an index into data.
        float[] reduced = new float[tiles];
        executor.threaded( tiles, (from, to) -> { for ( int i = from; i < to; ++i ) reduced[i] = data[winners[i]]; });
        return winners[ _runRecursively( Tensor.of( Float.class, Shape.of( tiles ), reduced ), device ) ];
    }

    /** Finds the index of the min/max item among the first {@code size} items of a {@code double} array. */
    private int _reduceF64( double[] data, int size, CPU device )
    {
        ComparatorF64 comparator = _type.getDoubleComparator();
        CPU.JVMExecutor executor = device.getExecutor();
        int tiles = _tilesFor( size );
        int[] winners = new int[tiles];
        executor.threaded( tiles, ti -> {
            int start = ti * RTS;
            int limit = Math.min( start + RTS, size );
            double best = data[start];
            int bestIndex = start;
            for ( int i = start + 1; i < limit; ++i ) {
                if ( comparator.compare( data[i], best ) ) { best = data[i]; bestIndex = i; }
            }
            winners[ti] = bestIndex;
        });
        if ( tiles == 1 ) return winners[0];
        double[] reduced = new double[tiles];
        executor.threaded( tiles, (from, to) -> { for ( int i = from; i < to; ++i ) reduced[i] = data[winners[i]]; });
        return winners[ _runRecursively( Tensor.of( Double.class, Shape.of( tiles ), reduced ), device ) ];
    }

    /** Finds the index of the min/max item among the first {@code size} items of an {@code int} array. */
    private int _reduceI32( int[] data, int size, CPU device )
    {
        ComparatorI32 comparator = _type.getIntComparator();
        CPU.JVMExecutor executor = device.getExecutor();
        int tiles = _tilesFor( size );
        int[] winners = new int[tiles];
        executor.threaded( tiles, ti -> {
            int start = ti * RTS;
            int limit = Math.min( start + RTS, size );
            int best = data[start];
            int bestIndex = start;
            for ( int i = start + 1; i < limit; ++i ) {
                if ( comparator.compare( data[i], best ) ) { best = data[i]; bestIndex = i; }
            }
            winners[ti] = bestIndex;
        });
        if ( tiles == 1 ) return winners[0];
        int[] reduced = new int[tiles];
        executor.threaded( tiles, (from, to) -> { for ( int i = from; i < to; ++i ) reduced[i] = data[winners[i]]; });
        return winners[ _runRecursively( Tensor.of( Integer.class, Shape.of( tiles ), reduced ), device ) ];
    }

    /** Finds the index of the min/max item among the first {@code size} items of a {@code long} array. */
    private int _reduceI64( long[] data, int size, CPU device )
    {
        ComparatorI64 comparator = _type.getLongComparator();
        CPU.JVMExecutor executor = device.getExecutor();
        int tiles = _tilesFor( size );
        int[] winners = new int[tiles];
        executor.threaded( tiles, ti -> {
            int start = ti * RTS;
            int limit = Math.min( start + RTS, size );
            long best = data[start];
            int bestIndex = start;
            for ( int i = start + 1; i < limit; ++i ) {
                if ( comparator.compare( data[i], best ) ) { best = data[i]; bestIndex = i; }
            }
            winners[ti] = bestIndex;
        });
        if ( tiles == 1 ) return winners[0];
        long[] reduced = new long[tiles];
        executor.threaded( tiles, (from, to) -> { for ( int i = from; i < to; ++i ) reduced[i] = data[winners[i]]; });
        return winners[ _runRecursively( Tensor.of( Long.class, Shape.of( tiles ), reduced ), device ) ];
    }

    /** Finds the index of the min/max item among the first {@code size} items of a {@code short} array. */
    private int _reduceI16( short[] data, int size, CPU device )
    {
        ComparatorI16 comparator = _type.getShortComparator();
        CPU.JVMExecutor executor = device.getExecutor();
        int tiles = _tilesFor( size );
        int[] winners = new int[tiles];
        executor.threaded( tiles, ti -> {
            int start = ti * RTS;
            int limit = Math.min( start + RTS, size );
            short best = data[start];
            int bestIndex = start;
            for ( int i = start + 1; i < limit; ++i ) {
                if ( comparator.compare( data[i], best ) ) { best = data[i]; bestIndex = i; }
            }
            winners[ti] = bestIndex;
        });
        if ( tiles == 1 ) return winners[0];
        short[] reduced = new short[tiles];
        executor.threaded( tiles, (from, to) -> { for ( int i = from; i < to; ++i ) reduced[i] = data[winners[i]]; });
        return winners[ _runRecursively( Tensor.of( Short.class, Shape.of( tiles ), reduced ), device ) ];
    }

    /** Finds the index of the min/max item among the first {@code size} items of a {@code byte} array. */
    private int _reduceI8( byte[] data, int size, CPU device )
    {
        ComparatorI8 comparator = _type.getByteComparator();
        CPU.JVMExecutor executor = device.getExecutor();
        int tiles = _tilesFor( size );
        int[] winners = new int[tiles];
        executor.threaded( tiles, ti -> {
            int start = ti * RTS;
            int limit = Math.min( start + RTS, size );
            byte best = data[start];
            int bestIndex = start;
            for ( int i = start + 1; i < limit; ++i ) {
                if ( comparator.compare( data[i], best ) ) { best = data[i]; bestIndex = i; }
            }
            winners[ti] = bestIndex;
        });
        if ( tiles == 1 ) return winners[0];
        byte[] reduced = new byte[tiles];
        executor.threaded( tiles, (from, to) -> { for ( int i = from; i < to; ++i ) reduced[i] = data[winners[i]]; });
        return winners[ _runRecursively( Tensor.of( Byte.class, Shape.of( tiles ), reduced ), device ) ];
    }

}