使用交叉驗證對鳶尾花分類模型進行調參(超參數)
阿新 • • 發佈:2018-09-15
www. eight data svc ans 分塊 分類 app files
如何選擇超參數:
交叉驗證:
如圖,
大訓練集
分塊,使用不同的分塊方法分成N對小訓練集
和驗證集
。- 使用
小訓練集
進行訓練,使用驗證集
進行驗證,得到準確率,求N個驗證集
上的平均正確率
; - 使用
平均正確率
最高的超參數
,對整個大訓練集
進行訓練,訓練出參數。 - 在
訓練集
上訓練。
網格搜索
諸如你有多個可調節的超參數,那麽選擇超參數的方法通常是網格搜索,即固定一個參、變化其他參,像網格一樣去搜索。
# 人工智能數據源下載地址:https://video.mugglecode.com/data_ai.zip,下載壓縮包後解壓即可(數據源與上節課相同)# -*- coding: utf-8 -*- """ 任務:鳶尾花識別 """ import pandas as pd from sklearn.model_selection import train_test_split, GridSearchCV from sklearn.neighbors import KNeighborsClassifier from sklearn.linear_model import LogisticRegression from sklearn.svm import SVC DATA_FILE = ‘./data_ai/Iris.csv‘ SPECIES_LABEL_DICT= { ‘Iris-setosa‘: 0, # 山鳶尾 ‘Iris-versicolor‘: 1, # 變色鳶尾 ‘Iris-virginica‘: 2 # 維吉尼亞鳶尾 } # 使用的特征列 FEAT_COLS = [‘SepalLengthCm‘, ‘SepalWidthCm‘, ‘PetalLengthCm‘, ‘PetalWidthCm‘] def main(): """ 主函數 """ # 讀取數據集 iris_data = pd.read_csv(DATA_FILE, index_col=‘Id‘) iris_data[‘Label‘] = iris_data[‘Species‘].map(SPECIES_LABEL_DICT) # 獲取數據集特征 X = iris_data[FEAT_COLS].values # 獲取數據標簽 y = iris_data[‘Label‘].values # 劃分數據集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/3, random_state=10) model_dict = {‘kNN‘: ( KNeighborsClassifier(), {‘n_neighbors‘: [5, 15, 25], ‘p‘: [1, 2]} ), ‘Logistic Regression‘: ( LogisticRegression(), {‘C‘: [1e-2, 1, 1e2]} ), ‘SVM‘: ( SVC(), {‘C‘: [1e-2, 1, 1e2]} ) } # 名稱+元組 for model_name, (model, model_params) in model_dict.items(): # 訓練模型 clf = GridSearchCV(estimator=model, param_grid=model_params, cv=5) #模型、參數、折數 clf.fit(X_train, y_train) #訓練 best_model = clf.best_estimator_ #最佳模型的對象 # 驗證 acc = best_model.score(X_test, y_test) print(‘{}模型的預測準確率:{:.2f}%‘.format(model_name, acc * 100)) print(‘{}模型的最優參數:{}‘.format(model_name, clf.best_params_)) #最好的模型名稱和參數 if __name__ == ‘__main__‘: main()
運行結果:
kNN模型的預測準確率:96.00%
kNN模型的最優參數:{‘n_neighbors‘: 15, ‘p‘: 2}
Logistic Regression模型的預測準確率:96.00%
Logistic Regression模型的最優參數:{‘C‘: 100.0}
SVM模型的預測準確率:98.00%
SVM模型的最優參數:{‘C‘: 1}
練習
練習:使用交叉驗證對水果分類模型進行調參
-
題目描述:為模型選擇最優的參數並進行水果類型識別,模型包括kNN,邏輯回歸及SVM。對應的超參數為:
-
kNN中的近鄰個數n_neighbors及閔式距離的p值
-
邏輯回歸的正則項系數C值
-
SVM的正則項系數C值
-
題目要求:
-
使用3折交叉驗證對模型進行調參
-
使用scikit-learn提供的方法為模型調參
-
數據文件:
-
數據源下載地址:https://video.mugglecode.com/fruit_data.csv(數據源與上節課相同)
-
fruit_data.csv,包含了59個水果的的數據樣本。
-
共5列數據
-
fruit_name:水果類別
-
mass: 水果質量
-
width: 水果的寬度
-
height: 水果的高度
-
color_score: 水果的顏色數值,範圍0-1。
-
0.85 - 1.00:紅色
-
0.75 - 0.85: 橙色
-
0.65 - 0.75: 黃色
-
0.45 - 0.65: 綠色
image
可能的代碼
import pandas as pd from sklearn.model_selection import GridSearchCV, train_test_split from sklearn.neighbors import KNeighborsClassifier from sklearn.linear_model import LogisticRegression from sklearn.svm import SVC #讀取數據 data = pd.read_csv(‘./data_ai/fruit_data.csv‘) #數據處理 fruit_dict = { ‘apple‘: 0, ‘lemon‘: 1, ‘mandarin‘: 2, ‘orange‘: 3 } data[‘label‘] = data[‘fruit_name‘].map(fruit_dict) feat_cols = [‘mass‘,‘width‘,‘height‘,‘color_score‘] #數據提取 X = data[feat_cols].values y = data[‘label‘].values X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=1/5, random_state= 3) model_dict = { ‘KNN‘: ( KNeighborsClassifier(), {‘n_neighbors‘: [5,15,25], ‘p‘ : [1,2]} ), ‘Logestic Regression‘: (LogisticRegression(), {‘C‘:[1e02, 1, 1e2] }), ‘SVM‘: (SVC(), {‘C‘:[1e02, 1, 1e2]}) } for model_name, (model, model_para) in model_dict.items(): #訓練 clf = GridSearchCV(estimator=model, param_grid=model_para, cv=5) # 模型、參數、折數 clf.fit(X_train,y_train) best_model = clf.best_estimator_ #驗證 acc = best_model.score(X_test, y_test) print(f‘{model_name}中選擇{clf.best_params_}為參數的預測準確率最好,準確率可達{acc*100}%‘)
運行結果:
KNN中選擇{‘n_neighbors‘: 5, ‘p‘: 1}為參數的預測準確率最好,準確率可達66.66666666666666%
Logestic Regression中選擇{‘C‘: 100.0}為參數的預測準確率最好,準確率可達91.66666666666666%
SVM中選擇{‘C‘: 100.0}為參數的預測準確率最好,準確率可達50.0%
作者:夏威夷的芒果
鏈接:https://www.jianshu.com/p/790ac622dc18
來源:簡書
使用交叉驗證對鳶尾花分類模型進行調參(超參數)