RWKV-v2-RNN 原理簡介:超越 Transformer,實現 O(T) 的語言建模

語言: CN / TW / HK

RWKV-2 是純 RNN,但效能媲美 Transformer。

RWKV-2 同時支援並行和序列模式,因此,它具有 RNN 和 Transformer 的所有優點:高效能,快速執行,快速訓練,節省視訊記憶體,“無限”ctxLen,隱狀態是免費的句嵌入,等等等等。

目前我正在 8*A100 訓練 400M 引數語言模型,很快將測試 1B 引數。 如果你的專案用到 Transformer,歡迎聯絡我,測試 RWKV-2 在各個任務的效能。

LSTM 提出者 Sepp Hochreiter 也在推特轉發:

我們先回顧傳統 Transformer 的注意力機制,然後看 RWKV-2 的做法。

一、傳統 Transformer 模型:生成長度 T 的序列,需 O(T^2) 複雜度

令 F[t] 為 t 時刻的系統狀態(高維向量)。

令 x[t] 為 t 時刻的外部輸入資訊狀態。

預測 F[t+1] 時,需考慮 F[0], F[1], .. F[t]。因此,生成長度 T 的序列,需 O(T^2) 複雜度。

簡化版本的公式:

這裡 Q K V 是三個可訓練的矩陣。

其意義為:

  • 每個狀態 i 對於後續的潛在貢獻是 V F[i]。
  • Q x[t] 向量,與此前的所有 K F[i] 向量分別做點乘,再 exp,得到 x[t] 與之前各個 F[i] 狀態的匹配度。
  • 如果匹配度 越大, V F[i] 的權重越大。
  • 分母為歸一化因子。

注意:公式中沒有顯式出現 t 與 i 的距離資訊。我們會採用其它方式(例如位置編碼)將其注入系統。

二、RWKV-2 模型:生成長度 T 的序列,只需 O(T) 複雜度

簡化版本的公式(這是並行版本,再往下你就會看到 O(T) 的序列 RNN 版本):

它和 https:// arxiv.org/abs/2105.1410 3 很像,核心區別是,我注意到了上述形式還擁有 RNN 形式。

這裡 R K V 是三個可訓練的矩陣, W 是一個可訓練的向量(代表時間衰減率)。

其意義為:

  • 每個狀態 i 對於後續的潛在貢獻是 V F[i]。
  • 匹配度由 改為 ,其中 是非線性函式,經實驗採用 sigmoid 函式的效果較好。注意 不參與歸一化,所以我將 R 稱為 receptance。
  • 這裡 是顯式的距離因子。實際我在 2020 年 8 月就提出可以在注意力中加入這種距離因子,當時我稱為 time-weighting(見 https:// github.com/BlinkDL/minG PT-tuned 的 commit 記錄)。

現在,我們進一步變換,將其變為 RNN 遞迴形式。即,生成 F[t+1] 時,只需考慮 x[t],以及固定大小的隱狀態 A[t] 和 B[t],無需與此前 F[i] 都進行計算。因此,生成長度 T 的序列,只需 O(T) 複雜度。

這是因為:

易見:

其中 A[t] 和 B[t] 是上一步的分子和分母。

可見,RWKV-2 的原理非常簡單。

這樣簡單的方式卻有效,關鍵是因為,不斷乘以 exp(W) 的衰減過程,如果把 exp(W) 看成是某個矩陣 M 對角化後的對角線,那麼它就足以模擬矩陣 M 的不斷作用。

RWKV-2 也可以寫成連續的形式,進一步聯絡到微分方程,這個我稍後會寫寫。