The world of GPU programming for AI has come a long way since I worked on writing a CUDA-based matrix library back in 2009. Both NVIDIA hardware and the CUDA ecosystem have evolved dramatically and are now the basis for the majority of AI compute in the world today (hence the $4T+ market cap). Nevertheless, writing CUDA is still really hard, primarily because you need a good understanding of low-level mechanisms within the GPU (e.g. memory hierarchy, warp scheduling, memory coalescing) to produce performant code. I recently came across a project called Triton, which is a Python-based DSL that makes it easier to build high-performance GPU kernels. I ended up writing a handful of LLM-related kernels to understand Triton better and found it quite interesting, so I want to share a little bit about this technology.
*Improving performance of CUDA matrix multiplication kernel. Source: siboehm.com.*
To illustrate how Triton works, it's helpful to start with one of the simplest GPU kernels and the backbone of modern deep learning: matrix multiplication (C = A * B). There is a really nice blog post written by a performance engineer at Anthropic that walks through what it takes to get matrix multiplication performance on par with a mature library like cuBLAS. It validates my point above -- you not only have to understand these low-level GPU concepts, you need to reason carefully about how they interact with your specific computation. Even for something as basic as matrix multiplication, this is quite complex. At a high level, though, it boils down to two questions: how do we move data from High Bandwidth Memory (HBM) through the cache hierarchy (e.g. Shared Memory) efficiently, and how do we keep arithmetic intensity high enough to avoid being memory-bound?
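To make "arithmetic intensity" concrete, here is a quick back-of-the-envelope calculation (the matrix size and hardware numbers are my own illustrative choices, not from the post linked above):

```python
# Rough arithmetic intensity estimate for C = A @ B with fp16 inputs.
# This is the algorithmic upper bound, assuming each element is read from
# HBM only once -- achieving that reuse in practice is what tiling is for.
M = N = K = 4096
flops = 2 * M * N * K                       # one multiply + one add per term
bytes_moved = 2 * (M * K + K * N + M * N)   # fp16 = 2 bytes; read A and B, write C
print(flops / bytes_moved)                  # ~1365 FLOPs per byte

# Compare against the hardware balance point (peak FLOPs / memory bandwidth),
# e.g. roughly 312e12 / 2e12 = ~156 FLOPs per byte for an A100 (80 GB).
# Stay above that ratio and the kernel can be compute-bound; fall below it
# (for example by re-reading tiles from HBM) and it becomes memory-bound.
```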
By contrast, Triton has a block-centric programming model that abstracts you away from the details of threads, warps, shared memory, etc. Instead, you schedule the execution of programs across a grid (similar to the CUDA grid), and each program instance operates on "blocks" of data, i.e. sub-regions of tensors. In the official Triton tutorial for matrix multiplication, you still need to be aware of the memory hierarchy and choose a good ordering for computing the output C, but the memory coalescing, shared-memory caching, block tiling, vectorization, and warp tiling are handled behind the scenes as optimizations by the Triton compiler.
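To make the block-centric model concrete before diving into matrix multiplication, here is essentially Triton's canonical vector-add example (adapted from the official "Vector Addition" tutorial, not from the matmul tutorial discussed below). Notice that the kernel reasons about offsets into blocks of data, never about individual threads:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                        # which block this program owns
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                        # guard the ragged last block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(10_000, device='cuda')
y = torch.randn(10_000, device='cuda')
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)                 # one program per 1024-element block
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
```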
Here is the Triton kernel code from the tutorial, which we'll walk through piece by piece.
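The full listing (including a long list of autotune configurations) lives in the tutorial; to keep the walkthrough readable, here is a condensed sketch of just the decorators and the kernel signature, so the parameter names used below have a home. The specific config values shown are illustrative only:

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        # The tutorial lists many more candidate configurations than this.
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128,
                       'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
                      num_stages=3, num_warps=8),
    ],
    key=['M', 'N', 'K'],  # re-tune when the problem shape changes
)
@triton.jit
def matmul_kernel(
        # Pointers to the first element of each matrix
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # Strides: elements to skip to move one row / one column in each matrix
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        # Meta-parameters: compile-time constants chosen by the autotuner
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
):
    ...  # body covered piece by piece below
```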
Let's break it down. First, assume we're launching a grid of programs, each computing one [BLOCK_SIZE_M, BLOCK_SIZE_N] block of the output C (we'll show the grid construction later). Each program instance is indexed by its program ID: pid = tl.program_id(axis=0). The first chunk of code decides which block we are going to compute:
```python
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
```
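To see what ordering this index arithmetic actually produces, here is a tiny standalone sketch (plain Python with toy sizes of my own choosing) that replays the same formulas for every pid:

```python
# Hypothetical example: a 4x4 grid of output blocks with GROUP_SIZE_M = 2.
num_pid_m, num_pid_n, GROUP_SIZE_M = 4, 4, 2
num_pid_in_group = GROUP_SIZE_M * num_pid_n

order = []
for pid in range(num_pid_m * num_pid_n):
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    order.append((pid_m, pid_n))

print(order)
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3),
#  (2, 0), (3, 0), (2, 1), (3, 1), (2, 2), (3, 2), (2, 3), (3, 3)]
```

Consecutive programs sweep down GROUP_SIZE_M row-blocks before stepping to the next column, so a small working set of A and B blocks gets revisited over and over.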
For reasons explained in the tutorial, there is a somewhat sophisticated scheme for deciding which block to compute, designed to improve L2 cache reuse on the GPU. The key insight is that consecutive program instances reuse the same blocks loaded from A and B, so those blocks stay warm in the L2 cache and don't require fresh HBM reads. This brings to light the importance of understanding program scheduling when writing Triton -- you can't hide all of the complexity! The next section of code is a very Triton-esque pattern:
```python
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
```
Based on which block we decided to compute (pid_m, pid_n), we calculate the blocks of A and B that the computation will read. offs_am is an array of row offsets into A (tl.arange returns a range of indices), offs_bn is an array of column offsets into B, and offs_k is an array of column offsets into A and row offsets into B (the dimension we accumulate over in matrix multiplication). These 1-D arrays of row and column offsets are then broadcast against each other to produce the 2-D arrays of offsets into A and B that cover the entries we are going to use in the computation.
The final two lines show how, in Triton, we treat tensors as pointers to the beginning of their data and address entries by their offset from that base pointer. That's why this function needs to know the strides -- how far to advance a pointer to move forward one row or one column. Here's a visual representation of each of these variables:
*Visual representation of offsets and blocks in Triton matrix multiplication kernel.*
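To make the stride and broadcasting arithmetic concrete, here is a tiny NumPy version of the same indexing for a block of A (the sizes and the pid are toy values of my own choosing):

```python
import numpy as np

# Hypothetical sizes, small enough to read the output by eye.
M, K = 6, 8
BLOCK_SIZE_M, BLOCK_SIZE_K = 2, 4
stride_am, stride_ak = K, 1    # row-major A: one row ahead = K elements ahead
pid_m = 1                      # pretend this program owns the second row-block

offs_am = (pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)) % M   # row offsets
offs_k = np.arange(BLOCK_SIZE_K)                                 # column offsets

# Broadcasting a column of row offsets against a row of column offsets
# yields a [BLOCK_SIZE_M, BLOCK_SIZE_K] grid of element offsets -- exactly
# what gets added to `a_ptr` in the kernel.
a_offsets = offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
print(a_offsets)
# [[16 17 18 19]
#  [24 25 26 27]]
```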
Once we've figured out the two blocks that we're performing the computation over, it's just a matter of actually computing the result:
```python
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    # Load the next block of A and B, generate a mask by checking the K dimension.
    # If it is out of bounds, set it to 0.
    a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
    # We accumulate along the K dimension.
    accumulator = tl.dot(a, b, accumulator)
    # Advance the ptrs to the next K block.
    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator.to(tl.float16)
```
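Before unpacking this loop, it's worth pausing on the mask argument to tl.load, since it is what makes the ragged final K-iteration safe. A quick toy sketch of the effect (the numbers here are my own, not from the tutorial):

```python
import numpy as np

# Suppose K = 10 and BLOCK_SIZE_K = 4: on the last iteration (k = 2),
# only the first 2 of the 4 columns are in bounds.
K, BLOCK_SIZE_K, k = 10, 4, 2
offs_k = np.arange(BLOCK_SIZE_K)
mask = offs_k < K - k * BLOCK_SIZE_K
print(mask)                                   # [ True  True False False]

# Out-of-bounds entries load as 0.0 (the `other` argument), so they add
# nothing to the accumulated dot product -- equivalent to zero-padding
# A and B up to a multiple of the block size.
a_block = np.where(mask, np.array([1.0, 2.0, 3.0, 4.0]), 0.0)
print(a_block)                                # [1. 2. 0. 0.]
```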
This is your typical matrix multiplication inner loop, reframed using Triton block loads and pointer advancement. Note that masking is performed to handle cases where the dimensions are not multiples of our block sizes (another common pattern you'll see in both Triton and CUDA code). Finally, we store the result into the actual output location, as accumulator is a temporarily allocated block.
```python
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
```
We do similar offset computations to identify the correct block of C to write to, and we again mask to handle non-divisible dimensions. Finally, we note that the function we examined is the code for a single program instance, and we still need to know the grid that we're executing these programs on. Here's how that is defined:
```python
# Host-side wrapper (as in the tutorial): validates inputs, allocates the
# output, and launches the kernel.
def matmul(a, b):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
```
The main thing worth noting here is that the grid is not a fixed constant: it's a function of the kernel's meta-parameters (META). That's what lets it work with Triton's autotune capability, which benchmarks a list of candidate configurations (the block sizes and GROUP_SIZE_M in this case, along with the number of warps and pipeline stages) and picks the best one for a particular input size. This helps better align the program schedule with the underlying hardware.
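Putting it all together, calling the wrapper from PyTorch looks something like this (a minimal sketch, assuming matmul and matmul_kernel from above are defined in scope; the tolerance mirrors the tutorial's correctness check):

```python
import torch

a = torch.randn((512, 256), device='cuda', dtype=torch.float16)
b = torch.randn((256, 1024), device='cuda', dtype=torch.float16)

c = matmul(a, b)   # the first call triggers autotuning for this (M, N, K)

# Sanity-check against PyTorch / cuBLAS.
torch.testing.assert_close(c, torch.matmul(a, b), atol=1e-2, rtol=0)
```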
The Triton code is far from simple, but it does eliminate having to think about many of the harder aspects of GPU programming. I won't go into the details here, but Triton does this by compiling the DSL code into an LLVM intermediate representation, where it can perform optimizations. Because the Triton compiler sees block-level access patterns rather than per-thread accesses, it can perform dataflow analysis on blocks and reason at a much higher level about what the code does. In doing so, it can decide when to prefetch memory, perform hierarchical tiling (e.g. to fit the thread/warp model), schedule memory accesses for coalescing, manage and synchronize shared memory, and more. See the original paper (which predates Triton 2.0) for more details.
*Triton vs cuBLAS matrix multiplication performance. Source: triton-lang.org.*
For further examples, my triton-practice repository has a few well-documented implementations of core LLM building blocks like Rotary Positional Embeddings (RoPE) and Flash Attention.