CUDA優化之LayerNorm效能優化實踐

語言: CN / TW / HK

撰文 | 郭冉、姚遲、鄭澤康、柳俊丞

2020年末,OneFlow 釋出了《OneFlow 效能優化分享: 如何實現一個高效的 Softmax CUDA kernel? ,其中介紹了 OneFlow深度優化後的Softmax, 尤其對很多框架沒有考慮的 half 型別做了充分優化,使得 效能大幅超 過了 cuDNN 的實現。

今天,奉上另一個重要運算元 LayerNorm 的效能優化實踐技術分享。

此外,OneFlow 還帶上了可以獨立使用的 OneFlow Softmax(具體見文末說明),歡迎大家試用、提建議。

1

OneFlow 效能優化後的測試結果

OneFlow 優化後的 LayerNorm 分別與 NVIDIA Apex、PyTorch 做了效能對比,測試結果顯示,OneFlow LayerNorm 有明顯的效能優勢。

與 NVIDIA Apex 的對比結果

NVIDIA Apex 中實現了高效的 fused LayerNorm Kernel 來擴充套件 PyTorch 運算元,我們對 OneFlow 優化後的 LayerNorm Kernel 和 NVIDIA Apex 進行了對比測試,測試結果如下:

橫軸為 num_cols 大小,縱軸為 Kernel 執行需要的時間(越低越好):

我們將時間換算成訪存頻寬,結果如下,縱軸為 Kernel 達到的有效頻寬(越高越好):

其中測試環境為 NVIDIA A100-PCIE-40GB GPU,資料型別為 halfShape =(49152, num_cols) ,我們將最後一維動態變化,測試了從32到32768不同大小的 LayerNorm Kernel,可以看到在所有情況下,OneFlow 的 Kernel 執行時間和有效訪存頻寬都優於 Apex 的實現。

與 PyTorch 的對比結果

PyTorch 的 LayerNorm 暫時不支援 half 型別,因此我們用 float型別做了一組對照,需要注意的是PyTorch中LayerNorm是分兩個CUDA Kernel(RowwiseMomentsCUDAKernel和LayerNormForwardCUDAKernel)做的,所以看起來效能比較差。

橫軸為 num_cols 大小,縱軸為 Kernel 執行需要的時間(越低越好):

可以看到,在各組對比實驗中,OneFlow 的效能也是最優的。

2

LayerNorm 效能優化

LayerNorm 是語言模型中常用的操作之一,其 CUDA Kernel 實現的高效性會影響很多網路最終的訓練速度,Softmax 的優化方法也適用於 LayerNorm,LayerNorm 的資料也可以表示為 (num_rows, num_cols) ,計算過程中對每一行的元素做 Reduce 操作求均值方差。因此我們使用了和 Softmax 同樣的優化方法來優化 LayerNorm 操作,本文以 LayerNorm 前向計算為例進行介紹。

LayerNorm 計算方法

以 PyTorch 為例,LayerNorm 的介面為:

torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)

其中 input 形狀為: [∗, normalized_shape[0], normalized_shape[1], …,normalized_shape[−1]]

第一個引數 normalized_shape 只能是輸入 x_shape 的後幾維,例如 x_shape (N, C, H, W) , normalized_shape 可以是 (W) (H, W) (C, H, W) (N, C, H, W) 。輸入 x normalized_shape 這幾維上求均值和方差。

第三個引數 elementwise_affine 代表是否要對 normalize 的結果做變換,即 normalize 的結果乘 gamma ,加 beta 。若 elementwise_affine=True ,就多了兩個模型引數 gammabeta ,形狀為 normalized_shape

例如對於輸入 x 形狀為 (N, C, H, W)normalized_shape(H, W) 的情況,可以理解為輸入 x(N*C, H*W) ,在 N*C 個行上,每行有 H*W 個元素,對每行的元素求均值和方差,得到 N*Cmeaninv_variance ,再對輸入按如下 LayerNorm 的計算公式計算得到 y 。若 elementwise_affine=True ,則有 H*Wgammabeta ,對每行 H*W 個的元素做變換。

LayerNorm 中求方差的方法

常見的求方差的方法有 two pass 方法、naive 方法、和 Welford 演算法,本文摘錄一些關鍵的公式和結論,詳細的介紹和推導可參考:Wiki: Algorithms for calculating variance( https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance ) ,和 GiantPandaCV: 用Welford演算法實現LN的方差更新

1.two-pass方法

使用的公式是:

two-pass 是指這種方法需要遍歷兩遍資料,第一遍累加 x 得到均值,第二遍用上面公式計算得到方差。這種方法在 n 比較小時仍然是數值穩定的。

2.naive方法

使用的公式是:

這種方法是一種 single pass 方法,在計算方差時只需要遍歷一遍資料累加 x 的平方及累加 x ,最後按上述公式計算得到方差。這種方法只需要遍歷一遍資料,相比 two-pass 的演算法,更容易達到好的效能,但是上面的 Wiki 參考連結中介紹由於 SumSquare 和 (Sum×Sum)/n 可能非常接近,可能會導致計算結果損失精度較大,因此這種方法不建議在實踐中使用。

3.Welford 演算法

使用的公式是:

Welford 演算法也是一種 single pass 方法,且數值穩定性很好,因此現在很多框架都採用這種方法。本文的程式碼中採用的也是 Welford 方法。

OneFlow 深度優化 LayerNorm CUDA Kernel 的技巧

和 Softmax 一樣,LayerNorm 也採用分段函式優化,對於不同的 num_cols 範圍,採用不同的實現,以在各種情況下都能達到較高的有效頻寬。

在每種實現中都採用了一個公共的優化:向量化訪存,NVIDIA 效能優化的部落格 Increase Performance with Vectorized Memory Access 中提到可以通過向量化記憶體操作來提高 CUDA Kernel 效能,很多 CUDA Kernel 都是頻寬受限的,使用向量化記憶體操作可以減少總的指令數,減少延遲,提高頻寬利用率。

理論上來說,在計算 LayerNorm 的過程中,輸入 x 需要被讀兩次,第一次用於計算均值和方差。第二次用於得到均值和方差後的計算過程。而對 Global Memory 的訪問操作是昂貴的,如果能將輸入 x 先存起來,不重複讀,就可以提升效能。在 GPU 中將輸入 x 存起來可以使用暫存器或 Shared memory,但是暫存器資源和 Shared memory 資源都是有限的,如果 num_cols 過大,就會超出資源的使用限制,因此我們針對不同 num_cols 採用不同的實現,下面分別進行介紹:

1. num_cols <= 1024 的情況

針對 num_cols <= 1024 的情況,以 Warp 為單位處理一行或兩行,將輸入 x 儲存到暫存器中。

硬體上並行執行的32個執行緒稱之為一個 Warp,同一個 Warp 的32個 thread 執行同一條指令, Warp是 GPU 排程執行的基本單元。執行緒塊和元素的對應關係如上圖所示,每個 Warp 的 threads 處理一行元素,每個 block 有 block_size / warp_size 個 Warp,每個 block 處理 block_size / warp_size 行元素。

具體的處理流程是,如下圖所示,每行有 num_cols 個元素,每個 warp 處理一行,因此每個執行緒需要處理 num_cols / warp_size 個元素,每個執行緒讀取自己需要處理的元素儲存到暫存器中,並用 Welford 演算法計算好均值和方差後,Warp 中的所有執行緒執行一次 WelfordWarpAllReduce,這樣每個執行緒上就得到了正確的均值和方差參與後續計算。

WelfordWarpAllReduce 由 WelfordWarpReduce 和 Broadcast 操作完成,WelfordWarpReduce 藉助 Warp 級別同步原語 __shfl_down_sync 實現,Broadcast操作藉助 __shfl_sync 實現,程式碼如下:

template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_mean, T thread_m2, T thread_count, T* mean,
                                             T* m2, T* count) {
  *mean = thread_mean;
  *m2 = thread_m2;
  *count = thread_count;
  for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
    T b_mean = __shfl_down_sync(0xffffffff, *mean, mask);
    T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask);
    T b_count = __shfl_down_sync(0xffffffff, *count, mask);
    WelfordCombine(b_mean, b_m2, b_count, mean, m2, count);
  }
}

template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_mean, T thread_m2, T thread_count, T* mean,
                                                T* m2, T* count) {
  WelfordWarpReduce<T, thread_group_width>(thread_mean, thread_m2, thread_count, mean, m2, count);
  *mean = __shfl_sync(0xffffffff, *mean, 0, thread_group_width);
  *m2 = __shfl_sync(0xffffffff, *m2, 0, thread_group_width);
  *count = __shfl_sync(0xffffffff, *count, 0, thread_group_width);
}

在這裡有個模板引數 thread_group_width ,當 num_cols > pack_size * WarpSize 時, thread_group_widthWarpSize 。當 num_cols 太小,即 num_cols<pack_size * WarpSize 時,一個 Warp 內的執行緒不是全部處理有效的值,此時我們採用更小的 thread_group_width ,取值可能是16、8、4、2、1,由 num_cols 決定,並且每個執行緒處理兩行增加並行度。

