詳解NLP和時序預測的相似性【附贈AAAI21最佳論文INFORMER的詳細解析】
摘要:本文主要分析自然語言處理和時序預測的相似性,並介紹Informer的創新點。
前言
時序預測模型無外乎RNN(LSTM, GRU)以及現在非常火的Transformer。這些時序神經網路模型的主要應用都集中在自然語言處理上面(transformer就是為了NLP設計的)。在近些年來,RNN和Transformer逐漸被應用於時序預測上面,並取得了很好的效果。2021年發表的Informer網路獲得了AAAI best paper。本文主要分析自然語言處理和時序預測的相似性,並介紹Informer的創新點。
具體的本文介紹了
- 早期機器翻譯模型RNN-AutoEncoder的原理
- RNN-AutoEncoder升級版Transformer的原理
- 時序預測與機器翻譯的異同以及時序預測演算法的分類
- AAAI21最佳論文,時序預測模型INFORMER的創新點分析
RNN AutoEncoder
早期自然語言處理:RNN autoencoder
Sutskever, Ilya, Oriol Vinyals, and Quoc V. Le. “Sequence to sequence learning with neural networks.” arXiv preprint arXiv:1409.3215 (2014). (google citation 14048)
這裡以機器翻譯為例子,介紹RNN autoencoder的原理。
輸入句子經過分詞,編碼成數字,然後embedding成神經網路可以接受的向量。
在訓練過程中,可以使用teacher forcing,即decoder的輸入使用真實值,迫使在訓練過程中,誤差不會累加
在線上翻譯過程中,encoder部分流程相同,decoder部分,目標句子是一個單詞一個單詞生成的
早期RNN auto encoder結構雖然相比於傳統模型取得了巨大成功,但encoder,decoder之間的資訊傳播僅僅時由單一的一個隱層連結完成的,這樣勢必會造成資訊丟失,因此,Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio. “Neural machine translation by jointly learning to align and translate.” arXiv preprint arXiv:1409.0473 (2014).(citation 16788)提出在輸入和輸出之間增加額外的attention連結,增加資訊傳遞的魯棒性以及體現輸出句子中不同單詞受輸入句子單詞影響的差異性。
Transformer
2017-劃時代:Transformer—LSTM autoencoder的延申。
既然attention效果如此的好,那麼能否只保留attention而拋棄迴圈神經網路呢?
Google在17年年底提出了transformer網路,帶來了nlp的技術革命。
transformer本質上還是一個資訊順序傳遞的模型。主要包含了positional encoding(用於區分詞語出現的先後順序),self-attention, attention, 以及feed forward網路四大部分。與RNN不同的是,Transformer利用了attention機制進行資訊傳遞,具體的,self-attention的資訊傳遞機制如下:
一個詞向量和句子中所有詞向量構成的矩陣做相關得到相關性向量,做softmax歸一化後,求得所有詞向量構成的加權和,得到新的詞向量。
transformer同樣用attention機制,代替了RNN-AE中用來傳遞句子之間資訊的隱層連結。此外,在decoder階段,為了保證矩陣中上一個下一個單詞僅僅由他前面的單詞決定,在self-attention中,還需要做一個上三角矩陣的masking。
在訓練過程中,一般同樣採取teacher forcing的方法,即decoder輸入是完整的目標句子的embedding。而在線上翻譯的時候,依然從採取瞭如RNN-AE一樣的滾動輸出的方式,即初始輸入為<SOS>,餘下向量全部用padding。得到輸出後,一個一個加入到decoder輸入中,直到遇到<EOS>。
在transformer提出以後,基於Transformer的BERT預言模型成為了NLP中統治級別的模型。
時序預測與機器翻譯的異同
時序預測按照輸入的區別可以分為兩大類,即直接時序預測和滾動時序預測。
直接時序預測,的輸入是被預測部分的時間戳,輸出是被預測部分的值。在訓練過程中直接時序預測演算法首先把輸出和時間戳的關係建立為y=f(x)函式,然後用訓練集擬合這個函式。在預測階段,直接輸入被預測部分的時間戳即可獲得目標值。典型的演算法即為FB的PROPHET演算法。
與直接時序預測演算法不同的是,滾動時間序列預測演算法絕大部分都不依靠時間戳資訊。滾動時間序列預測把時間序列模型建立為x_{t+1,t+n}=f(x_{t−m,t})xt+1,t+n=f(xt−m,t),即被預測時間段的值由歷史時間段的值決定。在訓練階段,把訓練集切分為一個一個的輸入輸出對,用SGD迭代減少輸出和目標真實值的誤差,在預測階段用被預測資料前一段的歷史資料作為輸入,得到預測輸出。
現階段,基於深度學習的預測演算法絕大多數都屬於滾動時間序列預測類別。
時序預測與機器翻譯的相同點
• 資料輸入都是一個時間序列矩陣
○ 時序預測的輸入矩陣為(t, d_{feature})(t,dfeature), t為輸入時間長度,d_{feature}dfeature為每個時間點的資料維度
○ nlp的輸入矩陣為(t, d_{embed})(t,dembed),t為輸入句子的最大長度,d_{embed}dembed為此嵌入向量長度
• 都是一個seq2seq的問題,可以用RNN-AE以及Transformer解決
時序預測與機器翻譯的不同點
• nlp中,詞語需要一系列預處理才能得到網路輸入矩陣而時序預測中,輸入矩陣是自然形成的。
• nlp中,線上翻譯採取了滾動輸出的方式,nlp輸出先做softmax並匹配為單詞後,重新做embedding才作為下一次預測的輸入,這樣的作法可以克服一部分誤差累積。而在時序預測中,如果採取滾動輸出的方式,上一個時間點的輸出是直接被當作下一時間點的輸入的。這樣可能會帶來過多的誤差累積。
Informer論文分析
Transformer近些年來成為了時序預測的主流模型。在剛剛結束的AAAI 2021中,來自北航的論文
Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting得到了BEST paper的榮譽。Informer論文的主體依然採取了transformer encoder-decoder的結構。在transformer的基礎上,informer做出了諸多提高效能以及降低複雜度的改進。
1)Probsparse attention
a. transformer最大的特點就是利用了attention進行時序資訊傳遞。傳統transformer在資訊傳遞時,需要進行兩次矩陣乘,即(softmax(QK)^T/\sqrt{d})∗V(softmax(QK)T/d)∗V,則attention的計算複雜度為O(L_q L_k)O(LqLk),其中L_qLq 為query矩陣的時間長度,L_kLk 為key矩陣的時間長度。為了減少attention的計算複雜度,作者提出,attention的資訊傳遞過程具有稀疏性。以t時間為例,並非所有t時間以前的時間點都和t時間點都有關聯性。部分時間點和t時間點的關聯性非常小,可以忽略。如果忽略掉這些時間點和t時間點的attention計算,則可以降低計算複雜度。
b. attention的數學表示式為
out_i=softmax(q_i K^T)V=\sum\limits_{j=1}^{L_k}\frac{\exp(q_ik_j^T/\sqrt{d})}{\sum\limits_{l=1}^{L_k}\exp(q_ik_j^T/\sqrt{d})}v_j=\sum\limits_{j=1}^{L_k}p(k_j|q_i)v_jouti=softmax(qiKT)V=j=1∑Lkl=1∑Lkexp(qikjT/d)exp(qikjT/d)vj=j=1∑Lkp(kj∣qi)vj
在計算attention的時候,若q_iqi 和key矩陣整體相關性較低,則p(k_j |q_i )p(kj∣qi)退化為均勻分佈,這時,attention的output退化為了對value矩陣的行求均值。因此,可以用p(k_j |q_i )p(kj∣qi)和均勻分佈的差別,即p(k_j |q_i )p(kj∣qi)和均勻分佈的KL散度,來度量queryq_iqi 的稀疏度。如果KL散度高,則按照傳統方法求attention,如果KL散度低,則用對V求行平均的方式代替attention。總的來說,INFORMER中提出了一種度量query稀疏度(和均勻分佈的相似程度)並用value的行平均近似attention的方法。
c. 具體的令q為均勻分佈,p為p(k_j |q_i )p(kj∣qi),則KL散度經過計算為M(q_i,K)=\ln\sum\limits_{j=1}^{L_k}e^{\frac{q_ik_j^T}{\sqrt{d}}}-\frac{1}{L_k}\sum\limits_{j=1}^{L_k}\frac{q_ik_j^T}{\sqrt{d}}M(qi,K)=lnj=1∑LkedqikjT−Lk1j=1∑LkdqikjT
按照INFORMER的思想,即可對每一個query計算KL散度,然後取topk的query進行attention,其餘的query對應的行直接用V的行平均進行填充。
d. 根據以上的思想,在attention的時候確實可以降低複雜度,然而,在排序的時候,複雜度依然是O(L_k L_q)O(LkLq)。因此,作者又提出了一種對M(q_i,K)M(qi,K)排序進行近似計算的方式。在這裡,由於證明涉及到我的一些陌生領域,例如隨機微分,我並沒有深入取細嚼慢嚥。這裡就直接呈現結論。
i. M(q_i,K)=\ln\sum\limits_{j=1}^{L_k}e^{q_i k_j^T/\sqrt{d}} −\frac{1}{L_k}\sum\limits_{j=1}^{L_k}{q_i k_j^T}/\sqrt{d}M(qi,K)=lnj=1∑LkeqikjT/d−Lk1j=1∑LkqikjT/d 可以用其上界\bar{M}(q_i,K)=\max\limits_j({q_i k_j^T/\sqrt{d}}) −\frac{1}{L_k}\sum\limits_{j=1}^{L_k}{q_i k_j^T}/\sqrt{d}Mˉ(qi,K)=jmax(qikjT/d)−Lk1j=1∑LkqikjT/d代替,作者證明近似後大概率不影響排序。
ii. 上界在計算的時候可以只隨機取樣一部分k_jkj,減少k_jkj 也就減少了乘法的次數,降低了複雜度。作者在附錄中證明這樣的隨機取樣近似大概率對排序沒有影響。
e.
i. 作者在附錄中,給出了probsparse self-attention的具體實施過程
ii. 在第2行,對K進行取樣,使得sparse 排序過程得以簡化複雜度
iii. 在第5行,只選出top-u作為query,使得attention做了簡化
f. 關於probsparse,需要注意的問題有以下幾點:
i. 這個機制只用在了self-attention中。在文中,作者把提出的方法稱為了prob-sparse self-attention,在原始碼中,也只用在了self-attention中。至於為什麼不能用於cross-attention,現在不太清楚。
ii. 這個機制在有三角矩陣masking的情況下也不能用,因為在有masking的情況下,query和key的乘法數量本來就減少了。
iii. 因此,probsparse只能用於encoder的self-attention中
iv. 雖然論文中提出probsparse可以減少複雜度,但由於增加了排序的過程,不一定能減少計算時間,在一些資料長度本來就較少的情況下,可能會增加計算時間。
2)Attention distilling
a. 與普通transformer不同的是,由於probsparse self-attention中,一些資訊是冗餘的,因此在後面採取了convolution+maxpooling的方法,減少冗餘資訊。這種方法也只能在encoder中使用。
3)CNN feed forward network
a. 在17年的transformer中,feedforward網路是用全連線網路構成的,在informer中,全連線網路由CNN代替。
4)Time stamp embedding
a. Time stamp embedding也是Informer的一個特色。在普通的transformer中,由於qkv的乘法並不區分矩陣行的先後順序,因此要加一個positional encoding。在INFORMER中,作者把每個時間點時間戳的年,月,日等資訊,也加入作為encoding的一部分,讓transformer能更好的學習到資料的週期性。
5)Generative decoding
a. 在NLP中,decoding部分是迭代輸出的。這樣的作法如果在時序預測中應用的化,在長序列預測中會引起較長的計算複雜度。其次,由於NLP中有詞語匹配的過程,會較少噪聲累積,而時序預測中,這種噪聲累積則會因為單步滾動預測而變得越發明顯。
b. 因此,在decoding時,作者採取了一次輸出多部預測的方式,decoder輸入是encoder輸入的後面部分的擷取+與預測目標形狀相同的0矩陣。
其中,X_{token}Xtoken 由X_{feed\_en}Xfeed_en 後半部分擷取而成。
6) Informer程式碼:https://github.com/zhouhaoyi/Informer2020
- 帶你掌握 C 中三種類成員初始化方式
- 實踐GoF的設計模式:工廠方法模式
- DCM:一個能夠改善所有應用資料互動場景的中介軟體新秀
- 手繪圖解java類載入原理
- 關於加密通道規範,你真正用的是TLS,而非SSL
- 程式碼重構,真的只有複雜化一條路嗎?
- 解讀分散式排程平臺Airflow在華為雲MRS中的實踐
- 透過例項demo帶你認識gRPC
- 帶你聚焦GaussDB(DWS)儲存時遊標使用
- 傳統到敏捷的轉型中,誰更適合做Scrum Master?
- 輕鬆解決研發知識管理難題
- Java中觀察者模式與委託,還在傻傻分不清
- 如何使用Python實現影象融合及加法運算?
- 什麼是強化學習?
- 探索開源工作流引擎Azkaban在MRS中的實踐
- GaussDB(DWS) NOT IN優化技術解密:排他分析場景400倍效能提升
- Java中觀察者模式與委託,還在傻傻分不清
- Java中的執行緒到底有哪些安全策略
- 一圖詳解java-class類檔案原理
- Java中的執行緒到底有哪些安全策略