OpenAI 開源了全新的 GPU 編程語言 Triton,它能成為 CUDA 的替代品嗎?
過去十年中,深度神經(jīng)網(wǎng)絡(luò) (DNN) 已成為最重要的機(jī)器學(xué)習(xí)模型之一,創(chuàng)造了從自然語言處理到計(jì)算機(jī)視覺、計(jì)算神經(jīng)科學(xué)等許多領(lǐng)域的 SOTA 實(shí)現(xiàn)。DNN 模型的優(yōu)勢(shì)來自于它的層次結(jié)構(gòu),這一特征導(dǎo)致其計(jì)算量巨大,但也會(huì)產(chǎn)生大量高度并行化的工作,特別適合多核和眾核處理器。
深度學(xué)習(xí)領(lǐng)域的新研究思路往往是結(jié)合原生框架 operator 來實(shí)現(xiàn)的,這種方法雖然方便,但需要?jiǎng)?chuàng)建或移動(dòng)許多臨時(shí)張量,因此可能會(huì)造成神經(jīng)網(wǎng)絡(luò)的性能損失。編寫專門的 GPU 內(nèi)核或許可以解決這個(gè)問題,但 GPU 編程的確是一件相當(dāng)復(fù)雜的事。
DNN 計(jì)算潛力與 GPU 編程困難之間的矛盾由來已久。英偉達(dá)在 2007 年發(fā)布了 CUDA 的初始版本,CUDA 平臺(tái)是一個(gè)軟件層,使用者可以直接訪問 GPU 的虛擬指令集和并行計(jì)算單元,用于執(zhí)行計(jì)算內(nèi)核。近年來,主流深度學(xué)習(xí)框架幾乎都是基于 CUDA 進(jìn)行加速,英偉達(dá)也一直在完善 CUDA 工具包,但對(duì)于一般的開發(fā)者來說,CUDA 還是「不那么容易上手」。
今天,OpenAI 正式推出 Triton 1.0,這是一種類 Python 的開源編程語言。即使沒有 CUDA 經(jīng)驗(yàn)的研究人員,也能夠高效編寫 GPU 代碼。例如,它可以用不到 25 行代碼寫出與 cuBLAS 性能相匹配的 FP16 矩陣乘法內(nèi)核,后者是許多專業(yè)的 GPU 編程者尚且無法做到的。此外,OpenAI 的研究者已經(jīng)使用 Triton 成功生成了比 PyTorch 同類實(shí)現(xiàn)效率高 2 倍的內(nèi)核。
代碼地址:https://github.com/openai/triton
Triton 的最初想法來源于現(xiàn)任 OpenAI 科學(xué)家的 Philippe Tillet 2019 年在哈佛大學(xué)攻讀研究生學(xué)位時(shí)發(fā)表的一篇論文,當(dāng)時(shí)他的導(dǎo)師是 H. T. Kung 和 David Cox。
論文鏈接:http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf
Tillet 希望解決的問題是打造一種比英偉達(dá)的 CUDA 等特定供應(yīng)商庫更好用的庫,能夠處理神經(jīng)網(wǎng)絡(luò)中涉及矩陣的各種操作,具備可移植性,且性能可與 cuDNN 或類似的供應(yīng)商庫相媲美。團(tuán)隊(duì)表示:「直接用 CUDA 進(jìn)行 GPU 編程太難了,比如為 GPU 編寫原生內(nèi)核或函數(shù)這件事,會(huì)因?yàn)?GPU 編程的復(fù)雜性而出奇困難?!?/p>
Facebook AI 研究中心科學(xué)家 Soumith Chintala 也在推特上表達(dá)了自己對(duì) Triton 的期待:
新發(fā)布的 Triton 可以為一些核心的神經(jīng)網(wǎng)絡(luò)任務(wù)(例如矩陣乘法)提供顯著的易用性優(yōu)勢(shì)?!肝覀兊哪繕?biāo)是使其成為深度學(xué)習(xí) CUDA 的可行替代方案,」Philippe Tillet 作為 Triton 項(xiàng)目負(fù)責(zé)人如此表示。
GPU 編程面臨的挑戰(zhàn)
現(xiàn)代 GPU 的架構(gòu)大致可以分為三個(gè)主要組件:DRAM、SRAM 和 ALU。優(yōu)化 CUDA 代碼時(shí),必須考慮到每一個(gè)組件:
來自 DRAM 的內(nèi)存?zhèn)鬏敱仨毢喜⑦M(jìn)大型事務(wù),以利用現(xiàn)代內(nèi)存接口的總線位寬;
必須在數(shù)據(jù)重新使用之前手動(dòng)存儲(chǔ)到 SRAM 中,并進(jìn)行管理以最大限度地減少檢索時(shí)共享內(nèi)存庫沖突;
計(jì)算必須在流處理器(SM)內(nèi)部或之間細(xì)致分區(qū)和調(diào)度,以促進(jìn)指令 / 線程級(jí)的并行以及專用算術(shù)邏輯單元(ALU)的利用。
GPU 基礎(chǔ)架構(gòu)。
種種因素導(dǎo)致 GPU 編程難度驟增,即使對(duì)于具有多年經(jīng)驗(yàn)的 CUDA 程序員也是如此。Triton 的目的是將這些優(yōu)化過程自動(dòng)化,以此讓開發(fā)人員更專注于并行代碼的高級(jí)邏輯。出于對(duì)泛用能力的考量,Triton 不會(huì)自動(dòng)調(diào)度跨流處理器的工作,而是將一些重要的算法考慮因素(例如 tiling、SM 間同步)留給開發(fā)者自行決定。
CUDA vs Triton 編譯器優(yōu)化對(duì)比。
編程模型
在所有可用的領(lǐng)域?qū)S谜Z言和 JIT 編譯器中,Triton 或許與 Numba 最相似:內(nèi)核被定義為修飾過的 Python 函數(shù),并與實(shí)例網(wǎng)格上不同的 program_id 的同時(shí)啟動(dòng)。但不同之處值得注意:如下圖代碼片段所示,Triton 通過對(duì) block 的操作來展示 intra-instance 并行,此處 block 是維數(shù)為 2 的冪的數(shù)組,而不是單指令多線程(SIMT)執(zhí)行模型。如此一來,Triton 高效地抽象出了與 CUDA 線程 block 內(nèi)的并發(fā)相關(guān)的所有問題(比如內(nèi)存合并、共享內(nèi)存同步 / 沖突、張量核心調(diào)度)。
Triton 中的向量加法。
雖然這對(duì) embarrassingly 并行(即 element-wise)計(jì)算可能沒什么幫助,但是可以簡(jiǎn)化更復(fù)雜的 GPU 程序的開發(fā)。例如,在融合 softmax 核的情況下,對(duì)于每個(gè)輸入張量 X∈R^M×N 來說,每個(gè)實(shí)例對(duì)給定輸入張量的不同行進(jìn)行歸一化。這種并行化策略的標(biāo)準(zhǔn) CUDA 實(shí)現(xiàn)可能難以編寫,需要線程之間的顯式同步,因?yàn)檫@種策略并發(fā)地減少 X 的同一行。而 Triton 很大程度上消除了這種復(fù)雜性,每個(gè)內(nèi)核實(shí)例加載感興趣的行,并使用類似 NumPy 的原語順序?qū)ζ溥M(jìn)行規(guī)范化。
import triton
import triton.language as tl
@triton.jit
def softmax(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N):
# row index
m = tl.program_id(0)
# col indices
# this specific kernel only works for matrices that
# have less than BLOCK_SIZE columns
BLOCK_SIZE = 1024
n = tl.arange(0, BLOCK_SIZE)
# the memory address of all the elements
# that we want to load can be computed as follows
X = X + m * stride_xm + n * stride_xn
# load input data; pad out-of-bounds elements with 0
x = tl.load(X, mask=n < N, other=-float(‘inf’))
# compute numerically-stable softmax
z = x - tl.max(x, axis=0)
num = tl.exp(z)
denom = tl.sum(num, axis=0)
y = num / denom
# write back to Y
Y = Y + m * stride_ym + n * stride_yn
tl.store(Y, y, mask=n < N)
import torch
# Allocate input/output tensors
X = torch.normal(0, 1, size=(583, 931), device='cuda‘)
Y = torch.empty_like(X)
# SPMD launch grid
grid = (X.shape[0], )
# enqueue GPU kernel
softmax[grid](Y, Y.stride(0), Y.stride(1),
X, X.stride(0), X.stride(1),
X.shape[0] , X.shape[1])
在 Triton 中融合 softmax
Triton JIT 把 X、Y 當(dāng)作指針而不是張量。最重要的是,softmax 這種特殊實(shí)現(xiàn)方式在整個(gè)規(guī)范化過程中保持 SRAM 中 X 的行不變,從而在適用時(shí)最大限度地實(shí)現(xiàn)數(shù)據(jù)重用(約 32K 列)。這與 PyTorch 的內(nèi)部 CUDA 代碼不同,后者使用臨時(shí)內(nèi)存使其更通用,但速度明顯變慢(見下圖)。
融合 softmax、M=4096 的 A100 性能。
Torch (v1.9) JIT 較低的性能突出了從高級(jí)張量操作序列自動(dòng)生成 CUDA 代碼的難度。
@torch.jit.script
def softmax(x):
x_max = x.max(dim=1)[0]
z = x - x_max[:, None]
numerator = torch.exp(x)
denominator = numerator.sum(dim=1)
return numerator / denominator[:, None]
融合 softmax 與 Torch JIT
矩陣乘法
能夠?yàn)樵夭僮鳎╡lement-wise operation)和規(guī)約操作(reduction operation)編寫融合內(nèi)核是很重要的,但考慮到神經(jīng)網(wǎng)絡(luò)中矩陣乘法的重要性,這還不夠。事實(shí)證明,Triton 在這些方面表現(xiàn)很好,僅用大約 25 行 Python 代碼就能達(dá)到最佳性能。相比之下,CUDA 效率就沒有那么高了。
Triton 中的矩陣乘法。
手寫矩陣乘法內(nèi)核的一個(gè)重要優(yōu)點(diǎn)是它們可以根據(jù)需要進(jìn)行定制,以適應(yīng)其輸入(例如切片)和輸出(例如 Leaky ReLU)的融合變換。假如不存在 Triton 這樣的系統(tǒng),那么對(duì)于沒有出色的 GPU 編程專業(yè)知識(shí)的開發(fā)人員來說,矩陣乘法內(nèi)核將很難大改。
高級(jí)系統(tǒng)架構(gòu)
Triton 的良好性能得益于以 Triton-IR 為中心的模塊化系統(tǒng)架構(gòu)。Triton-IR 是一種基于 LLVM 的中間表示,多維值塊(blocks of values)是其中最重要的東西。
Triton 的高級(jí)架構(gòu)。
@triton.jit 裝飾器的工作原理是遍歷由 Python 函數(shù)提供的抽象語法樹(AST),這樣一來就能使用通用的 SSA 構(gòu)造算法實(shí)時(shí)生成 Triton-IR。生成的 IR 代碼隨后由編譯器后端進(jìn)行簡(jiǎn)化、優(yōu)化和自動(dòng)并行化,然后轉(zhuǎn)換為高質(zhì)量的 LLVM-IR,最終轉(zhuǎn)換為 PTX,以便在最新的 NVIDIA GPU 上執(zhí)行。目前 Triton 還不支持 CPU 和 AMD GPU,但團(tuán)隊(duì)表示對(duì)二者的支持正在開發(fā)中。
編譯器后端
研究人員發(fā)現(xiàn)通過 Triton-IR 來使用塊狀程序表示,這種方法允許編譯器自動(dòng)執(zhí)行各種重要的程序優(yōu)化。例如,通過查看計(jì)算密集型塊級(jí)操作(例如 tl.dot)的操作數(shù),數(shù)據(jù)可以自動(dòng)存儲(chǔ)到共享內(nèi)存中,并使用標(biāo)準(zhǔn)的活躍性分析技術(shù)進(jìn)行數(shù)據(jù)的分配與同步。
Triton 編譯器通過分析計(jì)算密集型操作中使用的塊變量的活動(dòng)范圍來分配共享內(nèi)存。
此外,Triton 還可以在 SM 之間以及 SM 之內(nèi)高效、自動(dòng)地并行化,前者通過并發(fā)執(zhí)行不同的內(nèi)核實(shí)例來實(shí)現(xiàn),后者通過分析每個(gè)塊級(jí)操作的迭代空間,并將其充分劃分到不同的 SIMD 單元來實(shí)現(xiàn)。如下所示:
Triton 自動(dòng)并行化。每個(gè)塊級(jí)操作都定義了一個(gè)塊級(jí)迭代空間,該空間可以自動(dòng)并行化以利用 SM(Streaming Multiprocessor) 上的可用資源。