此外,在讀寫輸入輸出時,我們採用向量化訪存的優化,在滿足條件時,將 pack_size 個元素 pack 成更大的資料型別讀入,下圖為 pack_size=2 時的示意圖,每個執行緒以更大的資料型別讀入元素,可以更好的利用視訊記憶體頻寬。

pack_size 個元素 pack 成更大的資料型別讀入,但是 x 還要參與計算。因此我們定義一個 union 結構的 Pack 型別,storage 用於從 Global Memory中讀寫,做計算時用 elem[i] 取每個元素參與計算,Pack 型別定義如下:

template<typename T, int N>
union Pack {
  PackType<T, N> storage;
  T elem[N];
};

LayerNormWarpImpl Kernel 程式碼如下:

template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
         int thread_group_width, int rows_per_access, bool padding>
__global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols,                                   const double epsilon, ComputeType* mean,                                   ComputeType* inv_variance) {
  static_assert(cols_per_thread % pack_size == 0, "");
  static_assert(thread_group_width <= kWarpSize, "");
  static_assert(kWarpSize % thread_group_width == 0, "");
  constexpr int num_packs = cols_per_thread / pack_size;
  assert(cols <= cols_per_thread * thread_group_width);
  ComputeType buf[rows_per_access][cols_per_thread];
  const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
  const int64_t num_global_thread_group = gridDim.x * blockDim.y;
  const int64_t lane_id = threadIdx.x;
  for (int64_t row = global_thread_group_id * rows_per_access; row < rows;
       row += num_global_thread_group * rows_per_access) {
    ComputeType thread_mean[rows_per_access];
    ComputeType thread_m2[rows_per_access];
    ComputeType thread_count[rows_per_access];
#pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      thread_mean[row_id] = 0;
      thread_m2[row_id] = 0;
      thread_count[row_id] = 0;
      ComputeType* row_buf = buf[row_id];
#pragma unroll
      for (int pack_id = 0; pack_id < num_packs; ++pack_id) {
        const int col = (pack_id * thread_group_width + lane_id) * pack_size;
        const int pack_offset = pack_id * pack_size;
        if (!padding || col < cols) {
          load.template load<pack_size>(row_buf + pack_offset, row + row_id, col);
#pragma unroll
          for (int i = 0; i < pack_size; ++i) {
            WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id,
                           thread_count + row_id);
          }
        } else {
#pragma unroll
          for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = 0; }
        }
      }
    }
    ComputeType warp_mean[rows_per_access];
    ComputeType warp_m2[rows_per_access];
    ComputeType warp_count[rows_per_access];
#pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      int global_row_id = row + row_id;
      ComputeType* row_buf = buf[row_id];
      WelfordWarpAllReduce<ComputeType, thread_group_width>(
          thread_mean[row_id], thread_m2[row_id], thread_count[row_id], warp_mean + row_id,
          warp_m2 + row_id, warp_count + row_id);
      ComputeType row_mean = warp_mean[row_id];
      ComputeType row_variance =
          max(Div(warp_m2[row_id], warp_count[row_id]), static_cast<ComputeType>(0.0));
      ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
      if (lane_id == 0) {
        mean[global_row_id] = row_mean;
        inv_variance[global_row_id] = row_inv_var;
      }
#pragma unroll
      for (int i = 0; i < cols_per_thread; ++i) {
        row_buf[i] = (row_buf[i] - row_mean) * row_inv_var;
      }
#pragma unroll
      for (int i = 0; i < num_packs; ++i) {
        const int col = (i * thread_group_width + lane_id) * pack_size;
        if (!padding || col < cols) {
          store.template store<pack_size>(row_buf + i * pack_size, global_row_id, col);
        }
      }
    }
  }
}

