1. 程式人生 > >交叉驗證和網格搜尋

交叉驗證和網格搜尋

一、交叉驗證(Cross Validation)

1. 目的

交叉驗證的目的是為了讓模型評估更加準確可信。

2. 基本思想

基本思想是將原始資料(dataset)進行分組,一部分做為訓練集(train set),另一部分做為驗證集(validation set or test set),首先用訓練集對分類器進行訓練,再利用驗證集來測試訓練得到的模型,以此來作為評價分類器的效能指標。

3. 主要方法

交叉驗證主要有以下三種方法:

  • Holdout驗證
  • K折交叉驗證
  • 留一驗證

3.1 Holdout驗證

將原始資料隨機分為兩組,一組做為訓練集,一組做為驗證集,利用訓練集訓練分類器,然後利用驗證集驗證模型。

3.2 K折交叉驗證(K-fold Cross Validation)

以10折交叉驗證為例,如下圖所示。

步驟如下:

  1. 將資料集平均分成不相交的10個子集
  2. 每一次挑選其中的1份作為測試集,其餘的9份作為訓練集進行模型訓練,得到模型的指標
  3. 重複第2步10次,使每個子集都作為1次測試集,得到10個模型的指標
  4. 將10個模型指標取平均值,作為10折交叉驗證的模型的指標

3.3 留一驗證(Leave-One-Out Cross Validation,LOOCV)

留一驗證是K折交叉驗證的特例,假設原始資料有N個樣本,每個樣本單獨作為驗證集,其餘的N-1個樣本作為訓練集。此方法主要用於樣本量非常少的情況。

二、網格搜尋(Grid Search)

通常情況下,很多超引數需要調節,但是手動過程繁雜,所以需要對模型預設幾種超引數組合,每組超引數都採用交叉驗證來進行評估。最後選出最優引數組合建立模型。

sklearn中網格搜尋API

    sklearn.model_selection.GridSearchCV(estimator,param_grid,cv)

estimator:估計器物件
param_grid:估計器引數,引數名稱(字串)作為key,要測試的引數列表作為value的字典,或這樣的字典構成的列表
cv:整形,指定K折交叉驗證
方法:
fit:輸入訓練資料
score:準確率
best_score_

:交叉驗證中測試的最好的結果
best_estimator_:交叉驗證中測試的最好的引數模型
best_params_:交叉驗證中測試的最好的引數
cv_results_:每次交叉驗證的結果

簡單示例如下:

knn = KNeighborsClassifier()

param = {"n_neighbors": [3,5,10]}
gscv = GridSearchCV(knn, param_grid=param, cv=10)

gscv.fit(x_train, y_train)

print(gscv.score(x_test, y_test))
print(gscv.best_score_)
print(gscv.best_estimator_)
print(gscv.best_params_)
print(pd.DataFrame(gscv.cv_results_).T)

到不了的地方都叫做遠方,回不去的世界都叫做家鄉,我一直嚮往的卻是比遠更遠的地方。——《幽靈公主》