員工離職困擾?來看AI如何解決,基於人力資源分析的 ML 模型構建全方案
攜手創作,共同成長!這是我參與「掘金日新計劃 · 8 月更文挑戰」的第26天,點選檢視活動詳情
- 💡 作者:韓信子@ShowMeAI
- 📘 資料分析實戰系列:https://www.showmeai.tech/tutorials/40
- 📘 機器學習實戰系列:https://www.showmeai.tech/tutorials/41
- 📘 本文地址:https://www.showmeai.tech/article-detail/308
- 📢 宣告:版權所有,轉載請聯絡平臺與作者並註明出處
- 📢 收藏ShowMeAI檢視更多精彩內容
人力資源是組織的一個部門,負責處理員工的招聘、培訓、管理和福利。一個組織每年都會僱傭幾名員工,並投入大量時間、金錢和資源來提高員工的績效和效率。每家公司都希望能夠吸引和留住優秀的員工,失去一名員工並再次僱傭一名新員工的成本是非常高的,HR部門需要知道僱用和留住重要和優秀員工的核心因素是什麼,那那麼可以做得更好。
在本專案中,ShowMeAI 帶大家通過資料科學和AI的方法,分析挖掘人力資源流失問題,並基於機器學習構建解決問題的方法,並且,我們通過對AI模型的反向解釋,可以深入理解導致人員流失的主要因素,HR部門也可以根據分析做出正確的決定。
本篇涉及到的資料集大家可以通過 ShowMeAI 的百度網盤地址獲取。
🏆 實戰資料集下載(百度網盤):公眾號『ShowMeAI研究中心』回覆『實戰』,或者點選 這裡 獲取本文 [17]人力資源流失場景機器學習建模與調優 『HR-Employee-Attrition 資料集』
⭐ ShowMeAI官方GitHub:https://github.com/ShowMeAI-Hub
💡 探索性資料分析
和 ShowMeAI 之前介紹過的所有AI專案一樣,我們需要先對場景資料做一個深度理解,這就是我們提到的EDA(Exploratory Data Analysis,探索性資料分析)過程。
EDA部分涉及的工具庫,大家可以參考ShowMeAI製作的工具庫速查表和教程進行學習和快速使用。 - 📘資料科學工具庫速查表 | Pandas 速查表 - 📘資料科學工具庫速查表 | Seaborn 速查表 - 📘圖解資料分析:從入門到精通系列教程
📌 資料&欄位說明
我們本次使用到的資料集欄位基本說明如下:
| 列名 | 含義 | | :------------------------ | :----------------------------------- | | Age | 年齡 | | Attrition | 離職 | | BusinessTravel | 出差:0-不出差、1-偶爾出差、2-經常出差 | | Department | 部門:1-人力資源、2-科研、3-銷售 | | DistanceFromHome | 離家距離 | | Education | 教育程度:1-大學一下、2-大學、3-學士、4-碩士、5-博士 | | EducationField | 教育領域 | | EnvironmentSatisfaction | 環境滿意度 | | Gender | 性別:1-Mae男、0- Female女 | | Joblnvolvement | 工作投入 | | JobLevel | 職位等級 | | JobRole | 工作崗位 | | JobSatisfaction | 工作滿意度 | | Maritalstatus | 婚姻狀況:0- Divorced離婚、1- Single未婚、2-已婚 | | Monthlylncome | 月收入 | | NumCompaniesWorked | 服務過幾家公司 | | OverTime | 加班 | | RelationshipSatisfaction | 關係滿意度 | | StockOptionLevel | 股權等級 | | TotalworkingYears | 總工作年限 | | TrainingTimesLastYear | 上一年培訓次數 | | WorkLifeBalance | 工作生活平衡 | | YearsAtCompany | 工作時長 | | YearsInCurrentRole | 當前崗位在職時長 | | YearsSinceLastPromotion | 上次升職時間 | | YearsWithCurrManager | 和現任經理時長 |
📌 資料速覽
下面我們先匯入所需工具庫、載入資料集並檢視資料基本資訊:
```python import pandas as pd import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt import seaborn as sns sns.set_style("darkgrid")
import warnings warnings.filterwarnings("ignore") pd.set_option('display.max_columns',100) print("import complete") ```
```python
讀取資料
data = pd.read_csv("HR-Employee-Attrition.csv")
data.head()
```
檢視前 5 條資料記錄後,我們瞭解了一些基本資訊:
① 資料包含『數值型』和『類別型』兩種型別的特徵。 ② 有不少離散的數值特徵。
📌 檢視資料基本資訊
接下來我們藉助工具庫進一步探索資料。
```python
欄位、型別、缺失情況
data.info() ```
我們使用命令 data.info``()
來獲取資料的資訊,包括總行數(樣本數)和總列數(欄位數)、變數的資料型別、資料集中非缺失的數量以及記憶體使用情況。
從資料集的資訊可以看出,一共有 35 個特徵,Attrition 是目標欄位,26 個變數是整數型別變數,9 個是物件型別變數。
📌 缺失值檢測&處理
我們先來做一下缺失值檢測與處理,缺失值的存在可能會降低模型效果,也可能導致模型出現偏差。
```python
檢視缺失值情況
data.isnull().sum() ```
從結果可以看出,資料集中沒有缺失值。
📌 特徵編碼
因為目標特徵“Attrition”是一個類別型變數,為了分析方便以及能夠順利建模,我們對它進行類別編碼(對映為整數值)。
```python
since Attrition is a categotical in nature so will be mapping it with integrs variables for further analysis
data.Attrition = data.Attrition.map({"Yes":1,"No":0}) ```
📌 資料統計概要
接下來,我們藉助於pandas的describe函式檢查數值特徵的統計摘要:
```python
checking statistical summary
data.describe().T ```
注意這裡的“.T”是獲取資料幀的轉置,以便更好地分析。
從統計摘要中,我們得到資料的統計資訊,包括資料的中心趨勢——平均值、中位數、眾數和散佈標準差和百分位數,最小值和最大值等。
📌 數值型特徵分析
我們進一步對數值型變數進行分析
```python
選出數值型特徵
numerical_feat = data.select_dtypes(include=['int64','float64']) numerical_feat ```
python
print(numerical_feat.columns)
print("No. of numerical variables :",len(numerical_feat.columns))
print("Number of unique values \n",numerical_feat.nunique())
我們有以下觀察結論:
① 共有27個數值型特徵變數 ② 月收入、日費率、員工人數、月費率等為連續數值 ③ 其餘變數為離散數值(即有固定量的候選取值)
我們藉助於 seaborn 工具包中的分佈圖方法 sns.distplot()
來繪製數值分佈圖
```python
資料分析&分佈繪製
plt.figure(figsize=(25,30)) plot = 1 for var in numerical_feat: plt.subplot(9,3,plot) sns.distplot(data[var],color='skyblue') plot+=1 plt.show() ```
通過以上分析,我們獲得以下一些基本觀察和結論:
- 大多數員工都是 30 多歲或 40 多歲
- 大多數員工具有 3 級教育
- 大多數員工將環境滿意度評為 3 和 4
- 大多數員工的工作參與度為 3 級
- 大多數員工來自工作級別 1 和 2
- 大多數員工將工作滿意度評為 3 和 4
- 大多數員工只在 1 個公司工作過
- 大多數員工的績效等級為 3
- 大多數員工要麼沒有股票期權,要麼沒有一級股票期權
- 大多數員工有 5-10 年的工作經驗
- 大多數員工的工作與生活平衡評分為 3
接下來我們對目標變數做點分析:
```python
目標變數分佈
sns.countplot('Attrition',data=data) plt.title("Distribution of Target Variable") plt.show() print(data.Attrition.value_counts()) ```
我們可以看到資料集中存在類別不平衡問題(流失的使用者佔比少)。類別不均衡情況下,我們要選擇更有效的評估指標(如auc可能比accuracy更有效),同時在建模過程中做一些優化處理。
我們分別對各個欄位和目標欄位進行聯合關聯分析。
```python
Age 與 attrition
age=pd.crosstab(data.Age,data.Attrition) age.div(age.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(14,7),cmap='spring') plt.title("Age vs Attrition",fontsize=20) plt.show() ```
```python
Distance from home 與 attrition
dist=pd.crosstab(data.DistanceFromHome,data.Attrition) dist.div(dist.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7)) plt.title("Distance From Home vs Attrition",fontsize=20) plt.show() ```
```python
Education 與 Attrition
edu=pd.crosstab(data.Education,data.Attrition) edu.div(edu.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7)) plt.title("Education vs Attrition",fontsize=20) plt.show() ```
```python
Environment Satisfaction 與 Attrition
esat=pd.crosstab(data.EnvironmentSatisfaction,data.Attrition) esat.div(esat.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='BrBG') plt.title("Environment Satisfaction vs Attrition",fontsize=20) plt.show() ```
```python
Job Involvement 與 Attrition
job_inv=pd.crosstab(data.JobInvolvement,data.Attrition) job_inv.div(job_inv.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Spectral') plt.title("Job Involvement vs Attrition",fontsize=20) plt.show() ```
```python
Job Level 與 Attrition
job_lvl=pd.crosstab(data.JobLevel,data.Attrition) job_lvl.div(job_lvl.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='prism_r') plt.title("Job Level vs Attrition",fontsize=20) plt.show() ```
```python
Job Satisfaction 與 Attrition
job_sat=pd.crosstab(data.JobSatisfaction,data.Attrition) job_sat.div(job_sat.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='inferno') plt.title("Job Satisfaction vs Attrition",fontsize=20) plt.show() ```
```python
Number of Companies Worked 與 Attrition
num_org=pd.crosstab(data.NumCompaniesWorked,data.Attrition) num_org.div(num_org.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='cividis_r') plt.title("Number of Companies Worked vs Attrition",fontsize=20) plt.show() ```
```python
Percent Salary Hike 與 Attrition
sal_hike_percent=pd.crosstab(data.PercentSalaryHike,data.Attrition) sal_hike_percent.div(sal_hike_percent.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='RdYlBu') plt.title("Percent Salary Hike vs Attrition",fontsize=20) plt.show() ```
```python
Performance Rating 與 Attrition
performance=pd.crosstab(data.PerformanceRating,data.Attrition) performance.div(performance.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='viridis_r') plt.title("Performance Rating vs Attrition",fontsize=20) plt.show() ```
```python
Relationship Satisfaction 與 Attrition
rel_sat=pd.crosstab(data.RelationshipSatisfaction,data.Attrition) rel_sat.div(rel_sat.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='brg_r') plt.title("Relationship Satisfaction vs Attrition",fontsize=20) plt.show() ```
```python
Stock Option Level 與 Attrition
stock_opt=pd.crosstab(data.StockOptionLevel,data.Attrition) stock_opt.div(stock_opt.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Accent') plt.title("Stock Option Level vs Attrition",fontsize=20) plt.show() ```
```python
Training Times Last Year 與 Attrition
tr_time=pd.crosstab(data.TrainingTimesLastYear,data.Attrition) tr_time.div(tr_time.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='coolwarm') plt.title("Training Times Last Year vs Attrition",fontsize=20) plt.show() ```
```python
Work Life Balance 與 Attrition
work=pd.crosstab(data.WorkLifeBalance,data.Attrition) work.div(work.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='gnuplot') plt.title("Work Life Balance vs Attrition",fontsize=20) plt.show() ```
```python
Years With Curr Manager 與 Attrition
curr_mang=pd.crosstab(data.YearsWithCurrManager,data.Attrition) curr_mang.div(curr_mang.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='OrRd_r') plt.title("Years With Curr Manager vs Attrition",fontsize=20) plt.show() ```
```python
Years Since Last Promotion 與 Attrition
prom=pd.crosstab(data.YearsSinceLastPromotion,data.Attrition) prom.div(prom.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='PiYG_r') plt.title("Years Since Last Promotion vs Attrition",fontsize=20) plt.show() ```
```python
Years In Current Role 與 Attrition
role=pd.crosstab(data.YearsInCurrentRole,data.Attrition) role.div(role.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='terrain') plt.title("Years In Current Role vs Attrition",fontsize=20) plt.show() ```
這些堆積條形圖顯示了員工流失情況與各個欄位取值的關係,從上圖我們可以得出以下基本結論:
- 30 歲以下或特別是年輕的員工比年長員工的流失率更高。
- 離家較遠的員工的流失率較高,即如果員工住在附近,他或她離開公司的機會較小。
- 3 級和1 級教育的流失率較高,5 級教育的流失率最低,高等教育水平的候選人穩定性更高。
- 環境滿意度低會導致較高的人員流失率,滿意度1 級的人員流失率較高,4 級人員的流失率最低。
- 工作參與級別 1 的員工流失率較高,級別 4 的員工流失率最低,這意味著工作參與度更高的員工離職機會較低。
- 級別 1 和級別 3 的員工流失率較高,級別 5 的員工流失率最低,即職位級別較高的員工流失的可能性較小。
- 工作滿意度級別 1 的員工流失率較高,級別 4 的員工流失率最低,工作滿意度較高的員工流失的可能性較小。
- 在超過四家公司工作過的員工流失率較高,這個欄位本身在一定程度上體現了員工的穩定性。
- 1 級關係滿意度較高,4 級最少,這意味著與僱主關係好的員工流失可能性較低。
- 股票期權級別為 0 的員工流失率較高,而級別 1 和 2 的員工流失率較低,這意味著如果員工持有股票,會更傾向於留下
- 工作與生活平衡水平為 1 的員工流失率高,或者我們可以說工作與生活平衡低的員工更可能流失。
- 自過去 8 年以來未晉升的員工有大量流失。
- 隨著與經理相處的時間變長,員工流失率會下降。
📌 類別型特徵分析
現在我們對類別型特徵進行分析,在這裡我們使用餅圖和堆積條形圖來分析它們的分佈以及它們和目標變數的相關性。
```python
分析Buisness Travel
colors=['red','green','blue'] size = data.BusinessTravel.value_counts().values explode_list=[0,0.05,0.1] plt.figure(figsize=(15,10)) plt.pie(size,labels=None,explode=explode_list,colors=colors,autopct="%1.1f%%",pctdistance=1.15) plt.title("Business Travel",fontsize=15) plt.legend(labels=['Travel_Rarely','Travel_Frequently','Non-Travel'],loc='upper left') plt.show() ```
```python
分析Department
colors=['orchid','gold','olive'] size = data.Department.value_counts().values explode_list=[0,0.05,0.06] plt.figure(figsize=(15,10)) plt.pie(size,labels=None,explode=explode_list,colors=colors,autopct="%1.1f%%",pctdistance=1.1) plt.title("Department",fontsize=15) plt.legend(labels=['Sales','Research & Development','Human Resources'],loc='upper left') plt.show() ```
```python
分析Education Field
colors=["cyan","orange","hotpink","green","navy","grey"] size = data.EducationField.value_counts().values explode_list=[0,0.05,0.05,0.08,0.08,0.1] plt.figure(figsize=(15,10)) plt.pie(size,labels=None,explode=explode_list,colors=colors,autopct="%1.1f%%",pctdistance=1.1) plt.title("Education Field",fontsize=15) plt.legend(labels=['Life Sciences','Other','Medical','Marketing','Technical Degree','Human Resources'],loc='upper left') plt.show() ```
```python
分析婚姻狀況
colors=["red","orange","magenta","green","navy","grey","cyan","blue","black"] size = data.JobRole.value_counts().values explode_list=[0,0.05,0.05,0.05,0.08,0.08,0.08,0.1,0.1] plt.figure(figsize=(15,10)) plt.pie(size,labels=None,explode=explode_list,colors=colors,autopct="%1.1f%%",pctdistance=1.1) plt.title("Job Role",fontsize=15) plt.legend(labels=['Sales Executive','Research Scientist','Laboratory Technician','Manufacturing Director','Healthcare Representative','Manager','Sales Representative','Research Director','Human Resources'],loc='upper left') plt.show() ```
```python
分析gender性別
plt.figure(figsize=(10,9)) plt.title('Gender distribution',fontsize=15) sns.countplot('Gender',data=data,palette='magma') ```
從上面的圖中,我們獲得了一些資訊:
- 大部分員工很少出差。
- 銷售部門是公司的主體,研發佔公司的30%左右,人力資源佔比最小。
- 擁有生命科學教育背景的員工數量較多,而人力資源教育背景的員工數量較少。
- 大部分員工來自銷售職位,最少來自人力資源部門。
- 大部分員工未婚。
- 公司中男性的數量多於女性。
下面做關聯分析:
```python
Business Travel 與 Attrition
trav = pd.crosstab(data.BusinessTravel,data.Attrition) trav.div(trav.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Set1') plt.title("Business Travel vs Attrition",fontsize=20) plt.show() ```
```python
Department 與 Attrition
dept = pd.crosstab(data.Department,data.Attrition) dept.div(dept.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Set1') plt.title("Department vs Attrition",fontsize=20) plt.show() ```
```python
Education Field 與 Attrition
edu_f = pd.crosstab(data.EducationField,data.Attrition) edu_f.div(edu_f.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Set1') plt.title("Education Field vs Attrition",fontsize=20) plt.show() ```
```python
Job Role 與 Attrition
jobrole = pd.crosstab(data.JobRole,data.Attrition) jobrole.div(jobrole.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Set1') plt.title("Job Role vs Attrition",fontsize=20) plt.show() ```
```python
Marital Status 與 Attrition
mary = pd.crosstab(data.MaritalStatus,data.Attrition) mary.div(mary.sum(1),axis=0).plot(kind='bar',stacked=True,figsize=(12,7),cmap='Set1') plt.title("Marital Status vs Attrition",fontsize=20) plt.show() ```
```python
gender 與 Attrition
plt.figure(figsize=(10,9)) plt.title('Gender distribution',fontsize=15) sns.countplot('Gender',data=data,palette='magma') ```
上圖反應了一些資訊:
- 經常出差員工的離職率較高,非出差員工離職率較低,也就是說,經常出差的員工更有可能流失。
- 銷售和人力資源的流失率較高,而研發的流失率較低。
- 人力資源教育背景的流失率較高,而醫學和其他教育背景的流失率最低。,醫學和其他教育背景的員工離職的可能性較小。
- 銷售代表、人力資源和實驗室技術人員的流失率最高。
- 未婚員工離職率較高,離婚員工離職率最低。
- 男性員工的流失率更高。
📌 相關性分析
我們計算特徵之間的相關係數並繪製熱力圖:
```python
計算相關度矩陣並繪製熱力圖
plt.figure(figsize=(20,15)) sns.heatmap(data.corr(method='spearman'),annot=True,cmap='Accent') plt.title('Correlation of features',fontsize=20) plt.show() ```
```python
相關度排序
plt.figure(figsize=(15,9)) correlation = data . corr(method='spearman') correlation.Attrition.sort_values(ascending=False).drop('Attrition').plot.bar(color='r') plt.title('Correlation of independent features with target feature',fontsize=20) plt.show() ```
📌 異常值檢測與處理
下面我們檢測一下資料集中的異常值,在這裡,我們使用箱線圖來視覺化分佈並檢測異常值。
```python
繪製箱線圖
plot=1 plt.figure(figsize=(15,30)) for i in numerical_feat.columns: plt.subplot(9,3,plot) sns.boxplot(data[i],color='navy') plt.xlabel(i) plot+=1 plt.show() ```
箱線圖顯示資料集中有不少異常值,不過這裡的異常值主要是因為離散變數(可能是取值較少的候選),我們將保留它們(不然會損失掉這些樣本資訊),不過我們注意到月收入的異常值比較奇怪,這可能是由於資料收集錯誤造成的,可以清洗一下。
💡 特徵工程
關於機器學習特徵工程,大家可以參考 ShowMeAI 整理的特徵工程最全解讀教程。
📌 類別均衡處理
下面我們來完成特徵工程的部分,從原始資料中抽取強表徵的資訊,以便模型能更直接高效地挖掘和建模。
我們在EDA過程中發現 MonthlyIncome、JobLevel 和 YearsAtCompany 以及 YearsInCurrentRole 高度相關,可能會帶來多重共線性問題,我們會做一些篩選,同時我們會刪除一些與 EmployeeCount、StandardHours 等變數不相關的特徵,並剔除一些對預測不重要的特徵。
```python dataset = data.copy()
刪除與目標相關性低的Employee count 和 standard hours特徵
dataset.drop(['EmployeeCount','StandardHours'],inplace=True,axis=1) dataset.head(2) ```
下面我們對類別型特徵進行編碼,包括數字對映與獨熱向量編碼。
```python
按照出差的頻度進行編碼
dataset.BusinessTravel = dataset.BusinessTravel.replace({ 'Non-Travel':0,'Travel_Rarely':1,'Travel_Frequently':2 })
性別與overtime編碼
dataset.Gender = dataset.Gender.replace({'Male':1,'Female':0}) dataset.OverTime = dataset.OverTime.replace({'Yes':1,'No':0})
獨熱向量編碼
new_df = pd.get_dummies(data=dataset,columns=['Department','EducationField','JobRole', 'MaritalStatus']) new_df ```
處理與轉換後的資料如下所示:
在前面的資料探索分析過程中,我們發現目標變數是類別不平衡的,因此可能會導致模型偏向多數類而帶來偏差。我們在這裡會應用過取樣技術 SMOTE(合成少數類別的樣本補充)來處理資料集中的類別不平衡問題。
我們把資料先切分為特徵和標籤,處理之前的標籤類別比例如下:
```python
切分特徵和標籤
X = new_df.drop(['Attrition'],axis=1) Y = new_df.Attrition
標籤01取值比例
sns.countplot(data=new_df,x=Y,palette='Set1') plt.show() print(Y.value_counts()) ```
應用過取樣技術 SMOTE:
```python
SMOTE處理類別不均衡
from imblearn.over_sampling import SMOTE sm = SMOTE(sampling_strategy='minority') x,y = sm.fit_resample(X,Y) print(x.shape," \t ",y.shape)
(2466, 45) (2466,)
```
過取樣後
```python
過取樣之後的比例
sns.countplot(data=new_df,x=y,palette='Set1') plt.show() print(y.value_counts()) ```
📌 特徵幅度縮放
現在資料集已經類別均衡了,我們做一點特徵工程處理,比如有些模型對於特徵值的幅度是敏感的,我們做一點幅度縮放,這裡我們呼叫sklearn.preprocessing 類中的 MinMaxScaler 方法。
```python
特徵幅度縮放
from sklearn.preprocessing import MinMaxScaler scaler = MinMaxScaler() x_scaled = scaler.fit_transform(x) x_scaled = pd.DataFrame(x_scaled, columns=x.columns) x_scaled ```
處理後我們的資料集看起來像這樣
所有取值都已調整到 0 -1 的幅度範圍內。
💡 分析特徵重要性
通常在特徵工程之後,我們會得到非常多的特徵,太多特徵會帶來模型訓練效能上的問題,不相關的差特徵甚至會拉低模型的效果。
我們很多時候會進行特徵重要度分析的工作,篩選和保留有效特徵,而對其他特徵進行剔除。我們先將資料集拆分為訓練集和測試集,再基於互資訊判定特徵重要度。
```python
訓練集測試集切分
from sklearn.model_selection import train_test_split
xtrain,xtest,ytrain,ytest = train_test_split(x_scaled,y,test_size=0.3,random_state=1) ```
我們使用 sklearn.feature_selection 類中的mutual_info_classif 方法來獲得特徵重要度。Mutual _info_classif的工作原理是類似資訊增益。
```python from sklearn.feature_selection import mutual_info_classif
mutual_info = mutual_info_classif(xtrain,ytrain) mutual_info ```
下面我們繪製一下特徵重要性
```python mutual_info = pd.Series(mutual_info) mutual_info.index = xtrain.columns mutual_info.sort_values(ascending=False)
plt.title("Feature Importance",fontsize=20) mutual_info.sort_values().plot(kind='barh',figsize=(12,9),color='r') plt.show() ```
當然,實際判定特徵重要度的方式有很多種,甚至結果也會有一些不同,我們只是基於這個步驟,進行一定的特徵篩選,把最不相關的特徵剔除。
💡 模型構建和評估
關於建模與評估,大家可以參考 ShowMeAI 的機器學習系列教程與模型評估基礎知識文章。 - 📘圖解機器學習演算法:從入門到精通系列教程 - 📘圖解機器學習演算法(1) | 機器學習基礎知識 - 📘圖解機器學習演算法(2) | 模型評估方法與準則
好,我們前序工作就算完畢啦!下面要開始構建模型了。在建模之前,有一件非常重要的事情,是我們需要選擇合適的評估指標對模型進行評估,這能給我們指明模型優化的方向,我們在這裡,針對分類問題,儘量覆蓋地選擇了下面這些評估指標
- 準確度得分
- 混淆矩陣
- precision
- recall
- F1-score
- Auc-Roc
我們這裡選用了8個模型構建baseline,並應用交叉驗證以獲得對模型無偏的評估結果。
```python
匯入工具庫
from sklearn.linear_model import LogisticRegression from sklearn.tree import DecisionTreeClassifier from sklearn.svm import SVC from sklearn.neighbors import KNeighborsClassifier from sklearn.naive_bayes import BernoulliNB from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import AdaBoostClassifier from sklearn.ensemble import GradientBoostingClassifier from sklearn.model_selection import cross_val_score,cross_validate from sklearn.metrics import classification_report,confusion_matrix,accuracy_score,plot_roc_curve,roc_curve,auc,roc_auc_score,precision_score,r
初始化baseline模型(使用預設引數)
LR = LogisticRegression() KNN = KNeighborsClassifier() SVC = SVC() DTC = DecisionTreeClassifier() BNB = BernoulliNB() RTF = RandomForestClassifier() ADB = AdaBoostClassifier() GB = GradientBoostingClassifier()
構建模型列表
models = [("Logistic Regression ",LR), ("K Nearest Neighbor classifier ",KNN), ("Support Vector classifier ",SVC), ("Decision Tree classifier ",DTC), ("Random forest classifier ",RTF), ("AdaBoost classifier",ADB), ("Gradient Boosting classifier ",GB), ("Naive Bayes classifier",BNB)] ```
接下來我們遍歷這些模型進行訓練和評估:
```python for name,model in models: model.fit(xtrain,ytrain) print(name," trained")
遍歷評估
train_scores=[] test_scores=[] Model = [] for name,model in models: print("**",name,"****") train_acc = accuracy_score(ytrain,model.predict(xtrain)) test_acc = accuracy_score(ytest,model.predict(xtest)) print('Train score : ',train_acc) print('Test score : ',test_acc) train_scores.append(train_acc) test_scores.append(test_acc) Model.append(name)
不同的評估準則
precision_ =[] recall_ = [] f1score = [] rocauc = [] for name,model in models: print("**",name,"****") cm = confusion_matrix(ytest,model.predict(xtest)) print("\n",cm) fpr,tpr,thresholds=roc_curve(ytest,model.predict(xtest)) roc_auc= auc(fpr,tpr) print("\n","ROC_AUC_SCORE : ",roc_auc) rocauc.append(roc_auc) print(classification_report(ytest,model.predict(xtest))) precision = precision_score(ytest, model.predict(xtest)) print('Precision: ', precision) precision_.append(precision) recall = recall_score(ytest, model.predict(xtest)) print('Recall: ', recall) recall_.append(recall) f1 = f1_score(ytest, model.predict(xtest)) print('F1 score: ', f1) f1score.append(f1) plt.figure(figsize=(10,20)) plt.subplot(211) print(sns.heatmap(cm,annot=True,fmt='d',cmap='Accent')) plt.subplot(212) plt.plot([0,1],'k--') plt.plot(fpr,tpr) plt.xlabel('false positive rate') plt.ylabel('true positive rate') plt.show() ```
我們把所有的評估結果彙總,得到一個模型結果對比表單
```python
構建一個Dataframe儲存所有模型的評估指標
evaluate = pd.DataFrame({}) evaluate['Model'] = Model evaluate['Train score'] = train_scores evaluate['Test score'] = test_scores evaluate['Precision'] = precision_ evaluate['Recall'] = recall_ evaluate['F1 score'] = f1score evaluate['Roc-Auc score'] = rocauc
evaluate ```
我們從上述baseline模型的彙總評估結果裡看到:
- 邏輯迴歸和隨機森林在所有模型中表現最好,具有最高的訓練和測試準確度得分,並且它具有低方差的泛化性
- 從precision精度來看,邏輯迴歸0.976、隨機森林0.982,也非常出色
- 從recall召回率來看,Adaboost、邏輯迴歸、KNN表現都不錯
- F1-score會綜合precision和recall計算,這個指標上,邏輯迴歸、隨機森林、Adaboost表現都不錯
- Roc-Auc評估的是排序效果,它對於類別不均衡的場景,評估非常準確,這個指標上,邏輯迴歸和隨機森林、Adaboost都不錯
我們要看一下最終的交叉驗證得分情況
```python
檢視交叉驗證得分
for name,model in models: print("**",name,"****") cv_= cross_val_score(model,x_scaled,y,cv=5).mean() print(cv_) ```
從交叉驗證結果上看,隨機森林表現最優,我們把它選為最佳模型,並將進一步對它進行調優以獲得更高的準確性。
💡 超引數調優
關於建模與評估,大家可以參考ShowMeAI的相關文章。
我們剛才建模過程,使用的都是模型的預設超引數,實際超引數的取值會影響模型的效果。我們有兩種最常用的方法來進行超引數調優:
- 網格搜尋:模型針對具有一定範圍值的超引數網格進行評估,嘗試引數值的每種組合,並實驗以找到最佳超引數,計算成本很高。
- 隨機搜尋:這種方法評估模型的超引數值的隨機組合以找到最佳引數,計算成本低於網格搜尋。
下面我們演示使用隨機搜尋調參優化。
```python from sklearn.model_selection import RandomizedSearchCV
params = {'n_estimators': [int(x) for x in np.linspace(start = 100, stop = 1200, num = 12)], 'criterion':['gini','entropy'], 'max_features': ['auto', 'sqrt'], 'max_depth': [int(x) for x in np.linspace(5, 30, num = 6)], 'min_samples_split': [2, 5, 10, 15, 100], 'min_samples_leaf': [1, 2, 5, 10] } random_search=RandomizedSearchCV(RTF,param_distributions=params,n_jobs=-1,cv=5,verbose=5) random_search.fit(xtrain,ytrain) ```
擬合隨機搜尋後,我們取出最佳引數和最佳估計器。
python
random_search.best_params_
python
random_search.best_estimator_
我們對最佳估計器進行評估
```python
最終模型
final_mod = RandomForestClassifier(max_depth=10, max_features='sqrt', n_estimators=500) final_mod.fit(xtrain,ytrain) final_pred = final_mod.predict(xtest) print("Accuracy Score",accuracy_score(ytest,final_pred)) cross_val = cross_val_score(final_mod,x_scaled,y,scoring='accuracy',cv=5).mean() print("Cross val score",cross_val) plot_roc_curve(final_mod,xtest,ytest) ```
我們可以看到,超引數調優後:
- 模型的整體效能有所提升。
- 準確度和交叉嚴重分數提高了。
- Auc 得分達到了97%。
💡 儲存模型
最後我們對模型進行儲存,以便後續使用或者部署上線。
```python import joblib joblib.dump(final_mod,'hr_attrition.pkl')
['hr_attrition.pkl']
```
參考連結
- 📘 圖解資料分析:從入門到精通系列教程:https://www.showmeai.tech/tutorials/33
- 📘 資料科學工具庫速查表 | Pandas 速查表:https://www.showmeai.tech/article-detail/101
- 📘 資料科學工具庫速查表 | Seaborn 速查表:https://www.showmeai.tech/article-detail/105
- 📘 圖解機器學習演算法:從入門到精通系列教程:https://www.showmeai.tech/tutorials/34
- 📘 圖解機器學習演算法(1) | 機器學習基礎知識:https://www.showmeai.tech/article-detail/185
- 📘 圖解機器學習演算法(2) | 模型評估方法與準則:https://www.showmeai.tech/article-detail/186
- 📘 機器學習實戰 | 機器學習特徵工程最全解讀:https://www.showmeai.tech/article-detail/208
- 📘 深度學習教程(7) | 網路優化:超引數調優、正則化、批歸一化和程式框架:https://www.showmeai.tech/article-detail/218
- whylogs工具庫的工業實踐!機器學習模型流程與效果監控 ⛵
- 脈脈瘋傳!2023年程式設計師生存指南;多款prompt效率加倍工具;提示工程師最全祕籍;AI裁員正在發生 | ShowMeAI日報
- 中國風?古典系?AI中文繪圖創作嚐鮮!⛵
- Python中內建資料庫!SQLite使用指南!
- Pandas中你一定要掌握的時間序列相關高階功能
- 資料科學家賺多少?資料全分析與視覺化 ⛵
- 互動式儀表板!Python輕鬆完成!⛵
- ChatGPT!我是你的破壁人;比爾·蓋茨不看好Web3與元宇宙;FIFA押中4屆世界盃冠軍;GitHub今日熱榜 | ShowMeAI資訊日報
- ChatGPT要收費了;華爾街大裁員;阿里2023十大科技趨勢;小紅書元宇宙虛擬服飾被吐槽;GitHub今日熱榜 | ShowMeAI資訊日報
- AI創業時代!這9個方向有錢途;AIGC再添霸榜應用Lensa;美團SemEval2022冠軍方法分享;醫學影象處理工具箱… | ShowMeAI資訊日報
- 噓!P站資料分析年報;各省市疫情感染進度條;愛奇藝推出元宇宙App;You推出AI聊天機器人;GitHub今日熱榜 | ShowMeAI資訊日報
- 美國公司裁員潮時間線◉科技寒冬視覺化;3份報告回顧中國開發者2022;自動駕駛下半場,誰會衝出重圍 | ShowMeAI每週通訊 #005-01.07
- 副業月入過萬?資料有話說;掃地機器人發展到哪步了;疫情後要不要重返辦公室;淘寶元宇宙直播間;GitHub今日熱榜 | ShowMeAI資訊日報
- 大戰谷歌!微軟Bing引入ChatGPT;羊了個羊40萬年薪招研發;Debian徹底移除Python2;GitHub今日熱榜 | ShowMeAI資訊日報
- 酸了!樂視工作制改為四天半;高通新年裁員;AI繪畫公司開始倒閉;網易入股張藝謀元宇宙公司;GitHub今日熱榜 | ShowMeAI資訊日報
- 要麼幹要麼滾!推特開始裁員了;深度學習產品應用·隨書程式碼;可分離各種樂器音源的工具包;Transformer教程;前沿論文 | ShowMeAI資訊日報
- 真實世界的人工智慧應用落地——OpenAI篇 ⛵
- 陽過→陽康,資料裡的時代側影;谷歌慌了!看各公司如何應對ChatGPT;兩份優質AI年報;本週技術高光時刻 | ShowMeAI每週通訊 #003-12.24
- 用魔法打敗魔法!這件毛衣讓攝像頭看不到你;兩款酷炫的AI寫作軟體;快如閃電的B站下載工具;基於擴散模型的蛋白質設計 | ShowMeAI資訊日報
- 一文讀懂!異常檢測全攻略!從統計方法到機器學習 ⛵