AbstractCPUConvolution.java

package neureka.backend.main.implementations.convolution;

import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.ImplementationFor;
import neureka.backend.main.implementations.fun.api.CPUBiFun;
import neureka.math.args.Arg;
import neureka.devices.host.CPU;
import neureka.ndim.iterator.NDIterator;

/**
 *  Base class for CPU convolution implementations.
 *  Subclasses supply the element-wise binary operation through {@link #_getFun()};
 *  this class performs the N-dimensional traversal for {@code double} (64 bit)
 *  and {@code float} (32 bit) tensors, summing the operation's results into the drain.
 *  By convention of {@link #run(ExecutionCall)}, input 0 is the drain (output)
 *  and inputs 1 and 2 are the two operands.
 */
public abstract class AbstractCPUConvolution implements ImplementationFor<CPU>
{
    /**
     *  Provides the binary operation applied to each pair of overlapping source values.
     *  The results of this operation are summed up into the corresponding drain element.
     *
     * @return The binary function used by the convolution kernels of this implementation.
     */
    protected abstract CPUBiFun _getFun();

    /**
     *  Executes the convolution described by the given call.
     *  If {@link SimpleCPUConvolution} reports itself as suitable and no derivative is
     *  requested (negative derivative index), the simple kernel is used; otherwise the
     *  general purpose N-dimensional kernel runs.
     *
     * @param call The execution call carrying the drain (input 0) and the two sources (inputs 1 and 2).
     * @return The drain tensor, i.e. {@code call.input(0)}.
     */
    @Override
    public Tensor<?> run( ExecutionCall<CPU> call )
    {
        SimpleCPUConvolution simpleConvolution = new SimpleCPUConvolution( call.input( 1 ), call.input( 2 ), call.input( 0 ) );

        if ( simpleConvolution.isSuitable() && call.getValOf( Arg.DerivIdx.class ) < 0 )
            simpleConvolution.run();
        else
            _doNDConvolutionFor( call ); // General purpose ND convolution, -> any dimensionality.

        return call.input( 0 );
    }

    /**
     *  Runs the general N-dimensional kernel over the index range {@code [0, drain size)}
     *  through the device executor's {@code threaded(...)} method
     *  (presumably partitioned across worker threads — see {@code CPU} for details).
     */
    private void _doNDConvolutionFor( ExecutionCall<CPU> call )
    {
        call.getDevice()
            .getExecutor()
            .threaded(
                call.input( 0 ).size(),
                _workloadFor( call )
            );
    }

    /**
     *  Selects the kernel matching the drain's item type and the requested derivative.
     *  A negative derivative index selects the forward convolution, any other value the
     *  deconvolution. Both sources are made non-virtual first (presumably so their data
     *  arrays hold one entry per element, as the kernels' dense indexing expects).
     *
     * @throws IllegalArgumentException If the drain's item type is neither {@code Double} nor {@code Float}.
     */
    private CPU.RangeWorkload _workloadFor(
        ExecutionCall<CPU> call
    ) {
        Tensor<Number> t0_drn = call.input( Number.class, 0 );
        Tensor<Number> t1_src = call.input( Number.class, 1 ).mut().setIsVirtual(false);
        Tensor<Number> t2_src = call.input( Number.class, 2 ).mut().setIsVirtual(false);

        Class<?> typeClass = t0_drn.getItemType();

        int d = call.getDerivativeIndex();
        CPUBiFun f = _getFun();

        if ( typeClass == Double.class ) {
            if ( d < 0 )
                return (i, end) -> _convolve64( t0_drn, t1_src, t2_src, i, end, f );
            else
                return (i, end) -> _deConvolve64( t0_drn, t1_src, t2_src, i, end, f );
        }
        if ( typeClass == Float.class ) {
            if ( d < 0 )
                return (i, end) -> _convolve32( t0_drn, t1_src, t2_src, i, end, f );
            else
                return (i, end) -> _deConvolve32( t0_drn, t1_src, t2_src, i, end, f );
        }
        throw new IllegalArgumentException("Could not create convolution worker for type class '"+typeClass+"'!");
    }

    // --- Shared index helpers ---

    /**
     *  Forward convolution: positions both source iterators on axis {@code ri} relative to
     *  the drain position. Operands of equal extent both follow the drain index; otherwise
     *  the larger operand follows the drain index and the smaller one restarts at 0.
     */
    private static void _alignForConvolution(
            final NDIterator t0Idx, final NDIterator t1Idx, final NDIterator t2Idx, final int ri
    ) {
        int t1Shape = t1Idx.shape( ri );
        int t2Shape = t2Idx.shape( ri );
        if ( t1Shape == t2Shape ) {
            t1Idx.set( ri, t0Idx.get( ri ) );
            t2Idx.set( ri, t0Idx.get( ri ) );
        } else if ( t1Shape > t2Shape ) {
            t1Idx.set( ri, t0Idx.get( ri ) );
            t2Idx.set( ri, 0 );
        } else {
            t1Idx.set( ri, 0 );
            t2Idx.set( ri, t0Idx.get( ri ) );
        }
    }

    /**
     *  Deconvolution: the {@code t1} index on axis {@code ri} for the current {@code t2} position.
     *  It is the drain index minus the {@code t2} index when the drain axis is larger than the
     *  {@code t1} axis, and the drain index plus the {@code t2} index otherwise.
     */
    private static int _deConvolutionSourceIndex(
            final NDIterator t0Idx, final NDIterator t1Idx, final NDIterator t2Idx, final int ri
    ) {
        return t0Idx.shape( ri ) > t1Idx.shape( ri )
                ? t0Idx.get( ri ) - t2Idx.get( ri )
                : t0Idx.get( ri ) + t2Idx.get( ri );
    }

    /**
     *  Deconvolution: re-derives the {@code t1} index on axis {@code ri} from the drain position.
     *  If {@code t2} has run past its extent on that axis, it is reset to 0 and {@code t1}
     *  follows the drain index.
     */
    private static void _alignForDeConvolution(
            final NDIterator t0Idx, final NDIterator t1Idx, final NDIterator t2Idx, final int ri
    ) {
        if ( t2Idx.get( ri ) == t2Idx.shape( ri ) ) {
            t1Idx.set( ri, t0Idx.get( ri ) );
            t2Idx.set( ri, 0 );
        }
        else
            t1Idx.set( ri, _deConvolutionSourceIndex( t0Idx, t1Idx, t2Idx, ri ) );
    }

    /** Returns true if every index of {@code idx} on the first {@code rank} axes lies within {@code [0, shape)}. */
    private static boolean _isWithinBounds( final NDIterator idx, final int rank ) {
        for ( int axis = 0; axis < rank; axis++ ) {
            int current = idx.get( axis );
            if ( current < 0 || current >= idx.shape( axis ) ) return false;
        }
        return true;
    }

    // --- 64 bit kernels ---

    /**
     *  Forward convolution over the drain elements {@code [start, end)} for {@code double} data.
     *  For each drain element, the source iterators are aligned to the drain position. Both are
     *  then stepped together axis by axis until either reaches its extent, and
     *  {@code operation(t1, t2)} is summed at every visited position.
     */
    private static void _convolve64(
            final Tensor<?> t0_drn, final Tensor<?> t1_src, final Tensor<?> t2_src,
            final int start,
            final int end,
            final CPUBiFun operation
    ) {
        NDIterator t0Idx = NDIterator.of( t0_drn );
        NDIterator t1Idx = NDIterator.of( t1_src );
        t0Idx.set( t0_drn.indicesOfIndex( start ) );
        NDIterator t2Idx = NDIterator.of( t2_src );
        int rank = t0Idx.rank();

        double[] t0_value = t0_drn.mut().getDataForWriting( double[].class );
        double[] t1_value = t1_src.mut().getDataAs( double[].class );
        double[] t2_value = t2_src.mut().getDataAs( double[].class );

        for ( int i = start; i < end; i++ )
        {
            // Align the sources with the current drain position:
            for ( int axis = 0; axis < rank; axis++ )
                _alignForConvolution( t0Idx, t1Idx, t2Idx, axis );

            // Accumulate the operation over all overlapping positions:
            double value = 0;
            boolean running = true;
            boolean incrementing = false;
            int ri = 0;
            while ( running ) {
                ri = ( ri == rank ) ? 0 : ri;
                if ( !incrementing ) {
                    value += operation.invoke( t1_value[t1Idx.i()], t2_value[t2Idx.i()] );
                    incrementing = true;
                    ri = 0;
                } else { // incrementing:
                    if ( t1Idx.get( ri ) < t1Idx.shape( ri ) && t2Idx.get( ri ) < t2Idx.shape( ri ) ) {
                        t1Idx.set( ri, t1Idx.get( ri ) + 1 );
                        t2Idx.set( ri, t2Idx.get( ri ) + 1 );
                        if ( t1Idx.get( ri ) == t1Idx.shape( ri ) || t2Idx.get( ri ) == t2Idx.shape( ri ) ) {
                            // This axis is exhausted: the traversal ends once the last axis wraps.
                            running = ( ri != rank - 1 );
                            _alignForConvolution( t0Idx, t1Idx, t2Idx, ri );
                            ri++;
                        } else incrementing = false;
                    } else ri++;
                }
            }
            t0_value[ t0Idx.i() ] = value;
            t0Idx.increment();
        }
    }

