// CLMatMul.java
package neureka.backend.main.implementations.matmul;
import neureka.backend.main.implementations.SimpleCLImplementation;
import neureka.backend.main.operations.linear.internal.opencl.CLGEMM;
import neureka.ndim.config.NDConfiguration;
/**
 *  OpenCL backed matrix multiplication.
 *  <p>
 *  The execution lambda handed to the super constructor picks one of two paths: <br>
 *  1. If the layout check on the call succeeds (the tensors visited by
 *     {@code call.validate().all(..)} report {@link NDConfiguration.Layout#COLUMN_MAJOR}),
 *     the work is delegated to {@link CLGEMM}. <br>
 *  2. Otherwise the naive {@code simple_matMul} kernel defined in this class is launched
 *     with one work item per element of the result.
 *  <p>
 *  The fallback kernel declares its buffers as {@code float} and indexes them as
 *  {@code A[k + m*K]}, {@code B[n + k*N]} and {@code C[n + m*N]}.
 */
public class CLMatMul extends SimpleCLImplementation
{
    /** Name under which the fallback kernel is registered; must match the function name in the source. */
    private static final String KERNEL_NAME = "simple_matMul";

    /** OpenCL C source of the fallback kernel (one work item computes one element of C). */
    private static final String KERNEL_SOURCE =
            " __kernel void simple_matMul( \n" +
            " const int M, const int N, const int K, \n" +
            " const __global float* A, \n" +
            " const __global float* B, \n" +
            " __global float* C \n" +
            " ) { \n" +
            " const int m = get_global_id(0); // Row index of C (0..M) \n" +
            " const int n = get_global_id(1); // Col index of C (0..N) \n" +
            " \n" +
            " // Compute a single element (loop over K) \n" +
            " float acc = 0.0f; \n" +
            " for ( int k = 0; k < K; k++ ) \n" +
            " acc += A[ k + m * K ] * B[ n + k * N ]; \n" +
            " \n" +
            " // Store the result \n" +
            " C[ n + m * N ] = acc; \n" +
            " } \n";

    /**
     *  Wires the execution lambda, the kernel name and the kernel source into
     *  the {@link SimpleCLImplementation} super constructor.
     */
    public CLMatMul() {
        super(
            call -> {
                boolean isColumnMajor =
                        call.validate()
                            .all( tensor -> tensor.getNDConf().getLayout() == NDConfiguration.Layout.COLUMN_MAJOR )
                            .isValid();

                // Column major data is handed over to the dedicated GEMM implementation.
                if ( isColumnMajor )
                    return new CLGEMM().run( call );

                // Input 1 and 2 are the operands (A and B), input 0 receives the result (C).
                int rows  = call.input(1).shape(0); // M : first axis of A
                int cols  = call.input(2).shape(1); // N : second axis of B
                int inner = call.input(1).shape(1); // K : second axis of A, the summed-over axis

                call.getDevice()
                    .getKernel( call )
                    .pass( rows )
                    .pass( cols )
                    .pass( inner )
                    .pass( call.input(Number.class, 1) )  // A
                    .pass( call.input(Number.class, 2) )  // B
                    .pass( call.input(Number.class, 0) )  // C
                    .call( new long[]{ rows, cols }, null ); // global size M x N, no explicit local size

                return call.input(0);
            },
            3, // NOTE(review): presumably the number of tensor arguments (C, A, B) - confirm against SimpleCLImplementation
            KERNEL_NAME,
            KERNEL_SOURCE
        );
    }
}