scikit-learn調參輔助
阿新 • • 發佈:2018-12-15
learning_curve:
作用:展示模型精度(訓練集與驗證集)隨訓練資料集大小變化的關係,可用來判斷模型是過擬合還是欠擬合
# Learning curve: training / validation accuracy as a function of the number
# of training samples.
#
# Prerequisites that this snippet does not define: `np` (numpy), `plt`
# (matplotlib.pyplot), the estimator `pipe_clf`, and the data `X_train` /
# `y_train` must already be in scope.
from sklearn.model_selection import learning_curve

train_sizes, train_scores, test_scores = learning_curve(
    estimator=pipe_clf,
    X=X_train,
    y=y_train,
    train_sizes=np.linspace(0.1, 1.0, 10),  # 10%, 20%, ..., 100% of the data
    cv=10,
    n_jobs=1,
)

# cv=10, so each row holds 10 scores (one per fold); aggregate across folds.
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)

# Training accuracy: mean line plus a +/- one standard deviation band.
plt.plot(train_sizes, train_mean,
         color='blue', marker='o', markersize=5,
         label='training accuracy')
plt.fill_between(train_sizes,
                 train_mean + train_std,
                 train_mean - train_std,
                 alpha=0.15, color='blue')  # alpha controls transparency

# Validation accuracy: mean line plus a +/- one standard deviation band.
plt.plot(train_sizes, test_mean,
         color='green', linestyle='--', marker='s', markersize=5,
         label='validation accuracy')
plt.fill_between(train_sizes,
                 test_mean + test_std,
                 test_mean - test_std,
                 alpha=0.15, color='green')

plt.grid()
plt.xlabel('Number of training samples')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.ylim([0.8, 1.0])
plt.title('learning curve')
validation_curve
作用:展示模型精度(訓練集與驗證集)隨某一個超引數取值變化的關係,可用來挑選合適的引數值
# Validation curve: training / validation accuracy as a function of one
# hyper-parameter (here the regularization parameter C).
#
# Prerequisites that this snippet does not define: `np` (numpy), `plt`
# (matplotlib.pyplot), the estimator `pipe_clf`, and the data `X_train` /
# `y_train` must already be in scope.
from sklearn.model_selection import validation_curve

param_range = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]

# 'clf__C' uses the `<step>__<param>` naming scheme, so the estimator has to be
# a pipeline instance containing a step named 'clf' -- not the bare `SVC`
# class, which is neither an estimator instance nor has a 'clf__C' parameter.
# NOTE(review): assumes pipe_clf's 'clf' step is an SVC -- confirm.
train_scores, test_scores = validation_curve(
    estimator=pipe_clf,
    X=X_train,
    y=y_train,
    param_name='clf__C',  # C is the SVC penalty parameter
    param_range=param_range,
    cv=10,
)

# One row per value in param_range, one column per CV fold; aggregate folds.
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)

plt.figure(2)

# Training accuracy: mean line plus a +/- one standard deviation band.
plt.plot(param_range, train_mean,
         color='blue', marker='o', markersize=5,
         label='training accuracy')
plt.fill_between(param_range,
                 train_mean + train_std,
                 train_mean - train_std,
                 alpha=0.15, color='blue')

# Validation accuracy: mean line plus a +/- one standard deviation band.
plt.plot(param_range, test_mean,
         color='green', linestyle='--', marker='s', markersize=5,
         label='validation accuracy')
plt.fill_between(param_range,
                 test_mean + test_std,
                 test_mean - test_std,
                 alpha=0.15, color='green')

plt.grid()
plt.xscale('log')  # param_range spans several decades, so use a log x-axis
plt.legend(loc='lower right')
plt.xlabel('Parameter C')
plt.ylabel('Accuracy')
plt.ylim([0, 1.0])
plt.title('validation curve')
plt.show()
參考資料
機器學習系統模型調優實戰--所有調優技術都附相應的scikit-learn實現:https://blog.csdn.net/xlinsist/article/details/51344449