1. 程式人生 > >機器學習實戰程式碼_Python3.6_決策樹_程式碼

機器學習實戰程式碼_Python3.6_決策樹_程式碼

決策樹程式碼

from math import log
import operator

def calc_shannon_ent(data_set):
    """Return the base-2 Shannon entropy of the class labels in *data_set*.

    Every row of *data_set* is a sequence whose last element is the class
    label. An empty data set yields 0.0.
    """
    num_entries = len(data_set)
    label_counts = {}
    for feat_vec in data_set:
        current_label = feat_vec[-1]
        label_counts[current_label] = label_counts.get(current_label, 0) + 1
    shannon_ent = 0.0
    for count in label_counts.values():
        prob = float(count) / num_entries
        shannon_ent -= prob * log(prob, 2)
    return shannon_ent


def split_data_set(data_set, axis, value):
    """Return the rows whose column *axis* equals *value*, with that column removed.

    The input rows are not modified; each returned row is a new sequence.
    """
    return [
        feat_vec[:axis] + feat_vec[axis + 1:]
        for feat_vec in data_set
        if feat_vec[axis] == value
    ]


def choose_best_feature_to_split(data_set):
    """Return the index of the feature with the largest information gain.

    The last column of every row is treated as the class label and is not
    considered as a feature. Returns -1 when no feature yields a strictly
    positive information gain.
    """
    num_features = len(data_set[0]) - 1
    base_entropy = calc_shannon_ent(data_set)
    best_info_gain = 0.0
    best_feature = -1
    for i in range(num_features):
        unique_vals = {example[i] for example in data_set}
        # Entropy after the split: subset entropies weighted by subset size.
        new_entropy = 0.0
        for value in unique_vals:
            sub_data_set = split_data_set(data_set, i, value)
            prob = len(sub_data_set) / float(len(data_set))
            new_entropy += prob * calc_shannon_ent(sub_data_set)
        info_gain = base_entropy - new_entropy
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature


def majority_cnt(class_list):
    """Return the most frequent label in *class_list*.

    Ties are resolved in favour of the label that appears first.
    """
    class_count = {}
    for vote in class_list:
        class_count[vote] = class_count.get(vote, 0) + 1
    # dict.items() replaces the Python 2-only iteritems().
    sorted_class_count = sorted(
        class_count.items(), key=operator.itemgetter(1), reverse=True
    )
    return sorted_class_count[0][0]


def create_tree(data_set, labels):
    """Recursively build a decision tree as nested dicts.

    *data_set* is a list of rows (feature values followed by the class
    label) and *labels* holds the feature names, one per feature column.
    The result is either a class label (leaf) or a dict of the form
    ``{feature_name: {feature_value: subtree_or_label, ...}}``.

    *labels* is left untouched; a reduced copy is handed to the recursion.
    """
    class_list = [example[-1] for example in data_set]
    # All rows share one class: this branch is a leaf.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # No feature columns left: fall back to a majority vote.
    if len(data_set[0]) == 1:
        return majority_cnt(class_list)
    best_feat = choose_best_feature_to_split(data_set)
    # No feature improves the entropy; splitting on index -1 would split on
    # the class-label column, so settle for the majority class instead.
    if best_feat == -1:
        return majority_cnt(class_list)
    best_feat_label = labels[best_feat]
    my_tree = {best_feat_label: {}}
    # Build a new list rather than deleting from the caller's list.
    remaining_labels = labels[:best_feat] + labels[best_feat + 1:]
    unique_vals = {example[best_feat] for example in data_set}
    for value in unique_vals:
        my_tree[best_feat_label][value] = create_tree(
            split_data_set(data_set, best_feat, value), remaining_labels[:]
        )
    return my_tree

繪製程式碼

import matplotlib.pyplot as plt
import decison_tree

# Keyword dictionaries passed to matplotlib's annotate(): the box drawn
# around decision nodes, the box drawn around leaf nodes, and the arrow.
decision_node = {'boxstyle': 'sawtooth', 'fc': '0.8'}
leaf_node = {'boxstyle': 'round4', 'fc': '0.8'}
arrow_args = {'arrowstyle': '<-'}

