利用貝葉斯演算法實現手寫體識別
阿新 • • 發佈:2019-02-17
之前記錄過利用knn實現手寫體識別。現在記錄一下利用貝葉斯演算法實現,訓練資料和測試資料和knn的一樣。
首先了解貝葉斯理論知識。
貝葉斯分類是一類分類演算法的總稱,這類演算法均以貝葉斯定理為基礎,故統稱為貝葉斯分類。而樸素樸素貝葉斯分類是貝葉斯分類中最簡單,也是常見的一種分類方法。
那麼既然是樸素貝葉斯分類演算法,它的核心演算法又是什麼呢?
是下面這個貝葉斯公式:
換個表達形式就會明朗很多,如下:
我們最終求的p(類別|特徵)即可!就相當於完成了我們的任務。
- 訓練資料(求P(類別))
class Bayes: def __init__(self): self.length=-1 self.labelrate=dict() self.vectorrate=dict() def fit(self,dataset:list,labels:list): if len(dataset)!=len(labels): raise ValueError("輸入測試陣列和類別陣列長度不一致") self.length=len(dataset[0])#訓練資料特徵值的長度 labelsnum=len(labels) #類別的數量 norlabels=set(labels) #不重複類別的數量 for item in norlabels: self.labelrate[item]=labels.count(item)/labelsnum #求當前類別佔總類別的比例 for vector,label in zip(dataset,labels): if label not in self.vectorrate: self.vectorrate[label]=[] self.vectorrate[label].append(vector) print("訓練結束") return self
- 測試資料(求P(特徵|類別)/P(特徵))
def btest(self,testdata,labelset): if self.length==-1: raise ValueError("未開始訓練,先訓練") #計算testdata分別為各個類別的概率 lbDict=dict() for thislb in labelset: p = 1 alllabel = self.labelrate[thislb] allvector = self.vectorrate[thislb] vnum=len(allvector) allvector=npy.array(allvector).T for index in range(0,len(testdata)): vector=list(allvector[index]) p*=vector.count(testdata[index])/vnum lbDict[thislb]=p * alllabel thislbabel=sorted(lbDict,key=lambda x:lbDict[x],reverse=True)[0] return thislbabel
將測試資料計算的P(類別|特徵)進行排序,(每一個lbDict字典內容是測試資料0~9標籤與訓練資料標籤0~9所對應的概率)
{0: 3.1868338646386474e-110, 1: 0.0, 2: 0.0, 3: 0.0, 4: 1.6477211419058441e-296, 5: 2.955403551519686e-240, 6: 0.0, 7: 0.0, 8: 6.040460506986624e-226, 9: 6.948609891826844e-210}
比如標籤0,結果貝葉斯公式得到滿足0的特徵值且類別為0的概率為3.1868338646386474e-110,依此論推。
- 載入資料和取label值在之前knn中寫到過,因為訓練資料和測試資料一樣,所以可以直接使用之前的方法。
- 實現識別及大概計算出錯率:
labelsall=[0,1,2,3,4,5,6,7,8,9]
#識別多個手寫體數字(批量處理)
testfile=os.listdir("............/testdata")
num=len(testfile)
x=0
for i in range(num):
thisfilename=testfile[i]
thislabel=seplabel(thisfilename)
thisdataarr=datatoarray(".....testdata/"+thisfilename)
label=bys.btest(thisdataarr,labelsall)
print("測試數字是:"+str(thislabel)+"識別出來的數字是:"+str(label))
if label!=thislabel:
x+=1
print("識別出錯")
print(x)
print("出錯率:"+str(x/num))
效果圖: