1. 程式人生 > >用Keras進行手寫字型識別(MNIST資料集)

用Keras進行手寫字型識別(MNIST資料集)

資料

首先載入資料

from keras.datasets import mnist

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

接下來,看看這個資料集的基本情況:

train_images.shape

(60000, 28, 28)

len(train_labels)

60000

train_labels

array([5, 0, 4, …, 5, 6, 8], dtype=uint8)

test_images.shape

(10000, 28, 28)

len(test_labels)

10000

test_labels

array([7, 2, 1, …, 4, 5, 6], dtype=uint8)

網路構架

from keras import models
from keras import layers

network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))

Keras可以幫助我們實現一層一層的連線起來,在本例中的網路包含2個Dense層,他們是密集連線(也叫全連線)的神經層。第二層是一個10路softmax層,他將返回一個由10個概率值(總和為1)組成的陣列。每個概率值表示當前數字影象屬於10個數字類別中的某一個的概率。

編譯

要想訓練網路,我們還需要設定編譯步驟的三個引數:

  • 損失函式

  • 優化器(optimizer):基於訓練資料和損失函式來更新網路的機制

  • 在訓練和測試過程中需要監控的指標(metric):本例只關心精度,即正確分類的影象所佔的比例。

network.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])

影象資料預處理

在訓練之前,需要對影象資料預處理,將其變換成網路要求的形狀,並縮放所有值都在[0,1]區間。比如,之前訓練影象儲存在一個uint8型別的陣列中,其形狀為(60000,28, 28),取值範圍為[0, 255]。我們需要將其變換成一個float32陣列,其形狀為(60000, 28*28),取值範圍為0-1

train_images = train_images.reshape((60000, 28 * 28))  #把一個影象變成一列資料用於學習
train_images = train_images.astype('float32') / 255 #astype用於進行資料型別轉換

test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255

訓練

from keras.utils import to_categorical

train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

network.fit(train_images, train_labels, epochs=5, batch_size=128)

會有一個輸出

看看測試集表現如何:

test_loss, test_acc = network.evaluate(test_images, test_labels)

print('test_acc:', test_acc)