LayerNormWarpImpl 的實現的模板引數的意義分別如下:

  • LOADSTORE 分別代表輸入輸出,使用 load.template load<pack_size>(ptr, row_id, col_id);store.template store<pack_size>(ptr, row_id, col_id); 進行讀取和寫入。使用 LOADSTORE 有兩個好處:a) 可以在 CUDA Kernel中只關心計算型別 ComputeType ,而不用關心具體的資料型別 T 。b) 只需要加幾行程式碼就可以快速支援 LayerNorm 和其他 Kernel Fuse,減少頻寬需求,提升整體效能。

  • ComputeType 代表計算型別。 pack_size 代表向量化訪存操作的 pack 元素的個數,我們將幾個元素 pack 起來讀寫,提升頻寬利用率。

  • cols_per_thread 代表每個執行緒處理的元素個數。

  • thread_group_width 代表處理元素的執行緒組的寬度,當 cols > pack_size * warp_size 時, thread_group_width 就是warp_size,即32。當 cols < pack_size * warp_size 時,就根據 cols 大小用 1/2個warp 或 1/4個warp 來處理每行的元素。採用更小的 thread_group_width 後,WarpAllReduce需要執行的輪次也相應減少。

  • rows_per_access 代表每個 thread_group 一次處理的行數,當 cols 較小且 thread_group_width 小於warp_size時,若 rows 能被2整除,我們就讓每個執行緒處理2行來增加指令並行度,從而提升效能。

  • padding 代表當前是否做了 padding,若 cols 不是 warp_size 的整數倍,我們會把它padding 到最近的整數倍處理。

2. num_cols > 1024 的情況

針對 num_cols > 1024 ,以 block 為單位處理一行,利用 Shared Memory 儲存輸入資料 對於 num_cols > 1024 的情況,每個 block 處理一行元素,將輸入 x 儲存到 Shared Memory中。

具體的處理流程是,如下圖所示,每行有 num_cols 個元素,每個 block 處理一行,因此每個執行緒需要處理 num_cols / block_size 個元素,每個執行緒讀取自己需要處理的元素儲存到 Shared Memory 中,並用 Welford 演算法計算好均值和方差後,block 中的所有執行緒執行一次WelfordBlockAllReduce,這樣每個執行緒上就得到了正確的均值和方差參與後續計算。

WelfordBlockAllReduce 是藉助 WelfordWarpReduce 操作完成的,具體邏輯是,一個 Block 中最多有32個 Warp,對所有的 Warp 先執行一次 WelfordWarpReduce,執行完後,每個 warp 中的第一個執行緒,即 lane_id=0 的執行緒上得到當前 WelfordWarpReduce 的結果,再將每個 Warp 的第一個執行緒的結果拷貝到一塊 Shared Memory buffer 中,再用第一個 Warp 的32個執行緒執行一次 WelfordWarpReduce,此時第一個 Warp 中的 lane_id=0 的執行緒上得到的就是 block 中所有執行緒reduce 的結果。再借助 Shared Memory,將該結果 broadcast 到 block 中的所有執行緒上,即完成了 WelfordBlockAllReduce 的操作。

值得注意的是,GPU 上 Shared Memory 資源同樣有限,當 num_cols 超過一定範圍時需要佔用的Shared Memory 可能就超出了最大限制,Kernel 就無法啟動起來。

因此,我們採用 cudaOccupancyMaxActiveBlocksPerMultiprocessor 函式判斷當前硬體資源條件下 Kernel 是否能成功啟動,僅在返回值大於0時採用這種方案。

此外,由於 Block 內執行緒要做同步,當 SM 中正在排程執行的一個 Block 到達同步點時,SM 內可執行 Warp 逐漸減少,若同時執行的 Block 只有一個,則 SM 中可同時執行的 Warp 會在此時逐漸降成0,會導致計算資源空閒,造成浪費,若此時同時有其他 Block 在執行,則在一個 Block 到達同步點時仍然有其他 Block 可以執行。

block_size 越小時,SM 可同時排程的 Block 越多,因此在這種情況下 block_size 越小越好。但是當在調大 block_size ,SM 能同時排程的 Block 數不變的情況下, block_size 應該是越大越好,越大就有越好的並行度。因此程式碼中在選擇 block_size 時,對不同 block_size 都計算了 cudaOccupancyMaxActiveBlocksPerMultiprocessor ,若結果相同,使用較大的 block_size

LayerNormBlockSMemImpl Kernel的程式碼如下:

template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t rows,                                        const int64_t cols, const double epsilon, ComputeType* mean,                                        ComputeType* inv_variance) {
  extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
  auto* buf = reinterpret_cast<ComputeType*>(shared_buf);
  const int tid = threadIdx.x;
  assert(cols % pack_size == 0);
  const int num_packs = cols / pack_size;
  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
    ComputeType thread_mean = 0;
    ComputeType thread_m2 = 0;
    ComputeType thread_count = 0;
    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
      ComputeType pack[pack_size];
      load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
      for (int i = 0; i < pack_size; ++i) {
        buf[i * num_packs + pack_id] = pack[i];
        WelfordCombine(pack[i], &thread_mean, &thread_m2, &thread_count);
      }
    }
    ComputeType row_mean = 0;
    ComputeType row_m2 = 0;
    ComputeType row_count = 0;
    WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
                                       &row_count);
    ComputeType row_variance = max(Div(row_m2, row_count), static_cast<ComputeType>(0.0));
    ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
    if (threadIdx.x == 0) {
      mean[row] = row_mean;
      inv_variance[row] = row_inv_var;
    }
    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
      ComputeType pack[pack_size];
