1. 程式人生 > >編寫knn演算法實現手寫體識別

編寫knn演算法實現手寫體識別

  • 一、首先學習學習knn演算法。

kNN演算法的核心思想是如果一個樣本在特徵空間中的k個最相鄰的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別,並具有這個類別上樣本的特性。該方法在確定分類決策上只依據最鄰近的一個或者幾個樣本的類別來決定待分樣本所屬的類別。 kNN方法在類別決策時,只與極少量的相鄰樣本有關。由於kNN方法主要靠周圍有限的鄰近的樣本,而不是靠判別類域的方法來確定所屬類別的,因此對於類域的交叉或重疊較多的待分樣本集來說,kNN方法較其他方法更為適合。

                     

畫個簡單的圖,假設其他型別的圖案是有類別的,我們需要將中間的六邊形進行歸類,這是我們可以利用knn,計算它與其他圖形的距離,取k值,決策它應該歸類到哪一類中。

看右上圖,綠色圓要被決定賦予哪個類,是紅色三角形還是藍色四方形?如果K=3,由於紅色三角形所佔比例為2/3,綠色圓將被賦予紅色三角形那個類,如果K=5,由於藍色四方形比例為3/5,因此綠色圓被賦予藍色四方形類。

  • 二、接下來程式碼實現knn演算法:
def knn(k,testdata,traindata,labels):
    traindatasize=traindata.shape[0]
    dif=tile(testdata,(traindatasize,1))-traindata
    sqdif=dif**2
    sumsqdif=sqdif.sum(axis=1)
    distance=sumsqdif**0.5
    sortdistance=distance.argsort()
    count={}
    for i in range(0,k):
        vote=labels[sortdistance[i]]
        count[vote]=count.get(vote,0)+1
    sortcount=sorted(count.items(),key=operator.itemgetter(1),reverse=True)
    return sortcount[0][0]

knn演算法步驟:

1、處理資料

2、資料向量化

3、計算歐幾里得距離

4、根據距離進行分類

引數   k用於改變誤差率,testdata:測試資料集,traindata:訓練資料集,labels:標籤

  • 三、瞭解手寫體識別

我們通過畫圖,或者在紙上寫上數字或字母,將照片進行處理,得到固定的照片規格,將照片轉換為文字0,1表示的內容。例如下圖:(圖中內容68)訓練集和測試集可以自己手寫利用PIL庫轉化,也可以網際網路上找。

為了簡單起見,固定圖片的畫素為32*32。

為了保證結果的低誤差率,可以將訓練集數量設定多一點。

  • 四、將圖片轉換為0,1文字

pip install pillow   安裝對應庫,PIL.Image處理圖片,getpixel方法獲取畫素,判斷畫素的顏色,進行文字內容0,1的寫入。(下面我對應圖片沒有設定畫素,導致寫入內容很多,在此只做簡單思路分析)

  • 五、載入資料,將訓練集(測試集)轉化為陣列
def datatoarray(fname):
    arr=[]
    fh=open(fname)
    for i in range(0,32):
        thisline=fh.readline()
        for j in range(0,32):
            arr.append(int(thisline[j]))
    return arr
  • 六、建立一個函式取出對應手寫體的名字(輸入的引數是檔案目錄),從而建立label
def seplabel(fname):
    filestr=fname.split(".")[0]
    labels=int(filestr.split("_")[0])
    return labels
  • 七、建立訓練資料集
def traindata():
    labels=[]
    trainfile=os.listdir("./traindata")
    num=len(trainfile)
    #畫素32*32=1024
    #建立一個數組存放訓練資料,行為檔案總數,列為1024,為一個手寫體的內容 zeros建立規定大小的陣列
    trainarr=zeros((num,1024))
    for i in range(0,num):
        thisfname=trainfile[i]
        thislabel=seplabel(thisfname)
        labels.append(thislabel)
        trainarr[i]=datatoarray("./traindata/"+thisfname)
    return trainarr,labels
  • 八、用測試資料呼叫knn演算法完成測試
def datatest():
    trainarr,labels=traindata()
    testlist=os.listdir("./testdata")
    tnum=len(testlist)
    for i in range(tnum):
        thisname=testlist[i]
        testarr=datatoarray("./testdata/"+thisname)
        rknn=knn(k=3,testdata=testarr,traindata=trainarr,labels=labels)  
        print(str(thisname)+"  :  "+str(rknn))

執行效果(冒號前是測試集的檔名,對應數字的第幾個測試樣本):可以看到基本上都能準確的將測試集中的手寫體數字識別正確並歸類,有少部分數字識別失敗,將測試樣本歸類為其他內容。可以更改K值,來改變誤差率。