1. 程式人生 > >利用貝葉斯演算法實現手寫體識別

利用貝葉斯演算法實現手寫體識別

之前記錄過利用knn實現手寫體識別。現在記錄一下利用貝葉斯演算法實現,訓練資料和測試資料和knn的一樣。

首先了解貝葉斯理論知識。

貝葉斯分類是一類分類演算法的總稱,這類演算法均以貝葉斯定理為基礎,故統稱為貝葉斯分類。而樸素樸素貝葉斯分類是貝葉斯分類中最簡單,也是常見的一種分類方法。

那麼既然是樸素貝葉斯分類演算法,它的核心演算法又是什麼呢?

是下面這個貝葉斯公式:

換個表達形式就會明朗很多,如下:

                                                    

我們最終求的p(類別|特徵)即可!就相當於完成了我們的任務。

  • 訓練資料(求P(類別))
class Bayes:
    def __init__(self):
        self.length=-1
        self.labelrate=dict()
        self.vectorrate=dict()
    def fit(self,dataset:list,labels:list):
        if len(dataset)!=len(labels):
            raise ValueError("輸入測試陣列和類別陣列長度不一致")
        self.length=len(dataset[0])#訓練資料特徵值的長度
        labelsnum=len(labels) #類別的數量
        norlabels=set(labels) #不重複類別的數量
        for item in norlabels:
            self.labelrate[item]=labels.count(item)/labelsnum #求當前類別佔總類別的比例
        for vector,label in zip(dataset,labels):
            if label not in self.vectorrate:
                self.vectorrate[label]=[]
            self.vectorrate[label].append(vector)
        print("訓練結束")
        return self
  • 測試資料(求P(特徵|類別)/P(特徵))
    def btest(self,testdata,labelset):
        if self.length==-1:
            raise ValueError("未開始訓練,先訓練")
        #計算testdata分別為各個類別的概率
        lbDict=dict()
        for thislb in labelset:
            p = 1
            alllabel = self.labelrate[thislb]
            allvector = self.vectorrate[thislb]
            vnum=len(allvector)
            allvector=npy.array(allvector).T
            for index in range(0,len(testdata)):
                vector=list(allvector[index])
                p*=vector.count(testdata[index])/vnum
            lbDict[thislb]=p * alllabel
        thislbabel=sorted(lbDict,key=lambda x:lbDict[x],reverse=True)[0]
        return thislbabel

將測試資料計算的P(類別|特徵)進行排序,(每一個lbDict字典內容是測試資料0~9標籤與訓練資料標籤0~9所對應的概率)

{0: 3.1868338646386474e-110, 1: 0.0, 2: 0.0, 3: 0.0, 4: 1.6477211419058441e-296, 5: 2.955403551519686e-240, 6: 0.0, 7: 0.0, 8: 6.040460506986624e-226, 9: 6.948609891826844e-210}

比如標籤0,結果貝葉斯公式得到滿足0的特徵值且類別為0的概率為3.1868338646386474e-110,依此論推。

  • 載入資料和取label值在之前knn中寫到過,因為訓練資料和測試資料一樣,所以可以直接使用之前的方法。
  • 實現識別及大概計算出錯率:
labelsall=[0,1,2,3,4,5,6,7,8,9]
#識別多個手寫體數字(批量處理)
testfile=os.listdir("............/testdata")
num=len(testfile)
x=0
for i in range(num):
    thisfilename=testfile[i]
    thislabel=seplabel(thisfilename)
    thisdataarr=datatoarray(".....testdata/"+thisfilename)
    label=bys.btest(thisdataarr,labelsall)
    print("測試數字是:"+str(thislabel)+"識別出來的數字是:"+str(label))
    if label!=thislabel:
        x+=1
        print("識別出錯")
print(x)
print("出錯率:"+str(x/num))

效果圖: