1. 程式人生 > >機器學習第七篇

機器學習第七篇

決策樹

相比於其他方法,決策樹是一種更為簡單的機器學習方法,它是對被觀測資料進行分類的一種相當直觀的方法,決策樹在經過訓練之後,看起來更像是以樹狀形式排列的一系列if-then語句。只要沿著樹的路徑一直向下,正確回答每一個問題,最終就會得到答案,沿著最終的葉節點向上回溯,就會得到一個有關最終分類結果的推理過程。

一、預測註冊使用者

推出功能時,向所有註冊使用者群發郵件的方式,不具有針對性

如果我們知道有哪些因素可以表明使用者將會成為付費顧客,那麼就可以利用這些資訊指導我們的廣告策咯制定工作,讓網站的某些功能具有更好的可用性,或者採取其他能夠有效增加付費顧客數量的策略。

為了減少使用者的工作量,使其能夠儘快的註冊賬號

,網站不會過多的詢問使用者的個人資訊,相反,它會從伺服器的日誌中收集這些資訊,比如:使用者來自哪個網站,所在的地理位置,以及他們在註冊之前曾經瀏覽過多少網頁,等等。

新建treepredict.py

#來源網站、位置、是否閱讀過FAQ、瀏覽網頁數、選擇服務型別
my_data=[['slashdot','USA','yes',18,'None'],
         ['google','France','yes',23,'Premium'],
         ['digg','USA','yes',24,'Basic'],
         ['kiwitobes','France','yes',23,'Basic'],
         ['google','UK','no',21,'Premium'],
         ['(direct)','New Zealand','no',12,'None'],
         ['(direct)','UK','no',21,'Basic'],
         ['google','USA','no',24,'Premium'],
         ['slashdot','France','yes',19,'None'],
         ['digg','USA','no',18,'None'],
         ['google','UK','no',18,'None'],
         ['kiwitobes','UK','no',19,'None'],
         ['digg','New Zealand','yes',12,'Basic'],
         ['slashdot', 'UK', 'no', 21, 'None'],
         ['google','UK','yes',18,'Basic'],
         ['kiwitobes','France','yes',19,'Basic']]

或使用  my_data=[line.split('\t') for line in open('decision_tree_example.txt')]  將檔案載入進來

現在,只需找到一種方法,能夠將一個合理的推測值填入‘服務’欄即可。

類 decisionnode代表樹上的每一個節點

class decisionnode:
    def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
        self.col=col
        self.value=value
        self.results=results
        self.tb=tb
        self.fb=fb

對樹進行訓練:CART

首先建立一個根節點,然後通過評估表中的所有觀測變數,從中選出最合適的變數對資料進行拆分。

divideset(rows,column,value):根據列表中某一欄的資料將列表拆分成兩個資料集

#lambda作為一個表示式,定義了一個匿名函式
#在某一列上對資料進行拆分,能夠處理數值型資料或名詞性資料
def divideset(rows,column,value):
    #定義一個函式,令其告訴我們資料行屬於第一組還是第二組
    split_function=None
    if isinstance(value,int) or isinstance(value,float):
        split_function=lambda row:row[column]>=value
    else:
        split_function=lambda row:row[column]==value
    #將資料集拆分成兩個集合,並返回
    set1=[row for row in rows if split_function(row)]
    set2=[row for row in rows if not split_function(row)]
    return (set1,set2)

按照是否閱讀過FAQ劃分: 

 

拆分結果並不理想,因為兩邊似乎都混雜了各種情況,我們需要一種方法來選擇最合適的變數。

選擇最合適的拆分方案

uniquecounts(rows):找出所有不同的可能結果,並返回一個字典,其中包含了每一項的出現次數,其他函式將利用該函式來計算資料集合中的混雜程度

#對各種可能的結果進行計數(每一行資料的最後一列記錄了這一計數結果)
def uniquecounts(rows):
    results={}
    for row in rows:
        #計數結果在最後一行
        r=row[len(row)-1]
        if r not in results: results[r]=0
        results[r]+=1
    return results

基尼不純度:將來自集合中的某種結果隨機應用於集合中某一資料項的預期誤差率

#隨機放置的資料項出現於錯誤分類中的概率
def giniimpurity(rows):     #******************************************
    total=len(rows)
    counts=uniquecounts(rows)
    imp=0
    #
    for k1 in counts:
        p1=float(counts[k1])/total
        for k2 in counts:
            if k1==k2: continue
            p2=float(counts[k2])/total
            imp+=p1*p2
    return imp

沒搞懂。。。。。。。。。。。。。。

熵:代表集合的無序程度

