1. 程式人生 > >西瓜書上樸素貝葉斯的實現,完全按照書上的步驟

西瓜書上樸素貝葉斯的實現,完全按照書上的步驟

注:西瓜書上的資料有錯誤如P152的5/8=0.375,所以程式碼的計算是正確的。如果讀者想要“拉普拉斯修正“的原始碼請訪問https://download.csdn.net/download/song91425/10385345 。 所謂的拉普拉斯就是避免出現概率為0的情況。

import numpy as np


def load_data(filepath):
    '''
    :arg filepath  filepath是資料的路徑
    :fun 載入資料:1,青綠,蜷縮,濁響,清晰,凹陷,硬滑,0.697,0.46,是
    :return 載入後的資料
    '''

    file_object = open(filepath, encoding='UTF-8')
    train_data = []
    file_object.readline()
    while 1:
        data = file_object.readline()
        if not data:
            break
        else:
            train_data.append(data)
    file_object.close()
    test = []
    for s in train_data:
        test.append(s.replace('\n', '').split(','))   #去掉\n和把資料按照’,‘分割再存
    return test


def count_labels(data):
    '''

    :param data:資料集
    :return: 返回好瓜和壞瓜的數目
    '''
    yes = 0
    no = 0
    for s in range(data.__len__()):
        if data[s][-1] == '是':
            yes += 1
        else:
            no += 1
    return yes, no


def handle_one_data(data, attr, location, yes, no):
    '''
    :param data: 資料集
    :param attr: 要傳入的屬性
    :param location: 傳入屬性的位置
    :param yes: 好瓜數量
    :param no: 壞瓜數量
    :return: 返回該屬性在好瓜或者是壞瓜的前提下的概率
    '''
    attr_y, attr_n = 0, 0
    for s in range(data.__len__()):
        if data[s][-1] == '是':
            if data[s][location] == attr:
                attr_y += 1
        else:
            if data[s][location] == attr:
                attr_n += 1
    return attr_y / yes, attr_n / no


def handle_data(data):
    '''

    :param data: 資料集
    :return: 對密度和含糖率的均值和標準差
    '''
    midu_y = []
    tiandu_y = []
    midu_n = []
    tiandu_n = []
    for s in range(data.__len__()):
        if data[s][-1] == '是':
            midu_y.append(np.float(data[s][-3]))
            tiandu_y.append(np.float(data[s][-2]))
        else:
            midu_n.append(np.float(data[s][-3]))
            tiandu_n.append(np.float(data[s][-2]))
    m_midu_y = np.mean(midu_y)
    m_midu_n = np.mean(midu_n)
    t_tiandu_y = np.mean(tiandu_y)
    t_tiandu_n = np.mean(tiandu_n)
    std_midu_y = np.std(midu_y)
    std_midu_n = np.std(midu_n)
    std_tiandu_y = np.std(tiandu_y)
    std_tiandu_n = np.std(tiandu_n)

    return m_midu_y, m_midu_n, t_tiandu_y, t_tiandu_n, std_midu_y, std_midu_n, std_tiandu_y, std_tiandu_n


def show_result(p_yes, p_no):
    '''

    :param p_yes: 在好瓜的前提下,測試資料各個屬性的概率
    :param p_no: 在是壞瓜的前提下,測試資料的各個屬性的概率
    :return: 是好瓜或者是壞瓜
    '''
    p1 = 1.0
    p2 = 1.0
    for s in range(p_yes.__len__()):
        p1 *= np.float(p_yes[s])
        p2 *= np.float(p_no[s])
    if p1 > p2:
        print("好瓜", p1, p2)
    else:
        print("壞瓜", p1, p2)


def count_attr_dis(data):
    '''

    :param data: 資料集
    :return: 各個屬性取值的個數
    '''
    count = [] # 記錄各個屬性的取值有多少個不同
    for i in range(data[0].__len__()):
        if i == 0 or i == 7 or i == 8: # 去掉編號,密度,甜度這個屬性
           continue
        d = []
        for s in range(data.__len__()):
            if not d.__contains__(data[s][i]): # 如果讀到的屬性不包含在d裡就加入到d中
                d.append(data[s][i])
        count.append(d.__len__())  # 統計屬性取值不同的個數
    return count


if __name__ == '__main__':
    filepath = 'D:\\pycharm\\bayes.txt'
    data = load_data(filepath)
    m_midu_y, m_midu_n, t_tiandu_y, t_tiandu_n, std_midu_y, std_midu_n, std_tiandu_y, std_tiandu_n = handle_data(data)
    yes, no = count_labels(data)
    p_yes = [yes / (yes + no)]
    p_no = [no / (yes + no)]
    test_data = ['青綠', '蜷縮', '濁響', '清晰', '凹陷', '硬滑', 0.697, 0.460]
    for s in range(6):
        s_yes, s_no = handle_one_data(data, test_data[s], s+1, yes, no)
        p_yes.append(s_yes)
        p_no.append(s_no)

    #求西瓜書公式(7.18)
    p_yes.append(1/(np.sqrt(2*np.pi) * std_midu_y) * np.exp((-1) * ((test_data[6] - m_midu_y)**2)/std_midu_y**2))
    p_no.append(1/(np.sqrt(2 * np.pi) * std_midu_n) * np.exp((-1) * ((test_data[6] - m_midu_n) ** 2) / std_midu_n ** 2))

    p_yes.append(1/(np.sqrt(2 * np.pi) * std_tiandu_y) * np.exp((-1) * ((test_data[7] - t_tiandu_y) ** 2) / std_tiandu_y ** 2))
    p_no.append(1/(np.sqrt(2 * np.pi) * std_tiandu_n) * np.exp((-1) * ((test_data[7] - t_tiandu_n) ** 2) / std_tiandu_n ** 2))

    print(p_yes)
    print(p_no)
    show_result(p_yes, p_no)

    # 防止某個屬性的取值個數為0的概率出現,採用拉皮拉斯修正(各個屬性不同取值已經完成如函式count_attr_dis)

    print(count_attr_dis(data), '不同屬性取值')