def plot_node(node_text, center_pt, parent_pt, node_type):
    """Draw one boxed tree node with an arrow linking it to its parent.

    node_text -- text shown inside the box.
    center_pt -- (x, y) of the box, in axes-fraction coordinates.
    parent_pt -- (x, y) of the parent end of the arrow, in axes-fraction
                 coordinates.
    node_type -- bbox style dict (e.g. decision_node or leaf_node).

    Draws on the axes stored in ``creat_plot.axl``, so creat_plot() must have
    set that attribute first.
    """
    # A single annotate call draws both the box and the arrow; the earlier
    # extra call without arrowprops only drew the same box a second time.
    creat_plot.axl.annotate(
        node_text,
        xy=parent_pt,
        xycoords='axes fraction',
        xytext=center_pt,
        textcoords='axes fraction',
        va='center',
        ha='center',
        bbox=node_type,
        arrowprops=arrow_args,
    )

def creat_plot():
    """Draw a two-node demo figure, save it to 'tree_plot.png', and show it.

    NOTE: a later definition of ``creat_plot(in_tree)`` in this file rebinds
    the name, so this zero-argument version is only reachable before that
    definition executes.
    """
    # clf() removes every axes from the figure but keeps the window open,
    # so the same figure can be reused for other plots.
    plt.figure(1, facecolor='white').clf()
    creat_plot.axl = plt.subplot(111, frameon=False)
    demo_nodes = (
        ('Decision_node', (0.5, 0.1), (0.1, 0.5), decision_node),
        ('Leaf_node', (0.8, 0.1), (0.3, 0.8), leaf_node),
    )
    for text, center, parent, box_style in demo_nodes:
        plot_node(text, center, parent, box_style)
    plt.savefig('tree_plot.png')
    plt.show()


