有時(shí)候,好的訓(xùn)練「技巧」比蠻力堆參更有效。
現(xiàn)階段,視覺 transformer(ViT)模型已經(jīng)在圖像分類、目標(biāo)檢測(cè)與分割等各樣各樣的計(jì)算機(jī)視覺任務(wù)中得到了廣泛應(yīng)用,并可以在視覺表征與識(shí)別中實(shí)現(xiàn) SOTA 結(jié)果。由于計(jì)算機(jī)視覺模型的性能往往與參數(shù)量和訓(xùn)練時(shí)長(zhǎng)呈正相關(guān),AI 社區(qū)已經(jīng)實(shí)驗(yàn)了越來(lái)越大規(guī)模的 ViT 模型。
但應(yīng)看到,隨著模型開始超出萬(wàn)億次浮點(diǎn)運(yùn)算的規(guī)模,該領(lǐng)域已經(jīng)遇到了一些主要的瓶頸。訓(xùn)練單個(gè)模型可能耗費(fèi)數(shù)月,需要數(shù)以千塊的 GPU,進(jìn)而增加了加速器需求并導(dǎo)致大規(guī)模 ViT 模型將很多從業(yè)者「排除在外」。
為了擴(kuò)展 ViT 模型的使用范圍,Meta AI 的研究者已經(jīng)開發(fā)出了更高效的訓(xùn)練方法。非常重要的一點(diǎn)是對(duì)訓(xùn)練進(jìn)行優(yōu)化以實(shí)現(xiàn)最佳的加速器利用。但是,這一過(guò)程耗時(shí)費(fèi)力且需要大量的專業(yè)知識(shí)。為了設(shè)置有序的實(shí)驗(yàn),研究者必須從無(wú)數(shù)可能的優(yōu)化方案中進(jìn)行選擇:一次訓(xùn)練過(guò)程中執(zhí)行的百萬(wàn)次運(yùn)算中的任何一個(gè)都有可能受到低效率的影響和阻礙。
Meta AI 發(fā)現(xiàn),通過(guò)將一系列優(yōu)化應(yīng)用到其圖像分類代碼庫(kù) PyCls 中的 ViT 實(shí)現(xiàn),可以提升計(jì)算和存儲(chǔ)效率。對(duì)于使用 PyCIs 訓(xùn)練的 ViT 模型,Meta AI 的方法可以提升訓(xùn)練速度和每加速器吞吐量(TFLOPS)。
下圖展示了使用優(yōu)化代碼庫(kù) PyCIs 后每芯片(per chip)加速器吞吐量相較于 V100 基準(zhǔn)的相對(duì)增加,而 A100 優(yōu)化的加速器吞吐量是 V100 基準(zhǔn)的 4.05 倍。
運(yùn)行原理
Meta AI 首先對(duì) PyCIs 代碼庫(kù)進(jìn)行分析以確認(rèn)低訓(xùn)練效率的潛在來(lái)源,最終將注意力放在了對(duì)數(shù)字格式的選擇上。在默認(rèn)情況下,大多數(shù)應(yīng)用使用 32-bit 單精度浮點(diǎn)格式來(lái)表征神經(jīng)網(wǎng)絡(luò)值。轉(zhuǎn)換至 16-bit 半精度格式(FP16)可以減少模型的內(nèi)存占用和執(zhí)行時(shí)間,但往往也會(huì)降低準(zhǔn)確率。
研究者采取了折中方案,即混合精度。利用它,系統(tǒng)通過(guò)單精度格式執(zhí)行計(jì)算以加速訓(xùn)練并減少內(nèi)存使用,同時(shí)通過(guò)單精度存儲(chǔ)結(jié)果以保持準(zhǔn)確率。他們沒(méi)有手動(dòng)地將部分網(wǎng)絡(luò)轉(zhuǎn)換至半精度,而是實(shí)驗(yàn)了不同模式的自動(dòng)混合精度訓(xùn)練,這樣可以在數(shù)字格式之間自動(dòng)切換。更高級(jí)模式的自動(dòng)混合精度主要依賴半精度運(yùn)算和模型權(quán)重。研究者采用的平衡設(shè)置既能大幅度加速訓(xùn)練,同時(shí)也不犧牲準(zhǔn)確率。
為了使流程更加高效,研究者充分利用了 FairScale 庫(kù)中的完全分片數(shù)據(jù)并行(Fully Sharder Data Parallel, FSDP)訓(xùn)練算法,它在 GPU 上對(duì)參數(shù)、梯度和優(yōu)化器狀態(tài)進(jìn)行分片。通過(guò) FSDP 算法,研究者可以使用更少的 GPU 構(gòu)建更大量級(jí)的模型。此外,研究者還使用了 MTA 優(yōu)化器、一個(gè)池化的 ViT 分類器和一個(gè) batch-second 輸入張量布局來(lái)跳過(guò)冗余轉(zhuǎn)置運(yùn)算。
下圖 X 軸為可能的優(yōu)化,Y 軸為采用 ViT-H/16 訓(xùn)練時(shí)加速器吞吐量相較于分布式數(shù)據(jù)并行(DDP)基準(zhǔn)的相對(duì)增加。
研究者在總 patch 大小為 560 時(shí)實(shí)現(xiàn)了 1.51 倍的加速器吞吐量提升,以每個(gè)加速器芯片上每秒執(zhí)行的浮點(diǎn)運(yùn)算數(shù)量衡量。通過(guò)將圖像大小從 224 像素增加至 256 像素,他們可以將吞吐量提升至 1.86 倍。但是,改變圖像大小意味著超參數(shù)的變化,這會(huì)對(duì)模型的準(zhǔn)確率造成影響。在完全 FP16 模式下訓(xùn)練時(shí),相對(duì)吞吐量增加至 2.18 倍。盡管有時(shí)會(huì)降低準(zhǔn)確率,但在實(shí)驗(yàn)中準(zhǔn)確率降低少于 10%。
下圖 Y 軸為 epoch 時(shí)間,在整個(gè) ImageNet-1K 數(shù)據(jù)集上一次訓(xùn)練的持續(xù)時(shí)間。這里專注于現(xiàn)有配置的實(shí)際訓(xùn)練時(shí)間,這些配置通常使用 224 像素的圖像大小。
Meta AI 的研究者使用優(yōu)化方案,將 epoch 時(shí)間(在整個(gè) ImageNet-1K 數(shù)據(jù)集上一次訓(xùn)練的持續(xù)時(shí)間)從 0.65 小時(shí)減少到 0.43 小時(shí)。
下圖 X 軸表示特定配置中 A100 GPU 加速器芯片的數(shù)量,Y 軸表示每芯片 TFLOPS 的絕對(duì)吞吐量。
該研究還討論了不同 GPU 配置的影響。在每種情況下,系統(tǒng)都實(shí)現(xiàn)了比分布式數(shù)據(jù)并行(DDP)基線水平更高的吞吐量。隨著芯片數(shù)量的增加,由于設(shè)備間通信的開銷,我們可以觀察到吞吐量略有下降。然而,即使用 64 塊 GPU,Meta 的系統(tǒng)也比 DDP 基準(zhǔn)快 1.83 倍。
新研究的意義
將 ViT 訓(xùn)練中可實(shí)現(xiàn)的吞吐量翻倍可以有效讓訓(xùn)練集群規(guī)模翻倍,提高加速器利用率直接減少了 AI 模型的碳排放。由于最近大模型的發(fā)展帶來(lái)了更大模型和更長(zhǎng)訓(xùn)練時(shí)間的趨勢(shì),這種優(yōu)化有望幫助研究領(lǐng)域進(jìn)一步推動(dòng)最先進(jìn)的技術(shù),縮短周轉(zhuǎn)時(shí)間并提高生產(chǎn)力。