    /**
     *  Deconvolution over the drain elements {@code [start, end)} for {@code double} data.
     *  For each drain element, {@code t2} runs over its full extent. {@code t1} is placed at
     *  the drain position offset by the {@code t2} position (see
     *  {@link #_deConvolutionSourceIndex}). {@code operation(t1, t2)} is summed only where the
     *  {@code t1} index lies within bounds.
     *  The source iterators are re-derived from the drain position for every element, so no
     *  pre-advancing to {@code start} is needed beyond setting the drain iterator.
     */
    private static void _deConvolve64(
            final Tensor<?> t0_drn, final Tensor<?> t1_src, final Tensor<?> t2_src,
            final int start, final int end,
            final CPUBiFun operation
    ) {
        NDIterator t0Idx = NDIterator.of( t0_drn );
        NDIterator t1Idx = NDIterator.of( t1_src );
        t0Idx.set( t0_drn.indicesOfIndex( start ) );
        NDIterator t2Idx = NDIterator.of( t2_src );
        int rank = t0Idx.rank();

        double[] t0_value = t0_drn.mut().getDataForWriting( double[].class );
        double[] t1_value = t1_src.mut().getDataAs( double[].class );
        double[] t2_value = t2_src.mut().getDataAs( double[].class );

        assert t0_value != null;
        assert t1_value != null;
        assert t2_value != null;

        for ( int i = start; i < end; i++ ) {
            for ( int axis = 0; axis < rank; axis++ )
                _alignForDeConvolution( t0Idx, t1Idx, t2Idx, axis );

            double value = 0;
            boolean running = true;
            boolean incrementing = false;
            int ri = 0;
            while ( running ) {
                ri = ( ri == rank ? 0 : ri );
                if ( !incrementing ) { // Apply the operation only where t1 is in bounds:
                    value += _isWithinBounds( t1Idx, rank )
                                ? operation.invoke( t1_value[t1Idx.i()], t2_value[t2Idx.i()] )
                                : 0;
                    incrementing = true;
                    ri = 0;
                } else { // incrementing:
                    if ( t2Idx.get( ri ) < t2Idx.shape( ri ) ) {
                        t2Idx.set( ri, t2Idx.get( ri ) + 1 );
                        if ( t2Idx.get( ri ) == t2Idx.shape( ri ) ) {
                            // This axis is exhausted: the traversal ends once the last axis wraps.
                            running = ( ri != rank - 1 );
                            t1Idx.set( ri, t0Idx.get( ri ) );
                            t2Idx.set( ri, 0 );
                            ri++;
                        } else {
                            t1Idx.set( ri, _deConvolutionSourceIndex( t0Idx, t1Idx, t2Idx, ri ) );
                            incrementing = false;
                        }
                    } else ri++;
                }
            }
            t0_value[ t0Idx.i() ] = value;
            t0Idx.increment();
        }
    }

    // --- 32 bit kernels ---

