Skip to main content

July 28, 2021

Introducing Triton: Open-source GPU programming for neural networks

Introducing Triton Open Source Gpu Programming For Neural Networks

We’re releasing Triton 1.0, an open-source Python-like programming language which enables researchers with no CUDA experience to write highly efficient GPU code—most of the time on par with what an expert would be able to produce.

Why it matters

Triton makes it possible to reach peak hardware performance with relatively little effort; for example, it can be used to write FP16 matrix multiplication kernels that match the performance of cuBLAS—something that many GPU programmers can’t do—in under 25 lines of code. Our researchers have already used it to produce kernels that are up to 2x more efficient than equivalent Torch implementations, and we’re excited to work with the community to make GPU programming more accessible to everyone.

Novel research ideas in the field of Deep Learning are generally implemented using a combination of native framework operators. While convenient, this approach often requires the creation (and/or movement) of many temporary tensors, which can hurt the performance of neural networks at scale. These issues can be mitigated by writing specialized GPU kernels, but doing so can be surprisingly difficult due to the many intricacies of GPU programming.1, 2, 3 And, although a variety of systems have recently emerged4, 5 to make this process easier, we have found them to be either too verbose, lack flexibility or generate code noticeably slower than our hand-tuned baselines. This has led us to extend and improve Triton6, a recent language and compiler whose original creator now works at OpenAI.

The challenges of GPU programming

The architecture of modern GPUs can be roughly divided into three major components—DRAM, SRAM and ALUs—each of which must be considered when optimizing CUDA code:

  • Memory transfers from DRAM must be coalesced into large transactions to leverage the large bus width of modern memory interfaces.

  • Data must be manually stashed to SRAM prior to being re-used, and managed so as to minimize shared memory bank conflicts upon retrieval.

  • Computations must be partitioned and scheduled carefully, both across and within Streaming Multiprocessors (SMs), so as to promote instruction/thread-level parallelism and leverage special-purpose ALUs (e.g., tensor cores).

GPU Architecture

Basic architecture of a GPU.

Reasoning about all these factors can be challenging, even for seasoned CUDA programmers with many years of experience. The purpose of Triton is to fully automate these optimizations, so that developers can better focus on the high-level logic of their parallel code. Triton aims to be broadly applicable, and therefore does not automatically schedule work across SMs -- leaving some important algorithmic considerations (e.g. tiling, inter-SM synchronization) to the discretion of developers.

CUDA

TRITON

Memory Coalescing

Manual

Automatic

Shared Memory Management

Manual

Automatic

Scheduling (Within SMs)

Manual

Automatic

Scheduling (Across SMs)

Manual

Manual

Compiler optimizations in CUDA vs Triton.

Programming model

Out of all the Domain Specific Languages and JIT-compilers available, Triton is perhaps most similar to Numba: kernels are defined as decorated Python functions, and launched concurrently with different program_id’s on a grid of so-called instances. However, as shown in the code snippet below, the resemblance stops there: Triton exposes intra-instance parallelism via operations on blocks—small arrays whose dimensions are powers of two—rather than a Single Instruction, Multiple Thread (SIMT)7 execution model. In doing so, Triton effectively abstracts away all the issues related to concurrency within CUDA thread blocks (e.g., memory coalescing, shared memory synchronization/conflicts, tensor core scheduling).

Loading...

While this may not be particularly helpful for embarrassingly parallel (i.e., element-wise) computations, it can greatly simplify the development of more complex GPU programs.

Consider for example the case of a fused softmax kernel (below) in which each instance normalizes a different row of the given input tensor X_∈R_M_×_N. Standard CUDA implementations of this parallelization strategy can be challenging to write, requiring explicit synchronization between threads as they concurrently reduce the same row of X. Most of this complexity goes away with Triton, where each kernel instance loads the row of interest and normalizes it sequentially using NumPy-like primitives.

python
1
import triton
2
import triton.language as tl
3

