1. 程式人生 > >Python實現KNN演算法手寫識別數字

Python實現KNN演算法手寫識別數字

本文實現用KNN演算法實現手寫識別數字功能。
語言:Python
訓練材料:手寫數字素材32*32畫素

from numpy import *
import os
from os import listdir
import operator
#將檔案32*32轉成1*1024
def img2vector(filename):
    vect=zeros((1,1024))
    f=open(filename)
    for i in range(32):
        line=f.readline()
        for j in range(32):
            vect[0
,32*i+j]=int(line[j]) return vect def dict2list(dic:dict): #''' 將字典轉化為列表 ''' keys = dic.keys() vals = dic.values() lst = [(key, val) for key, val in zip(keys, vals)]#zip是一個可迭代物件 return lst #inputvector:輸入的用於測試的向量 #trainDataSet:訓練的樣本集 #labels:標籤 #k:k鄰近的個數 def knntest(inputvector,trainDataSet,labels,k)
:
datasetsize=trainDataSet.shape[0] #tile(a,[2,3]) ([a a a],[a,a,a])用第一個引數來構造 #這裡用輸入向量來構造一個1024行 1列的矩陣,剛好和訓練矩陣同樣大小 diffmat=tile(inputvector,(datasetsize,1))-trainDataSet #求平方和 #每個元素都平方 sqdiffmat=diffmat**2 #按行求和 sqdistance=sqdiffmat.sum(axis=1) #平方根,得到的是一個一維的矩陣
distance=sqdistance**0.5 #按照從低到高排序 #argsort函式排列後得到的是按下標進行排列的矩陣, #在原先distance中的下標按距離最近排列 argsort函式返回的是陣列值從小到大的索引值 sortdistance=distance.argsort() classcout={}#用來儲存key(標籤)value(標籤出現的次數,選取次數最大的前幾個數,找到其標籤) #依次取出最近的樣本資料 for i in range(k): #記樣本的類別 votelabel=labels[sortdistance[i]] #統計每個標籤的次數 classcout[votelabel]=classcout.get(votelabel,0)+1#獲取votelabel鍵對應的值,無返回預設 #print("*************") #print(classcout) #classcout.iteritems()在Python3中取消了,key=lambda x:x[0](按第0個元素排序)字典排序,按照value來排序,返回鍵 sortclasscount=sorted(dict2list(classcout),key=operator.itemgetter(1),reverse=True) #返回出現頻次最高的類別 return sortclasscount[0][0] #手寫識別 def handwritingClassTest(): print(os.getcwd()) #將訓練資料儲存到一個矩陣中1024維,並存儲對應的標籤 handlabel=[] trainName=listdir(r'digits\trainingDigits') trainNum=len(trainName) trainNumpy = zeros((trainNum,1024)) #print("trainNum=%d"%trainNum) #對檔名進行分析,訓練文字對應的標籤 for i in range(trainNum): filename=trainName[i]#檔名 filestr=filename.split('.')[0]#不帶字尾的檔名 filelabel=int(filestr.split('_')[0])#檔案的標籤 #將標籤新增至handlabel中 handlabel.append(filelabel) trainNumpy[i,:]=img2vector(r'digits\trainingDigits\%s'%filename)#轉成1024 #print(handlabel[:20]) testfilelist=listdir(r'digits\testDigits') errornum=0 testnum=len(testfilelist) errfile=[] #將每一個測試樣本放入訓練集中使用KNN進行測試 for i in range(testnum): testfilename=testfilelist[i] testfilestr=testfilename.split('.')[0] testfilelabel=int(testfilestr.split('_')[0])#實際的數字標籤 #將測試樣本1024 testvector=img2vector(r'digits\testDigits\%s'%testfilename) #進行測試 #print("-----------") result=knntest(testvector,trainNumpy,handlabel,3) print("test value is %d, real value is %d"%(result,testfilelabel)) if(result!=testfilelabel): errornum+=1 errfile.append(testfilename) print("the num of error is %d"%errornum) print("the right rate of test is %f "%(1-errornum/float(testnum))) print("the error of file are ") count=0 for i in range(len(errfile)): if(count==9): print() print(errfile[i]+' ',end="") count+=1 def main(): #path=os.getcwd() handwritingClassTest() if __name__=='__main__': main();