    /**
     *  Forward convolution over the drain elements {@code [start, end)} for {@code float} data.
     *  This is the same traversal as {@link #_convolve64}.
     */
    private static void _convolve32(
            final Tensor<?> t0_drn, final Tensor<?> t1_src, final Tensor<?> t2_src,
            final int start, final int end,
            final CPUBiFun operation
    ) {
        NDIterator t0Idx = NDIterator.of( t0_drn );
        NDIterator t1Idx = NDIterator.of( t1_src );
        t0Idx.set( t0_drn.indicesOfIndex( start ) );
        NDIterator t2Idx = NDIterator.of( t2_src );
        int rank = t0Idx.rank();

        float[] t0_value = t0_drn.mut().getDataForWriting( float[].class );
        float[] t1_value = t1_src.mut().getDataAs( float[].class );
        float[] t2_value = t2_src.mut().getDataAs( float[].class );

        for ( int i = start; i < end; i++ )
        {
            // Align the sources with the current drain position:
            for ( int axis = 0; axis < rank; axis++ )
                _alignForConvolution( t0Idx, t1Idx, t2Idx, axis );

            // Accumulate the operation over all overlapping positions:
            float value = 0;
            boolean running = true;
            boolean incrementing = false;
            int ri = 0;
            while ( running ) {
                ri = ( ri == rank ? 0 : ri );
                if ( !incrementing ) {
                    value += operation.invoke( t1_value[t1Idx.i()], t2_value[t2Idx.i()] );
                    incrementing = true;
                    ri = 0;
                } else { // incrementing:
                    if ( t1Idx.get( ri ) < t1Idx.shape( ri ) && t2Idx.get( ri ) < t2Idx.shape( ri ) ) {
                        t1Idx.set( ri, t1Idx.get( ri ) + 1 );
                        t2Idx.set( ri, t2Idx.get( ri ) + 1 );
                        if ( t1Idx.get( ri ) == t1Idx.shape( ri ) || t2Idx.get( ri ) == t2Idx.shape( ri ) ) {
                            // This axis is exhausted: the traversal ends once the last axis wraps.
                            running = ( ri != rank - 1 );
                            _alignForConvolution( t0Idx, t1Idx, t2Idx, ri );
                            ri++;
                        } else incrementing = false;
                    } else ri++;
                }
            }
            t0_value[ t0Idx.i() ] = value;
            t0Idx.increment();
        }
    }

    /**
     *  Deconvolution over the drain elements {@code [start, end)} for {@code float} data.
     *  This is the same traversal as {@link #_deConvolve64}.
     */
    private static void _deConvolve32(
            final Tensor<?> t0_drn, final Tensor<?> t1_src, final Tensor<?> t2_src,
            final int start, final int end,
            final CPUBiFun operation
    ) {
        NDIterator t0Idx = NDIterator.of( t0_drn );
        NDIterator t1Idx = NDIterator.of( t1_src );
        t0Idx.set( t0_drn.indicesOfIndex( start ) );
        NDIterator t2Idx = NDIterator.of( t2_src );
        int rank = t0Idx.rank();

        float[] t0_value = t0_drn.mut().getDataForWriting( float[].class );
        float[] t1_value = t1_src.mut().getDataAs( float[].class );
        float[] t2_value = t2_src.mut().getDataAs( float[].class );

        for ( int i = start; i < end; i++ ) {
            for ( int axis = 0; axis < rank; axis++ )
                _alignForDeConvolution( t0Idx, t1Idx, t2Idx, axis );

            float value = 0;
            boolean running = true;
            boolean incrementing = false;
            int ri = 0;
            while ( running ) {
                ri = ( ri == rank ? 0 : ri );
                if ( !incrementing ) { // Apply the operation only where t1 is in bounds:
                    value += ( _isWithinBounds( t1Idx, rank )
                                ? operation.invoke( t1_value[t1Idx.i()], t2_value[t2Idx.i()] )
                                : 0 );
                    incrementing = true;
                    ri = 0;
                } else { // incrementing:
                    if ( t2Idx.get( ri ) < t2Idx.shape( ri ) ) {
                        t2Idx.set( ri, t2Idx.get( ri ) + 1 );
                        if ( t2Idx.get( ri ) == t2Idx.shape( ri ) ) {
                            // This axis is exhausted: the traversal ends once the last axis wraps.
                            running = ( ri != rank - 1 );
                            t1Idx.set( ri, t0Idx.get( ri ) );
                            t2Idx.set( ri, 0 );
                            ri++;
                        } else {
                            t1Idx.set( ri, _deConvolutionSourceIndex( t0Idx, t1Idx, t2Idx, ri ) );
                            incrementing = false;
                        }
                    } else ri++;
                }
            }
            t0_value[ t0Idx.i() ] = value;
            t0Idx.increment();
        }
    }

}