BERT 蒸餾在垃圾輿情識別中的探索

語言: CN / TW / HK

近來 BE RT等大規模預訓練模型在 NLP 領域各項子任務中取得了不凡的結果,但是模型海量引數,導致上 線困難,不能滿足生產需求。 輿情稽核業務中包含大量的垃圾輿情,會耗費大量的人力。本文在垃圾輿情識別任務中嘗試 BERT 蒸餾技術 ,提升 textCNN 分類器效能,利用其小而快的優點,成功落地。

風險樣本如下:

一  傳統蒸餾方案

目前,對模型壓縮和加速的技術主要分為四種:

  • 引數剪枝和共享

  • 低秩因子分解

  • 轉移/緊湊卷積濾波器

  • 知識蒸餾

知識蒸餾就是將教師網路的知識遷移到學生網路上,使得學生網路的效能表現如教師網路一般。本文主要集中講解知識蒸餾的應用。

1  soft label

知識蒸餾最早是 2014 年 Caruana 等人提出方法。通過引入 teacher network(複雜網路,效果好,但預測耗時久) 相關的軟標籤作為總體 loss 的一部分,來引導 student network(簡單網路,效果稍差,但預測耗時低) 進行學習,來達到知識的遷移目的。這是一個通用而簡單的、不同的模型壓縮技術。

  • 大規模神經網路 (teacher network)得到的類別預測包含了資料結構間的相似性。

  • 有了先驗的小規模神經網路(student network)只需要很少的新場景資料就能夠收斂。

  • Softmax函式隨著溫度變數(temperature)的升高分佈更均勻。

Loss公式如下:

其中,  

由此我們可以看出蒸餾有以下優點:

  • 學習到大模型的特徵表徵能力,也能學習到one-hot label中不存在的類別間資訊。

  • 具有抗噪聲能力,如下圖,當有噪聲時,教師模型的梯度對學生模型梯度有一定的修正性。

  • 一定的程度上,加強了模型的泛化性。

紅色為噪聲資料梯度,黃色為教師模型梯度,綠色為最優梯度

2  using hints

(ICLR 2015) FitNets Romero等人的工作不僅利用教師網路的最後輸出logits,還利用了中間隱層引數值,訓練學生網路。獲得又深又細的FitNets。

中間層學習loss如下:

作者通過新增中間層loss的方式,通過teacher network 的引數限制student network的解空間的方式,使得引數的最優解更加靠近到teacher network,從而學習到teacher network的高階表徵,減少網路引數的冗餘。

3  co-training

(arXiv 2019) Route Constrained Optimization (RCO) Jin和Peng等人的工作受課程學習(curriculum learning)啟發,並且知道學生和老師之間的gap很大導致蒸餾失敗,導致認知偏差,提出路由約束提示學習(Route Constrained Hint Learning),把學習路徑更改為每訓練一次teacher network,並把結果輸出給student network進行訓練。student network可以一步一步地根據這些中間模型慢慢學習,from easy-to-hard。

訓練路徑如下圖:

二  Bert2TextCNN蒸餾方案

為了提高模型的準確率,並且保障時效性,應對GPU資源緊缺,我們開始構建bert模型蒸餾至textcnn模型的方案。

方案1:離線logit textcnn 蒸餾

使用的是Caruana的傳統方法進行蒸餾。

離線 logit textcnn 蒸餾訓練流程

方案2:聯合訓練 bert textcnn 蒸餾

引數隔離:teacher model 訓練一次,並把logit傳給student。teacher 的引數更新至受到label的影響,student 引數更新受到teacher loigt的soft label loss 和label 的 hard label loss 的影響。

聯合訓練 bert textcnn 蒸餾引數隔離訓練流程

方案3:聯合訓練 bert textcnn 蒸餾

引數不隔離: 與方案2類似,主要區別在於前一次迭代的student 的 soft label 的梯度會用於teacher引數的更新。

聯合訓練 bert textcnn 蒸餾引數不隔離訓練流程

方案4:聯合訓練 bert textcnn loss 相加

teacher 和student 同時訓練,使用mutil-task的方式。

聯合訓練 bert textcnn loss 相加訓練流程

方案5:多teacher

大部分模型,在更新時候需要覆蓋線上歷史模型的樣本,使用線上歷史模型作為teacher,讓模型學習原有歷史模型的知識,保障對原有模型有較高的覆蓋。

多 teacher 訓練流程

實驗結果如下:

從以上的實驗,可以發現很有趣的現象。

1)方案2和方案3均使用先訓練teacher,再訓練student的方式,但是由於梯度返回更新是否隔離的差異,導致方案2低於方案3。是由於方案3中,每次訓練一次teacher,在訓練一次student,student學習完了的soft loss 會再反饋給teacher,讓teacher知道指如何導student是合適的,並且還提升了teacher的效能。

2)方案4採用共同更新的,同時反饋梯度的方式。反而textcnn 的效能迅速下降,雖然bert的效能基本沒有衰減,但是bert難以對textcnn每一步的反饋有個正確性的引導。

3)方案5中使用了歷史textcnn 的logit,主要是為了用替換線上模型時候,並保持對原有模型有較高的覆蓋率,雖然召回下降,但是整體的覆蓋率相比於單textcnn 提高了5%的召回率。

Reference

1.Dean, J. (n.d.). Distilling the Knowledge in a Neural Network. 1–9.

2.Romero A , Ballas N , Kahou S E , et al. FitNets: Hints for Thin Deep Nets[J].

3.Jin X , Peng B , Wu Y , et al. Knowledge Distillation via Route Constrained Optimization[J].

歡迎各位技術同路人加入螞蟻集團大安全機器智慧團隊,我們專注於面向海量輿情藉助大資料技術和自然語言理解技術挖掘存在的金融風險、平臺風險,為使用者資金安全護航、提高使用者在螞蟻生態下的使用者體驗。內推直達 [email protected],有信必回。

AI 場景體驗

機器學習演算法: 基於邏輯迴歸的分類預測

邏輯迴歸(Logistic regression,簡稱LR)是一個分類模型,其模型簡單和可解釋性強, 是很多分類演算法的基礎 元件 。邏輯迴歸模型廣泛應用於機器學習、大多數醫學領域和社會科學等領域。通過本次實驗,幫助大家掌握邏輯迴歸的理論,以及 sklearn 函式呼叫使用並將其運用到鳶尾花資料集的預測中。

擊”閱讀原文“立即體驗吧 ~

關注 機器智慧

把握未來可能

戳我,立即體驗。