程式人生 > python實現貝葉斯推斷——垃圾郵件分類

python實現貝葉斯推斷——垃圾郵件分類

理論

前期準備

資料來源

資料來源於《機器學習實戰》中的第四章樸素貝葉斯分類器的實驗資料。資料書上只提供了50條資料(25條正常郵件,25條垃圾郵件),感覺資料量偏小,以後打算使用scikit-learn提供的iris資料。

資料準備

和很多機器學習一樣,資料需要拆分成訓練集和測試集。
拆分訓練集和測試集的思路如下:
1.遍歷包含50條資料的email資料夾,獲取檔案列表
2.使用random.shuffle()函式打亂列表
3.擷取亂序後的檔案列表前10個檔案路徑,並轉移到test資料夾下,作為測試集。
程式碼實現:

# -*- coding: utf-8 -*-
# @Date     : 2017-05-09 13:06:56
# @Author   : Alan Lau ([email protected])
# @Language : Python3.5

# from fwalker import fun
# from reader import writetxt, readtxt
import os
import random
import shutil


def fileWalker(path):
    """Recursively list every file under *path* using Windows-style '\\' joins."""
    fileArray = []
    for root, dirs, files in os.walk(path):
        for fn in files:
            eachpath = str(root + '\\' + fn)
            fileArray.append(eachpath)
    return fileArray


def main():
    """Randomly pick 10 emails and move them into the test folder.

    Moved files are renamed to '<label>_<original name>' (e.g. 'spam_8.txt')
    so the class label survives the move out of its labeled folder.
    """
    filepath = r'..\email'
    testpath = r'..\test'
    files = fileWalker(filepath)
    random.shuffle(files)
    top10 = files[:10]
    for ech in top10:
        # Join the parent folder name (ham/spam) with the file name.
        ech_name = testpath + '\\' + ('_'.join(ech.split('\\')[-2:]))
        shutil.move(ech, testpath)
        os.rename(testpath + '\\' + ech.split('\\')[-1], ech_name)
        print('%s moved' % ech_name)


if __name__ == '__main__':
    main()

最後獲取的檔案列表如下:


copy是備份資料,防止操作錯誤
ham檔案列表:


spam檔案列表:

test檔案列表:


可見,資料準備後的測試集中,有7個垃圾郵件,3個正常的郵件。

程式碼實現

# -*- coding: utf-8 -*-
# @Date     : 2017-05-09 09:29:13
# @Author   : Alan Lau ([email protected])
# @Language : Python3.5

# from fwalker import fun
# from reader import readtxt
import os


def readtxt(path, encoding):
    """Read the text file at *path* and return its lines, newlines preserved."""
    with open(path, 'r', encoding=encoding) as handle:
        return handle.readlines()

def fileWalker(path):
    """Recursively collect every file path under *path*.

    NOTE: joins with a literal '\\' (Windows-style) because the rest of
    this script splits paths on '\\'.
    """
    collected = []
    for base, _subdirs, names in os.walk(path):
        collected.extend(str(base + '\\' + name) for name in names)
    return collected

def email_parser(email_path):
    """Tokenize one email file.

    Reads the file as UTF-8, replaces punctuation with spaces, and returns
    the lowercased tokens longer than 2 characters (order preserved).
    """
    punctuations = """,.<>()*&^%$#@!'";~`[]{}|、\\/~+_-=?"""
    content_list = readtxt(email_path, 'utf8')
    content = (' '.join(content_list)).replace('\r\n', ' ').replace('\t', ' ')
    # Map every punctuation character to a space in one C-level pass instead
    # of re-splitting/re-joining the whole string once per character.
    content = content.translate(str.maketrans(dict.fromkeys(punctuations, ' ')))
    # Build the token list once; the original rebuilt this comprehension on
    # every iteration of the punctuation loop, discarding all but the last.
    return [word.lower() for word in content.split(' ') if len(word) > 2]


def get_word(email_file):
    """Tokenize every email under *email_file*.

    Returns a pair: (list of per-email token lists, set of all distinct
    tokens seen across those emails).
    """
    per_email = []
    vocabulary = set()
    for path in fileWalker(email_file):
        tokens = email_parser(path)
        per_email.append(tokens)
        vocabulary.update(tokens)
    return per_email, vocabulary


def count_word_prob(email_list, union_set):
    """Estimate P(word | class) as document frequency, floored at 0.01.

    For each word in *union_set*, computes the fraction of emails in
    *email_list* that contain it; words appearing in no email get 0.01
    so downstream probability products never collapse to a hard zero.
    """
    total = len(email_list)
    word_prob = {}
    for word in union_set:
        hits = sum(1 for email in email_list if word in email)
        word_prob[word] = hits / total if hits else 0.01
    return word_prob


def filter(ham_word_pro, spam_word_pro, test_file):
    """Classify every email under *test_file* as 'spam' or 'ham'.

    For each distinct word, computes P(spam | word) with Bayes' rule under
    equal priors, then combines the per-word probabilities with the
    product formula p1*p2*.../(p1*p2*... + (1-p1)(1-p2)...) and prints
    one verdict line per file.

    NOTE(review): this shadows the builtin `filter`; renamed only in a
    breaking change since main() calls it by this name.
    """
    spam_prob = 0.5  # prior P(spam)
    ham_prob = 0.5   # prior P(ham)
    for test_path in fileWalker(test_file):
        file_name = test_path.split('\\')[-1]
        prob_dict = {}
        for word in set(email_parser(test_path)):
            if word in spam_word_pro:
                Pws = spam_word_pro[word]
                Pwh = ham_word_pro[word]
                # Bayes: P(spam|w) = P(w|spam)P(spam) / P(w)
                Psw = spam_prob*(Pws/(Pwh*ham_prob+Pws*spam_prob))
            else:
                # Unseen word: mildly ham-leaning default.
                Psw = 0.4
            prob_dict[word] = Psw
        numerator = 1
        denominator_h = 1
        for p in prob_dict.values():
            numerator *= p
            denominator_h *= (1-p)
        email_spam_prob = round(numerator/(numerator+denominator_h), 4)
        verdict = 'spam' if email_spam_prob > 0.5 else 'ham'
        print(file_name, verdict, email_spam_prob)


def main():
    """Train per-word probabilities on the ham/spam folders, classify test set."""
    ham_dir = r'..\email\ham'
    spam_dir = r'..\email\spam'
    test_dir = r'..\email\test'
    ham_emails, ham_vocab = get_word(ham_dir)
    spam_emails, spam_vocab = get_word(spam_dir)
    # Shared vocabulary so both classes assign a probability to every word.
    vocabulary = ham_vocab | spam_vocab
    ham_word_pro = count_word_prob(ham_emails, vocabulary)
    spam_word_pro = count_word_prob(spam_emails, vocabulary)
    filter(ham_word_pro, spam_word_pro, test_dir)


# Script entry point: run the full train-and-classify pipeline.
if __name__ == '__main__':
    main()

實驗結果

ham_24.txt ham 0.0005
ham_3.txt ham 0.0
ham_4.txt ham 0.0
spam_11.txt spam 1.0
spam_14.txt spam 0.9999
spam_17.txt ham 0.0
spam_18.txt spam 0.9992
spam_19.txt spam 1.0
spam_22.txt spam 1.0
spam_8.txt spam 1.0

可見正確率為90%。實際上嚴格來說,應當將所有資料隨機均分為十組,每一組輪流作為一次測試集,剩下九組作為訓練集,再將十次計算結果求均值,這樣得到的分類效果才具有可靠性;其次,也不排除資料量過小導致準確率偏低的可能。