1. 程式人生 > >【決策樹】ID3演算法理解與R語言實現

【決策樹】ID3演算法理解與R語言實現

一、演算法理解

想來想去,還是決定用各大暢銷書中的相親例子來解釋什麼叫決策樹。


簡單來說,決策樹就是根據各種變數,作為輸入條件,最終輸出決策的過程。比如上圖中女方在相親過程中,影響是否見男方的變數有年齡、長相、收入、是否是公務員等。

最終在各種變數組合下,最終輸出見或不見的決策。

下邊是決策樹的一種定義:

決策樹(decision tree)是一個樹結構(可以是二叉樹或非二叉樹)。其每個非葉節點表示一個特徵屬性上的測試,每個分支代表這個特徵屬性在某個值域上的輸出,而每個葉節點存放一個類別。使用決策樹進行決策的過程就是從根節點開始,測試待分類項中相應的特徵屬性,並按照其值選擇輸出分支,直到到達葉子節點,將葉子節點存放的類別作為決策結果。

二、數學公式

對於決策樹有大體認識後,我們來討論其背後的包含的數學理論支撐,主要是資訊理論中的資訊。為了理解,我們需要了解兩個數學概念。

資訊熵:熵是無序性(或不確定性)的度量指標。假如事件A的全概率劃分是(A1,A2,...,An),每部分發生的概率是(p1,p2,...,pn),那資訊熵的公式如下:


資訊增益

簡單來說,就是在某種變數算出其相應的資訊熵後,用總體資訊熵減去,即為該變數的資訊增益。比如,我們算出見與不見總體的資訊熵,減去年齡變數的資訊熵,即為資訊增益,Gain(x)。

一般我們選擇資訊增益最大的變數進行節點劃分,這樣能快速對決策樹進行分叉,並且保證決策樹的高度最小。

舉例說明:

在某社群中,我們根據某使用者的使用者部落格密度,好友密度、是否使用真實頭像來判斷該使用者是真人還是機器人。

具體的資料如下:

日誌密度 好友密度 是否真實頭像 賬號是否是真實的
s s no no
s l yes yes
l m yes yes
m m yes yes
l m yes yes
m l no yes
m s no no
l m no yes
m s no yes
s s yes no

很明顯,我們需要判別的分類變數為賬號是否是真實的。於是,我們計算該變數的資訊熵為:


0.7代表上述訓練集中,賬號為真的概率為0.7,賬號為假的概率為0.3

下面,我們再計算一下日誌密度(簡稱L)變數的資訊熵:


第一個0.3代表日誌密度為L的概率為0.3,括號中0/3 代表在日誌密度為L的情況下,賬號為假的概率為0/3 ,後邊的3/3代表在日誌密度為L的情況下,賬號為真的概率為3/3.

0.4代表日誌密度為M的概率為0.4,括號中1/4 代表在日誌密度為M的情況下,賬號為假的概率為1/4 ,後邊的3/4代表在日誌密度為L的情況下,賬號為真的概率為3/4.

0.3代表日誌密度為S的概率為0.3,括號中2/3代表在日誌密度為S的情況下,賬號為假的概率為2/3 ,後邊的1/3代表在日誌密度為L的情況下,賬號為真的概率為1/3.

在計算其它變數的資訊熵時,也是這個邏輯,這裡不再贅述,最終算得:

總體的資訊熵為:0.879

日誌密度L的資訊熵為:0.603 ,資訊增益Gain(L) = 0.879-0.603=0.276。

同理,好友密度的資訊增益為0.553。真實頭像的資訊增益為0.033。

我們以資訊增益最大的變數作為初始的分支判斷條件。

------------------

其實,不管算哪個變數的熵值,都是在以決策結果變數為維度算,只不過限制在了某個變數等於特定值的子集中去算了。

如果一個變數在某種取值下,決策變數的取值也唯一(在上例子中,好友密度為M的情況下,是真實賬號的情況權威yes),這時候該變數在該取值下的資訊熵為0,

我們稱該節點的純度較高。

我們選擇純度高、資訊熵高的變數,因為拿這種變數進行劃分,最能直接將樹節點分分開。

可以通過下邊的R語言自定義函式中的熵值計算函式以及決定用哪個變數拆分函式來理解這個道理。

三、用R語言自帶包實現演算法

用R中的Rpart包實現iris資料集分類的程式碼:

SNS<-read.csv("./DataSource/SNS.data.csv")

library(rpart)

#使用rpart包並傳引數
iris.rp<-rpart(class~.,data = iris,method = "class") 

#畫圖
plot(iris.rp,uniform = T,branch = 0,margin = 0.1,main="iris ID3")#http://f.dataguru.cn/thread-121228-1-1.html
text(iris.rp,use.n = T,col="blue",cex=1.2) #use.n 是控制下邊50/0/0樣本分類概況,col字型顏色、cex 字型大小


#用fancyRpartPlot畫圖,但是rattle包總是安裝失敗

以下是輸出結果:


四、用R語言自定義函式實現演算法

#用R語言實現決策樹ID3演算法,以iris資料集為例

