1. 程式人生 > >十、用scikit-learn的網格搜尋快速找到最優模型引數

十、用scikit-learn的網格搜尋快速找到最優模型引數


任何一種機器學習模型都附帶很多引數,不同場景對應不同的最佳引數,手工嘗試各種引數無疑浪費很多時間,scikit-learn幫我們實現了自動化,那就是網格搜尋

網格搜尋

這裡的網格指的是不同引數不同取值交叉後形成的一個多維網格空間。比如引數a可以取1、2,引數b可以取3、4,引數c可以取5、6,那麼形成的多維網格空間就是:

1、3、5
1、3、6
1、4、5
1、4、6
2、3、5
2、3、6
2、4、5
2、4、6

一共2*2*2=8種情況

網格搜尋就是遍歷這8種情況進行模型訓練和驗證,最終選擇出效果最優的引數組合

用法舉例

# coding:utf-8

import sys
reload(sys)
sys.setdefaultencoding( "utf-8"
) from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.linear_model.logistic import LogisticRegression from sklearn.grid_search import GridSearchCV from sklearn.pipeline import Pipeline # 構造樣本,這塊得多構造點,不然會報class不足的錯誤,因為gridsearch會拆分成小組 X = [] X.append("fuck you") X.append("fuck you all"
) X.append("hello everyone") X.append("fuck me") X.append("hello boy") X.append("fuck you") X.append("fuck you all") X.append("hello everyone") X.append("fuck me") X.append("hello boy") X.append("fuck you") X.append("fuck you all") X.append("hello everyone") X.append("fuck me") X.append("hello boy") X.append("fuck you"
) X.append("fuck you all") X.append("hello everyone") X.append("fuck me") X.append("hello boy") X.append("fuck you") X.append("fuck you all") X.append("hello everyone") X.append("fuck me") X.append("hello boy") y = [1,0,1,0,1,1,0,1,0,1,1,0,1,0,1,1,0,1,0,1,1,0,1,0,1] # 這是執行的序列,gridsearch是構造多程序順序執行序列並比較結果 # 這裡的vect和clf名字自己隨便起,但是要和parameters中的字首對應 pipeline = Pipeline([ ('vect', TfidfVectorizer(stop_words='english')), ('clf', LogisticRegression()) ]) # 這裡面的max_features必須是TfidfVectorizer的引數, 裡面的取值就是子程序分別執行所用 parameters = { 'vect__max_features': (3, 5), } # accuracy表示按精確度判斷最優值 grid_search = GridSearchCV(pipeline, parameters, n_jobs = -1, verbose = 1, scoring = 'accuracy', cv = 3) grid_search.fit(X, y) print '最佳效果: %0.3f' % grid_search.best_score_ print '最優引數組合: ' best_parameters = grid_search.best_estimator_.get_params() for param_name in sorted(parameters.keys()): print('\t%s: %r' % (param_name, best_parameters[param_name]))

執行結果如下:

Fitting 3 folds for each of 2 candidates, totalling 6 fits
[Parallel(n_jobs=-1)]: Done   7 out of   6 | elapsed:    0.0s remaining:   -0.0s
[Parallel(n_jobs=-1)]: Done   7 out of   6 | elapsed:    0.1s remaining:   -0.0s
[Parallel(n_jobs=-1)]: Done   7 out of   6 | elapsed:    0.1s remaining:   -0.0s
[Parallel(n_jobs=-1)]: Done   7 out of   6 | elapsed:    0.1s remaining:   -0.0s
[Parallel(n_jobs=-1)]: Done   7 out of   6 | elapsed:    0.1s remaining:   -0.0s
[Parallel(n_jobs=-1)]: Done   6 out of   6 | elapsed:    0.1s finished
最佳效果: 0.800
最優引數組合:
    vect__max_features: 3

這裡面並行啟動了6個任務,最終判斷出max_features的最優解值是3