#熵是遍歷所有可能的結果之後所得的p(x)log(p(x))之和
def entropy(rows):    #******************************************
    from math import log
    log2=lambda x:log(x)/log(2)
    results=uniquecounts(rows)
    #此處開始計算熵的值
    ent=0.0
    for r in results.keys():
        p=float(results[r])/len(rows)
        ent=ent-p*log2(p)
    return ent

沒搞懂。。。。。。。。。。。。。。

群組越是混亂,相應的熵就越高

熵和基尼不純度之間的主要區別在於,熵達到峰值的過程要相對慢一些。因此,熵對於混亂集合的‘判罰’往往更重一些

以遞迴方式構樹

 

def buildtree(rows,scoref=entropy):
    if len(rows)==0: return decisionnode()
    current_score=scoref(rows)

    #定義一些變數以記錄最佳拆分條件
    best_gain=0.0
    best_criteria=None
    best_sets=None

    column_count=len(rows[0])-1
    for col in range(0,column_count):
        #在當前列中生成一個由不同值構成的序列
        column_values={}
        for row in rows:
            column_values[row[col]]=1
        #接下來根據這一列中的每個值,嘗試對資料集進行拆分
        for value in column_values.keys():
            (set1,set2)=divideset(rows,col,value)
            #資訊增溢
            p=float(len(set1))/len(rows)
            gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
            if gain>best_gain and len(set1)>0 and len(set2)>0:
                best_gain=gain
                best_criteria=(col,value)
                best_sets=(set1,set2)
    if best_gain>0:
        trueBranch=buildtree(best_sets[0])
        falseBranch=buildtree(best_sets[1])
        return decisionnode(col=best_criteria[0],value=best_criteria[1],
                            tb=trueBranch,fb=falseBranch)
    else:
        return decisionnode(results=uniquecounts(rows))

buildtree遍歷資料集中的每一列(最後一列除外),針對各列查詢一種可能的取值(資訊增溢最大),並將資料集拆分成兩個新的子集若由熵值最低的一對子集求得的加權平均熵比當前的熵要大,則拆分過程就結束了,否則演算法就會在新生成的子集上繼續呼叫buildtree函式,並把呼叫所得的結果新增到樹上。

決策樹顯示

 

文字顯示:

def printtree(tree,indent=''):
    #這是一個葉節點嗎
    if tree.results!=None:
        print(str(tree.results))
    else:
        #列印判斷條件
        print(str(tree.col)+':'+str(tree.value)+'?')
        #列印分支
        print(indent+'T->', )
        printtree(tree.tb,indent+'  ')
        print(indent+'F->', )
        printtree(tree.fb,indent+'  ')

圖形顯示:


def getwidth(tree):
    if tree.tb==None and tree.fb==None: return 1
    return getwidth(tree.tb)+getwidth(tree.fb)

def getdepth(tree):
    if tree.tb==None and tree.fb==None: return 0
    return max(getdepth(tree.tb),getdepth(tree.fb))+1

from PIL import Image,ImageDraw

def drawtree(tree,jpeg='tree.jpg'):
    w=getwidth(tree)*100
    h=getdepth(tree)*100+120

    img=Image.new('RGB',(w,h),(255,255,255))
    draw=ImageDraw.Draw(img)

    drawnode(draw,tree,w/2,20)
    img.save(jpeg,'JPEG')

def drawnode(draw,tree,x,y):
    if tree.results==None:
        #得到每個分支的寬度
        w1=getwidth(tree.fb)*100
        w2=getwidth(tree.tb)*100

        #確定此節點所要佔據的總空間
        left=x-(w1+w2)/2
        right=x+(w1+w2)/2

        #繪製判斷條件字串
        draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))

        #繪製到分支的連線
        draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
        draw.line((x, y, right - w2 / 2, y + 100), fill=(255, 0, 0))

        #繪製分支的節點
        drawnode(draw,tree.fb,left+w1/2,y+100)
        drawnode(draw, tree.tb, right - w2 / 2, y + 100)
    else:
        txt=' \n'.join(['%s:%d' %v for v in tree.results.items()])
        draw.text((x-20,y),txt,(0,0,0))
tree=buildtree(my_data)
drawtree(tree,jpeg='treeview.jpeg')

 

 對新的觀測資料進行分類

def classify(observation,tree):
    print(tree.results)
    if tree.results!=None:
        return tree.results
    else:
        v=observation[tree.col]
        branch=None
        if isinstance(v,int) or isinstance(v,float):
            if v>=tree.value: branch=tree.tb
            else: branch=tree.fb
        else:
            if v==tree.value: branch=tree.tb
            else: branch=tree.fb
        return classify(observation,branch)


print(classify(['(direct)','USA','yes',5],tree))