梯度提升樹(GBDT)詳解之二:分類舉例
在2006年12月召開的 IEEE 資料探勘國際會議上(ICDM, International Conference on Data Mining),與會的各位專家選出了當時的十大資料探勘演算法( top 10 data mining algorithms ),可以參見文獻【1】。作為整合學習(Ensemble learning)的一個重要代表AdaBoost入選其中。但基於Boosting 思想設計的演算法,也是常常用來同AdaBoost進行比較的另外一個演算法就是Gradient Boost。它是在傳統機器學習演算法裡面是對真實分佈擬合的最好的幾種演算法之一,在前幾年深度學習還沒有大行其道之時,Gradient Boost幾乎橫掃各種資料探勘(Data mining)或知識發現(Knowledge discovery)競賽。
梯度提升樹(GBDT, Gradient Boosted Decision Trees),或稱Gradient Tree Boosting,是一個以決策樹為基學習器,以Boost為框架的加法模型的整合學習技術。在之前的文章【2】中,我們已經見識過利用GBDT進行迴歸分析的具體方法,希望讀者在此基礎之上閱讀本文。
我們已經知道,GBDT可以用來做迴歸,也可以用來做分類。本文將著重介紹利用GBDT進行分類的具體方法。作為一個例子,我們將要使用的訓練資料集(Train Set)如下表所示,觀眾的年齡、是否喜歡吃爆米花,以及他們喜歡的顏色共同構成了特徵向量,而每個觀眾是否喜歡看科幻電影則是作為分類的標籤。
就像之前我們在利用GBDT做迴歸時一樣,作為開始,先構建一個單結點的決策樹來作為對每個個體的初始估計或初始預測。之前在做迴歸時,這一步操作是取平均值,而現在我們要做的是分類任務,那麼這個用作初始估計的值使用的就是log(odds),這跟邏輯迴歸中所使用的是一樣的,具體你可以參考我之前的文章【3】和【4】。
我們如何使用該值來進行分類預測呢?回憶邏輯迴歸時所做的,我們將其轉化成為一個概率值,具體來說,是利用Logistic函式來將log(odds)值轉化成概率。因此,可得給定一個觀眾,他喜歡看科幻電影的概率是
注意,為了方便計算,我們保留小數點後面一位有效數字,但其實log(odds)和喜歡看科幻片的概率這兩個0.7只是近似計算之後出於巧合才相等的,二者直接並沒有必然聯絡。
現在我們知道給定一個觀眾,他喜歡看科幻電影的概率是0.7 ,說明概率大於50%,所以我們最終斷定他是喜歡看科幻電影的。這個最初的猜測或者估計怎麼樣呢?現在這是一個非常粗糙的估計。在訓練資料集上,有4個人的預測結果正確,而另外兩人的預測結果錯誤。或者說現在構建的模型在訓練資料集上擬合的還不十分理想。我們可以使用偽殘差來衡量最初的估計距離真實情況有多遠。正如前面在迴歸分析時我們所做過的那樣,這裡的偽殘差就是指觀察值與預測值之間的差。
如上圖所示,紅色的兩個點,表示資料集中不喜歡看科幻片的兩個觀眾,或者說他們喜歡看科幻片的概率為0。類似地,藍色的四個點,表示資料集中喜歡看科幻片的四個觀眾,或者說他們喜歡看科幻片的概率為1。紅色以及藍色的點是實際觀察值,而灰色的虛線是我們的預測值。因此可以很容易算得每個資料點的偽殘差如下:
接下來要做的事情就跟之前做迴歸時一樣了,也就是基於特徵向量來構建決策樹從而預測上表中給出的偽殘差。於是得到下面這棵決策樹。注意,我們需要限定決策樹中允許之葉子數量的上限(在歸回中也有類似限定),在本例中我們設定這個上限是3,畢竟這裡的例子中資料規模非常小。但實踐中,當面對一個較大的資料集時,通常會設定允許的最多葉子數量在8到32個之間。在使用scikit-learn工具箱時,例項化類sklearn.ensemble.GradientBoostingClassifier時,可以通過指定max_leaf_nodes來控制該值。
但如何使用上面這棵剛剛構建起來的決策樹相對而言是比較複雜的。注意到,最初的估計值log(odds)≈0.7是一個對數似然,而新建的決策樹葉子結點給出的偽殘差是基於概率值的,因此二者是不能直接做加和的。一些轉化是必不可少的,具體來說,在利用GBDT做分類時,需要使用如下這個變換對每個葉子結點上的值做進一步的處理。
上面這個變換的推導涉及到一些數學細節,這一點我們留待後續文章再做詳細闡釋。現在,我們基於上述公式在做計算:例如,對於第一個只有一個值-0.7的葉子結點而言,因為只有一個值,所以可以忽略上述公式中的加和符號,即有
注意,因為在構建這棵決策樹之前,我們的上一棵決策樹,對於所有觀眾給出的喜歡看科幻片的概率都是0.7,所以上述公式中Previous Probability則帶入該值。於是,將該葉子結點中的值替換為-3.3。
接下來,計算包含0.3和-0.7這兩個值的葉子結點。於是有
注意,因為葉子結點中包含有兩個偽殘差,因此分母上的加和部分就是對每個偽殘差所對應的結果都加一次。另外,目前Previous Probability對於每個偽殘差都是一樣的(因為上一棵決策樹中只有一個結點),但我們在生成下一次決策樹時情況就不一樣了。
類似地,最後一個結點的計算結果如下
所以現在決策樹變成了如下這個樣子
現在,基於之前建立的決策樹,以及現在剛剛得到的新決策樹,就可以對每個觀眾是否喜歡看科幻片進行預測了。與之前回歸的情況一樣,我們還要使用一個learning rate對新得到的決策樹進行縮放。這裡我們使用的值是0.8,出於演示方便的目的這裡所取的值相對較大。在使用scikit-learn工具箱時,例項化類sklearn.ensemble.GradientBoostingClassifier時,可以通過指定引數learning_rate的值來控制學習率,該引數的預設值為0.1。
例如現在要計算上表中第一名觀眾的log(odds),可得0.7 + (0.8×1.4) = 1.8。所以,利用Logistic函式計算概率得
注意到,此前我們對第一名觀眾是否喜歡看科幻片的概率估計是0.7,現在是0.9,顯然我們的模型朝著更好的方向前進了一小步。採用上述這個方法,下面逐個計算其餘觀眾喜歡看科幻片的概率,然後再計算偽殘差,可得下表之結果
接下來,基於特徵向量來構建決策樹從而預測上表中給出的偽殘差。於是得到下面這樣一棵決策樹。
再根據之前使用過的變換對每個葉子結點上的值做進一步的處理,也就是得到每個葉子結點的最終輸出。
例如,我們計算上圖中所示的表中第二行資料對應的葉子結點的輸出
其中Previous Probability就帶入上表中的概率預測值0.5。
再來計算一個稍微複雜一點的葉子,如上圖所示,有四個觀眾的預測結果都指向該葉子結點。於是有
所以該葉子結點的輸出就是0.6。這一步我們也可以看出,每個偽殘差對應的分母項未必一致,這是因為上一次的預測概率不盡相同。對所有葉子結點都計算對應輸出之後可得一棵新的決策樹如下:
現在,我們可以把所有已經得到的組合到一起了,如下圖所示。最開始,我們有一個只有一個結點的樹,在此基礎上我們得到了第二棵決策樹,現在又有了一棵新的決策樹,所有這些新生成的決策樹都通過Learning Rate進行縮放,然後再加總到最開始的單結點樹上。
接下來,根據已有的所有決策樹,又可以算得新的偽殘差。注意,第一次得到的偽殘差僅僅是根據初始估計值算得的。第二次得到的偽殘差則是基於初始估計值,連同第一棵決策樹一起算得的,而接下來將要計算的第三次偽殘差則是基於初始估計值,連同第一、二棵決策樹一起算得的。而且,每次新引入一棵決策樹後,偽殘差都會逐漸變小。偽殘差逐漸變小,就意味著構建的模型正朝著好的方向逐漸逼近。
為了得到更好的結果,我們將反覆執行計算偽殘差並構建新決策樹這一過程,直到偽殘差的變化不再顯著大於某個值,或者達到了預先設定的決策樹數量上限。在使用scikit-learn工具箱時,例項化類GradientBoostingClassifier時,可以通過指定n_estimators來控制決策樹數量的上限,該引數的預設值為100。
為了便於演示,現在假設,我們已經得到了最終的GBDT,它只有上述三棵決策樹構成(一個初始樹以及兩個後續建立的決策樹)。如果有一個年紀25歲,喜歡吃爆米花,喜歡綠色的觀眾,請問他是否喜歡看科幻電影呢?我們將這個特徵向量在三棵決策樹上執行一遍,得到葉子結點的值,再由Learning Rate作用後加和到一起,於是有0.7+(0.81.4)+(0.80.6)=2.3。注意這是一個log(odds),所以我們用Logistic函式將其轉化為概率,得0.9。因此,我們斷定該觀眾喜歡看科幻片。這就是利用已經訓練好的GBDT進行分類的具體方法。
在後續的文章中,我們還將詳細解釋Gradient Boost的數學原理,屆時讀者將會更加深刻地理解演算法如此設計的緣由。
參考文獻
*本文中的例子主要參考及改編自文獻【5】,文獻【6】中提供了在scikit-learn中利用GBDT進行分類預測的一個簡單例子。
【1】Wu, X., Kumar, V., Quinlan, J.R., Ghosh, J., Yang, Q., Motoda, H., McLachlan, G.J., Ng, A., Liu, B., Philip, S.Y. and Zhou, Z.H., 2008. Top 10 algorithms in data mining. Knowledge and information systems, 14(1), pp.1-37.
【3】部落格文章連結
【4】部落格文章連結
【5】Gradient Boost for Classification
【6】http://scikit-learn.org/stable/modules/ensemble.html#gradient-tree-boosting
- Highcharts使用HTML表中的資料建立互動式圖表教程
- bzoj1529: [POI2005]ska Piggy banks(並查集)
- 深入剖析Java中的斷言assert
- Nacos 1.4.1 之前存在鑑權漏洞,建議修復到最新版
- 記憶體操作函式:memcpy函式,memove函式
- 微軟開源 Python 自動化神器 Playwright
- Web前端之HTML
- java class檔案安全加密工具
- 5G QoS和DNN以及網路切片技術
- 有獎問答獲獎名單出爐,快來看看有沒有你!
- IDEA Groovy指令碼一鍵生成實體類,用法舒服,高效!
- 微信api呼叫限制,45009 reach max api daily quota limit 解決方法
- OpenStack Placement元件
- 【Linux伺服器開發系列】手寫使用者態協議棧,udpipeth資料包的封裝,零拷貝的實現,柔性陣列
- requestAnimationFrame詳解
- k8s叢集多容器Pod和資源共享
- 梯度提升樹(GBDT)詳解之二:分類舉例
- Automatic Model Evaluation - 知乎
- Linux下IPMI iBMC遠端管理配置查詢及密碼重置
- 京東/淘寶的手機銷售榜(前4名 -- 手機品牌 --手機型號*3 --手機解析度 -- 手機作業系統 --安卓版本號)(android / IOS)