用 KNN 做手寫數字識別
阿新 • • 發佈:2018-11-25
用 KNN 做手寫數字識別
目錄
作為一個小白,寫此文章主要是為了自己記錄,方便回過頭來查詢! 本文主要參考ApacheCN(專注於優秀專案維護的開源組織)中MachineLearning中的KNN專案。有些程式碼參考這個專案(大部分),有些程式碼是自己寫的。建議去ApacheCN去學習,還有專門的視訊講解,個人感覺非常好。下面對利用KNN進行手寫數字識別的過程進行簡要的描述:
1. KNN的原理
KNN的原理,本文不做解釋,想做了解的人可以去ApacheCN上的專案進行學習或者觀看對應視訊學習。
2. KNN實現手寫數字識別過程
本文主要是測試了一下在測試集上的準確度。測試集樣本個數為946個(資料集同樣可以在ApacheCN上面進行下載),訓練集樣本個數為1934個(0~9),其樣本儲存方式是用.txt檔案儲存的圖片文字。用KNN實現手寫識別的核心思想就是在訓練集中找到一個歐氏距離最小的那個樣本所屬的類別,用該類別來確定未知樣本的類別。
在識別中需要對圖片進行向量化,因此需要一個圖片轉換成向量的函式:
# 將影象文字資料轉換為向量 def img2vector(filename): returnVect = np.zeros((1,1024)) # returnVect = [] fr = open(filename) for i in range(32): read_oneline=fr.readline() for j in range(32): returnVect[0,i*32+j]=int(read_oneline[j]) return returnVect
然後就是在測試集上的精度測試:
def handwritingClassTest(filename,testFileName): # 1. 匯入訓練資料 hwLabels=[] # 標籤集 trainingFileList = os.listdir(filename) # 獲得檔案列表 m = len(trainingFileList) trainingMat = np.zeros((m,1024)) for i in range(m): fileNameStr = trainingFileList[i] classNumStr = fileNameStr.split('_')[0] hwLabels.append(classNumStr) filename_all = filename+'/'+fileNameStr trainingMat[i, :] = img2vector(filename_all) # 2. 匯入測試資料 testFileList=os.listdir(testFileName) mTest = len(testFileList) errorCount = 0.0 for i in range(mTest): fileNameStr = testFileList[i] classNumStr = int(fileNameStr.split('_')[0]) filename_all = testFileName + '/' + fileNameStr vectorUnderTest =img2vector(filename_all) classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels) print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)) if (classifierResult != classNumStr): errorCount += 1.0 print("\nthe total number of errors is: %d" % errorCount) print("\nthe total error rate is: %f" % (errorCount / float(mTest)))
上面測試功能中的classify0利用歐氏距離度量實現了最近類別查詢:
def classify0(testVector, traningMat, hwLabels): row_num_train=traningMat.shape[0] testMat=np.zeros((row_num_train,1024)) for i in range(row_num_train): testMat[i,:]=testVector diff=testMat-traningMat diff=np.abs(diff) diff_row=np.sum(diff,axis=1) # 因為向量中的值不是1就是-1,平方後都是1,因此開根號後直接進行求和即可。 diff_min_index=np.argmin(diff_row) return int(hwLabels[diff_min_index])
最後,用 handwritingClassTest 函式測試一下就OK了。測試集946個,錯誤了13個,錯誤率為0.013742。自己可以試一下,程式碼有些地方寫的不夠規範,體諒下吧。