4
@triton.jit
5
def softmax(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N):
6
# row index
7
m = tl.program_id(0)
8
# col indices
9
# this specific kernel only works for matrices that
10
# have less than BLOCK_SIZE columns
11
BLOCK_SIZE = 1024
12
n = tl.arange(0, BLOCK_SIZE)
13
# the memory address of all the elements
14
# that we want to load can be computed as follows
15
X = X + m * stride_xm + n * stride_xn
16
# load input data; pad out-of-bounds elements with 0
17
x = tl.load(X, mask=n < N, other=-float('inf'))
18
# compute numerically-stable softmax
19
z = x - tl.max(x, axis=0)
20
num = tl.exp(z)
21
denom = tl.sum(num, axis=0)
22
y = num / denom
23
# write back to Y
24
Y = Y + m * stride_ym + n * stride_yn
25
tl.store(Y, y, mask=n < N)
26

27
import torch
28
# Allocate input/output tensors
29
X = torch.normal(0, 1, size=(583, 931), device='cuda')
30
Y = torch.empty_like(X)
31
# SPMD launch grid
32
grid = (X.shape[0], )
33
# enqueue GPU kernel
34
softmax[grid](Y, Y.stride(0), Y.stride(1),
35
X, X.stride(0), X.stride(1),
36
X.shape[0] , X.shape[1])

Note that the Triton JIT treats X and Y as pointers rather than tensors; we felt like retaining low-level control of memory accesses was important to address more complex data structures (e.g., block-sparse tensors).

Importantly, this particular implementation of softmax keeps the rows of X in SRAM throughout the entire normalization process, which maximizes data reuse when applicable (~<32K columns). This differs from PyTorch’s internal CUDA code, whose use of temporary memory makes it more general but significantly slower (below). The bottom line here is not that Triton is inherently better, but that it simplifies the development of specialized kernels that can be much faster than those found in general-purpose libraries.

Loading...

The lower performance of the Torch (v1.9) JIT highlights the difficulty of automatic CUDA code generation from sequences of high-level tensor operations.

python
1
@torch.jit.script
2
def softmax(x):
3
x_max = x.max(dim=1)[0]
4
z = x - x_max[:, None]
5
numerator = torch.exp(x)
6
denominator = numerator.sum(dim=1)
7
return numerator / denominator[:, None]
Fused softmax with the Torch JIT.

Matrix multiplication

Being able to write fused kernels for element-wise operations and reductions is important, but not sufficient given the prominence of matrix multiplication tasks in neural networks. As it turns out, Triton also works very well for those, achieving peak performance with just ~25 lines of Python code. On the other hand, implementing something similar in CUDA would take a lot more effort(opens in a new window) and would even be likely to achieve lower performance.

python
1
@triton.jit
2
def matmul(A, B, C, M, N, K, stride_am, stride_ak,
3
stride_bk, stride_bn, stride_cm, stride_cn,
4
**META):
5
# extract metaparameters
6
BLOCK_M, GROUP_M = META['BLOCK_M'], META['GROUP_M']
7
BLOCK_N = META['BLOCK_N']
8
BLOCK_K = META['BLOCK_K']
9
# programs are grouped together to improve L2 hit rate
10
_pid_m = tl.program_id(0)
11
_pid_n = tl.program_id(1)
12
pid_m = _pid_m // GROUP_M
13
pid_n = (_pid_n * GROUP_M) + (_pid_m % GROUP_M)
14
# rm (resp. rn) denotes a range of indices
15
# for rows (resp. col) of C
16
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
17
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
18
# rk denotes a range of indices for columns
19
# (resp. rows) of A (resp. B)
20
rk = tl.arange(0, BLOCK_K)
21
# the memory addresses of elements in the first block of
22
# A and B can be computed using numpy-style broadcasting
23
A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
24
B = B + (rk [:, None] * stride_bk + rn[None, :] * stride_bn)
25
# initialize and iteratively update accumulator
26
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
27
for k in range(K, 0, -BLOCK_K):
28
a = tl.load(A)
29
b = tl.load(B)
30
# block level matrix multiplication
31
acc += tl.dot(a, b)
32
# increment pointers so that the next blocks of A and B
33
# are loaded during the next iteration
34
A += BLOCK_K * stride_ak
35
B += BLOCK_K * stride_bk
36
# fuse leaky ReLU if desired
37
# acc = tl.where(acc >= 0, acc, alpha * acc)
38
# write back result
39
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
40
mask = (rm[:, None] < M) & (rn[None, :] < N)
41
tl.store(C, acc, mask=mask)
Matrix multiplication in Triton.

