1. 程式人生 > >【機器學習】手寫數字識別算法

【機器學習】手寫數字識別算法

alt gdi 數字識別 -1 轉換 error: erro files turn

1.數據準備

樣本數據獲取忽略,實際上就是將32*32的圖片上數字格式化成一個向量,如下:

技術分享

本demo所有樣本數據都是基於這種格式的

訓練數據:將圖片數據轉成1*1024的數組,作為一個訓練數據。

訓練數據集:https://github.com/zimuqi/machine_Learning/tree/master/ch02/trainingDigits

測試數據集:https://github.com/zimuqi/machine_Learning/tree/master/ch02/testDigits

樣本的文件名格式為:真實值_xxx.txt

轉換代碼:

1 def img2vector(filename):
2     returnVect=zeros((1,1024))
3 fr=open(filename) 4 for i in range(32): 5 lineStr=fr.readline() 6 for j in range(32): 7 returnVect[0,32*i+j]=int(lineStr[j]) 8 return returnVect

2.測試算法

 1 def handwritingClassTest():
 2     hwLabels=[]    # 訓練樣本的標簽數組
 3     traningFileList=listdir("trainingDigits
") # 獲取所有的訓練樣本目錄下的文件名 4 m=len(traningFileList) 5 traningMat=zeros((m,1024)) # 初始化訓練樣本數列 6 7 for i in range(m): 8 fileNameStr=traningFileList[i] # 獲取文件名 9 fileStr=fileNameStr.split(".")[0] 10 clasNumStr=int(fileStr.split("_")[0]) # 獲取樣本的實際值 放入標簽數組
11 hwLabels.append(clasNumStr) 12 traningMat[i,:]=img2vector("trainingDigits/{}".format(fileNameStr)) # 將樣本轉化成1*1024的行放入訓練樣本數列 13 14 testFileList=listdir("testDigits") # 測試樣本目錄 15 error=0 16 mtest=len(testFileList) 17 for i in range(mtest): 18 fileNameStr=testFileList[i] 19 fileStr=fileNameStr.split(".")[0] 20 clasNumStr=int(fileStr.split("_")[0]) 21 testMat=img2vector("testDigits/{}".format(fileNameStr)) 22 res=classify(testMat,traningMat,hwLabels,3) # 使用分類器分類 23 print "came bank with:{} the real anwser is:{}".format(clasNumStr,res) 24 if clasNumStr!=res: # 對比與真實的結果 計算錯誤率 25 error+=1 26 27 print "total:{}".format(mtest) 28 print "error:{}".format(error) 29 print "error:{}".format(float(error/mtest))

這個案例中 算法的識別率為:98.84%

classify是分類器 上上一篇文章中有寫到,具體了解可以點擊這裏

【機器學習】手寫數字識別算法