#pragma unroll
      for (int i = 0; i < pack_size; ++i) {
        pack[i] = (buf[i * num_packs + pack_id] - row_mean) * row_inv_var;
      }
      store.template store<pack_size>(pack, row, pack_id * pack_size);
    }
  }
}

3.num_cols 較大時,不使用 Shared Memory 的情況

num_cols 較大,當前硬體資源條件下使用Shared Memory的方法無法成功Launch Kernel時,使用這種實現:一個 Block 處理一行的元素,不使用 Shared Memory,重複讀輸入 x

這種方法和前面第二種情況執行緒和元素對應關係一致,唯一的區別在於,第二種方法將輸入 x 儲存到Shared Memory 中,本方法不儲存 x ,在每次計算時需要再從 Global Memory 中讀入 x 。這種方法雖然需要多讀一份 x ,但是在實際執行時,部分輸入可以被 Cache 快取起來,不會實際增加很多時間。值得注意的是,在這種實現中, block_size 越大,SM 中能同時並行執行的 block 數就越少,對 Cache 的需求就越少,就有更多機會命中 Cache,因此我們使用較大的 block_size

LayerNormBlockUncachedImpl 程式碼如下:

template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t rows,                                            const int64_t cols, const double epsilon,                                            ComputeType* mean, ComputeType* inv_variance) {
  const int tid = threadIdx.x;
  assert(cols % pack_size == 0);
  const int num_packs = cols / pack_size;
  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
    ComputeType thread_mean = 0;
    ComputeType thread_m2 = 0;
    ComputeType thread_count = 0;
    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
      ComputeType pack[pack_size];
      load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
      for (int i = 0; i < pack_size; ++i) {
        WelfordCombine(pack[i], &thread_mean, &thread_m2, &thread_count);
      }
    }
    ComputeType row_mean = 0;
    ComputeType row_m2 = 0;
    ComputeType row_count = 0;
    WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
                                       &row_count);
    ComputeType row_variance = max(Div(row_m2, row_count), static_cast<ComputeType>(0.0));
    ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
    if (threadIdx.x == 0) {
      mean[row] = row_mean;
      inv_variance[row] = row_inv_var;
    }
    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
      ComputeType pack[pack_size];
      const int pack_offset = pack_id * pack_size;
      load.template load<pack_size>(pack, row, pack_offset);
#pragma unroll
      for (int i = 0; i < pack_size; ++i) { pack[i] = (pack[i] - row_mean) * row_inv_var; }
      store.template store<pack_size>(pack, row, pack_offset);
    }
  }
}

3

OneFlow Softmax 庫

經過反覆迭代,OneFlow 的 Softmax 的介面和實現已經成熟,趨於穩定,所以 OneFlow 團隊把它解耦後,作為獨立的介面提供,優化程式碼放在 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/softmax.cuh ,它可以脫離 OneFlow 程式碼獨立編譯。

在你的專案中 include 這個標頭檔案後,就可以直接使用。比如,使用以下幾行程式碼就可以實現一個 Softmax GPU Kernel。

    oneflow::cuda::softmax::DirectLoad<half, float> load(in, cols);
    oneflow::cuda::softmax::DirectStore<float, half> store(out, cols);
    oneflow::cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(
        cuda_stream, load, store, rows, cols);

如果要實現一個 LogSoftmax Kernel 也很簡單:只需要將以上程式碼中的的 DispatchSoftmax 換成 DispatchLogSoftmax 就可以了。

與其它地方提供的 Softmax 相比,OneFlow Softmax 的主要優勢有:

  • 效能優勢,可見之前的文章分享。此外,最近一年進一步優化了小的 num_cols 下的效能。

  • 同時支援了 Softmax 和 LogSoftmax,適用場景更廣。

  • 輸入輸出通過 Load/Store 結構傳遞,解耦資料IO和計算,只需要加幾行程式碼就可以快速支援 Softmax 和其他 Kernel Fuse,減少頻寬需求,帶來很高的效能收益。

其他人都在看

點選“ 閱讀原文 ,歡迎下載體驗OneFlow新一代開源深度學習框架