One important advantage of handwritten matrix multiplication kernels is that they can be customized as desired to accommodate fused transformations of their inputs (e.g., slicing) and outputs (e.g., Leaky ReLU). Without a system like Triton, non-trivial modifications of matrix multiplication kernels would be out-of-reach for developers without exceptional GPU programming expertise.

Loading...

High-level system architecture

The good performance of Triton comes from a modular system architecture centered around Triton-IR, an LLVM-based intermediate representation in which multi-dimensional blocks of values are first-class citizens.

Loading...

The @triton.jit decorator works by walking the Abstract Syntax Tree (AST) of the provided Python function so as to generate Triton-IR on-the-fly using a common SSA construction algorithm.8 The resulting IR code is then simplified, optimized and automatically parallelized by our compiler backend, before being converted into high-quality LLVM-IR—and eventually PTX—for execution on recent NVIDIA GPUs. CPUs and AMD GPUs are not supported at the moment, but we welcome community contributions aimed at addressing this limitation.

Compiler backend

We have found that the use of blocked program representations via Triton-IR allows our compiler to automatically perform a wide variety of important program optimizations. For example, data can be automatically stashed to shared memory by looking at the operands of computationally intensive block-level operations (e.g., tl.dot)—and allocated/synchronized using standard liveness analysis techniques.

Loading...

On the other hand, Triton programs can be efficiently and automatically parallelized both (1) across SMs by executing different kernel instances concurrently, and (2) within SMs by analyzing the iteration space of each block-level operation and partitioning it adequately across different SIMD units, as shown below.

Loading...

Contributing

We intend for Triton to become a community-driven project. Feel free to fork our repository on GitHub(opens in a new window)!

If you’re interested in joining our team and working on Triton & GPU kernels, we’re hiring!

References

  1. 1

    Gray, S. (2017). SGEMM Walkthrough(opens in a new window)URL https://github.com/NervanaSystems/maxas/wiki/SGEMM(opens in a new window).

  2. 2

    Kerr, A. (2020). Developing CUDA kernels to push Tensor Cores to the Absolute Limit on NVIDIA A100(opens in a new window)URL https://developer.nvidia.com/gtc/2020/video/s21745-vid(opens in a new window).

  3. 3

    Yan, D., Wang, W., & Chu, X. (2020, May). Demystifying tensor cores to optimize half-precision matrix multiply(opens in a new window). In 2020 IEEE International Parallel and Distributed Processing Symposium (IPDPS). IEEE.

  4. 4

    NVIDIA CUTLASS(opens in a new window)

  5. 5

    Apache TVM(opens in a new window)

  6. 6

    Tillet, P., Kung, H. T., & Cox, D. (2019, June). Triton: an intermediate language and compiler for tiled neural network computations(opens in a new window). In Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages (pp. 10-19).

  7. 7

    Lin, Y. & Grover, V. (2018). Using CUDA Warp-Level Primitives(opens in a new window)URL https://developer.nvidia.com/blog/using-cuda-warp-level-primitives/(opens in a new window).

  8. 8

    Braun, M., Buchwald, S., Hack, S., Leißa, R., Mallon, C., & Zwinkau, A. (2013, March). Simple and efficient construction of static single assignment form(opens in a new window). In International Conference on Compiler Construction (pp. 102-122). Springer, Berlin, Heidelberg.

Author

Philippe Tillet

Acknowledgments

Da Yan (HKUST), DeepSpeed (Microsoft), Anthropic