1. 程式人生 > >通過網格搜尋和巢狀交叉驗證尋找機器學習模型的最優引數

通過網格搜尋和巢狀交叉驗證尋找機器學習模型的最優引數

在機器學習的模型中,通常有兩類引數,第一類是通過訓練資料學習得到的引數,也就是模型的係數,如迴歸模型中的權重係數,第二類是模型演算法中需要進行設定和優化的超參,如logistic迴歸中的正則化係數和決策樹中的樹的深度引數等。在上一篇文章中,我們通過驗證曲線來尋找最優的超參,在這篇文章中,將通過一種功能更為強大的尋找超參的技巧:網格搜尋,它可以尋找最優的超參組合,來提高模型的效能。

一、網格(grid search)搜尋尋找超參

網格搜尋:網格搜尋其實是一種暴力搜尋引數的方法,它通過我們指定不同的超參列表進行窮舉搜尋,並計算每一個超參組合對於模型效能的影響,來獲取最優的超參組合。下面通過sklearn來實現網格搜尋尋找超參

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV

if __name__ == "__main__":
    #讀取資料
    data = pd.read_csv("G:/dataset/wdbc.csv")
    #獲取X
    X = data.ix[:,2:32]
    #獲取字串類別標籤
    label_y = data.ix[:,1]
    #將字串的標籤轉為數字
    label = LabelEncoder()
    Y = label.fit_transform(label_y)
    #將資料集分為訓練集和測試集
    train_x,test_x,train_y,test_y = train_test_split(X,Y,test_size=0.2,random_state=1)
    #初始化一個流水線類
    pipe = Pipeline([("std",StandardScaler()),
                     ("clf",SVC(random_state=1))])
    #定義引數的取值
    param_range = [0.0001,0.001,0.01,0.1,1.0,10.0,100.0,1000.0]
    #定義一個網格搜尋的引數
    '''
    線性的SVM只需要,只需要調優正則化引數C
    基於RBF核的SVM,需要調優gamma引數和C
    '''
    param_grid = [{"clf__C":param_range,
                   "clf__kernel":['linear']},
                  {"clf__C":param_range,
                   "clf__gamma":param_range,
                   "clf__kernel":['rbf']}]
    #網格搜尋超參
    grid_search = GridSearchCV(estimator=pipe,param_grid=param_grid,
                               scoring="accuracy",cv=10,n_jobs=-1)
    grid_search = grid_search.fit(train_x,train_y)
    #獲取模型的最優超參
    print(grid_search.best_params_)
    #{'clf__C': 0.1, 'clf__kernel': 'linear'}
    #獲取最好的結果
    print(grid_search.best_score_)
    #0.978021978022

通過上面的結果可以發現,當SVM的核為"linear"時,引數C為0.1時,模型獲得最好的結果為97.8%。

測試模型在測試集上的準確率

    clf = grid_search.best_estimator_
    print(clf.score(test_x,test_y))
    #0.964912280702

網格搜尋是一種功能強大的尋找超參的方法,但是由於它在尋找超參的時候使用的是窮舉法,需要評估所有的引數組合,所以計算成本也是非常高的。sklearn還提供一種隨機網格搜尋引數RandomizedSearchCV類,可以以特定的代價從抽樣分佈中隨機抽取引數組合。

二、巢狀交叉驗證


當我們需要在不同的機器學習演算法中進行選擇的時候,可以通過巢狀交叉驗證來進行選擇。在對於誤差估計的偏差情形研究中表明:使用巢狀交叉驗證,估計的真實誤差與在測試集上得到的結果幾乎沒有差距

巢狀交叉驗證分為外部迴圈和內部迴圈,在外部迴圈中,我們將資料分為訓練塊和測試塊。在內部迴圈中,我們將訓練塊分為訓練塊和測試塊,在訓練塊上使用k折交叉驗證,測試塊用於對於模型進行評估,通過內部迴圈來進行模型選擇。通過上圖可以發現,外部迴圈由5個模組組合,內部迴圈由2個模組組成,因此巢狀交叉驗證也被稱為5×2交叉驗證。下面通過sklearn來實現巢狀交叉驗證

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score

