1. 程式人生 > >隨機森林實戰

隨機森林實戰

res code style odin ensemble n) 部分 範圍 dict

代碼實現:

 1 # -*- coding: utf-8 -*-
 2 """
 3 Created on Tue Sep  4 09:38:57 2018
 4 
 5 @author: zhen
 6 """
 7 
 8 from sklearn.ensemble import RandomForestClassifier
 9 from sklearn.model_selection import train_test_split
10 from sklearn.metrics import accuracy_score
11 from sklearn.datasets import load_iris
12 import matplotlib.pyplot as plt 13 14 iris = load_iris() 15 x = iris.data[:, :2] 16 y = iris.target 17 x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42) 18 19 # n_estimators:森林中樹的個數(默認為10),建議為奇數 20 # n_jobs:並行執行任務的個數(包括模型訓練和預測),默認值為-1,表示根據核數 21 rnd_clf = RandomForestClassifier(n_estimators=15, max_leaf_nodes=16, n_jobs=1) 22 rnd_clf.fit(x_train, y_train) 23 24 y_predict_rf = rnd_clf.predict(x_test)
25 26 print(accuracy_score(y_test, y_predict_rf)) 27 28 for name, score in zip(iris[feature_names], rnd_clf.feature_importances_): 29 print(name, score) 30 31 # 可視化 32 plt.plot(x_test[:, 0], y_test, r., label=real) 33 plt.plot(x_test[:, 0], y_predict_rf, b., label=predict) 34 plt.xlabel(
sepal-length, fontsize=15) 35 plt.ylabel(type, fontsize=15) 36 plt.legend(loc="upper left") 37 plt.show() 38 39 plt.plot(x_test[:, 1], y_test, r., label=real) 40 plt.plot(x_test[:, 1], y_predict_rf, b., label=predict) 41 plt.xlabel(sepal-width, fontsize=15) 42 plt.ylabel(type, fontsize=15) 43 plt.legend(loc="upper right") 44 plt.show()

結果:

技術分享圖片

可視化(查看每個預測條件的影響):

技術分享圖片

技術分享圖片

  分析:鳶尾花的花萼長度在小於6時預測準確率很高,隨著長度的增加,在6~7這段中,預測出現較大錯誤率,當大於7時,預測會恢復到較好的情況。寬度也出現類似的情況,在3~3.5這個範圍出現較高錯誤,因此在訓練中建議在訓練數據中適量增加中間部分數據的訓練量(該部分不容易區分),以便得到較好的訓練模型!

隨機森林實戰