Efficiently Teaching an Effective Dense Retriever with Balanced Topic Aware Sampling

語言: CN / TW / HK

Source: SIGIR 2021: Efficiently Teaching an Effective Dense Retriever with Balanced Topic Aware Sampling

Code: sebastian-hofstaetter/tas-balanced-dense-retrieval

TL;DR: 基於BERT的 稠密檢索模型 雖然在IR領域取得了階段性的成功,但檢索模型的訓練、索引和查詢效率一直是IR社群關注的重點問題,雖然超越SOTA的檢索模型越來越多,但模型的訓練成本也越來越大,以至於要訓練最先進的稠密檢索模型通常都需要8 V100的配置。而採用本文提出的TAS-Balanced和Dual-supervision訓練策略, 我們僅需要在單個消費級GPU上花費48小時從頭訓練一個6層的DistilBERT就能取得SOTA結果 ,這再一次證明了當前大部分稠密檢索模型的訓練是緩慢且低效的。

緒言

在短短的兩年時間內,當初被質疑是Neural Hype的Neural IR現在已經被IR社群廣泛接受,不少開源搜尋引擎也逐漸支援了基於BERT的稠密檢索(dense retrieval),基本達到了開箱即用的效果。其中,DPR提出的 是當前最主流的稠密檢索模型,然而眾所周知的是, 的可遷移性遠不如BM25這類learning-free的傳統檢索方法,想要在具體的業務場景下使用 並取得理想的結果,我們通常需要準備充足的標註資料進一步訓練檢索模型。

因此,如何高效地訓練一個又快又好的 一直是Neural IR的研究熱點。目前來看,改進 主要有兩條路線可走,其中一條路線是改變batch內的樣本組合,讓模型能夠獲取更豐富的對比資訊:

  • 優化模型的訓練過程: 這類方法的代表作是 ANCE 提出的動態負取樣策略,其基本思路是在訓練過程中定期重新整理索引,從而為模型提供更優質的難負樣本,而不是像DPR那樣僅從BM25中獲取負樣本。在此基礎上, LTRe 指出目前的檢索模型其實是按learning to rank來訓練的,因為訓練過程中模型僅能看到一個batch內的樣本,但如果我們只訓練query encoder,凍結passage embedding,我們就可以按照learning to retrieve的方式計算全域性損失,而不是僅計算一個batch的損失。除此之外, RocketQA 提出了Cross Batch技巧來增大batch size,由於檢索模型採用對比損失訓練,因此理論上增大batch size帶來的基本都是正收益。

然而,這三種策略都在原始的 的基礎上增加了額外的計算成本,並且實現都比較複雜。除此之外,我們也可以利用 知識蒸餾 (knowledge distillation)為模型提供更優質的監督訊號:

  • 優化模型的監督訊號: 我們可以將表達能力更強但執行效率更低的 當作teacher model來為 提供 soft label 。在檢索模型的訓練中,知識蒸餾的損失函式有很多可能的選擇,本文僅討論pairwise loss和in-batch negative loss,其中in-batch negative loss在pairwise loss的基礎上將batch內部其他query的負樣本也當作當前query的負樣本,這兩類蒸餾loss的詳細定義後文會講。

本文同樣是在上述兩個方面對 做出優化,在訓練過程方面,作者提出了 Balanced Topic Aware Sampling(TAS-Balanced) 策略來構建batch內的訓練樣本;在監督訊號方面,作者提出了將 pairwise loss 和in-batch negative loss結合的 dual-supervision 蒸餾方式。

Dual Supervision

越來越多的證據表明知識蒸餾能夠帶來稠密檢索模型效能的提升,本文將 提供的pairwise loss和 提供的in-batch negative loss結合起來為 提供監督訊號,下面先簡單介紹一下teacher model和student model。

Teacher Model:

是當前應用最為廣泛的排序模型,它簡單地將query和passage的拼接作為 的輸入序列,然後對 輸出向量做一個線性變換得到相關性打分:

是一個經典的多向量表示模型,它將query和passage之間的互動簡化為max-sum來克服 無法快取passage向量的問題,其基本思路是首先對query和passage分別編碼

然後計算每個query term和每個passage term的 點積 相似度,按doc term做max-pooling並按query term求和獲取query和passage的相似度:

