1. 程式人生 > >R語言編寫決策樹(rpart)CART ID3演算法

R語言編寫決策樹(rpart)CART ID3演算法

決策樹(decision tree)是一類常見的機器學習方法。以二分類任務為例,我們希望從給定訓練資料集學得一個模型用以對新示例進行分類,這個把樣本分類的任務,可看做對“當前樣本屬於正常嗎?”這個問題的‘決策’或者‘判定’過程。顧名思義,決策樹是基於樹結構來進行決策的,這恰是人類在面臨決策問題時一種很自然的處理機制。

常用的決策樹演算法:

  1. ID3 以資訊增益作為分類標準
  2. CART 以基尼係數作為分類標準
    演算法的具體理論可以參考周志華的《機器學習》

資料預處理

loc<-"http://archive.ics.uci.edu/ml/machine-learning-databases/"
ds<-"breast-cancer-wisconsin/breast-cancer-wisconsin.data" url<-paste(loc,ds,sep="") data<-read.table(url,sep=",",header=F,na.strings="?") names(data)<-c("編號","腫塊厚度","腫塊大小","腫塊形狀","邊緣黏附","單個表皮細胞大小","細胞核大小","染色質","細胞核常規","有絲分裂","類別") #print(data) data$類別[data$類別==2]<-"良性" data$類別[data$類
別==4]<-"惡性" #print(data) data<-data[-1] #刪除第一列元素# #print(data) set.seed(1234) #隨機抽樣設定種子 train<-sample(nrow(data),0.7*nrow(data)) #抽樣函式,第一個引數為向量,nrow()返回行數 後面的是抽樣引數前 tdata<-data[train,] #根據抽樣引數列選擇樣本,都好逗號是選擇行 vdata<-data[-train,] #刪除抽樣行

採用的資料是UCI機器學習資料庫裡的威斯康星州乳腺癌資料集,通過對資料的分析,提取出關鍵特徵來判斷乳腺癌患病情況
tdata為訓練資料集
vdata為測試資料集

建立決策樹

library(rpart)
dtree<-rpart(類別~.,data=tdata,method="class", parms=list(split="information"))
printcp(dtree)

rpart()函式的格式:
rpart(formula,data,weights,subsets,na.action=na.rpart,method,parms,control…)
如果library報錯,需要install資料包
這裡寫圖片描述
使用ID3演算法時候,split = “information” ,使用CART演算法的時候, split = “gini”

決策樹剪枝

剪枝(pruning)是決策樹學習演算法對付“過擬合”的主要手段,在決策樹學習中,為了儘可能正確分類訓練樣本,結合劃分過程將不斷重複,有事會造成決策樹分支過多,這時就可能因訓練樣本學得“太好了”,以致於把訓練集自身的一些特點當做所有資料都具有的一般性質而導致過擬合。因此,可通過主動去掉一些分支來降低過擬合的風險。

print(dtree)

可以看到,訓練之後,採用了四個指標作為分支節點來建立決策樹,而忽略了很多與乳腺癌不相關的特徵

Variables actually used in tree construction:
[1] 細胞核大小 腫塊大小   腫塊厚度   腫塊形狀  

在建立決策樹之後通過printcp可以列印決策樹的複雜性引數,觀察樹的誤差等資料。
這裡寫圖片描述
cp是引數複雜度(complexity parameter)作為控制樹規模的懲罰因子,簡而言之,就是cp越大,樹分裂規模(nsplit)越小。輸出引數(rel error)指示了當前分類模型樹與空樹之間的平均偏差比值。xerror為交叉驗證誤差,xstd為交叉驗證誤差的標準差。可以看到,當nsplit為3的時候,即有四個葉子結點的樹,要比nsplit為4,即五個葉子結點的樹的交叉誤差要小。而決策樹剪枝的目的就是為了得到更小交叉誤差(xerror)的樹。

使用prune()來剪枝,格式:prune(tree,cp,…)
從格式可以看出,按照的是cp值來進行剪枝,選擇cp=0.0125來剪枝

tree<-prune(dtree,cp=0.0125)

如果要寫更加具有通用性的程式碼,可以自動選擇xerror最小時候對應的cp值來剪枝

tree<-prune(dtree,cp=dtree$cptable[which.min(dtree$cptable[,"xerror"]),"CP"])

畫出樹圖

格式:格式 rpart.plot(tree,type,fallen.leaves=T,branch,…)

引數 解釋
tree 畫圖所用的樹模型。
type 可取1,2,3,4.控制圖形中節點的形式。
fallen.leaves fallen.leaves
branch 控制圖的外觀。如branch=1,獲得垂直樹幹的決策樹。
opar<-par(no.readonly = T)
par(mfrow=c(1,2))
library(rpart.plot)
png(file = "./R/tree1.png")
rpart.plot(dtree,branch=1,type=2, fallen.leaves=T,cex=0.8, sub="剪枝前")
png(file = "./R/tree2.png")
rpart.plot(tree,branch=1, type=4,fallen.leaves=T,cex=0.8, sub="剪枝後")
par(opar)
dev.off()

這裡寫圖片描述
可以看到剪枝前後的對比圖,剪枝前有五個葉節點(nplist = 4),剪枝之後(nplist = 3),剪枝之後,具有更小交叉驗證誤差。

利用測試集檢測模型

格式 predict(fit,newdata,type,…)

predtree<-predict(tree,newdata=vdata,type="class")   #利用預測集進行預測
table(vdata$類別,predtree,dnn=c("真實值","預測值"))    #輸出混淆矩陣
真實值 惡性 良性
惡性 79 2
良性 7 122

從混淆矩陣可以看出此模型準確率為(79+122)/(79+2+7+122)=95.71%

使用基尼係數建立決策樹的混淆矩陣:

真實值 惡性 良性
惡性 76 5
良性 7 122

準確率為:(76+122) / (81+129) = 94.29%