用Keras進行手寫字型識別(MNIST資料集)
阿新 • • 發佈:2018-12-09
資料
首先載入資料
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
接下來,看看這個資料集的基本情況:
train_images.shape
(60000, 28, 28)
len(train_labels)
60000
train_labels
array([5, 0, 4, …, 5, 6, 8], dtype=uint8)
test_images.shape
(10000, 28, 28)
len(test_labels)
10000
test_labels
array([7, 2, 1, …, 4, 5, 6], dtype=uint8)
網路構架
from keras import models
from keras import layers
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))
Keras可以幫助我們實現一層一層的連線起來,在本例中的網路包含2個Dense層,他們是密集連線(也叫全連線)的神經層。第二層是一個10路softmax層,他將返回一個由10個概率值(總和為1)組成的陣列。每個概率值表示當前數字影象屬於10個數字類別中的某一個的概率。
編譯
要想訓練網路,我們還需要設定編譯步驟的三個引數:
損失函式
優化器(optimizer):基於訓練資料和損失函式來更新網路的機制
在訓練和測試過程中需要監控的指標(metric):本例只關心精度,即正確分類的影象所佔的比例。
network.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
影象資料預處理
在訓練之前,需要對影象資料預處理,將其變換成網路要求的形狀,並縮放所有值都在[0,1]區間。比如,之前訓練影象儲存在一個uint8型別的陣列中,其形狀為(60000,28, 28),取值範圍為[0, 255]。我們需要將其變換成一個float32陣列,其形狀為(60000, 28*28),取值範圍為0-1
train_images = train_images.reshape((60000, 28 * 28)) #把一個影象變成一列資料用於學習
train_images = train_images.astype('float32') / 255 #astype用於進行資料型別轉換
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255
訓練
from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
network.fit(train_images, train_labels, epochs=5, batch_size=128)
會有一個輸出
看看測試集表現如何:
test_loss, test_acc = network.evaluate(test_images, test_labels)
print('test_acc:', test_acc)