機器學習實戰--KNN手寫數字識別
阿新 • 發佈:2018-12-15
程式碼:
import numpy as np
import operator
import matplotlib
import matplotlib.pyplot as plt
import os
from collections import Counter


def classfy0KNN(intX, dataset, labels, K):
    """Classify one sample by majority vote of its K nearest neighbours.

    Args:
        intX: the sample to classify; anything that broadcasts against one
            row of ``dataset`` (e.g. shape ``(n_features,)`` or
            ``(1, n_features)`` as returned by ``img2vector``).
        dataset: 2-D array of training samples, one sample per row.
        labels: sequence of class labels; ``labels[i]`` belongs to
            ``dataset[i]``.
        K: number of neighbours that vote. A value larger than the number
            of training samples is clamped to that number.

    Returns:
        The label with the most votes among the ``K`` nearest training
        samples (Euclidean distance). A vote tie is resolved in favour of
        the tied label whose nearest member is closest to ``intX``.

    Raises:
        ValueError: if ``K`` is smaller than 1 or ``dataset`` is empty.
    """
    dataset = np.asarray(dataset)
    if K < 1:
        raise ValueError('K must be at least 1, got %r' % (K,))
    if dataset.shape[0] == 0:
        raise ValueError('dataset must contain at least one sample')
    # Broadcasting subtracts intX from every row, so no np.tile copy is needed.
    # Squared distances rank neighbours the same way Euclidean distances do
    # (sqrt is monotonic), so the square root is skipped.
    sqDistances = ((dataset - intX) ** 2).sum(axis=1)
    # A stable sort makes the choice among equidistant neighbours
    # deterministic (training order). Slicing clamps K to the dataset size.
    nearest = sqDistances.argsort(kind='stable')[:K]
    # Counter remembers first-seen order and most_common() keeps that order
    # for equal counts, so a tie goes to the label seen first (the closest).
    votes = Counter(labels[i] for i in nearest)
    return votes.most_common(1)[0][0]


def img2vector(filename):
    """Read a 32x32 text image made of digit characters into a row vector.

    Each of the first 32 lines of ``filename`` must start with 32 digit
    characters. Anything after the 32nd character of a line, and any line
    after the 32nd, is ignored.

    Returns:
        A float array of shape ``(1, 1024)``; pixel ``(row, col)`` is stored
        at column ``32 * row + col``.

    Raises:
        ValueError: if the file has fewer than 32 lines, a line has fewer
            than 32 characters, or a character is not a digit.
    """
    oneImg = np.zeros((1, 1024))
    with open(filename) as f:
        for row in range(32):
            line = f.readline().rstrip('\r\n')
            if len(line) < 32:
                raise ValueError('%s: line %d has fewer than 32 pixels'
                                 % (filename, row + 1))
            oneImg[0, 32 * row:32 * (row + 1)] = [int(ch) for ch in line[:32]]
    return oneImg


def _labelFromFilename(fileNameStr):
    """Return the digit label encoded in a file name such as '0_13.txt'."""
    return int(fileNameStr.split('.')[0].split('_')[0])


def _listDigitFiles(dirName):
    """Return the sorted '.txt' file names in ``dirName``.

    Other entries (hidden files such as '.DS_Store') are skipped because
    their names carry no label. Sorting makes the run order reproducible.
    """
    return sorted(name for name in os.listdir(dirName) if name.endswith('.txt'))


def handwritingClassTest(trainingDir='trainingDigits', testDir='testDigits', K=5):
    """Train KNN on ``trainingDir`` and report its error rate on ``testDir``.

    Every image file is named '<digit>_<index>.txt'; the digit before the
    underscore is its true label. Each prediction and the final error rate
    are printed.

    Args:
        trainingDir: directory holding the training images.
        testDir: directory holding the test images.
        K: number of neighbours passed to ``classfy0KNN``.

    Returns:
        The fraction of test images that were misclassified.

    Raises:
        ValueError: if ``testDir`` contains no '.txt' images (or, via
            ``classfy0KNN``, if ``trainingDir`` contains none).
    """
    trainingFileList = _listDigitFiles(trainingDir)
    testFileList = _listDigitFiles(testDir)
    if not testFileList:
        raise ValueError('no test images found in %r' % (testDir,))

    hwLables = [_labelFromFilename(name) for name in trainingFileList]
    trainDataset = np.zeros((len(trainingFileList), 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        trainDataset[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

    errorNum = 0
    for fileNameStr in testFileList:
        classNumLable = _labelFromFilename(fileNameStr)
        testVect = img2vector(os.path.join(testDir, fileNameStr))
        predict = classfy0KNN(testVect, trainDataset, hwLables, K)
        print('the real number is : ', classNumLable, ' predict is : ', predict)
        if predict != classNumLable:
            errorNum += 1
    errorRate = errorNum / len(testFileList)
    print('the error rate is : ', errorRate)
    return errorRate


if __name__ == '__main__':
    handwritingClassTest()
執行結果: