Write a program that multiplies two matrices of 32-bit floating point numbers on a GPU. Given matrix A of dimensions MxN and matrix B of dimensions NxK, compute the product matrix C, which will have dimensions MxK. All matrices are stored in row-major format.
__global__ void matrix_multiplication_kernel(const float* A, const float* B, float* C, int M, int N, int K) {
}

// A, B, C are device pointers (i.e. pointers to memory on the GPU)
extern "C" void solve(const float* A, const float* B, float* C, int M, int N, int K) {
    dim3 threadsPerBlock(16, 16);
    dim3 blocksPerGrid((K + threadsPerBlock.x - 1) / threadsPerBlock.x,
                       (M + threadsPerBlock.y - 1) / threadsPerBlock.y);

    matrix_multiplication_kernel<<<blocksPerGrid, threadsPerBlock>>>(A, B, C, M, N, K);
    cudaDeviceSynchronize();
}

Global Indexes, revisited
Unlike the previous problem, we now work in two dimensions, so we need two global indexes:
- A global row index (i)
- A global column index (j)
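i = (blockIdx.y * blockDim.y) + threadIdx.y
j = (blockIdx.x * blockDim.x) + threadIdx.x
The y dimension of the grid covers the M rows of C and the x dimension covers its K columns, matching how blocksPerGrid is built in solve(). With these two indices, each thread is assigned to calculate one unique element $C_{i,j}$.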
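To check that this mapping really gives every element of C exactly one thread, we can emulate the launch on the host. The sketch below is plain C++, which nvcc also accepts as host code. Ordinary loops stand in for blockIdx, threadIdx and blockDim. The names Coverage and emulate_launch are hypothetical, introduced only for illustration. The sketch assumes the same 16x16 blocks and rounded-up grid sizing as solve().

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type for the host-side emulation (illustration only).
struct Coverage {
    std::vector<int> hits;  // how many emulated threads mapped to each C element
    long long outOfRange;   // threads whose (i, j) falls outside the M x K matrix
};

// Emulates solve()'s 2D launch on the CPU: each "thread" computes
// i (row) and j (column) with the same formulas as the kernel.
Coverage emulate_launch(int M, int K, int bx = 16, int by = 16) {
    assert(M > 0 && K > 0 && bx > 0 && by > 0);
    const int gridX = (K + bx - 1) / bx;  // blocksPerGrid.x spans the columns
    const int gridY = (M + by - 1) / by;  // blocksPerGrid.y spans the rows
    Coverage cov{std::vector<int>(static_cast<std::size_t>(M) * K, 0), 0};
    for (int blockY = 0; blockY < gridY; ++blockY)
        for (int blockX = 0; blockX < gridX; ++blockX)
            for (int ty = 0; ty < by; ++ty)
                for (int tx = 0; tx < bx; ++tx) {
                    const int i = blockY * by + ty;  // global row index
                    const int j = blockX * bx + tx;  // global column index
                    if (i < M && j < K)
                        ++cov.hits[static_cast<std::size_t>(i) * K + j];
                    else
                        ++cov.outOfRange;
                }
    return cov;
}
```

For M = 100 and K = 50 the grid is 4 x 7 blocks, or 7168 threads. Every element of C is hit exactly once. The remaining 2168 threads land outside the matrix, and these are the threads the bounds check must send home early.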
Checks for integer division
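Because blocksPerGrid is sized with rounded-up integer division, the grid contains more threads than C has elements whenever M or K is not a multiple of 16. Those extra threads must not read or write outside the matrices. Each thread therefore checks its indexes first and exits if either one is out of range:
if (i >= M || j >= K) {
    return;
}

Core Matrix Multiplication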
Now that each valid thread is assigned a unique element $C_{i,j}$, we need to compute the dot product of the i-th row of matrix A and the j-th column of matrix B.
To do this calculation, we use a loop that iterates N times (once for each k from 0 to N-1), with a local accumulator variable holding the running sum $\sum_{k=0}^{N-1} A_{i,k} B_{k,j}$.
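The following is a minimal sketch of the per-thread work. It computes a single element with a local accumulator, using nested host vectors for readability. The Matrix alias and dot_row_col are hypothetical names. The real kernel works on flat row-major arrays, which the next part addresses.

```cpp
#include <cassert>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // conceptual 2D view (illustration only)

// One output element: the dot product of row i of A with column j of B,
// accumulated in a local variable like the kernel's inner loop.
float dot_row_col(const Matrix& A, const Matrix& B, int i, int j) {
    const int N = static_cast<int>(B.size());  // shared inner dimension
    assert(static_cast<int>(A[i].size()) == N);
    float sum = 0.0f;                          // local accumulator
    for (int k = 0; k < N; ++k) {              // N iterations
        sum += A[i][k] * B[k][j];
    }
    return sum;
}
```

As an example, take A = [[1, 2, 3], [4, 5, 6]] and B = [[7, 8], [9, 10], [11, 12]]. Then $C_{0,0} = 1 \cdot 7 + 2 \cdot 9 + 3 \cdot 11 = 58$.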
2D to 1D Indexing
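Since the matrices are stored in row-major format, the element in row $r$ and column $c$ of a matrix with $W$ columns is stored at the 1D index $r \cdot W + c$.
Thus, we can calculate the indexes for $A_{i,k}$, $B_{k,j}$ and $C_{i,j}$ (A has N columns; B and C have K columns):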
$$
\begin{aligned}
idx_{A_{i,k}} &= i \cdot N + k \\
idx_{B_{k,j}} &= k \cdot K + j \\
idx_{C_{i,j}} &= i \cdot K + j
\end{aligned}
$$
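The three index formulas can be combined into a serial CPU version of the whole product, which is handy as a reference when checking the kernel's output. The name matmul_reference is hypothetical and not part of the original solution. Each iteration of its (i, j) loop does the same work as one GPU thread.

```cpp
#include <cassert>

// Hypothetical serial reference using the row-major formulas:
//   A[i][k] -> A[i*N + k], B[k][j] -> B[k*K + j], C[i][j] -> C[i*K + j]
void matmul_reference(const float* A, const float* B, float* C,
                      int M, int N, int K) {
    assert(A && B && C && M > 0 && N > 0 && K > 0);
    for (int i = 0; i < M; ++i) {          // row of C (the kernel's i)
        for (int j = 0; j < K; ++j) {      // column of C (the kernel's j)
            float sum = 0.0f;              // local accumulator
            for (int k = 0; k < N; ++k) {
                sum += A[i * N + k] * B[k * K + j];
            }
            C[i * K + j] = sum;
        }
    }
}
```

When comparing this against the GPU result, a small tolerance is safer than exact equality. By default, nvcc may contract a multiply followed by an add into a fused multiply-add (`--fmad=true`), which can change the last bits of the sum.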
Solving
__global__ void matrix_multiplication_kernel(const float* A, const float* B, float* C, int M, int N, int K) {
// 1. Calculate i and j
int i = (blockIdx.y * blockDim.y) + threadIdx.y;
int j = (blockIdx.x * blockDim.x) + threadIdx.x;
// 2. Bounds check
if (i>=M || j>=K){
return;
}
// 3. Initialize accumulator
float d = 0.0f;
// 4. Inner loop for dot product
for (int k=0;k