1. 程式人生 > >python手寫神經網路實現識別手寫數字

python手寫神經網路實現識別手寫數字

實驗說明

一直想自己寫一個神經網路來實現手寫數字的識別,而不是套用別人的框架。恰巧前幾天,有幸從同學那拿到5000張已經貼好標籤的手寫數字圖片,於是我就嘗試用matlab寫一個網路。

  • 實驗資料:5000張手寫數字圖片(.jpg),圖片命名為1.jpg,2.jpg…5000.jpg。還有一個放著標籤的excel檔案。

  • 資料處理:前4000張作為訓練樣本,後1000張作為測試樣本。

  • 圖片處理:用matlab的imread()函式讀取圖片的灰度值矩陣(28,28),然後把每張圖片的灰度值矩陣reshape為(28*28,1),然後把前4000張圖片的灰度值矩陣合併為x_train,把後1000張圖片的灰度值矩陣合併為x_test。

數字圖片截圖


數字標籤截圖

神經網路設計

  • 網路層設計:一層隱藏層,一層輸出層

  • 輸入層:一張圖片的灰度值矩陣reshape後的784個數,也就是x_train中的某一列

  • 輸出層:(10,1)的列向量,其中列向量中最大的數所在的索引+1就是預測的數字

  • 激勵函式:sigmoid函式(公式)

  • 更新法則:後向傳播演算法(參考

  • 測試:統計預測正確的個數

網路實現

  • 函式說明:讀圖片的函式(read_photo() )、讀excel的函式(read_excel(path) )、修正函式(layerout(w,b,x) )、訓練函式(mytrain(x_train,y_train) )、測試函式(mytest(x_test,y_test,w,b,w_h,b_h) )、主函式(main() )

具體程式碼如下:

# -*- coding: utf-8 -*-

from PIL import Image

from pylab import *

import numpy as np

import xlrd



#讀取圖片的灰度值矩陣
def read_photo():
    for i in range(5000):
        j = i+1
        j = str(j)
        st = '.jpg'
        j = j+st
        im1 = array(Image.open(j))
        #(28,28)-->(28*28,1)
im1 = im1.reshape((784,1)) #把所有的圖片灰度值放到一個矩陣中 #一列代表一張圖片的資訊 if i == 0: im = im1 else: im = np.hstack((im,im1)) return im #讀取excel檔案內容(path為檔案路徑) def read_excel(path): # 獲取所有sheet workbook = xlrd.open_workbook(path) sheet_names = workbook.sheet_names() # 根據sheet索引或者名稱獲取sheet內容 for sheet_name in sheet_names: isheet = workbook.sheet_by_name(sheet_name) #獲取該sheet的列數 ncols = isheet.ncols #獲取每一列的內容 for i in range(ncols): if i == 0: xl1 = isheet.col_values(i) xl1 = np.array(xl1) xl1 = xl1.reshape((10,1)) xl = xl1 else: xl1 = isheet.col_values(i) xl1 = np.array(xl1) xl1 = xl1.reshape((10,1)) xl = np.hstack((xl,xl1)) return xl #layerout函式 def layerout(w,b,x): y = np.dot(w,x) + b t = -1.0*y # n = len(y) # for i in range(n): # y[i]=1.0/(1+exp(-y[i])) y = 1.0/(1+exp(t)) return y #訓練函式 def mytrain(x_train,y_train): ''' 設定一個隱藏層,784-->隱藏層神經元個數-->10 ''' step=int(input('mytrain迭代步數:')) a=double(input('學習因子:')) inn = 784 #輸入神經元個數 hid = int(input('隱藏層神經元個數:'))#隱藏層神經元個數 out = 10 #輸出層神經元個數 w = np.random.randn(out,hid) w = np.mat(w) b = np.mat(np.random.randn(out,1)) w_h = np.random.randn(hid,inn) w_h = np.mat(w_h) b_h = np.mat(np.random.randn(hid,1)) for i in range(step): #打亂訓練樣本 r=np.random.permutation(4000) x_train = x_train[:,r] y_train = y_train[:,r] #mini_batch for j in range(400): #取batch為10 更新取10次的平均值 x = np.mat(x_train[:,j]) x = x.reshape((784,1)) y = np.mat(y_train[:,j]) y = y.reshape((10,1)) hid_put = layerout(w_h,b_h,x) out_put = layerout(w,b,hid_put) #更新公式的實現 o_update = np.multiply(np.multiply((y-out_put),out_put),(1-out_put)) h_update = np.multiply(np.multiply(np.dot((w.T),np.mat(o_update)),hid_put),(1-hid_put)) outw_update = a*np.dot(o_update,(hid_put.T)) outb_update = a*o_update hidw_update = a*np.dot(h_update,(x.T)) hidb_update = a*h_update w = w + outw_update b = b+ outb_update w_h = w_h +hidw_update b_h =b_h +hidb_update return w,b,w_h,b_h #test函式 def mytest(x_test,y_test,w,b,w_h,b_h): ''' 統計1000個測試樣本中有多少個預測正確了 預測結果表示:10*1的列向量中最大的那個數的索引+1就是預測結果了 ''' sum = 0 for k in range(1000): x = np.mat(x_test[:,k]) x = x.reshape((784,1)) y = np.mat(y_test[:,k]) y = y.reshape((10,1)) yn = np.where(y ==(np.max(y))) # print(yn) # print(y) hid = layerout(w_h,b_h,x); pre = layerout(w,b,hid); #print(pre) pre = np.mat(pre) pre = pre.reshape((10,1)) pren = np.where(pre ==(np.max(pre))) # print(pren) # print(pre) if yn == pren: sum += 1 print('1000個樣本,正確的有:',sum) def main(): #獲取圖片資訊 im = read_photo() immin = im.min() immax = im.max() im = (im-immin)/(immax-immin) #前4000張圖片作為訓練樣本 x_train = im[:,0:4000] #後1000張圖片作為測試樣本 x_test = im[:,4000:5000] #獲取label資訊 xl = read_excel('./label.xlsx') y_train = xl[:,0:4000] y_test = xl[:,4000:5000] print("---------------------------------------------------------------") w,b,w_h,b_h = mytrain(x_train,y_train) mytest(x_test,y_test,w,b,w_h,b_h) print("---------------------------------------------------------------") if __name__ == '__main__': main()

實驗結果

---------------------------------------------------------------
mytrain迭代步數:300
學習因子:0.3
隱藏層神經元個數:28
1000個樣本,正確的有: 933
---------------------------------------------------------------

迭代300步,正確率就有93.3%啦,還不錯的正確率~

相關推薦

no