// CPUMatMul.java

package neureka.backend.main.implementations.matmul;

import neureka.Tensor;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.ImplementationFor;
import neureka.backend.main.operations.linear.internal.blas.GEMM;
import neureka.backend.main.operations.linear.internal.blas.IGEMM;
import neureka.devices.host.CPU;
import neureka.ndim.config.NDConfiguration;

/**
 *  This is a library internal class, do not depend on this.
 *  <p>
 *  CPU implementation of a 2D matrix multiplication {@code A * B = C}
 *  ({@code [MxK] * [KxN] = [MxN]}). The call inputs are read as follows:
 *  input {@code 0} is the output {@code C}, input {@code 1} is {@code A}
 *  and input {@code 2} is {@code B}. The actual computation is delegated to
 *  {@link GEMM} (for {@code double} and {@code float} data) and to
 *  {@link IGEMM} (for {@code long} and {@code int} data).
 */
public class CPUMatMul implements ImplementationFor<CPU>
{
    /**
     *  Validates the tensors of the supplied call, then multiplies input {@code 1}
     *  with input {@code 2} and writes the result into the data of input {@code 0}.
     *
     * @param call The execution call holding the output tensor at index 0 and the
     *             two operand tensors at index 1 and 2.
     * @return The tensor at input index 0 of the supplied call.
     * @throws IllegalArgumentException If the layout validation fails, if the tensors
     *                                  do not share the same data type, if the column count
     *                                  of 'A' differs from the row count of 'B', or if the
     *                                  item type is not one of
     *                                  {@code Double}, {@code Float}, {@code Long}, {@code Integer}.
     */
    @Override
    public Tensor<?> run(ExecutionCall<CPU> call )
    {
        if ( !call.validate().all( (t1, t2) -> t1.getNDConf().getLayout().isCompatible(t2.getNDConf().getLayout()) ).isValid() )
            throw new IllegalArgumentException(
                        "Data layout inconsistency between provided tensors encountered. " +
                        "All tensors must be of the same layout."
                    );

        if ( !call.validate().allShare(Tensor::getDataType).isValid() )
            throw new IllegalArgumentException(
                       "Type inconsistency between provided tensors encountered. " +
                       "All tensors must be of the same type."
                    );

        NDConfiguration.Layout layout = call.input( 1 ).getNDConf().getLayout();
        NDConfiguration.Layout layout2 = call.input( 2 ).getNDConf().getLayout();

        // The row major kernel is used when each operand is either row major or symmetric,
        // every other layout combination is handled by the non row major kernel.
        boolean rowMajor = isRowMajorOrSymmetric( layout ) && isRowMajorOrSymmetric( layout2 );

        int[] shapeA = call.input( 1 ).getNDConf().shape();
        int[] shapeB = call.input( 2 ).getNDConf().shape();

        // A * B = C // [MxK]*[KxN] = [MxN]
        int aRows = shapeA[0];
        int aCols = shapeA[1];
        int bRows = shapeB[0];
        int bCols = shapeB[1];

        // The shared inner dimension K must agree: columns of A versus rows of B.
        if ( aCols != bRows )
            throw new IllegalArgumentException("'A' matrix columns " + aCols + " did not match 'B' matrix rows " + bRows + ".");

        // Dispatch on the item type of the output tensor; all tensors were validated to share one data type.
        Class<?> type = call.input( 0 ).getDataType().getItemTypeClass();
        if ( type == Double.class ) {
            double[] A = call.input(Double.class, 1).mut().getDataAs(double[].class);
            double[] B = call.input(Double.class, 2).mut().getDataAs(double[].class);
            double[] C = call.input(Double.class, 0).mut().getDataForWriting(double[].class);
            execute( rowMajor, A, B, C, aRows, aCols, bCols );
        } else if ( type == Float.class ) {
            float[] A = call.input(Float.class, 1).mut().getDataAs(float[].class);
            float[] B = call.input(Float.class, 2).mut().getDataAs(float[].class);
            float[] C = call.input(Float.class, 0).mut().getDataForWriting(float[].class);
            execute( rowMajor, A, B, C, aRows, aCols, bCols );
        }
        else if ( type == Long.class ) {
            long[] A = call.input(Long.class, 1).mut().getDataAs(long[].class);
            long[] B = call.input(Long.class, 2).mut().getDataAs(long[].class);
            long[] C = call.input(Long.class, 0).mut().getDataForWriting(long[].class);
            execute( rowMajor, A, B, C, aRows, aCols, bCols );
        }
        else if ( type == Integer.class ) {
            int[] A = call.input(Integer.class, 1).mut().getDataAs(int[].class);
            int[] B = call.input(Integer.class, 2).mut().getDataAs(int[].class);
            int[] C = call.input(Integer.class, 0).mut().getDataForWriting(int[].class);
            execute( rowMajor, A, B, C, aRows, aCols, bCols );
        }
        else
            throw new IllegalArgumentException(
                        "Data type '"+type.getSimpleName()+"' not yet supported " +
                        "for CPU based matrix multiplication!"
                    );

        return call.input( 0 );
    }

