1. 程式人生 > >XGBoost使用教程(純xgboost方法)一

XGBoost使用教程(純xgboost方法)一

一、匯入必要的工具包

# 匯入必要的工具包
import xgboost as xgb

# 計算分類正確率
from sklearn.metrics import accuracy_score
二、資料讀取
XGBoost可以載入libsvm格式的文字資料,libsvm的檔案格式(稀疏特徵)如下:
1  101:1.2 102:0.03
1:2.1 10001:300 10002:400
...
每一行表示一個樣本,第一行的開頭的“1”是樣本的標籤“101”和“102”為特徵索引,'1.2'和'0.03' 為特徵的值。

在兩類分類中,用“1”表示正樣本,用“0” 表示負樣本。也支援[0,1]表示概率用來做標籤,表示為正樣本的概率。

下面的示例資料需要我們通過一些蘑菇的若干屬性判斷這個品種是否有毒。
UCI資料描述:http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/ ,
每個樣本描述了蘑菇的22個屬性,比如形狀、氣味等等(將22維原始特徵用加工後變成了126維特徵,

並存為libsvm格式),然後給出了這個蘑菇是否可食用。其中6513個樣本做訓練,1611個樣本做測試。

XGBoost載入的資料儲存在物件DMatrix中
XGBoost自定義了一個資料矩陣類DMatrix,優化了儲存和運算速度

DMatrix文件:http://xgboost.readthedocs.io/en/latest/python/python_api.html

# read in data,資料在xgboost安裝的路徑下的demo目錄,現在我們將其copy到當前程式碼下的data目錄
my_workpath = './data/'
dtrain = xgb.DMatrix(my_workpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(my_workpath + 'agaricus.txt.test')

檢視資料情況

dtrain.num_col()
dtrain.num_row()
dtest.num_row()
三、訓練引數設定
max_depth: 樹的最大深度。預設值為6,取值範圍為:[1,∞]
eta:為了防止過擬合,更新過程中用到的收縮步長。在每次提升計算之後,演算法會直接獲得新特徵的權重。 
eta通過縮減特徵的權重使提升計算過程更加保守。預設值為0.3,取值範圍為:[0,1]
silent:取0時表示打印出執行時資訊,取1時表示以緘默方式執行,不列印執行時資訊。預設值為0
objective: 定義學習任務及相應的學習目標,“binary:logistic” 表示二分類的邏輯迴歸問題,輸出為概率。

其他引數取預設值。
# specify parameters via map
param = {'max_depth':2, 'eta':1, 'silent':0, 'objective':'binary:logistic' }
print(param)

四、訓練模型

# 設定boosting迭代計算次數
num_round = 2

import time
starttime = time.clock()

bst = xgb.train(param, dtrain, num_round) #  dtrain是訓練資料集

endtime = time.clock()
print (endtime - starttime)
XGBoost預測的輸出是概率。這裡蘑菇分類是一個二類分類問題,輸出值是樣本為第一類的概率。

我們需要將概率值轉換為0或1。

train_preds = bst.predict(dtrain)
train_predictions = [round(value) for value in train_preds]
y_train = dtrain.get_label() #值為輸入資料的第一行
train_accuracy = accuracy_score(y_train, train_predictions)
print ("Train Accuary: %.2f%%" % (train_accuracy * 100.0))

五、測試

模型訓練好後,可以用訓練好的模型對測試資料進行預測

# make prediction
preds = bst.predict(dtest)
檢查模型在測試集上的正確率

XGBoost預測的輸出是概率,輸出值是樣本為第一類的概率。我們需要將概率值轉換為0或1。

predictions = [round(value) for value in preds]
y_test = dtest.get_label()
test_accuracy = accuracy_score(y_test, predictions)
print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))

六、模型視覺化

呼叫XGBoost工具包中的plot_tree,在顯示
要視覺化模型需要安裝graphviz軟體包
plot_tree()的三個引數:
1. 模型
2. 樹的索引,從0開始
3. 顯示方向,預設為豎直,‘LR'是水平方向

from matplotlib import pyplot
import graphviz
xgb.plot_tree(bst, num_trees=0, rankdir= 'LR' )
pyplot.show()

#xgb.plot_tree(bst,num_trees=1, rankdir= 'LR' )
#pyplot.show()
#xgb.to_graphviz(bst,num_trees=0)
#xgb.to_graphviz(bst,num_trees=1)

七、程式碼整理

# coding:utf-8
import xgboost as xgb

# 計算分類正確率
from sklearn.metrics import accuracy_score

# read in data,資料在xgboost安裝的路徑下的demo目錄,現在我們將其copy到當前程式碼下的data目錄
my_workpath = './data/'
dtrain = xgb.DMatrix(my_workpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(my_workpath + 'agaricus.txt.test')

dtrain.num_col()

dtrain.num_row()

dtest.num_row()

# specify parameters via map
param = {'max_depth':2, 'eta':1, 'silent':0, 'objective':'binary:logistic' }
print(param)

# 設定boosting迭代計算次數
num_round = 2

import time

starttime = time.clock()

bst = xgb.train(param, dtrain, num_round)  # dtrain是訓練資料集

endtime = time.clock()
print (endtime - starttime)


train_preds = bst.predict(dtrain)    #
print ("train_preds",train_preds)

train_predictions = [round(value) for value in train_preds]
print ("train_predictions",train_predictions)

y_train = dtrain.get_label()
print ("y_train",y_train)

train_accuracy = accuracy_score(y_train, train_predictions)
print ("Train Accuary: %.2f%%" % (train_accuracy * 100.0))


# make prediction
preds = bst.predict(dtest)
predictions = [round(value) for value in preds]

y_test = dtest.get_label()

test_accuracy = accuracy_score(y_test, predictions)
print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))

# from matplotlib import pyplot
# import graphviz

import graphviz

# xgb.plot_tree(bst, num_trees=0, rankdir='LR')
# pyplot.show()

# xgb.plot_tree(bst,num_trees=1, rankdir= 'LR' )
# pyplot.show()
# xgb.to_graphviz(bst,num_trees=0)
# xgb.to_graphviz(bst,num_trees=1)