#計算總體資訊值的函式,這裡只允許最後一列作為決策結果列
info<-function(dataSet){
  rowCount=nrow(dataSet) #計算資料集中有幾行,也即有幾個樣本點
  colCount=ncol(dataSet)
  resultClass=NULL
  resultClass=levels(factor(dataSet[,colCount]))  #此程式碼取得判別列中有個可能的值,輸出  "Iris-setosa"     "Iris-versicolor" "Iris-virginica" 
  classCount=NULL
  classCount[resultClass]=rep(0,length(resultClass)) #以決策變數的值為下標構建計數陣列,用於計算和儲存樣本中出現相應變數的個數
  
  for(i in 1:rowCount){ #該for迴圈的作用是計算決策變數中每個值出現的個數,為計算資訊值公式做準備
    if(dataSet[i,colCount] %in% resultClass){
      temp=dataSet[i,colCount]
      classCount[temp]=classCount[temp]+1
    }
  }
  
  #計算總體的資訊值
  t=NULL
  info=0
  for (i in 1:length(resultClass)) {
    t[i]=classCount[i]/rowCount
    info=-t[i]*log2(t[i])+info
  }
  return(info)
}


#拆分資料集,此函式的作用在於對於每列自變數,按照其包含的類別值將原始資料集按行拆分,以便在這個子集中計算特定自變數的熵值
splitDataSet<-function(originDataSet,axis,value){#含義即從originDataSet資料集中拆分出第axis個變數等於value的所有行,合併成子集
  retDataSet=NULL
  for (i in 1:nrow(originDataSet)) { #迴圈原始資料集所有行
    if(originDataSet[i,axis]==value){ #限制特定自變數,遇到目標值則記錄下原始資料集整行,然後rbind行連線
      tempDataSet=originDataSet[i,]
      retDataSet=rbind(tempDataSet,retDataSet)
    }
  }
  rownames(retDataSet)=NULL
  return(retDataSet) #返回針對某個自變數的值篩選後的子集
}

#選擇最佳拆分變數
chooseBestFeatureToSplita<-function(dataSet){
  bestGain=0.0
  bestFeature=-1
  baseInfo=info(dataSet) #計算總的資訊熵
  numFeature<-ncol(dataSet)-1 #計算除決策變數之外的所有列,即為自變數個數 
  for (i in 1:numFeature) {#對於每個自變數計算資訊熵
    featureInfo=0.0
    Feature=dataSet[,i]#定位到第i列
    classCount=levels(factor(Feature)) #計算第i列中變數類別,即有幾種值
    for (j in 1:classCount) { 
    subDataSet=splitDataSet(dataSet,i,Feature[j]) #將dataSet中第i個變數等於Feature[j]的行拆分出來
    newInfo=info(subDataSet) #計算該子集的資訊熵,也就是計算該變數在該取值下的資訊熵部分
    prob=length(subDataSet[,1]*1.0)/nrow(dataSet)# 這裡計算該變數等於Feature[j]的情況在總資料集中出現的概率
    featureInfo=featureInfo+prob*newInfo #不不斷將該變數下各部分資訊熵加總
    } #第第i個變數的資訊熵計算結束
    
    infoGain=baseInfo-featureInfo 
    if(infoGain>bestGain){ #
      bestGain=infoGain
      bestFeature=i
    }
    
  }# 所有所有變數資訊熵計算結束,並且得出了最佳拆分變數
  return(bestFeature) #返回最佳變數值
}


#最終判斷屬於哪一類的條件  
majorityCnt <- function(classList){  
  classCount = NULL  
  count = as.numeric(table(classList))  
  majorityList = levels(as.factor(classList))  
  if(length(count) == 1){  
    return (majorityList[1])  
  }else{  
    f = max(count)  
    return (majorityList[which(count == f)][1])  
  }  
}  

#判斷剩餘的值是否屬於同一類,是否已經純淨了
trick <- function(classList){  
  count = as.numeric(table(classList))  
  if(length(count) == 1){  
    return (TRUE)  
  }else  
    return (FALSE)  
} 

#遞迴生成樹
createTree<-function(dataSet){
  decision_tree = list()  
  classList = dataSet[,length(dataSet)]  
  #判斷是否屬於同一類  
  if(trick(classList))  
    return (rbind(decision_tree,classList[1]))  
  #是否在矩陣中只剩Label標籤了,若只剩最後一列,則都分完了  
  if(ncol(dataSet) == 1){  
    decision_tree = rbind(decision_tree,majorityCnt(classList))  
    return (decision_tree)  
  } 
  
  #選擇最佳屬性進行分割
  bestFeature=chooseBestFeatureToSplita(dataSet)
  labelFeature=colnames(dataSet)[bestFeature] #獲取最佳劃分屬性的變數名
  decision_tree=rbind(decision_tree,labelFeature) #這裡rbind方法,如果有一個變數列數不足,會自動重複補齊
  t=dataSet[,bestFeature]
  temp_tree=data.frame()
  for(j in 1:length(levels(as.factor(t)))){  
    #這個標籤的兩個屬性,比如“yes”,“no”,所屬的資料集  
    dataSet = splitDataSet(dataSet,bestFeature,levels(as.factor(t))[j])  
    dataSet=dataSet[,-bestFeature]  
    #遞迴呼叫這個函式  
    temp_tree = createTree(dataSet)  
    decision_tree = rbind(decision_tree,temp_tree)  
  } 
  return (decision_tree)
}

t<-createTree(iris)





以上程式碼及問題說明請訪問我的github:https://github.com/HelloMrChen/AlgorithmPractise-R