1. 程式人生 > >keras中使用MLP(多層感知機)神經網路來實現MNIST手寫體識別

keras中使用MLP(多層感知機)神經網路來實現MNIST手寫體識別

    Keras是一個基於python的的深度學習框架,比tensorflow更簡單易用,適合入門學習,本篇文章主要介紹使用keras實現手寫體識別任務。環境為python3+,Keras2.1,神經網路基礎知識在此不做過多介紹。

    1.    載入MNIST資料。

方式一:from keras.datasets import mnist

    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    print("The MNIST database has a training set of %d examples." % len(X_train))
    print("The MNIST database has a test set of %d examples." % len(X_test))



        方式二:方式一有可能載入失敗,可以直接下載資料mnist.npz(文末有連結,可直接下載),放在當前目錄下,使用numpy進行載入:

import numpy as np
f = np.load('mnist.npz')
x_train, y_train = f['x_train'], f['y_train']  
x_test, y_test = f['x_test'], f['y_test']  
f.close() 
print('訓練資料集樣本數: %d ,標籤個數 %d ' % (len(x_train), len(y_train)))
print('測試資料集樣本數: %d ,標籤個數  %d ' % (len(x_test), len(y_test)))

print(x_train.shape)
print(y_train[0])

    輸出如下:

        

    2.    使用matplotlib檢視前六張圖片:

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.cm as cm
import numpy as np

fig = plt.figure(figsize = (20, 20))
for i in range(6):
    ax = fig.add_subplot(1, 6,i + 1, xticks = [], yticks = [])
    ax.imshow(x_train[i], cmap = 'gray')
    ax.set_title(str(y_train[i]))

        輸出如下:

        

    3.    每張圖片都是28*28畫素組成的,我們可以檢視一張圖片的畫素構成細節:

def visualize_input(img, ax):
    ax.imshow(img, cmap='gray')
    width, height = img.shape
    thresh = img.max()/2.5
    for x in range(width):
        for y in range(height):
            ax.annotate(str(round(img[x][y],2)), xy=(y,x),
                        horizontalalignment='center',
                        verticalalignment='center',
                        color='white' if img[x][y]<thresh else 'black')

fig = plt.figure(figsize = (12,12)) 
ax = fig.add_subplot(111)
visualize_input(x_train[0], ax)

    輸出:

        

4.    特徵值縮放:該示例中影象的每個畫素點都是28 * 28畫素的圖片,每個畫素的值在0-255之間。我們需要將訓練資料和測試資料的輸入特徵值縮放到0-1之間(除以255),方便處理。

x_train = x_train.astype('float') / 255
x_test = x_test.astype('float') / 255

    5.    對輸出標籤進行One-hot編碼:

            注意:注意:這段程式碼不能重複執行,因為後面執行的y_train已經不是最初的資料了。

from keras.utils import np_utils

print('Integer-valued labels:')
print(y_train[:10])

#標籤進行one-hot編碼
y_train = np_utils.to_categorical(y_train, 10)
y_test = np_utils.to_categorical(y_test, 10)

print('One-hot labels:')
print(y_train[:10])

    輸出:

        

    6.    定義神經網路模型:使用的網路模型架構如下:

            

        程式碼:            

from keras.models import Sequential
from keras.layers import Dense,Dropout,Flatten

model = Sequential()
model.add(Flatten(input_shape = x_train.shape[1:]))
model.add(Dense(512, activation = 'relu'))
model.add(Dropout(0.2))
model.add(Dense(512, activation = 'relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation = 'softmax'))

    輸出:

        

    7.    編譯模型:

model.compile(loss = 'categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
    8.    在訓練模型之前在測試集上看分類精確度:
score = model.evaluate(x_test, y_test, verbose=0)
accuracy = 100*score[1]
print('Test accuracy: %.4f%%' % accuracy)

    輸出:

    Test accuracy: 11.0700%
    可以看到此時模型在測試集上的準確率只有11.0700%

    9.    開始迭代訓練模型:

from keras.callbacks import ModelCheckpoint

checkpointer = ModelCheckpoint(filepath = 'mnist.model.best.hdf5',verbose=1, save_best_only=True)
hist = model.fit(x_train, y_train, batch_size=128, epochs=10,
          validation_split=0.2, callbacks=[checkpointer],
          verbose=1, shuffle=True)

    輸出:

        

    10.    載入訓練好的模型,並在測試集上進行測試準確率:

model.load_weights('mnist.model.best.hdf5')

score = model.evaluate(x_test, y_test, verbose=0)
accuracy = 100 * score[1]
print('Test accuracy: %.4f%%' % accuracy)

       輸出:

            Test accuracy: 97.8800%

    可以看到模型在測試集上準確率達到了97.8800%。

    對於第9步訓練模型方法model.fit()方法的引數解釋:

            a.     validation_split=0.2引數代表的意思是將原來訓練集中20%的資料拿出來作為交叉驗證資料集,驗證集的資料不參與反向傳播更新權重,但是它能提供判斷模型是否過擬合以及幫助選擇最好的模型。

            b.    ModelCheckpoint類允許我們在每個epoch之後儲存模型權重,用於比較得到最好的模型,filepath引數指定了權重的儲存位置,通過將save_best_only引數設定為True,可以告訴模型僅儲存權重,以讓驗證集達到最佳準確率,verbose設為1表示訓練過程中的文字輸出將告訴你權重檔案何時更新了。

            c.    model.load_weights('mnist.model.best.hdf5') 可以載入達到最佳驗證準確率的權重。