斯坦福大學CS博士新作:新型Attention提速2-4倍,BERT單節點訓練最快

語言: CN / TW / HK

一種快速、記憶體高效的注意力演算法來了,被命名為 FlashAttention。通過減少 GPU 記憶體讀取 / 寫入,FlashAttention 的執行速度比 PyTorch 標準注意力快 2-4 倍,所需記憶體減少 5-20 倍。

這項研究由斯坦福大學、紐約州立大學布法羅分校的研究者共同完成。共同一作是兩位斯坦福計算機博士生 Tri Dao 和 Dan Fu。

下面我們介紹一下論文具體內容。

FlashAttention

Transformer 已然成為自然語言處理和影象分類等應用中最廣泛使用的架構。隨著研究的不斷前進,Transformer 尺寸變得越來越大、層數也越來越深,但是給 Transformer 配備更長的上下文仍然很困難,因為 Transformer 核心自注意力模組的時間複雜度以及記憶體複雜度在序列長度上是二次方的。

有研究者提出一些近似注意力的方法,旨在減少注意力計算和記憶體需求。這些方法包括稀疏近似、低秩近似以及它們的組合。從序列長度來看,儘管這些方法可以將計算降低到線性或接近線性,但它們並沒有顯示出針對標準注意力的 wall-clock 加速,因而沒有被廣泛使用。這其中一個主要原因是這些研究專注於減少 FLOP(這可能與 wall-clock 速度無關)並且傾向於忽略來自記憶體訪問 (IO) 的開銷。

在本文中,該研究認為應該讓注意力演算法具有 IO 感知——即考慮視訊記憶體級間的讀寫。現代 GPU 計算速度超過了記憶體速度,transformer 中的大多數操作都被記憶體訪問所阻塞。IO 感知演算法對於類似的記憶體繫結操作至關重要,這種重要性體現在當讀寫資料佔據很大執行時——例如資料庫連線、影象處理、數值線性代數等。然而,用於深度學習的常見 Python 介面,如 PyTorch 和 Tensorflow,不允許對記憶體訪問進行細粒度控制。

論文地址:https://arxiv.org/pdf/2205.14135.pdf

GitHub 地址:https://github.com/HazyResearch/flash-attention

該研究提出了一種新的注意力演算法 FlashAttention,它可以使用更少的記憶體訪問來計算精確的注意力。FlashAttention 旨在避免從 HBM(High Bandwidth Memory)中讀取和寫入注意力矩陣。這需要做到:(i) 在不訪問整個輸入的情況下計算 softmax reduction;(ii) 在後向傳播中不能儲存中間注意力矩陣。

該研究採用兩種成熟的技術來應對這些挑戰:

(i) 該研究重組注意力計算,將輸入分成塊,並在輸入塊上進行多次傳遞,從而逐步執行 softmax reduction(也稱為 tiling);

(ii) 該研究儲存前向傳遞的 softmax 歸一化因子,在後向傳播中快速重新計算片上注意力,這比從 HBM 中讀取中間注意力矩陣的標準方法更快。

該研究在 CUDA 中實現 FlashAttention ,以達到對記憶體訪問的細粒度控制,並將所有注意力操作融合到一個 GPU 核心中。即使由於重新計算導致 FLOPs 增加,但其執行速度更快(在 GPT-2 上高達 7.6 倍,圖 1 右圖)並且使用更少的記憶體(序列長度線性),主要是因為大大減少了 HBM 訪問量。

該研究分析了 FlashAttention 的 IO 複雜度,證明它需要 ( ^2 ^2^ −1)HBM 訪問,其中 是 head 維度, 是 SRAM 的大小,而標準的注意力需要Ω( + ^2 )HBM 訪問。對於 和 的典型值,與標準注意力相比,FlashAttention 需要的 HBM 訪問次數要少很多(最多減少 9 倍,如圖 2 所示)。此外,該研究還提供了一個下限,表明沒有精確的注意力演算法可以漸近地提高所有 SRAM 大小的 HBM 訪問次數。

該研究還表明,FlashAttention 可以作為一種原語(primitive),通過克服記憶體訪問開銷問題來實現近似注意力演算法。作為概念證明,該研究實現了塊稀疏 FlashAttention,這是一種稀疏注意力演算法,比 FlashAttention 快 2-4 倍,可擴充套件到 64k 的序列長度。該研究證明了塊稀疏 FlashAttention 比 FlashAttention 具有更好的 IO 複雜度。

值得一提的是,該研究還開源了 FlashAttention。

實驗結果

BERT:FlashAttention 得到了最快的單節點 BERT 訓練速度。該研究在 Wikipedia 上用 FlashAttention 訓練了一個 BERT-large 模型。表 1 將 FlashAttention 訓練時間與 Nvidia MLPerf 1.1 進行了比較,結果表明 FlashAttention 的訓練速度提高了 15%。

GPT-2:表 2 顯示,與 HuggingFace 相比,FlashAttention 端到端加速可達 3 倍,與 Megatron-LM 相比,加速可達 1.7 倍

Long-range Arena:該研究在 long-range arena (LRA) 基準上進行了實驗,他們測量了準確率、吞吐量、訓練時間。每個任務有不同的序列長度,從 1024 到 4096 不等。此外,實驗遵循 Tay 和 Xiong 等人的實驗設定。表 3 顯示,與標準注意力相比,FlashAttention 的速度提高了 2.4 倍。塊稀疏 FlashAttention 比所有近似注意力方法都要快。

具有長上下文的語言模型:FlashAttention 的執行時間和記憶體效率允許我們將 GPT-2 的上下文長度增加 4 倍,同時仍然比 Megatron-LM 的執行更快。從表 4 可以看出,上下文長度為 4K 的 FlashAttention GPT-2 仍然比上下文長度為 1K 的 Megatron 的 GPT-2 快 30%,同時 perplexity 提高了 0.7。

表 5 表明,在 MIMIC 上序列長度為 16K 的效能比長度為 512 的高出 4.3 個點,而在 ECtHR 上,序列長度為 8K 的比長度 512 高出 8.5 個點。

表 6 展示了 Transformer 模型可以解決 Path-X、Path-256 問題。該研究在 Path-64 上預訓練 transformer,然後通過空間插值位置嵌入遷移到 Path-X。FlashAttention 在 Path-X 上達到 61.4 的準確率。此外,塊稀疏 FlashAttention 使得 Transformers 將序列擴充套件到 64K,在 Path-256 實現 63.1 的準確率。

圖 3(左) 報告了以毫秒為單位的 FlashAttention 和塊稀疏 FlashAttention 前向 + 後向傳播的執行時間與基準比較,圖 3(右) 顯示了與各種精確、近似和稀疏注意基線相比,FlashAttention 和塊稀疏 FlashAttention 的記憶體佔用情況。

「其他文章」