機器學習實戰:基於MNIST數據集的二分類問題
theme: smartblue
公眾號:尤而小屋
作者:Peter
編輯:Peter
大家好,我是Peter~
MNIST數據集是一組由美國高中生和人口調查局員工手寫的70,000個數字的圖片,每張圖片上面有代表的數字標記。
這個數據集被廣泛使用,被稱之為機器學習領域的“Hello World”,主要是被用於分類問題。本文是對MNIST數據集執行一個二分類的建模
關鍵詞:隨機梯度下降、二元分類、混淆矩陣、召回率、精度、性能評估
導入數據
在這裏是將一份存放在本地的mat文件的數據導進來。
第一次用Python讀取MATLAB文件
In [1]:
``` import pandas as pd import numpy as np
import scipy.io as si
from sklearn.datasets import fetch_openml
```
In [2]:
mnist = si.loadmat('mnist-original.mat')
In [3]:
type(mnist) # 查看數據類型
Out[3]:
dict
In [4]:
mnist.keys()
Out[4]:
dict_keys(['__header__', '__version__', '__globals__', 'mldata_descr_ordering', 'data', 'label'])
我們發現導進來的數據是一個字典。其中data和label兩個鍵的值就是我們想要的特徵和標籤數據
創建特徵和標籤
修改的內容1:一定要執行轉置功能,原書是沒有的,保證數據shape合理。
In [5]:
```
修改1:一定要轉置
X, y = mnist["data"].T, mnist["label"].T
X.shape ```
Out[5]:
(70000, 784)
總共是70000張圖片,每個圖片中有784個特徵。圖片是28*28的像素,所以每個特徵代表一個像素點,取值從0-255。
In [6]:
y.shape
Out[6]:
(70000, 1)
In [7]:
y # 每個圖片有個專屬的數字
Out[7]:
array([[0.],
[0.],
[0.],
...,
[9.],
[9.],
[9.]])
顯示一張圖片
matplotlib庫能夠顯示圖像:
In [8]:
``` import matplotlib as mpl import matplotlib.pyplot as plt
one_digit = X[0]
one_digit_image = one_digit.reshape(28, 28)
plt.imshow(one_digit_image, cmap="binary") plt.axis("off") plt.show() ```
In [9]:
y[0] # 真實的標籤的確是0
Out[9]:
python
array([0.]) # 結果是0
標籤類型轉換
元數據中標籤是字符串,我們需要轉成整數類型
In [10]:
y.dtype
Out[10]:
dtype('<f8')
In [11]:
y = y.astype(np.uint8)
創建訓練集和測試集
前面的6萬條是訓練集,後面的1萬條是測試集
In [12]:
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
二元分類器
比如現在有1張圖片,顯示是0,我們識別是:"0和非0",兩種情形即可,這就是簡單的二元分類問題
In [13]:
``` y_train_0 = (y_train == 0) # 挑選出5的部分
y_test_0 = (y_test == 0) ```
隨機梯度下降分類器SGD
使用scikit-learn自帶的SGDClassifier分類器:能夠處理非常大型的數據集,同時SGD適合在線學習
In [14]:
``` from sklearn.linear_model import SGDClassifier
sgd_c = SGDClassifier(random_state=42) # 設置隨機種子,保證運行結果相同
sgd_c.fit(X_train, y_train_0) ```
Out[14]:
SGDClassifier(random_state=42)
結果驗證
在這裏我們檢查下數字0的圖片:結果為True
In [15]:
sgd_c.predict([one_digit]) # one_digit是0,非5 表示為False
Out[15]:
array([ True])
性能測量1-交叉驗證
一般而言,分類問題的評估比迴歸問題要困難的多。下面採用多個指標來評估分類的結果
自定義交差驗證(優化)
- 每個摺疊由StratifiedKFold執行分層抽樣,產生的每個類別中的比例符合原始數據中的比例
- 每次迭代會創建一個分類器的副本,用訓練器對這個副本進行訓練,然後測試集進行測試
- 最後預測出準確率,輸出正確的比例
In [16]:
```python
K折交叉驗證
from sklearn.model_selection import StratifiedKFold
用於生成分類器的副本
from sklearn.base import clone
實例化對象
k_folds = StratifiedKFold( n_splits = 3, # 3折 shuffle=True, # add 一定要設置shuffle才能保證random_state生效 random_state=42 )
每個摺疊由StratifiedKFold執行分層抽樣
for train_index, test_index in k_folds.split(X_train, y_train_0): # 分類器的副本 clone_c = clone(sgd_c)
X_train_folds = X_train[train_index] # 訓練集的索引號
y_train_folds = y_train_0[train_index]
X_test_fold = X_train[test_index] # 測試集的索引號
y_test_fold = y_train_0[test_index]
clone_c.fit(X_train_folds, y_train_folds) # 模型訓練
y_pred = clone_c.predict(X_test_fold) # 預測
n_correct = sum(y_pred == y_test_fold) # 預測準確的數量
print(n_correct / len(y_pred)) # 預測準確的比例
```
運行的結果如下:
python
[0.09875 0.09875 0.09875 ... 0.90125 0.90125 0.90125]
[0.0987 0.0987 0.0987 ... 0.9013 0.9013 0.9013]
[0.0987 0.0987 0.0987 ... 0.9013 0.9013 0.9013]
scikit_learn的交叉驗證
使用cross_val_score來評估分類器:
In [17]:
```python
評估分類器的效果
from sklearn.model_selection import cross_val_score
cross_val_score(sgd_c, # 模型 X_train, # 數據集 y_train_0, cv=3, # 3折 scoring="accuracy" # 準確率 )
結果
array([0.98015, 0.95615, 0.9706 ]) ```
可以看到準確率已經達到了95%以上,效果是相當的可觀
自定義一個“非0”的簡易分類器,看看效果:
In [18]:
```python from sklearn.base import BaseEstimator # 基分類器
class Never0Classifier(BaseEstimator): def fit(self, X, y=None): return self
def predict(self, X):
return np.zeros((len(X), 1), dtype=bool)
```
In [19]:
```python never_0_clf = Never0Classifier()
cross_val_score( never_0_clf, # 模型 X_train, # 訓練集樣本 y_train_0, # 訓練集標籤 cv=3, # 折數 scoring="accuracy" ) ```
Out[19]:
array([0.70385, 1. , 1. ])
In [20]:
統計數據中每個字出現的次數:
pd.DataFrame(y).value_counts()
Out[20]:
1 7877
7 7293
3 7141
2 6990
9 6958
0 6903
6 6876
8 6825
4 6824
5 6313
dtype: int64
In [21]:
6903 / 70000
Out[21]:
下面顯示大約有10%的概率是0這個數字
0.09861428571428571
In [22]:
(0.70385 + 1 + 1) / 3
Out[22]:
0.9012833333333333
可以看到判斷“非0”準確率基本在90%左右,因為只有大約10%的樣本是屬於數字0。
所以如果猜測一張圖片是非0,大約90%的概率是正確的。
性能測量2-混淆矩陣
預測結果
評估分類器性能更好的方法是混淆矩陣,總體思路是統計A類別實例被劃分成B類別的次數
混淆矩陣是通過預測值和真實目標值來進行比較的。
cross_val_predict函數返回的是每個摺疊的預測結果,而不是評估分數
In [23]:
```python from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict( sgd_c, # 模型 X_train, # 特徵訓練集 y_train_0, # 標籤訓練集 cv=3 # 3折 )
y_train_pred ```
Out[23]:
array([ True, True, True, ..., False, False, False])
混淆矩陣
In [24]:
```
導入混淆矩陣
from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_0, y_train_pred) ```
Out[24]:
array([[52482, 1595],
[ 267, 5656]])
混淆矩陣中:行表示實際類別,列表示預測類別
- 第一行表示“非0”:52482張被正確地分為“非0”(真負類),有1595張被錯誤的分成了“0”(假負類)
- 第二行表示“0”:267被錯誤地分為“非0”(假正類),有5656張被正確地分成了“0”(真正類)
In [25]:
```
假設一個完美的分類器:只存在真正類和真負類,它的值存在於對角線上
y_train_perfect_predictions = y_train_0
confusion_matrix(y_train_0, y_train_perfect_predictions) ```
Out[25]:
array([[54077, 0],
[ 0, 5923]])
精度和召回率
$$精度=\frac{TP}{TP+FP}$$
召回率的公式為:
$$召回率 = \frac {TP}{TP+FN}$$
混淆矩陣顯示的內容:
- 左上:真負
- 右上:假正
- 左下:假負
- 右下:真正
精度:正類預測的準確率
召回率(靈敏度或真正類率):分類器正確檢測到正類實例的比例
計算精度和召回率
In [26]:
``` from sklearn.metrics import precision_score, recall_score
precision_score(y_train_0, y_train_pred) # 精度 ```
Out[26]:
0.78003034064267
In [27]:
recall_score(y_train_0, y_train_pred) # 召回率
Out[27]:
0.9549214924869154
F_1係數
F_1係數是精度和召回率的諧波平均值。只有當召回率和精度都很高的時候,分類器才會得到較高的F_1分數
𝐹1=21精度+1召回率(3)(3)F1=21精度+1召回率
In [28]:
``` from sklearn.metrics import f1_score
f1_score(y_train_0, y_train_pred) ```
Out[28]:
0.8586609989373006
精度/召回率權衡
精度和召回率通常是一對”抗體“,我們一般不可能同時增加精度又減少召回率,反之亦然,這就現象叫做精度/召回率權衡
In [29]:
```
使用decision_function
y_scores = sgd_c.decision_function([one_digit]) y_scores ```
Out[29]:
array([24816.66593936])
In [30]:
threshold = 0 # 設置閾值
y_digit_pred = y_scores > threshold
y_digit_pred
Out[30]:
array([ True])
In [31]:
```
提升閾值
threshold = 100000
y_digit_pred = y_scores > threshold
y_digit_pred
```
Out[31]:
array([False])
如何使用閾值
- 先使用cross_val_predict函數獲取訓練集中所有實例的分數
In [32]:
``` y_scores = cross_val_predict( sgd_c, X_train, y_train_0.ravel(), # 原文 y_train_0 cv=3, method="decision_function")
y_scores ```
Out[32]:
array([ 51616.39393745, 27082.28092103, 20211.29278048, ...,
-23195.59964776, -21022.63597851, -18702.17990507])
2、有了這些分數就可以計算精度和召回率:
In [33]:
``` from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_0, y_scores) ```
In [34]:
precisions # 精度
Out[34]:
array([0.10266944, 0.10265389, 0.10265566, ..., 1. , 1. ,
1. ])
In [35]:
recalls # 召回率
Out[35]:
array([1.00000000e+00, 9.99831167e-01, 9.99831167e-01, ...,
3.37666723e-04, 1.68833361e-04, 0.00000000e+00])
In [36]:
thresholds # 閾值
Out[36]:
array([-86393.49001095, -86375.60229796, -86374.22313529, ...,
92555.12952489, 93570.30614671, 96529.58216984])
繪製精度和召回率曲線
In [37]:
```python
def figure_precision_recall(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1],"b--",label="Precision") # 精度-藍色
plt.plot(thresholds, recalls[:-1],"g-",label="Recall") # 召回率-綠色
plt.legend(loc="center right", fontsize=12)
plt.xlabel("Threshold", fontsize=16)
plt.grid(True)
figure_precision_recall(precisions, recalls, thresholds) plt.show() ```
直接繪製精度和召回率的曲線圖:
```python
精度-召回率
plt.plot(recalls[:-1], precisions[:-1],"b--")
plt.legend(loc="center right", fontsize=12)
plt.xlabel("Threshold", fontsize=16)
plt.grid(True)
```
現在我們將精度設置成90%,通過np.argmax()函數來獲取最大值的第一個索引,即表示第一個True的值:
In [39]:
threshold_90_precision = thresholds[np.argmax(precisions >= 0.9)]
threshold_90_precision
Out[39]:
9075.648564157285
In [40]:
y_train_pred_90 = (y_scores >= threshold_90_precision)
y_train_pred_90
Out[40]:
array([ True, True, True, ..., False, False, False])
In [41]:
```
再次查看精度和召回率
precision_score(y_train_0, y_train_pred_90) ```
Out[41]:
0.9001007387508395
In [42]:
recall_score(y_train_0, y_train_pred_90)
Out[42]:
0.9051156508526085
性能測量3-ROC曲線
繪製ROC
還有一種經常和二元分類器一起使用的工具,叫做受試者工作特徵曲線ROC。
繪製的是真正類率(召回率的別稱)和假正類率(FPR)。FPR是被錯誤分為正類的負類實例比率,等於1減去真負類率(TNR)
TNR是被正確地分為負類的負類實例比率,也稱之為特異度。
ROC繪製的是靈敏度和(1-特異度)的關係圖
In [43]:
```
1、計算TPR、FPR
from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_train_0, y_scores) ```
In [44]:
```python def plot_roc_curve(fpr,tpr,label=None):
plt.plot(fpr, tpr, linewidth=2,label=label)
plt.plot([0,1], [0,1], "k--")
plt.legend(loc="center right", fontsize=12)
plt.xlabel("FPR", fontsize=16)
plt.ylabel("TPR", fontsize=16)
plt.grid(True)
plot_roc_curve(fpr,tpr) plt.show() ```
AUC面積
auc就是上面ROC曲線的線下面積。完美的分類器ROC_AUC等於1;純隨機分類器的ROC_AUC等於0.5
In [45]:
``` from sklearn.metrics import roc_auc_score
roc_auc_score(y_train_0, y_scores) ```
Out[45]:
0.9910680354987216
ROC曲線和精度/召回率(PR)曲線非常類似,選擇經驗:當正類非常少見或者我們更加關注假正類而不是假負類,應該選擇PR曲線,否則選擇ROC曲線
對比隨機森林分類器
報錯解決方案:http://stackoverflow.com/questions/63506197/method-predict-proba-for-cross-val-predict-return-index-1-is-out-of-bounds-fo
報錯:index 1 is out of bounds for axis 1 with size 1
In [46]:
X_train.shape
Out[46]:
(60000, 784)
In [47]:
```
解決方案
y_train_0 = y_train_0.reshape(X_train.shape[0], ) y_train_0 ```
Out[47]:
array([ True, True, True, ..., False, False, False])
In [48]:
``` from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(random_state=42) y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_0, cv=3, method="predict_proba") y_probas_forest ```
Out[48]:
array([[0. , 1. ],
[0.04, 0.96],
[0.15, 0.85],
...,
[0.93, 0.07],
[0.97, 0.03],
[0.96, 0.04]])
使用roc_curve函數來提供分類的概率:
In [49]:
``` y_scores_forest = y_probas_forest[:,1]
fpr_rf, tpr_rf, thresholds_rf = roc_curve(y_train_0, y_scores_forest) ```
In [50]:
```python plt.plot(fpr, tpr, "b:", label="SGD") plot_roc_curve(fpr_rf,tpr_rf,"Random Forest") plt.legend(loc="lower right")
plt.show() ```
現在我們重新查看ROC-AUC值、精度和召回率,發現都得到了提升:
In [51]:
roc_auc_score(y_train_0,y_scores_forest) # ROC-AUC值
Out[51]:
0.9975104189747056
In [52]:
precision_score(y_train_0,y_train_pred) # 精度
Out[52]:
0.78003034064267
In [53]:
recall_score(y_train_0,y_train_pred) # 召回率
Out[53]:
0.9549214924869154
總結
本文從公開的MNIST數據出發,通過SGD建立一個二元分類器,同時利用交叉驗證來評估我們的分類器,以及使用不同的指標(精度、召回率、精度/召回率平衡)、ROC曲線等來比較SGD和RandomForestClassifier不同的模型。
- 基於機器學習分類算法的鋼材缺陷檢測分類
- JSON數據,Python搞定!
- 邏輯迴歸:信貸違規預測!
- kaggle實戰-腫瘤數據統計分析
- Pandas操作mysql數據庫!
- 數學公式編輯神器-Mathpix Snipping Tool
- 精選20個Pandas統計函數
- 露一手,利用Python分析房價
- 德國信貸數據建模baseline!
- Python函數傳參機制詳解
- Python爬蟲周遊全國-蘭州站
- 一道Pandas題:3種解法
- 機器學習實戰:基於3大分類模型的中風病人預測
- 利用seaborn繪製柱狀圖
- 機器學習實戰:基於MNIST數據集的二分類問題
- 基於深度學習Keras的深圳租房建模
- 機器學習高頻使用代碼片段
- Python入門:Python變量和賦值
- 深度學習框架Keras入門保姆教程
- kaggle實戰:極度不均衡的信用卡數據分析