if __name__ == "__main__":
    #讀取資料
    data = pd.read_csv("G:/dataset/wdbc.csv")
    #獲取X
    X = data.ix[:,2:32]
    #獲取字串類別標籤
    label_y = data.ix[:,1]
    #將字串的標籤轉為數字
    label = LabelEncoder()
    Y = label.fit_transform(label_y)
    #將資料集分為訓練集和測試集
    train_x,test_x,train_y,test_y = train_test_split(X,Y,test_size=0.2,random_state=1)
    #初始化一個流水線類
    pipe = Pipeline([("std",StandardScaler()),
                     ("clf",SVC(random_state=1))])
    #定義引數的取值
    param_range = [0.0001,0.001,0.01,0.1,1.0,10.0,100.0,1000.0]
    #定義一個網格搜尋的引數
    '''
    線性的SVM只需要,只需要調優正則化引數C
    基於RBF核的SVM,需要調優gamma引數和C
    '''
    param_grid = [{"clf__C":param_range,
                   "clf__kernel":['linear']},
                  {"clf__C":param_range,
                   "clf__gamma":param_range,
                   "clf__kernel":['rbf']}]
    #網格搜尋超參
    grid_search = GridSearchCV(estimator=pipe,param_grid=param_grid,
                               scoring="accuracy",cv=5,n_jobs=-1)
    scores = cross_val_score(grid_search,train_x,train_y,scoring="accuracy",cv=5)
    print("CV accuracy:%.3f +/- %.3f"%(np.mean(scores),np.std(scores)))
    #CV accuracy:0.978 +/- 0.012

通過巢狀交叉驗證來判斷決策樹的表現效能

    from sklearn.tree import DecisionTreeClassifier
    param_range = [1,2,3,4,5,6,7,8,9,10]
    grid_search = GridSearchCV(estimator=DecisionTreeClassifier(random_state=0),
                               param_grid=[
                                   {"max_depth":param_range}
                               ],scoring="accuracy",cv=5)
    scores = cross_val_score(estimator=grid_search,X=train_x,y=train_y,scoring="accuracy",cv=5)
    print("CV accuray:%.3f +/- %.3f"%(np.mean(scores),np.std(scores)))
    #CV accuray:0.908 +/- 0.045
通過SVM和決策樹的巢狀交叉驗證結果表明,SVM的模型效能要高於決策樹模型的效能。

相關推薦

通過網格搜尋交叉驗證尋找機器學習模型引數

在機器學習的模型中,通常有兩類引數,第一類是通過訓練資料學習得到的引數,也就是模型的係數,如迴歸模型中的權重係數,第二類是模型演算法中需要進行設定和優化的超參,如logistic迴歸中的正則化係數和決策樹中的樹的深度引數等。在上一篇文章中,我們通過驗證曲線來尋找最優的超參,在

[讀書筆記] 《Python 機器學習》- 使用交叉驗證進行模型選擇

摘要 通過巢狀交叉驗證選擇演算法(外部迴圈通過k-折等進行引數優化,內部迴圈使用交叉驗證),我們可以對特定資料集進行模型選擇 程式碼 # 6.4.2: 巢狀交叉驗證選擇演算法,用於在不同的機器學習演算法中進行選擇 import matplotli

學習筆記(七)模型的調參之網格搜尋交叉驗證的簡單應用

學習筆記(七)模型的調參之網格搜尋和交叉驗證的簡單應用 資料概述 交叉驗證 1. Cross——Validation 交叉驗證 2. k折交叉驗證(kfold) 3.留一法Leave-one-out Cross-validation

Go基礎系列:structstruct

struct struct定義結構,結構由欄位(field)組成,每個field都有所屬資料型別,在一個struct中,每個欄位名都必須唯一。 說白了就是拿來儲存資料的,只不過可自定義化的程度很高,用法很靈活,Go中不少功能依賴於結構,就這樣一個角色。 Go中不支援面向物件,面向物件中描述事物的類的重擔

Python_從零開始學習_(21) 函式的返回值呼叫

1.  函式的返回值 在程式開發中,  有時候,  會希望 一個函式執行結束後,  告訴呼叫者一個結果,  以便呼叫者針對具體的結果做後續的處理 返回值 是函式 完成工作 後,  最後 給呼叫者的 一個結果

