1. 程式人生 > >機器學習為什麼需要交叉驗證?怎麼使用k-fold cross validation(k-摺疊交叉驗證)

機器學習為什麼需要交叉驗證?怎麼使用k-fold cross validation(k-摺疊交叉驗證)

介紹這個非常重要的概念,希望在訓練演算法時能幫助各位。

概念和思維解讀

叉驗證的目的:在實際訓練中,模型通常對訓練資料好,但是對訓練資料之外的資料擬合程度差。用於評價模型的泛化能力,從而進行模型選擇

交叉驗證的基本思想:把在某種意義下將原始資料(dataset)進行分組,一部分做為訓練集(train set),另一部分做為驗證集(validation set or test set),首先用訓練集對模型進行訓練,再利用驗證集來測試模型的泛化誤差。另外,現實中資料總是有限的,為了對資料形成重用,從而提出k-摺疊交叉驗證。

對於個分類或迴歸問題,假設可選的模型為k-摺疊交叉驗證就是將訓練集的1/k作為測試集,每個模型訓練k次,測試k次,錯誤率為k次的平均,最終選擇平均率最小的模型Mi。

1、 將全部訓練集S分成k個不相交的子集,假設S中的訓練樣例個數為m,那麼每一個子集有m/k個訓練樣例,相應的子集稱作{}。

2、 每次從模型集合M中拿出來一個,然後在訓練子集中選擇出k-1個

{}(也就是每次只留下一個),使用這k-1個子集訓練後,得到假設函式。最後使用剩下的一份作測試,得到經驗錯誤。

3、 由於我們每次留下一個(j從1到k),因此會得到k個經驗錯誤,那麼對於一個,它的經驗錯誤是這k個經驗錯誤的平均。

4、 選出平均經驗錯誤率最小的,然後使用全部的S再做一次訓練,得到最後的。

程式碼使用案例

一、選擇正確的Model基礎驗證法

from sklearn.datasets import load_iris # iris資料集  

from sklearn.model_selection import train_test_split # 分割資料模組  
from sklearn.neighbors import KNeighborsClassifier # K最近鄰(kNN,k-NearestNeighbor)分類演算法  
#載入iris資料集  
iris = load_iris()  
X = iris.data  
y = iris.target  
#分割資料並  
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=4)  
#建立模型  

knn = KNeighborsClassifier()  
#訓練模型  
knn.fit(X_train, y_train)  
#將準確率打印出  
print(knn.score(X_test, y_test))  
# 0.973684210526     基礎驗證的準確率  

二、選擇正確的Model交叉驗證法(Cross-validation)

cv= 5

from sklearn.cross_validation import cross_val_score # K折交叉驗證模組  
#使用K折交叉驗證模組  
scores = cross_val_score(knn, X, y, cv=5, scoring='accuracy')  
#將5次的預測準確率打印出  
print(scores)  
# [ 0.96666667  1.          0.93333333  0.96666667  1.        ]  
#將5次的預測準確平均率打印出  
print(scores.mean())  
# 0.973333333333  
三、準確率和平均方差

一般來說準確率(accuracy)會用於判斷分類(Classification)模型的好壞。

import matplotlib.pyplot as plt #視覺化模組  
#建立測試引數集  
k_range = range(1, 31)  
k_scores = []  
#藉由迭代的方式來計算不同引數對模型的影響,並返回交叉驗證後的平均準確率  
for k in k_range:  
   knn = KNeighborsClassifier(n_neighbors=k)  
   scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')  
   k_scores.append(scores.mean())  
#視覺化資料  
plt.plot(k_range, k_scores)  
plt.xlabel('Value of K for KNN')  
plt.ylabel('Cross-Validated Accuracy')  
plt.show()  


結果如圖,從圖中可以得知,選擇12~18的k值最好。高過18之後,準確率開始下降則是出現過擬合了。



一般來說平均方差(Mean squared error)會用於判斷迴歸(Regression)模型的好壞。
import matplotlib.pyplot as plt  
k_range = range(1, 31)  
k_scores = []  
for k in k_range:  
   knn = KNeighborsClassifier(n_neighbors=k)  
   loss = -cross_val_score(knn, X, y, cv=10, scoring='mean_squared_error')  
   k_scores.append(loss.mean())  
plt.plot(k_range, k_scores)  
plt.xlabel('Value of K for KNN')  
plt.ylabel('Cross-Validated MSE')  
plt.show()  

結果如下圖,當K取13~20時,平方誤差最小,模型最好。