def get_num_leafs(my_tree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    *my_tree* has the form ``{feature_name: {value: subtree_or_leaf, ...}}``.
    Any child that is a dict (including dict subclasses) is treated as a
    subtree; every other child counts as one leaf.
    """
    # dict views are not indexable in Python 3, hence the list() conversion.
    first_str = list(my_tree.keys())[0]
    second_dict = my_tree[first_str]
    return sum(
        get_num_leafs(child) if isinstance(child, dict) else 1
        for child in second_dict.values()
    )

def get_tree_depth(my_tree):
    """Return the depth (number of decision levels) of a nested-dict tree.

    *my_tree* has the form ``{feature_name: {value: subtree_or_leaf, ...}}``.
    A node whose children are all leaves has depth 1; a node with no
    children has depth 0. Any child that is a dict (including dict
    subclasses) is treated as a subtree.
    """
    # dict views are not indexable in Python 3, hence the list() conversion.
    first_str = list(my_tree.keys())[0]
    second_dict = my_tree[first_str]
    return max(
        (
            1 + get_tree_depth(child) if isinstance(child, dict) else 1
            for child in second_dict.values()
        ),
        default=0,
    )

def retrieve_tree(i):
    """Return one of two hard-coded sample decision trees, selected by index *i*.

    Index 0 is a two-level tree and index 1 a three-level tree; any other
    index follows normal list indexing (IndexError when out of range).
    """
    list_of_trees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        # The last leaf used to read 'n0', a typo for the class label 'no'.
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return list_of_trees[i]

def plot_mid_text(cntr_pt, parent_pt, txt_string):
    """Write *txt_string* halfway between *cntr_pt* and *parent_pt*.

    Both points are (x, y) pairs. The text is placed on the axes stored in
    ``creat_plot.axl``.
    """
    child_x, child_y = cntr_pt[0], cntr_pt[1]
    midpoint = (
        (parent_pt[0] - child_x) / 2.0 + child_x,
        (parent_pt[1] - child_y) / 2.0 + child_y,
    )
    creat_plot.axl.text(midpoint[0], midpoint[1], txt_string)

def plot_tree(my_tree, parent_pt, node_txt):
    """Recursively draw *my_tree* below *parent_pt*.

    my_tree   -- nested-dict tree ``{feature_name: {value: subtree_or_leaf}}``.
    parent_pt -- (x, y) of the parent node the arrow points back to.
    node_txt  -- text written on the edge between the parent and this node.

    Layout state lives in function attributes that creat_plot() initialises
    before the first call: ``totalW`` (leaf count), ``totalD`` (depth) and
    the running cursor ``xOff`` / ``yOff``.
    """
    num_leafs = get_num_leafs(my_tree)
    first_str = list(my_tree.keys())[0]
    # Centre this node horizontally over the leaves that hang below it.
    cntr_pt = (
        plot_tree.xOff + (1.0 + float(num_leafs)) / 2.0 / plot_tree.totalW,
        plot_tree.yOff,
    )
    plot_mid_text(cntr_pt, parent_pt, node_txt)
    plot_node(first_str, cntr_pt, parent_pt, decision_node)
    second_dict = my_tree[first_str]
    # Step one level down before drawing the children ...
    plot_tree.yOff = plot_tree.yOff - 1.0 / plot_tree.totalD
    for key, child in second_dict.items():
        if isinstance(child, dict):
            plot_tree(child, cntr_pt, str(key))
        else:
            # Each leaf takes the next horizontal slot.
            plot_tree.xOff = plot_tree.xOff + 1.0 / plot_tree.totalW
            leaf_pt = (plot_tree.xOff, plot_tree.yOff)
            plot_node(child, leaf_pt, cntr_pt, leaf_node)
            plot_mid_text(leaf_pt, cntr_pt, str(key))
    # ... and step back up so siblings of this node are drawn at its level.
    plot_tree.yOff = plot_tree.yOff + 1.0 / plot_tree.totalD


def creat_plot(in_tree):
    """Render *in_tree* with matplotlib, save it to 'tree_plotter.png', and show it.

    Initialises the layout attributes on plot_tree (totalW, totalD, xOff,
    yOff) and stores the drawing axes on ``creat_plot.axl`` before delegating
    to plot_tree().
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Empty tick lists hide the axis ticks.
    creat_plot.axl = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    leaf_total = float(get_num_leafs(in_tree))
    depth_total = float(get_tree_depth(in_tree))
    plot_tree.totalW = leaf_total
    plot_tree.totalD = depth_total
    plot_tree.xOff = -0.5 / leaf_total
    plot_tree.yOff = 1.0
    plot_tree(in_tree, (0.5, 1.0), '')
    # savefig() has to come before show(); saving after show() produces a
    # blank image.
    plt.savefig('tree_plotter.png')
    plt.show()

def classify(input_tree, feat_labels, test_vec):
    """Return the class label that *input_tree* assigns to *test_vec*.

    input_tree  -- nested-dict tree ``{feature_name: {value: subtree_or_leaf}}``.
    feat_labels -- feature names; the position of a name gives the index of
                   that feature inside *test_vec*.
    test_vec    -- feature values of the sample to classify.

    Raises ValueError when a feature name is missing from *feat_labels* or
    when the sample has a feature value that the tree has no branch for.
    """
    first_str = list(input_tree.keys())[0]
    second_dict = input_tree[first_str]
    feat_index = feat_labels.index(first_str)
    feat_value = test_vec[feat_index]
    # Without this check an unseen value fell through the loop and hit an
    # unbound local at the return statement.
    if feat_value not in second_dict:
        raise ValueError(
            'no branch for value %r of feature %r' % (feat_value, first_str)
        )
    branch = second_dict[feat_value]
    if isinstance(branch, dict):
        return classify(branch, feat_labels, test_vec)
    return branch


def store_tree(input_tree, filename):
    """Serialise *input_tree* to the file *filename* with pickle.

    Any existing file at *filename* is overwritten.
    """
    import pickle
    # pickle writes bytes, so the file must be opened in binary mode; the
    # with-block closes it even if dump() raises.
    with open(filename, 'wb') as fw:
        pickle.dump(input_tree, fw)

def grab_tree(filename):
    """Load and return the object pickled in the file *filename*.

    NOTE(security): unpickling can execute arbitrary code; only call this on
    files from a trusted source.
    """
    import pickle
    # pickle reads bytes, so the file must be opened in binary mode; the
    # with-block makes sure the handle is closed after loading.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)




if __name__ == '__main__':
    # Build a decision tree from the tab-separated 'lenses.txt' data set and
    # plot it. The with-block closes the file once the rows are read.
    with open('lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr]
    lenses_labels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lenses_tree = decison_tree.create_tree(lenses, lenses_labels)
    creat_plot(lenses_tree)