雖然理論上 可以對passage建立離線索引,但儲存passage多向量表示的資源開銷是非常大的,並且該儲存成本隨著語料庫的term數量呈線性增長,再加上max-sum的操作也會帶來額外的計算成本,因此這裡我們將 當作 的teacher。

Student Model:

DPR提出的 僅使用二元標籤和BM25生成的負樣本訓練模型, 首先將query和passage獨立編碼為單個向量:

然後計算 和 的點積相似度:

在檢索階段, 首先對query編碼,然後利用faiss做最大內積檢索,下表展示了在單個消費級GPU上6層DistilBERT在800萬passage集合上的檢索速度。

Dual-Teacher Supervision

如果僅看 監督訊號 的質量, 提供的in-batch negative loss當然是最優質的。然而, 雖然在表達能力上比 更強,但它實際上很少用於計算in-batch negative loss,因為 需要單獨編碼每個query-passage樣本對,所以其計算開銷隨著batch size二次增長,而 解耦了query和passage的表示,因此它的開銷是隨著batch size線性增長的,其in-batch negative loss的計算效率要高得多。

因此這裡我們只讓 提供pairwise loss,具體來說,我們首先利用訓練好的 對訓練集中所有的query-passage樣本對打分,然後計算 蒸餾損失 ,蒸餾損失的具體形式有很多選擇,這裡作者選擇了Margin-MSE loss作為pairwise loss:

其中 分別為

我們同時讓 提供in-batch negative loss:

in-batch negative loss中的 其實也可以替換成別的loss,作者在後續實驗中也嘗試了一些看起來更有效的listwise loss,然而實驗結果表明Margin-MSE loss依舊是最佳的選擇。因此,作者最終提出的蒸餾loss是pairwise loss和in-batch negative loss的加權平均,在後續實驗中,作者設 加權係數

Balanced Topic Aware Sampling

在原始的 的訓練中,我們首先隨機地從 query集合 中取樣 個 ,然後再為每個 隨機取樣一個正樣本 和一個負樣本 組成一個batch:

其中 表示從集合 無放回地取樣 個樣本。由於訓練集是非常大的,每個batch中的 幾乎都是沒有相關性的,但是當我們計算in-batch negative loss時,query不僅和自身的 互動,也和別的query對應的 互動,然而,由於 對模型來說大概率是簡單樣本,因此它所能提供的資訊增益是非常少的,這也導致了每個batch所能提供的資訊量偏少,使得檢索模型需要長時間的訓練才能收斂。

TAS

針對這個問題,作者提出了Topic Aware Sampling(TAS)策略來構建batch內的訓練樣本,具體來說,在訓練之前,我們先利用 -mea ns演算法 將query聚類到 個cluster中:

其中query的表示 由基線模型 提供, 聚類中心 ,這樣,每個cluster中的query都是主題相關的,在構建batch的時候,我們可以先從cluster的集合 隨機抽樣 個cluster,然後在每個cluster上隨機抽樣 個query:

在後續的實驗中,作者為40萬個query建立了 個cluster,並設batch size大小為 ,組建batch時隨機抽樣的cluster數量為 ,這樣,每個batch中的樣本都來自於同一個cluster。如下圖所示, 相比於在整個query集合上隨機抽樣,TAS策略生成的batch內部的query有更高的主題相似性。

TAS-balanced

在組建batch的時候,我們還需要為每個取樣到的query配置正負樣本對 。不難想到,幾乎所有query對應的 都比 少得多,如果用獨立隨機抽樣的方式獲取 ,那麼組成的 的margin(也就是 )大概率是很大的,因此大部分 對模型來說是簡單樣本,因為模型很容易將 分開。

因此,我們可以在TAS策略的基礎上進一步均衡batch內正負樣本對的margin分佈以減少high margin(low information)的正負樣本對。具體來說,針對每個query,我們首先計算它對應的樣本對集合的最小margin和最大margin,然後將該區間分割為 個子區間 ,在為query配置 時,我們首先從這 個子區間中隨機選擇一個子區間,然後從margin落在該子區間內的 集合中隨機取樣並組成一個訓練樣本:

這樣,在構建一個batch的時候,我們首先需要取樣一個cluster,然後取樣 個query,接下來為每個query取樣一個margin子區間,最後在該子區間上取樣一個正負樣本對,這整套流程就是所謂的 TAS-balanced batch sampling

