Notes on reading "How to Optimize a CUDA Matmul Kernel"
Some quick notes and commentary from reading this article on writing CUDA kernels for matmuls.
This article was pretty daunting, even after watching the CS336 lectures on GPU optimizations. I spent probably 5 hours in total reading this material alongside another Recurser.
kernel 1
This one is the most straightforward implementation: the grid of blocks is laid out to cover the output matrix C, and each thread computes exactly one element of C. No optimizations are applied.
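As a rough illustration, here is a CPU-side C++ sketch (not CUDA, and not the article's code) of the naive launch. The bx/by loops stand in for blockIdx, the tx/ty loops stand in for threadIdx, and each simulated thread writes one element of C. The block edge of 32 and the name naiveMatmul are assumptions for illustration.

```cpp
#include <cassert>
#include <vector>

// Assumed block edge: 32 x 32 threads per block.
constexpr int BLOCK = 32;

// CPU-side sketch of a naive matmul launch, C = A * B with row-major
// A (M x K), B (K x N), C (M x N). Each simulated thread computes one C element.
std::vector<float> naiveMatmul(const std::vector<float>& A,
                               const std::vector<float>& B, int M, int N, int K) {
    assert(static_cast<int>(A.size()) == M * K);
    assert(static_cast<int>(B.size()) == K * N);
    std::vector<float> C(M * N, 0.0f);
    const int gridX = (M + BLOCK - 1) / BLOCK;  // blocks along the rows of C
    const int gridY = (N + BLOCK - 1) / BLOCK;  // blocks along the columns of C
    for (int bx = 0; bx < gridX; ++bx)
        for (int by = 0; by < gridY; ++by)
            for (int tx = 0; tx < BLOCK; ++tx)
                for (int ty = 0; ty < BLOCK; ++ty) {
                    const int x = bx * BLOCK + tx;   // output row
                    const int y = by * BLOCK + ty;   // output column
                    if (x >= M || y >= N) continue;  // threads past the edge do nothing
                    float tmp = 0.0f;
                    for (int i = 0; i < K; ++i) tmp += A[x * K + i] * B[i * N + y];
                    C[x * N + y] = tmp;
                }
    return C;
}
```

For a 2x2 example, naiveMatmul({1, 2, 3, 4}, {5, 6, 7, 8}, 2, 2, 2) yields {19, 22, 43, 50}.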
kernel 2
The key trick here, which isn’t stated very explicitly, is that the block is now “unrolled” into a 1D block of 1024 (32*32) threads instead of a 2D block of 32 by 32. Then the row and column for a given thread are computed from threadIdx.x so that consecutive threads in a warp get consecutive columns of C, which is exactly the memory access pattern that gets coalesced.
One other interesting part here is that the coalescing is automatic and implicit in the memory access pattern. I thought this might be possible because the threads in a warp share a program counter (PC), i.e. they run instructions in lockstep. Claude tells me that newer GPU architectures stopped doing that and coalescing is effectively best effort, though it seems like it generally works. This seemed counterintuitive given my general experience with batching, where you have to be rather intentional about aligning batched accesses to avoid thrashing, but I can see how this works since GPUs are simpler and perhaps better behaved.
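To make the flattened indexing concrete, here is a small C++ sketch comparing how many 128-byte segments the first warp touches when writing C under the two mappings. It assumes, as in a typical naive kernel, that threadIdx.x selects the output row in the 2D version. The helper names are hypothetical, and counting distinct 128-byte segments is only a simplified proxy for memory transactions.

```cpp
#include <cassert>
#include <set>
#include <vector>

constexpr int BLOCKSIZE = 32;  // assumed: 32 x 32 = 1024 threads per block
constexpr int kWarpSize = 32;

struct RowCol { int row; int col; };

// Kernel-1-style mapping (2D block): threadIdx.x picks the output row, so the
// 32 lanes of a warp (consecutive threadIdx.x, same threadIdx.y) span 32 rows.
RowCol naiveMapping(int tx, int ty) { return {tx, ty}; }

// Kernel-2-style mapping (flattened 1D block of 1024 threads): consecutive
// thread indices t get consecutive columns within the same row.
RowCol flattenedMapping(int t) { return {t / BLOCKSIZE, t % BLOCKSIZE}; }

// Rough proxy for memory transactions: distinct 128-byte segments touched when
// each lane writes the 4-byte float C[row * N + col].
int segmentsTouched(const std::vector<RowCol>& warp, int N) {
    assert(static_cast<int>(warp.size()) == kWarpSize);
    std::set<long long> segments;
    for (const RowCol& rc : warp) {
        const long long byteOffset = (static_cast<long long>(rc.row) * N + rc.col) * 4;
        segments.insert(byteOffset / 128);
    }
    return static_cast<int>(segments.size());
}

// (row, col) pairs for the first warp of block 0 under each mapping.
std::vector<RowCol> firstWarpNaive() {
    std::vector<RowCol> warp;
    for (int tx = 0; tx < kWarpSize; ++tx) warp.push_back(naiveMapping(tx, 0));
    return warp;
}

std::vector<RowCol> firstWarpFlattened() {
    std::vector<RowCol> warp;
    for (int t = 0; t < kWarpSize; ++t) warp.push_back(flattenedMapping(t));
    return warp;
}
```

With N = 4096, the first warp under the naive mapping touches 32 different segments (one per row of C), while the flattened mapping touches just 1.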
kernel 3
This kernel introduces a simple form of tiling. The code for this section was a bit tricky to read because the definitions of some variables are elided, but my guess is that As and Bs are arrays declared with __shared__ and threadCol is set to threadIdx.x. I feel like threadCol must always be 0 for things to work?
Looking at the algorithm, there are two major steps separated by __syncthreads. First, all threads in the block fill the shared memory with the current tile. Then, in the second phase, each thread adds up the products over the tile into tmp. The result is stored in C at the end, like in the previous kernels.
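Here is a CPU-side C++ sketch of that two-phase structure, with local arrays standing in for __shared__ memory and comments marking where the barriers would go (on the GPU a second __syncthreads is also needed before the next tile overwrites As and Bs). TILE = 32, the function names, and the requirement that M, N, K be multiples of TILE are assumptions to keep it short.

```cpp
#include <cassert>
#include <vector>

constexpr int TILE = 32;  // assumed tile edge, matching a 32 x 32 block

// CPU sketch of shared-memory tiling: for each output block, walk along K one
// TILE-wide chunk at a time. Phase 1 copies a tile of A and of B into As/Bs;
// phase 2 lets every simulated thread (r, c) accumulate over that tile.
std::vector<float> tiledMatmul(const std::vector<float>& A,
                               const std::vector<float>& B, int M, int N, int K) {
    assert(M % TILE == 0 && N % TILE == 0 && K % TILE == 0);
    std::vector<float> C(M * N, 0.0f);
    float As[TILE * TILE], Bs[TILE * TILE];  // stand-ins for __shared__ arrays
    for (int cRow = 0; cRow < M / TILE; ++cRow)
        for (int cCol = 0; cCol < N / TILE; ++cCol) {
            std::vector<float> tmp(TILE * TILE, 0.0f);  // one accumulator per thread
            for (int bk = 0; bk < K; bk += TILE) {
                // Phase 1: each thread (r, c) loads one element of each tile.
                for (int r = 0; r < TILE; ++r)
                    for (int c = 0; c < TILE; ++c) {
                        As[r * TILE + c] = A[(cRow * TILE + r) * K + bk + c];
                        Bs[r * TILE + c] = B[(bk + r) * N + cCol * TILE + c];
                    }
                // __syncthreads() would go here on the GPU.
                // Phase 2: each thread accumulates its dot product over the tile.
                for (int r = 0; r < TILE; ++r)
                    for (int c = 0; c < TILE; ++c)
                        for (int i = 0; i < TILE; ++i)
                            tmp[r * TILE + c] += As[r * TILE + i] * Bs[i * TILE + c];
                // A second __syncthreads() would go here before the next load.
            }
            for (int r = 0; r < TILE; ++r)
                for (int c = 0; c < TILE; ++c)
                    C[(cRow * TILE + r) * N + cCol * TILE + c] = tmp[r * TILE + c];
        }
    return C;
}

// Self-check against a plain triple loop on small-integer inputs (exact in float).
bool tiledMatchesReference(int M, int N, int K) {
    std::vector<float> A(M * K), B(K * N);
    for (int i = 0; i < M * K; ++i) A[i] = static_cast<float>(i % 7 - 3);
    for (int i = 0; i < K * N; ++i) B[i] = static_cast<float>(i % 5 - 2);
    const std::vector<float> C = tiledMatmul(A, B, M, N, K);
    for (int x = 0; x < M; ++x)
        for (int y = 0; y < N; ++y) {
            float ref = 0.0f;
            for (int i = 0; i < K; ++i) ref += A[x * K + i] * B[i * N + y];
            if (C[x * N + y] != ref) return false;
        }
    return true;
}
```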
occupancy calculation
This part had a somewhat surprising result for me: the kernel is bottlenecked on smem loads/stores, even though the chip's shared memory can do 12TB/s. I guess the rough issue is that each value loaded from smem feeds only a single multiply-add, so even very fast loads still take longer than the arithmetic, which is even cheaper.
I was also surprised to see that we can’t get better occupancy (more active warps per SM), since the kernel uses too many registers.
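A rough occupancy calculator helps show how register usage can cap the number of resident warps. This is a simplified C++ sketch: the per-SM limits in SmLimits (64K registers, 1536 threads, 16 blocks, 100 KB shared memory, registers allocated per warp in chunks of 256) are assumed values roughly matching a compute capability 8.6 card, and it ignores details such as per-block reserved shared memory. The CUDA runtime's cudaOccupancyMaxActiveBlocksPerMultiprocessor computes the real number for a compiled kernel.

```cpp
#include <algorithm>
#include <cassert>

// Assumed per-SM limits, roughly those of a compute capability 8.6 GPU.
// Query the real values for your card before trusting the output.
struct SmLimits {
    int regsPerSm = 65536;
    int maxThreadsPerSm = 1536;
    int maxBlocksPerSm = 16;
    int smemBytesPerSm = 100 * 1024;
    int regAllocUnit = 256;  // registers are handed to each warp in chunks of this size
};

// Resident blocks per SM: the tightest of the register, thread, shared-memory
// and block-count limits.
int blocksPerSm(int threadsPerBlock, int regsPerThread, int smemBytesPerBlock,
                const SmLimits& lim = SmLimits{}) {
    assert(threadsPerBlock > 0 && threadsPerBlock % 32 == 0 && regsPerThread > 0);
    const int warpsPerBlock = threadsPerBlock / 32;
    const int regsPerWarp =
        (regsPerThread * 32 + lim.regAllocUnit - 1) / lim.regAllocUnit * lim.regAllocUnit;
    const int byRegs = lim.regsPerSm / (regsPerWarp * warpsPerBlock);
    const int byThreads = lim.maxThreadsPerSm / threadsPerBlock;
    const int bySmem = smemBytesPerBlock > 0 ? lim.smemBytesPerSm / smemBytesPerBlock
                                             : lim.maxBlocksPerSm;
    return std::min({byRegs, byThreads, bySmem, lim.maxBlocksPerSm});
}

// Fraction of the SM's thread slots (equivalently, warp slots) that are occupied.
double occupancy(int threadsPerBlock, int regsPerThread, int smemBytesPerBlock,
                 const SmLimits& lim = SmLimits{}) {
    const int blocks = blocksPerSm(threadsPerBlock, regsPerThread, smemBytesPerBlock, lim);
    return static_cast<double>(blocks * threadsPerBlock) / lim.maxThreadsPerSm;
}
```

For example, with 256 threads per block, going from 32 to 128 registers per thread drops the resident blocks from 6 (thread-limited) to 2 (register-limited), i.e. from 100% to about 33% occupancy.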
kernel 4
This one really confused me because quite a bit of setup code is elided. I think looking at the full program in the godbolt link helps a lot (and probably would have helped in previous steps too).
The rough idea here seems to be that instead of each thread computing only one cell of C, each thread computes several cells at a time by reusing what’s already been loaded into shared memory for the block. Each thread still loops through the entirety of K (the inner dimension), i.e. it walks down one column of B and across all columns of the A tile, for several rows of A at once. Threads in a warp work on neighboring columns of B. We now also need to accumulate into threadResults as a per-thread temporary buffer, because each output cell is only complete after going through all of the tiles along K; only then do we write once to each spot in C.
A question I had: does this lead to fewer threads running for a fixed shape of matrix inputs? Presumably yes, because each thread is doing more work.
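A CPU-side C++ sketch of this 1D blocktiling idea: each simulated thread keeps TM partial sums in threadResults, loads one Bs value per step into a local tmpB, and reuses it for TM rows of the A tile. The tile sizes (BM = BN = 64, BK = 8, TM = 8) and the helper names are assumptions, not necessarily what the article uses. threadsNeeded speaks to the question above for this configuration: you launch M*N/TM threads instead of M*N.

```cpp
#include <cassert>
#include <vector>

// Assumed tile sizes for this sketch.
constexpr int BM = 64, BN = 64, BK = 8, TM = 8;
constexpr int kThreadsPerBlock = (BM * BN) / TM;  // 512: each thread owns TM cells

// CPU sketch of 1D blocktiling: simulated thread t computes TM vertically
// adjacent cells in one column of the BM x BN output tile.
std::vector<float> blocktiled1D(const std::vector<float>& A,
                                const std::vector<float>& B, int M, int N, int K) {
    assert(M % BM == 0 && N % BN == 0 && K % BK == 0);
    std::vector<float> C(M * N, 0.0f), As(BM * BK), Bs(BK * BN);
    for (int cRow = 0; cRow < M / BM; ++cRow)
        for (int cCol = 0; cCol < N / BN; ++cCol) {
            // threadResults for all simulated threads: TM partial sums each.
            std::vector<float> threadResults(kThreadsPerBlock * TM, 0.0f);
            for (int bk = 0; bk < K; bk += BK) {
                // "Shared memory" fill (a __syncthreads() would follow on the GPU).
                for (int r = 0; r < BM; ++r)
                    for (int c = 0; c < BK; ++c)
                        As[r * BK + c] = A[(cRow * BM + r) * K + bk + c];
                for (int r = 0; r < BK; ++r)
                    for (int c = 0; c < BN; ++c)
                        Bs[r * BN + c] = B[(bk + r) * N + cCol * BN + c];
                for (int t = 0; t < kThreadsPerBlock; ++t) {
                    const int threadCol = t % BN;  // neighboring threads, neighboring columns
                    const int threadRow = t / BN;  // which group of TM rows
                    for (int dot = 0; dot < BK; ++dot) {
                        const float tmpB = Bs[dot * BN + threadCol];  // reused TM times
                        for (int res = 0; res < TM; ++res)
                            threadResults[t * TM + res] +=
                                As[(threadRow * TM + res) * BK + dot] * tmpB;
                    }
                }
            }
            // Each output cell is final only now, so it is written to C exactly once.
            for (int t = 0; t < kThreadsPerBlock; ++t)
                for (int res = 0; res < TM; ++res) {
                    const int row = cRow * BM + (t / BN) * TM + res;
                    const int col = cCol * BN + t % BN;
                    C[row * N + col] = threadResults[t * TM + res];
                }
        }
    return C;
}

// Threads launched for an M x N output when each thread computes TM cells.
long long threadsNeeded(long long M, long long N) { return M * N / TM; }

// Self-check against a plain triple loop on small-integer inputs (exact in float).
bool blocktiled1DMatchesReference(int M, int N, int K) {
    std::vector<float> A(M * K), B(K * N);
    for (int i = 0; i < M * K; ++i) A[i] = static_cast<float>(i % 7 - 3);
    for (int i = 0; i < K * N; ++i) B[i] = static_cast<float>(i % 5 - 2);
    const std::vector<float> C = blocktiled1D(A, B, M, N, K);
    for (int x = 0; x < M; ++x)
        for (int y = 0; y < N; ++y) {
            float ref = 0.0f;
            for (int i = 0; i < K; ++i) ref += A[x * K + i] * B[i * N + y];
            if (C[x * N + y] != ref) return false;
        }
    return true;
}
```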
kernel 5
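The same idea as before, but done in both directions this time: each thread computes a 2D tile of C instead of a strip of one column. We also now copy values from smem into registers before using them, because registers are even faster.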
One thing we noticed while analyzing this kernel is that it does a lot more work in registers, so we thought it might use more registers and allow for lower occupancy. However, in godbolt, the compiler actually reports:
ptxas info : Used 128 registers, 8192 bytes smem, 400 bytes cmem[0]
Only 128 registers are used. Looking at the program, this kind of makes sense: there are 8 registers each for regM and regN, then another 8x8 = 64 registers for threadResults, leaving us with ~48 registers for other intermediate computation. Super cool! This kernel uses ~3.5x more registers but gets 64x more work done with them.
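Extending the 1D idea in both directions gives this CPU-side C++ sketch of 2D blocktiling: for each step along the tile's K dimension, each simulated thread copies TM values of As into regM and TN values of Bs into regN, then does a TM x TN outer product into threadResults. The sizes BM = BN = 128, BK = 8, TM = TN = 8 are assumptions (TM = TN = 8 matches the 8-entry regM/regN and 8x8 threadResults above). smemLoadsPerFma is a back-of-the-envelope helper for why this relieves the smem bottleneck: TM + TN loads now feed TM * TN multiply-adds.

```cpp
#include <cassert>
#include <vector>

// Assumed tile sizes; TM = TN = 8 gives 8-entry regM/regN and an 8x8 threadResults.
constexpr int BM = 128, BN = 128, BK = 8, TM = 8, TN = 8;
constexpr int kThreadsPerBlock = (BM * BN) / (TM * TN);  // 256
constexpr int kThreadsPerRow = BN / TN;                  // 16 threads across the tile

// CPU sketch of 2D blocktiling with register caching.
std::vector<float> blocktiled2D(const std::vector<float>& A,
                                const std::vector<float>& B, int M, int N, int K) {
    assert(M % BM == 0 && N % BN == 0 && K % BK == 0);
    std::vector<float> C(M * N, 0.0f), As(BM * BK), Bs(BK * BN);
    for (int cRow = 0; cRow < M / BM; ++cRow)
        for (int cCol = 0; cCol < N / BN; ++cCol) {
            std::vector<float> threadResults(kThreadsPerBlock * TM * TN, 0.0f);
            for (int bk = 0; bk < K; bk += BK) {
                for (int r = 0; r < BM; ++r)
                    for (int c = 0; c < BK; ++c)
                        As[r * BK + c] = A[(cRow * BM + r) * K + bk + c];
                for (int r = 0; r < BK; ++r)
                    for (int c = 0; c < BN; ++c)
                        Bs[r * BN + c] = B[(bk + r) * N + cCol * BN + c];
                for (int t = 0; t < kThreadsPerBlock; ++t) {
                    const int threadRow = t / kThreadsPerRow;
                    const int threadCol = t % kThreadsPerRow;
                    float* res = &threadResults[t * TM * TN];
                    for (int dot = 0; dot < BK; ++dot) {
                        float regM[TM], regN[TN];  // would live in registers on the GPU
                        for (int i = 0; i < TM; ++i)
                            regM[i] = As[(threadRow * TM + i) * BK + dot];
                        for (int j = 0; j < TN; ++j)
                            regN[j] = Bs[dot * BN + threadCol * TN + j];
                        // Outer product: TM + TN smem loads feed TM * TN multiply-adds.
                        for (int i = 0; i < TM; ++i)
                            for (int j = 0; j < TN; ++j)
                                res[i * TN + j] += regM[i] * regN[j];
                    }
                }
            }
            for (int t = 0; t < kThreadsPerBlock; ++t)
                for (int i = 0; i < TM; ++i)
                    for (int j = 0; j < TN; ++j) {
                        const int row = cRow * BM + (t / kThreadsPerRow) * TM + i;
                        const int col = cCol * BN + (t % kThreadsPerRow) * TN + j;
                        C[row * N + col] = threadResults[t * TM * TN + i * TN + j];
                    }
        }
    return C;
}

// Shared-memory loads per multiply-add when each thread computes a tm x tn tile.
double smemLoadsPerFma(int tm, int tn) { return static_cast<double>(tm + tn) / (tm * tn); }

// Self-check against a plain triple loop on small-integer inputs (exact in float).
bool blocktiled2DMatchesReference(int M, int N, int K) {
    std::vector<float> A(M * K), B(K * N);
    for (int i = 0; i < M * K; ++i) A[i] = static_cast<float>(i % 7 - 3);
    for (int i = 0; i < K * N; ++i) B[i] = static_cast<float>(i % 5 - 2);
    const std::vector<float> C = blocktiled2D(A, B, M, N, K);
    for (int x = 0; x < M; ++x)
        for (int y = 0; y < N; ++y) {
            float ref = 0.0f;
            for (int i = 0; i < K; ++i) ref += A[x * K + i] * B[i * N + y];
            if (C[x * N + y] != ref) return false;
        }
    return true;
}
```

With one output per thread (TM = TN = 1) that is 2 smem loads per multiply-add; with an 8x8 tile it drops to 0.25.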
kernel 6
I originally thought this was similar to the change in kernel 2 with global memory coalescing, but it’s slightly different. The coalescing in kernel 2 happens across threads, taking advantage of memory bursts. This optimization happens within a single thread: an explicit vectorized load that fetches four adjacent floats (e.g. as a float4) with one instruction.
I was a bit surprised this optimization gives a 20% speedup (though it also transposes As in memory), since I would have expected coalescing to have given most of the gains.
Looking at the godbolt output for both made things clearer.
Compare the non-vectorized load:
with the vectorized load:
Just by the number of instructions issued, the SM is doing a lot more work to load memory in kernel 5. If we look more closely at the assembly for kernel 5, we can see that it has to do a series of add and mul for each ld.global to compute the indexing at runtime. This made a lot more sense to me: the compiler doesn't figure out that we’re just accessing adjacent memory locations and has to do a bunch of address arithmetic at runtime, whereas the vectorized version loads four adjacent values with a single instruction.
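The vectorized-and-transposed load can be sketched on the CPU as well. Below, a 16-byte aligned Float4 struct stands in for CUDA's float4 and a single 16-byte std::memcpy is a rough analogue of one 128-bit load; each group of four floats from a row of A is then scattered down one column of a transposed tile AsT. The struct, function, and parameter names are hypothetical. On the GPU the source address must also be 16-byte aligned for a 128-bit load, which this CPU version does not require.

```cpp
#include <cassert>
#include <cstring>
#include <vector>

// Hypothetical stand-in for CUDA's built-in float4: four floats, 16-byte aligned.
struct alignas(16) Float4 { float x, y, z, w; };

// Copy a BM x BK tile of row-major A (row stride K) starting at (row0, col0)
// into AsT, stored transposed: AsT[c * BM + r] = A[row0 + r][col0 + c].
// Each inner step moves four adjacent floats with one 16-byte memcpy.
void loadTileTransposedVec4(const std::vector<float>& A, int K, int row0, int col0,
                            int BM, int BK, std::vector<float>& AsT) {
    assert(BK % 4 == 0);
    assert(static_cast<int>(AsT.size()) == BM * BK);
    for (int r = 0; r < BM; ++r)
        for (int c = 0; c < BK; c += 4) {
            Float4 v;
            std::memcpy(&v, &A[(row0 + r) * K + col0 + c], sizeof v);
            // Scatter the four values down one column of the transposed tile.
            AsT[(c + 0) * BM + r] = v.x;
            AsT[(c + 1) * BM + r] = v.y;
            AsT[(c + 2) * BM + r] = v.z;
            AsT[(c + 3) * BM + r] = v.w;
        }
}

// Self-check: fill A with its own linear indices and verify the transposed layout.
bool checkTransposedLoad(int K, int BM, int BK, int row0, int col0) {
    assert(col0 + BK <= K);
    std::vector<float> A((row0 + BM) * K);
    for (int i = 0; i < static_cast<int>(A.size()); ++i) A[i] = static_cast<float>(i);
    std::vector<float> AsT(BM * BK);
    loadTileTransposedVec4(A, K, row0, col0, BM, BK, AsT);
    for (int r = 0; r < BM; ++r)
        for (int c = 0; c < BK; ++c)
            if (AsT[c * BM + r] != A[(row0 + r) * K + col0 + c]) return false;
    return true;
}
```

The transposed layout means that later reads of one column of the A tile (the same k for consecutive rows) hit contiguous addresses in AsT.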
kernel 7 and 8
The author skips these since they didn’t seem to work, but they were attempts at avoiding shared-memory bank conflicts. A bank conflict happens when threads in the same warp access different addresses that map to the same shared memory bank in a single instruction, so the SM has to serialize those accesses. (Threads reading the exact same address don't conflict, since the value is broadcast.)
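A small C++ model of that rule: map each lane's 4-byte word index to bank word % 32, then count how many distinct words land in the busiest bank. It assumes 32 banks of 4-byte words and 32-bit accesses, which is a simplification (wider accesses are handled differently); conflictDegree and stridedAccess are hypothetical helper names.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <set>
#include <vector>

constexpr int kNumBanks = 32;  // assumed: 32 banks, each 4 bytes wide
constexpr int kWarpSize = 32;

// Simplified model of one warp-wide 32-bit shared-memory read.
// wordIndexPerLane[lane] is the 4-byte word each lane reads; its bank is word % 32.
// Returns n for an "n-way" conflict: the largest number of distinct words that
// land in one bank. Lanes reading the same word share a broadcast, so count once.
int conflictDegree(const std::vector<int>& wordIndexPerLane) {
    assert(static_cast<int>(wordIndexPerLane.size()) == kWarpSize);
    std::map<int, std::set<int>> wordsPerBank;
    for (int word : wordIndexPerLane) wordsPerBank[word % kNumBanks].insert(word);
    int worst = 1;
    for (const auto& entry : wordsPerBank)
        worst = std::max(worst, static_cast<int>(entry.second.size()));
    return worst;
}

// Lane i reads word i * stride, e.g. walking down column 0 of a row-major
// tile whose rows are `stride` floats wide.
std::vector<int> stridedAccess(int stride) {
    std::vector<int> words;
    for (int lane = 0; lane < kWarpSize; ++lane) words.push_back(lane * stride);
    return words;
}
```

Stride 1 and the broadcast case (stride 0) give 1, stride 2 gives a 2-way conflict, and stride 32 (walking down a column of a 32-float-wide tile) gives a 32-way conflict; padding each row to 33 floats brings it back to 1.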
kernel 9
This one is kind of simple but cool to see. We do a sort of hyperparameter search (autotuning) to figure out which constants, such as the tile sizes, work best for this GPU. I kind of imagine there’s some matrix of cards sitting around in the office at Nvidia so CUDA developers can autotune the whole product line together.
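Much of the search is just enumerating a grid of tile sizes and discarding combinations that can't launch. Below is a C++ sketch of that filtering step only; the candidate values and the constraints in isValid (at most 1024 threads per block, at most 48 KB of static shared memory, tiles that split evenly into float4 loads) are hypothetical stand-ins for whatever the article's script actually checks. Each surviving config would then be built and timed, e.g. with cudaEventRecord and cudaEventElapsedTime, which is left out here.

```cpp
#include <cassert>
#include <initializer_list>
#include <vector>

// Hypothetical tunable constants: block tile BM x BN, K-step BK, thread tile TM x TN.
struct Config { int BM, BN, BK, TM, TN; };

// Hypothetical launchability checks, not the article's exact rules.
bool isValid(const Config& c) {
    assert(c.TM > 0 && c.TN > 0);
    if (c.BM % c.TM != 0 || c.BN % c.TN != 0) return false;
    const int threads = (c.BM * c.BN) / (c.TM * c.TN);
    const int smemBytes = (c.BM * c.BK + c.BK * c.BN) * 4;  // As + Bs in floats
    if (threads < 32 || threads > 1024) return false;       // per-block thread limit
    if (smemBytes > 48 * 1024) return false;                 // static smem cap
    // Each thread loads whole float4 groups per pass over the As and Bs tiles.
    if ((c.BM * c.BK) % (4 * threads) != 0) return false;
    if ((c.BN * c.BK) % (4 * threads) != 0) return false;
    return true;
}

// Enumerate the candidate grid and keep only configs that pass isValid.
std::vector<Config> enumerateConfigs() {
    std::vector<Config> out;
    for (int BM : {64, 128, 256})
        for (int BN : {64, 128, 256})
            for (int BK : {8, 16, 32, 64})
                for (int TM : {4, 8, 16})
                    for (int TN : {4, 8, 16}) {
                        const Config c{BM, BN, BK, TM, TN};
                        if (isValid(c)) out.push_back(c);
                    }
    return out;
}
```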
kernel 10
At this point in the post, I was kind of mentally drained. My high-level takeaway is that we haven’t optimized for warps yet, even though warps can give us even better locality, e.g. better register cache locality and avoiding shared memory bank conflicts.
I don’t fully understand this one, but the vague intuition seems to be that this kernel shapes the warps into tiles, so that the threads in a warp work on a compact sub-tile of the output and don't hit different addresses in the same shared memory bank at the same time (“shared memory bank conflicts”). If we didn’t do this, threads in a warp might be arbitrarily bunched into accesses that go across the same row of the A matrix, which is slower.
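Here is a C++ sketch of the kind of index mapping warptiling uses, simplified (a full warptiling kernel may also loop over several sub-tiles per warp): thread t belongs to warp t / 32, each warp owns a WM x WN sub-tile of the block tile, and each lane owns a TM x TN patch inside its warp's sub-tile. All sizes and helper names are assumptions; the helpers only check that the mapping covers the block tile exactly once and that a warp's lanes stay inside their sub-tile.

```cpp
#include <cassert>
#include <vector>

// Assumed sizes for a simplified warptiling layout.
constexpr int BM = 128, BN = 128;  // block tile
constexpr int WM = 64, WN = 32;    // warp tile
constexpr int TM = 8, TN = 8;      // thread tile
constexpr int kWarpSize = 32;
static_assert((WM / TM) * (WN / TN) == kWarpSize, "one warp must exactly cover its warp tile");
constexpr int kWarpsPerBlock = (BM / WM) * (BN / WN);        // 8
constexpr int kThreadsPerBlock = kWarpsPerBlock * kWarpSize;  // 256

struct Origin { int row, col; };

// Top-left corner (within the block tile) of the TM x TN outputs owned by thread t.
Origin threadTileOrigin(int t) {
    assert(t >= 0 && t < kThreadsPerBlock);
    const int warpIdx = t / kWarpSize, lane = t % kWarpSize;
    const int warpRow = warpIdx / (BN / WN), warpCol = warpIdx % (BN / WN);
    const int laneRow = lane / (WN / TN), laneCol = lane % (WN / TN);
    return {warpRow * WM + laneRow * TM, warpCol * WN + laneCol * TN};
}

// Every output cell of the block tile is owned by exactly one thread.
bool coversBlockTileOnce() {
    std::vector<int> owners(BM * BN, 0);
    for (int t = 0; t < kThreadsPerBlock; ++t) {
        const Origin o = threadTileOrigin(t);
        for (int i = 0; i < TM; ++i)
            for (int j = 0; j < TN; ++j) ++owners[(o.row + i) * BN + o.col + j];
    }
    for (int n : owners)
        if (n != 1) return false;
    return true;
}

// All 32 lanes of warp w stay inside that warp's WM x WN sub-tile.
bool warpStaysInItsTile(int w) {
    const int warpRow = w / (BN / WN), warpCol = w % (BN / WN);
    for (int lane = 0; lane < kWarpSize; ++lane) {
        const Origin o = threadTileOrigin(w * kWarpSize + lane);
        if (o.row < warpRow * WM || o.row + TM > (warpRow + 1) * WM) return false;
        if (o.col < warpCol * WN || o.col + TN > (warpCol + 1) * WN) return false;
    }
    return true;
}
```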