    /**
     *  Tells whether the supplied layout is one of the two layouts for which
     *  the row major kernel is selected.
     *
     * @param layout The layout of an operand tensor.
     * @return {@code true} if the layout is {@code ROW_MAJOR} or {@code SYMMETRIC}.
     */
    private static boolean isRowMajorOrSymmetric( NDConfiguration.Layout layout )
    {
        return layout == NDConfiguration.Layout.ROW_MAJOR || layout == NDConfiguration.Layout.SYMMETRIC;
    }

    /**
     *  Multiplies two {@code double} matrices by delegating to {@link GEMM#operationForF64}.
     *
     * @param rowMajor Forwarded to {@link GEMM} to select the kernel variant.
     * @param A The data of the left operand ({@code aRows x aCols}).
     * @param B The data of the right operand ({@code aCols x bCols}).
     * @param C The array receiving the result ({@code aRows x bCols}).
     * @param aRows The number of rows of 'A' (and of the result).
     * @param aCols The number of columns of 'A', which is the shared inner dimension.
     * @param bCols The number of columns of 'B' (and of the result).
     */
    public static void execute(
            boolean rowMajor, double[] A, double[] B, double[] C, int aRows, int aCols, int bCols
    ) {
        GEMM.operationForF64( rowMajor, aRows, bCols ).invoke( C, A, aCols, B );
    }

    /**
     *  Multiplies two {@code float} matrices by delegating to {@link GEMM#operationForF32}.
     *
     * @param rowMajor Forwarded to {@link GEMM} to select the kernel variant.
     * @param A The data of the left operand ({@code aRows x aCols}).
     * @param B The data of the right operand ({@code aCols x bCols}).
     * @param C The array receiving the result ({@code aRows x bCols}).
     * @param aRows The number of rows of 'A' (and of the result).
     * @param aCols The number of columns of 'A', which is the shared inner dimension.
     * @param bCols The number of columns of 'B' (and of the result).
     */
    public static void execute(
            boolean rowMajor, float[] A, float[] B, float[] C, int aRows, int aCols, int bCols
    ) {
        GEMM.operationForF32( rowMajor, aRows, bCols ).invoke( C, A, aCols, B );
    }

    /**
     *  Multiplies two {@code long} matrices by delegating to {@link IGEMM#operationForI64}.
     *
     * @param rowMajor Forwarded to {@link IGEMM} to select the kernel variant.
     * @param A The data of the left operand ({@code aRows x aCols}).
     * @param B The data of the right operand ({@code aCols x bCols}).
     * @param C The array receiving the result ({@code aRows x bCols}).
     * @param aRows The number of rows of 'A' (and of the result).
     * @param aCols The number of columns of 'A', which is the shared inner dimension.
     * @param bCols The number of columns of 'B' (and of the result).
     */
    public static void execute(
            boolean rowMajor, long[] A, long[] B, long[] C, int aRows, int aCols, int bCols
    ) {
        IGEMM.operationForI64( rowMajor, aRows, bCols ).invoke( C, A, aCols, B );
    }

    /**
     *  Multiplies two {@code int} matrices by delegating to {@link IGEMM#operationForI32}.
     *
     * @param rowMajor Forwarded to {@link IGEMM} to select the kernel variant.
     * @param A The data of the left operand ({@code aRows x aCols}).
     * @param B The data of the right operand ({@code aCols x bCols}).
     * @param C The array receiving the result ({@code aRows x bCols}).
     * @param aRows The number of rows of 'A' (and of the result).
     * @param aCols The number of columns of 'A', which is the shared inner dimension.
     * @param bCols The number of columns of 'B' (and of the result).
     */
    public static void execute(
            boolean rowMajor, int[] A, int[] B, int[] C, int aRows, int aCols, int bCols
    ) {
        IGEMM.operationForI32( rowMajor, aRows, bCols ).invoke( C, A, aCols, B );
    }

}