// CLGEMM.java
package neureka.backend.main.operations.linear.internal.opencl;
import neureka.Tensor;
import neureka.backend.api.ImplementationFor;
import neureka.backend.api.ExecutionCall;
import neureka.devices.opencl.KernelCaller;
import neureka.devices.opencl.OpenCLDevice;
import java.util.function.Supplier;
/**
 * OpenCL implementation of a single precision matrix multiplication which
 * generates, compiles (once per problem size) and launches a tiled kernel
 * whose source is specialised for the concrete {@code M x K x N} extents.
 * <p>
 * The operands are read from the {@link ExecutionCall} as follows:
 * input {@code 0} is the output {@code C}, input {@code 1} is {@code A} and
 * input {@code 2} is {@code B}.
 * <p>
 * NOTE(review): the index arithmetic inside the kernel addresses {@code A} as
 * {@code k*M + m} and {@code B} as {@code n*K + k}, which looks like a
 * column major layout (the kernel name prefix "CM" hints at the same) --
 * confirm that callers supply the data in that layout.
 * <p>
 * NOTE(review): the kernel writes its result with {@code +=}, so the product is
 * added onto whatever {@code C} already contains; presumably callers pass a
 * zero initialised output -- verify against the caller.
 */
public class CLGEMM implements ImplementationFor<OpenCLDevice>
{
    /** Candidate tile widths for the M dimension, widest first. */
    private static final int[] M_TILE_WIDTHS = {16, 8, 4, 2, 1};

    /** Candidate tile widths shared by the K and N dimensions, widest first. */
    private static final int[] K_TILE_WIDTHS = {8, 4, 2, 1};

    /**
     * Multiplies the two float matrices found in the given call on its OpenCL device.
     *
     * @param call The call providing the output tensor (input 0), the two
     *             operands (inputs 1 and 2) and the target device.
     * @return The tensor at input position 0 of the call, which is the tensor
     *         that was passed to the kernel as {@code C}.
     * @throws IllegalArgumentException If the second axis of {@code A} and the
     *                                  first axis of {@code B} differ in size.
     */
    @Override
    public Tensor<?> run(ExecutionCall<OpenCLDevice> call)
    {
        Tensor<Float> c = call.input(Float.class, 0);
        Tensor<Float> a = call.input(Float.class, 1);
        Tensor<Float> b = call.input(Float.class, 2);

        int M = a.shape(0);
        int K = a.shape(1);
        int N = b.shape(1);
        int K1 = b.shape(0);

        // An explicit check instead of 'assert': assertions are disabled by default,
        // and a mismatch would make the kernel index the buffers with wrong extents.
        if ( K != K1 )
            throw new IllegalArgumentException(
                "Inner matrix dimensions do not match: A is "+M+"x"+K+" but B is "+K1+"x"+N+"."
            );

        // Tile widths are a pure function of (M, K, N), so the name identifies the generated source.
        String kernelName = "fast_CM_MM_"+M+"x"+K+"x"+N+"";

        // Determining optimal tile widths: the widest candidate which divides the extents evenly.
        int MW = _widestTile(M_TILE_WIDTHS, M);
        int KW = _widestTile(K_TILE_WIDTHS, N, K);
        int NW = KW; // The kernel requires the N and K tile widths to be equal.

        // Lazy, so the source string is only assembled when the kernel has to be compiled.
        Supplier<String> code = () -> _kernelSource(kernelName, M, K, N, MW, KW);

        KernelCaller caller =
                call.getDevice().hasAdHocKernel(kernelName)
                    ? call.getDevice().findAdHocKernel(kernelName).orElseThrow(()-> new RuntimeException("Kernel not found!"))
                    : call.getDevice().compileAndGetAdHocKernel(kernelName, code.get());

        long[] local = null; // This kernel does not have local memory (uses register/private memory instead)
        // One work item per (M tile, N tile) pair; the divisions are exact because
        // the tile widths were chosen as divisors. The third (batch) axis is always 1.
        long[] global = new long[]{ M / MW, N / NW, 1 };

        // Argument order matches the kernel signature: A, B, C.
        caller.pass( a ).pass( b ).pass( c ).call( global, local );

        return call.input(0);
    }

    /**
     * Picks the first candidate width which divides every given extent without remainder.
     *
     * @param candidates The widths to try, in order of preference.
     * @param extents The axis sizes which must all be divisible by the chosen width.
     * @return The first fitting candidate, or 1 if none of them fits.
     */
    private static int _widestTile(int[] candidates, int... extents)
    {
        for ( int width : candidates ) {
            boolean dividesAll = true;
            for ( int extent : extents )
                if ( extent % width != 0 ) { dividesAll = false; break; }

            if ( dividesAll ) return width;
        }
        return 1;
    }

    /**
     * @param width The number of floats in the vector.
     * @return The OpenCL type name for the given width, which is the scalar
     *         {@code float} for a width of 1 and {@code float<width>} otherwise.
     */
    private static String _floatType(int width)
    {
        return width != 1 ? "float"+width : "float";
    }

    /**
     * Assembles the OpenCL C source of the kernel for the given problem size and tile widths.
     * Every work item accumulates one {@code NW x MW} tile of the result in private arrays
     * while iterating over the K tiles, and finally adds that tile onto {@code C}.
     *
     * @param kernelName The name of the generated {@code __kernel} function.
     * @param M The size of the first axis of {@code A}.
     * @param K The size of the shared inner axis.
     * @param N The size of the second axis of {@code B}.
     * @param MW The tile width along M, expected to divide M.
     * @param KW The tile width along K (also used for N), expected to divide both K and N.
     * @return The complete kernel source code.
     */
    private static String _kernelSource(String kernelName, int M, int K, int N, int MW, int KW)
    {
        int NW = KW;
        return
            " #define K "+K+" \n" +
            " #define N "+N+" \n" +
            " #define MW "+MW+" // M tile Width \n" +
            " #define NW "+NW+" // N tile Width -- NW & KW should be the same ! \n" +
            " #define KW "+KW+" // K tile Width \n" +
            " #define MT "+(M / MW)+" // MT is max for 'mt' (M tile count) \n" +
            " #define KT "+(K / KW)+" // KT is max for 'kt' (K tile count) \n" +
            " #define floatMW "+_floatType(MW)+" \n" +
            " #define floatKW "+_floatType(KW)+" \n" +
            " __kernel void "+kernelName+"( \n" +
            " const __global floatMW* restrict A, \n" +
            " const __global floatKW* restrict B, \n" +
            " __global floatMW* C \n" +
            " ) {{ \n" +
            " size_t mt = get_global_id(0); //global M-tile id \n" +
            " size_t nc = get_global_id(1); //global N-tile id \n" +
            " size_t batch = get_global_id(2); \n" +
            " \n" +
            " float AT[KW][MW]; // sub tiles \n" +
            " float BT[NW][KW]; \n" +
            " float CT[NW][MW]; \n" +
            " #pragma unroll \n" +
            " for ( uint i=0; i<NW*MW; ++i ) // zero CT tile \n" +
            " ((float*) CT)[i] = 0.0; \n" +
            " for ( uint kt=0; kt<KT; ++kt ) // iterate over K-dim tiles \n" +
            " {{ \n" +
            " #pragma unroll \n" +
            " for ( uint k=0; k<KW; ++k ) // every k-element inside K-dim tile \n" +
            " *( (floatMW*) AT[k] ) = A[batch*K*MT + (kt*KW + k)*MT + mt]; // store M-Width floats\n" +
            " #pragma unroll \n" +
            " for ( uint n=0; n<NW; ++n ) // every n-element inside N-dim tile \n" +
            " *( (floatKW*) BT[n] ) = B[batch*N*KT + (nc*NW + n)*KT + kt]; // store K-Width floats\n" +
            " #pragma unroll \n" +
            " for ( uint k=0; k<KW; ++k ) \n" +
            " #pragma unroll \n" +
            " for ( uint n=0; n<NW; ++n ) // sub tiles multiplication \n" +
            " #pragma unroll \n" +
            " for ( uint m=0; m<MW; ++m ) \n" +
            " CT[n][m] += AT[k][m] * BT[n][k]; \n" +
            " }} \n" +
            " #pragma unroll \n" +
            " for ( uint n = 0; n < NW; ++n ) \n" +
            " C[ batch * N * MT + ( nc * NW + n ) * MT + mt ] += *( (floatMW*) CT[n] ); \n" +
            " }} ";
    }
}