構建7種分類模型,評分並畫出ROC曲線
阿新 • • 發佈:2018-12-22
構建7種分類模型,評分並畫出ROC曲線
- 匯入的包
import lightgbm as lgb
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from xgboost.sklearn import XGBClassifier
- 讀取資料集
# Load the full dataset from disk; the file is decoded as GBK.
data_all = pd.read_csv(
    filepath_or_buffer='/home/infisa/wjht/project/DataWhale/data_all.csv',
    encoding='gbk',
)
- 劃分訓練集和測試集
# Every column except the 'status' label is used as a model feature.
features = [x for x in data_all.columns if x not in ['status']]
X = data_all[features]
y = data_all['status']

# Hold out 30% of the rows for testing; the fixed seed keeps the split
# reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=2018
)
- 構建模型
# Fit each of the seven classifiers on the training split, then keep both the
# hard class predictions (*_y_pre) and the class probabilities (*_y_proba) for
# the test split so that later metrics can use whichever they need.

# Logistic regression
lr = LogisticRegression(random_state=2018, tol=1e-6)
lr.fit(X_train, y_train)
lr_y_proba = lr.predict_proba(X_test)
lr_y_pre = lr.predict(X_test)

# Decision tree
tr = DecisionTreeClassifier(random_state=2018)
tr.fit(X_train, y_train)
tr_y_pre = tr.predict(X_test)
tr_y_proba = tr.predict_proba(X_test)

# SVM; probability=True is needed so that predict_proba is available.
# NOTE(review): the recorded SVM results (precision/recall 0, AUC 0.5) suggest
# the model predicts a single class; this is presumably caused by unscaled
# features -- confirm, and consider standardizing the inputs before fitting.
svm = SVC(probability=True, random_state=2018, tol=1e-6)
svm.fit(X_train, y_train)
svm_y_pre = svm.predict(X_test)
svm_y_proba = svm.predict_proba(X_test)

# Random forest
forest = RandomForestClassifier(n_estimators=100, random_state=2018)
forest.fit(X_train, y_train)
forest_y_pre = forest.predict(X_test)
forest_y_proba = forest.predict_proba(X_test)

# GBDT
Gbdt = GradientBoostingClassifier(random_state=2018)
Gbdt.fit(X_train, y_train)
Gbdt_y_pre = Gbdt.predict(X_test)
Gbdt_y_proba = Gbdt.predict_proba(X_test)

# XGBoost
Xgbc = XGBClassifier(random_state=2018)
Xgbc.fit(X_train, y_train)
Xgbc_y_pre = Xgbc.predict(X_test)
# Probabilities were missing for this model only; AUC/ROC need them.
Xgbc_y_proba = Xgbc.predict_proba(X_test)

# LightGBM
gbm = lgb.LGBMClassifier(random_state=2018)
gbm.fit(X_train, y_train)
gbm_y_pre = gbm.predict(X_test)
gbm_y_proba = gbm.predict_proba(X_test)
- 模型評分
# Model scoring: evaluate logistic regression on the held-out test split.
lr_score = lr.score(X_test, y_test)
(lr_accuracy_score,
 lr_preci_score,
 lr_recall_score,
 lr_f1_score) = (
    metric(y_test, lr_y_pre)
    for metric in (accuracy_score, precision_score, recall_score, f1_score)
)
# AUC is computed from the second predict_proba column, not the hard labels.
lr_auc = roc_auc_score(y_test, lr_y_proba[:, 1])
print(
    'lr_accuracy_score: %f,lr_preci_score: %f,lr_recall_score: %f,lr_f1_score: %f,lr_auc: %f'
    % (lr_accuracy_score, lr_preci_score, lr_recall_score, lr_f1_score, lr_auc)
)
# Output recorded by the original author (kept as a no-op string literal):
'lr_accuracy_score: 0.768746,lr_preci_score: 0.688312,lr_recall_score: 0.147632,lr_f1_score: 0.243119,lr_auc: 0.716681'
# Evaluate the decision tree on the held-out test split.
tr_score = tr.score(X_test, y_test)
tr_accuracy_score = accuracy_score(y_test, tr_y_pre)
tr_preci_score = precision_score(y_test, tr_y_pre)
tr_recall_score = recall_score(y_test, tr_y_pre)
tr_f1_score = f1_score(y_test, tr_y_pre)
# AUC is computed from the second predict_proba column, not the hard labels.
tr_auc = roc_auc_score(y_test, tr_y_proba[:, 1])
# The report was commented out for this model only; it is restored so every
# model prints its metrics the same way.
print(
    'tr_accuracy_score: %f,tr_preci_score: %f,tr_recall_score: %f,tr_f1_score: %f,tr_auc: %f'
    % (tr_accuracy_score, tr_preci_score, tr_recall_score, tr_f1_score, tr_auc)
)
# Output recorded by the original author (kept as a no-op string literal):
'tr_accuracy_score: 0.684653,tr_preci_score: 0.382429,tr_recall_score: 0.412256,tr_f1_score: 0.396783,tr_auc: 0.594237'
# Evaluate the SVM on the held-out test split.
(svm_accuracy_score,
 svm_preci_score,
 svm_recall_score,
 svm_f1_score) = (
    metric(y_test, svm_y_pre)
    for metric in (accuracy_score, precision_score, recall_score, f1_score)
)
# AUC is computed from the second predict_proba column, not the hard labels.
svm_auc = roc_auc_score(y_test, svm_y_proba[:, 1])
print(
    'svm_accuracy_score: %f,svm_preci_score: %f,svm_recall_score: %f,svm_f1_score: %f,svm_auc: %f'
    % (svm_accuracy_score, svm_preci_score, svm_recall_score, svm_f1_score, svm_auc)
)
# Output recorded by the original author (kept as a no-op string literal):
'svm_accuracy_score: 0.748423,svm_preci_score: 0.000000,svm_recall_score: 0.000000,svm_f1_score: 0.000000,svm_auc: 0.500000'
# Evaluate the random forest on the held-out test split.
(forest_accuracy_score,
 forest_preci_score,
 forest_recall_score,
 forest_f1_score) = (
    metric(y_test, forest_y_pre)
    for metric in (accuracy_score, precision_score, recall_score, f1_score)
)
# AUC is computed from the second predict_proba column, not the hard labels.
forest_auc = roc_auc_score(y_test, forest_y_proba[:, 1])
print(
    'forest_accuracy_score: %f,forest_preci_score: %f,forest_recall_score: %f,forest_f1_score: %f,forest_auc: %f'
    % (forest_accuracy_score, forest_preci_score, forest_recall_score, forest_f1_score, forest_auc)
)
# Output recorded by the original author (kept as a no-op string literal):
'forest_accuracy_score: 0.782060,forest_preci_score: 0.681818,forest_recall_score: 0.250696,forest_f1_score: 0.366599,forest_auc: 0.749137'
# Evaluate the gradient-boosting classifier on the held-out test split.
(Gbdt_accuracy_score,
 Gbdt_preci_score,
 Gbdt_recall_score,
 Gbdt_f1_score) = (
    metric(y_test, Gbdt_y_pre)
    for metric in (accuracy_score, precision_score, recall_score, f1_score)
)
# AUC is computed from the second predict_proba column, not the hard labels.
Gbdt_auc = roc_auc_score(y_test, Gbdt_y_proba[:, 1])
print(
    'Gbdt_accuracy_score: %f,Gbdt_preci_score: %f,Gbdt_recall_score: %f,Gbdt_f1_score: %f,Gbdt_auc: %f'
    % (Gbdt_accuracy_score, Gbdt_preci_score, Gbdt_recall_score, Gbdt_f1_score, Gbdt_auc)
)
# Output recorded by the original author (kept as a no-op string literal):
'Gbdt_accuracy_score: 0.780659,Gbdt_preci_score: 0.611650,Gbdt_recall_score: 0.350975,Gbdt_f1_score: 0.446018,Gbdt_auc: 0.763828'
# Evaluate XGBoost on the held-out test split.
Xgbc_accuracy_score = accuracy_score(y_test, Xgbc_y_pre)
Xgbc_preci_score = precision_score(y_test, Xgbc_y_pre)
Xgbc_recall_score = recall_score(y_test, Xgbc_y_pre)
Xgbc_f1_score = f1_score(y_test, Xgbc_y_pre)
# AUC ranks samples by a continuous score. Passing the hard 0/1 predictions
# (as the original did) evaluates a single operating point and understates the
# AUC, so use the second predict_proba column like every other model here.
Xgbc_y_proba = Xgbc.predict_proba(X_test)
Xgbc_auc = roc_auc_score(y_test, Xgbc_y_proba[:, 1])
print(
    'Xgbc_accuracy_score: %f,Xgbc_preci_score: %f,Xgbc_recall_score: %f,Xgbc_f1_score: %f,Xgbc_auc: %f'
    % (Xgbc_accuracy_score, Xgbc_preci_score, Xgbc_recall_score, Xgbc_f1_score, Xgbc_auc)
)
# Output recorded by the original author (kept as a no-op string literal).
# Its Xgbc_auc value predates the fix above and was computed from hard labels.
'Xgbc_accuracy_score: 0.785564,Xgbc_preci_score: 0.630542,Xgbc_recall_score: 0.356546,Xgbc_f1_score: 0.455516,Xgbc_auc: 0.643161'
# Evaluate LightGBM on the held-out test split.
(gbm_accuracy_score,
 gbm_preci_score,
 gbm_recall_score,
 gbm_f1_score) = (
    metric(y_test, gbm_y_pre)
    for metric in (accuracy_score, precision_score, recall_score, f1_score)
)
# AUC is computed from the second predict_proba column, not the hard labels.
gbm_auc = roc_auc_score(y_test, gbm_y_proba[:, 1])
print(
    'gbm_accuracy_score: %f,gbm_preci_score: %f,gbm_recall_score: %f,gbm_f1_score: %f,gbm_auc: %f'
    % (gbm_accuracy_score, gbm_preci_score, gbm_recall_score, gbm_f1_score, gbm_auc)
)
# Output recorded by the original author (kept as a no-op string literal):
'gbm_accuracy_score: 0.770147,gbm_preci_score: 0.570136,gbm_recall_score: 0.350975,gbm_f1_score: 0.434483,gbm_auc: 0.757402'
- 畫出Roc曲線
# ROC curve for logistic regression; lr_threasholds holds the thresholds.
lr_fpr, lr_tpr, lr_threasholds = roc_curve(y_test, lr_y_proba[:, 1])
plt.plot(lr_fpr, lr_tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title("roc_curve of %s(AUC=%.4f)" % ('logist', lr_auc))
plt.show()
# ROC curve for the decision tree; tr_threasholds holds the thresholds.
tr_fpr, tr_tpr, tr_threasholds = roc_curve(y_test, tr_y_proba[:, 1])
plt.plot(tr_fpr, tr_tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title("roc_curve of %s(AUC=%.4f)" % ('decisiontree', tr_auc))
plt.show()
# ROC curve for the SVM; svm_threasholds holds the thresholds.
svm_fpr, svm_tpr, svm_threasholds = roc_curve(y_test, svm_y_proba[:, 1])
plt.plot(svm_fpr, svm_tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title("roc_curve of %s(AUC=%.4f)" % ('svm', svm_auc))
plt.show()
# ROC curve for the random forest; forest_threasholds holds the thresholds.
forest_fpr, forest_tpr, forest_threasholds = roc_curve(y_test, forest_y_proba[:, 1])
plt.plot(forest_fpr, forest_tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title("roc_curve of %s(AUC=%.4f)" % ('forest', forest_auc))
plt.show()
# ROC curve for the GBDT model; Gbdt_threasholds holds the thresholds.
Gbdt_fpr, Gbdt_tpr, Gbdt_threasholds = roc_curve(y_test, Gbdt_y_proba[:, 1])
plt.plot(Gbdt_fpr, Gbdt_tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title("roc_curve of %s(AUC=%.4f)" % ('Gbdt', Gbdt_auc))
plt.show()
# ROC curve for XGBoost; Xgbc_threasholds holds the thresholds.
# roc_curve needs continuous scores: with hard 0/1 predictions (as in the
# original) there is only one usable threshold, so the curve degenerates into
# two straight segments. Use the second predict_proba column instead.
Xgbc_pos_proba = Xgbc.predict_proba(X_test)[:, 1]
Xgbc_fpr, Xgbc_tpr, Xgbc_threasholds = roc_curve(y_test, Xgbc_pos_proba)
# Recompute the AUC from the same scores so the title matches the plotted curve.
Xgbc_auc = roc_auc_score(y_test, Xgbc_pos_proba)
plt.title("roc_curve of %s(AUC=%.4f)" % ('Xgbc', Xgbc_auc))
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.plot(Xgbc_fpr, Xgbc_tpr)
plt.show()
# ROC curve for LightGBM; gbm_threasholds holds the thresholds.
gbm_fpr, gbm_tpr, gbm_threasholds = roc_curve(y_test, gbm_y_proba[:, 1])
plt.plot(gbm_fpr, gbm_tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title("roc_curve of %s(AUC=%.4f)" % ('gbm', gbm_auc))
plt.show()
- 思考
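01 對於 ROC 曲線,為什麼有時會被畫成直線(折線)不太理解;(註:若傳給 roc_curve 的是 predict 輸出的 0/1 類別標籤,而不是 predict_proba 輸出的機率,就只有一個有效閾值,曲線會退化成由 (0,0)、單一點與 (1,1) 連成的兩段直線。)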
02 對於 svm_accuracy_score: 0.748423,svm_preci_score: 0.000000,svm_recall_score: 0.000000,svm_f1_score: 0.000000,svm_auc: 0.500000,不理解其中 precision、recall、F1 為什麼為 0。(註:recall 為 0 表示 SVM 在測試集上沒有正確預測出任何正例;可能原因是特徵未做標準化,有待確認。)
- 參考的文章
機器學習中的 precision、recall、accuracy、F1 Score
分類問題的幾個評價指標(Precision、Recall、F1-Score、Micro-F1、Macro-F1)