需要注意的是,TAS-balanced策略不會影響模型的訓練速度,因為batch的構建是可以並行處理或者預先處理好的。TAS-balanced策略組建的batch對模型來說整體的難度更大,因此為模型提供了更多的資訊量,即使採用較小的batch size,模型也能很好地收斂。如下表所示,我們可以在消費級顯示卡上(11GB記憶體)高效地訓練 而不需要昂貴的8 V100的配置,因為該方法不需要像ANCE那樣重複重新整理索引,也不需要像RocketQA那樣進行超大批量的訓練。

Experiment

作者選擇MSMACRO-Passage官方提供的4000萬正負樣本對作為檢索模型的訓練集,並選擇MSMACRO-DEV(sparsely-judged,包含6980個query)和TREC-DL 19/20(densely-judged,包含43/54個query)作為驗證集。同時 均採用6層的DistilBERT初始化,且沒有使用預訓練的檢索模型。

Results

Source of Effectiveness

首先我們對作者提出的Dual-supervision做消融實驗,如下表所示。對於基於pairwise loss的知識蒸餾,Margin-MSE loss的優越性已經被之前的論文證明,所以這裡僅討論in-batch negative loss的有效性。作者對比了基於listwise loss的KL Divergence、ListNet和Lambdarank,實驗結果表明這些損失的效果都不如Margin-MSE loss,尤其是在R@1K上面。

為什麼pairwise的Margin-MSE比listwise loss更好呢?因為Margin-MSE不僅僅是讓模型去學習teacher所給出的排序,同時還學習teacher score的分佈,由於batch內部樣本的order實際上是有偏的,它並不能準確刻畫樣本間的真實距離,因此比起學習order,學習score分佈其實是一種更精確的方式。另外,由於teacher和student在訓練階段所使用的損失是一致的,這也會讓student更容易學習到teacher的score分佈。

接下來我們對TAS-Balanced策略做消融實驗,如下表所示。總體來說,TAS-balanced策略加上Dual-supervision蒸餾可以在各個資料集上取得最優效能。值得關注的是, 在單獨的pairwise loss的監督下使用TAS策略其實並不能帶來明顯的提升 ,這是因為TAS是面向in-batch negative loss設計的,使用pairwise loss訓練時,batch內的樣本是沒有互動的,因此TAS也就不會起作用。而TAS-balanced策略會影響正負樣本對的組成方式,因此會對pairwise loss產生一定的影響。

Comparing to Baselines

下表對比了作者的模型和其他模型的效能,對比最後三行,我們可以發現一個有趣的現象: 增大batch size在TREC-DL這類densely-judge的資料集上沒有帶來提升,但在MSMACRO-DEV這類sparsely-judge的資料集上會帶來持續的提升。 因此作者猜想增大batch size會導致模型在sparsely-judge的MSMACRO上過擬合,RocketQA的SOTA表現可能僅僅是因為它的batch size夠大。

TAS-Balanced Retrieval in a Pipeline

為了進一步證明方法的有效性,作者嘗試將TAS-Balanced訓練的檢索模型應用到召回-排序系統中。眾所周知,稠密檢索和稀疏檢索是互補的,且融合稀疏檢索幾乎不會影響召回速度,因此作者考慮將稀疏檢索的docT5query的檢索結果和TAS-balanced稠密檢索模型的結果融合,然後使用最先進的mono-duo-T5排序模型對檢索結果做重排。

選擇不同的召回模型、排序模型和不同大小的候選集,我們可以得到不同延遲水平的檢索系統。如上表所示,作者提出的模型在各個延遲水平上均取得了優異的表現。值得注意的是,在高延遲系統中,排序模型mono-duo-T5是在BM25的召回結果上訓練的,這實際上會導致 訓練測試分佈不一致 的問題,所以TAS-B+mono-duo-T5甚至沒能超越BM25+mono-duo-T5,為了取得更好的效能, 我們應該先訓召回模型,然後在召回模型的給出召回結果上訓練排序模型,這其實也間接反映了當前的排序模型泛化性不足的問題。

Discussion

本篇論文最大的亮點是TAS-Balanced策略的高效性,使用作者的模型,我們僅需要在單個消費級GPU上從頭訓練48小時就能取得SOTA結果,極大地降低了檢索模型的訓練成本,這在之前是無法想象的。實際上,比起NLP社群,IR社群更加強調模型和資料的Efficiency,這一課題在將來也一定會受到持續的關注。