C#程式設計基礎第六課:C#中三元運算子的初級使用

知識點:三元運算子的使用。 1、三元運算子 三元運算子的初級使用: 符號: ?: 舉例:int c=bool ? a : b 當bool=true,c=表示式a,當bool=false,c=表示式b。 三元運算子?:是 if~else 語句的簡寫形式 書寫格式

CSS之分組選擇器選擇器

分組選擇器, 將一個樣式應用於多個類,或者標籤啥的 每個選擇器用逗號隔開 <!DOCTYPE html> <html> <head> <meta charset="utf-8"> <title>菜鳥教程(runoob.c

Python List資料去重List資料去重

單個list中資料去重 例如: 去除a中重複的資料 ‘b’ a = ['a','b','c','b'] b = list(set(a)) print(b) 輸出結果為: ['a', 'c', 'b']    巢狀list中去除相同list資料

css關於position定位元素並列顯示的小發現

一、fixed定位 1、元素並列 <!-- html程式碼 --> <div class="container"> <div class="fixed-one"> <el-button type="primar

python 中高階函式函式

 1、高階函式:變數可以指向函式;                         函式的引數可以接收變數;      

vue——46-webpack打包vue-路由 路由

一、路由 main.js 中 1.引入 vue-router 包 安裝命名:cnpm i vue-router -s import Vue from 'vue'; import app from

Python List資料去重List資料去重

單個list中資料去重 例如: 去除a中重複的資料 ‘b’ a = ['a','b','c','b'] b = list(set(a)) print(b) 輸出結果為: ['a', 'c', 'b']  巢狀list中去除相同list資料 例如: 去除

elasticsearch複合資料型別——陣列,物件

在ElasticSearch中,使用JSON結構來儲存資料,一個Key/Value對是JSON的一個欄位,而Value可以是基礎資料型別,也可以是陣列,文件(也叫物件),或文件陣列,因此,每個JSON文件都內在地具有層次結構。複合資料型別是指陣列型別,物件型別和巢狀型別,各個

CSS 分組 選擇器

Grouping Selectors在樣式表中有很多具有相同樣式的元素。h1{color:green;}h2{color:green;}p{color:green;}   為了儘量減少程式碼,你可以使用

CSS選擇器的宣告

前言   在利用CSS選擇器控制HTML標記時,除了每個選擇器可以一次宣告多個,選擇器本身也可以同時宣告多個。 集體宣告   在宣告各種CSS選擇器時,如果某些選擇器的風格是完全相同或部分相同,這時便可以利用集體宣告的方法,將風格相同的CSS選擇器同時

CSS 分組

組選擇器 <!DOCTYPE html > <html> <head> <meta charset="utf-8"> <title&g

轉 CSS的組合 id class, 點, #

CSS的組合和巢狀2007年04月27日 星期五 上午 8:58 CSS的組合和巢狀 組合 你不必重複有相同屬性的多個選擇符,你只要用英文逗號(,)隔開選擇符就可以了。 比如,你有如下的程式碼: h2 { color: red; } .thisOtherClass { color: red; } .yetAn

CSS 分組 選擇器

在樣式表中有很多具有相同樣式的元素。 h1{color:green;} h2{color:green;} p{color:green;} 分組選擇器:(相同樣式不同元素可使用分組選擇器) h1,h2,p{color:green;} 巢狀選擇器 它可能適用於選擇

web前端-CSS 分組-015

Grouping Selectors 在樣式表中有很多具有相同樣式的元素。 h1 { color:green; } h2 { color:green; } p { color:green; } 為了儘量減少程式碼,你可以使用分組選擇器。 每個選擇器用

Java內部類(inner Class)類(static inner Class)也就是靜態內部類的區別

內部類和靜態類有著本質的區別,有點類似普通成員變數和靜態成員變數的區別。 . 內部類可以看成是外部類的普通成員變數,這個成員變數可以使用外部類的屬性(靜態和非靜態),可以呼叫外部類的方法(靜態和非靜態),而且內部類還持有外部類物件作為其自身的一個屬性,這個屬