Learning Keras (5): Classification with an RNN (Recurrent Neural Network)
阿新 • Published: 2018-10-31
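This post shows how to classify the MNIST handwritten digit dataset with an RNN, treating each 28x28 image as a sequence of 28 rows of 28 pixels.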
Example code:
```python
import numpy as np
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Input, Dense, Activation, SimpleRNN
from keras.optimizers import Adam

# Seed NumPy so repeated runs generate the same random numbers
# (this does not seed the backend's own weight initialisation)
np.random.seed(1337)

# Hyperparameters
TIME_STEP = 28     # one time step per image row (same as the image height)
INPUT_SIZE = 28    # each step reads one row of 28 pixels (same as the image width)
BATCH_SIZE = 50
BATCH_INDEX = 0
OUTPUT_SIZE = 10   # one-hot label, e.g. [0 0 0 0 1 0 0 0 0 0] -> 4
CELL_SIZE = 50     # number of hidden units in the RNN
LR = 0.001

# Download the dataset
# X_train: (60000, 28, 28), y_train: (60000,)
# X_test:  (10000, 28, 28), y_test:  (10000,)
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Preprocess: shape is (samples, time steps, features per step);
# dividing by 255 scales the pixel values to [0, 1]
X_train = X_train.reshape(-1, TIME_STEP, INPUT_SIZE).astype("float32") / 255
X_test = X_test.reshape(-1, TIME_STEP, INPUT_SIZE).astype("float32") / 255

# Convert the labels to one-hot vectors
y_train = to_categorical(y_train, num_classes=OUTPUT_SIZE)
y_test = to_categorical(y_test, num_classes=OUTPUT_SIZE)

# Build the model
model = Sequential()
model.add(Input(shape=(TIME_STEP, INPUT_SIZE)))  # batch size left unspecified
# RNN cell: outputs only the hidden state after the last time step
model.add(SimpleRNN(units=CELL_SIZE))
# Output layer
model.add(Dense(OUTPUT_SIZE))
model.add(Activation('softmax'))  # softmax for classification

# Optimizer
adam = Adam(learning_rate=LR)
model.compile(optimizer=adam,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Training
for step in range(4001):
    # data shape = (batch_size, time_steps, input_size / output_size)
    x_batch = X_train[BATCH_INDEX: BATCH_INDEX + BATCH_SIZE, :, :]
    y_batch = y_train[BATCH_INDEX: BATCH_INDEX + BATCH_SIZE, :]
    cost = model.train_on_batch(x_batch, y_batch)
    BATCH_INDEX += BATCH_SIZE
    # Wrap around to the start once the whole training set has been used
    BATCH_INDEX = 0 if BATCH_INDEX >= X_train.shape[0] else BATCH_INDEX

    # Evaluate on the full test set every 500 steps
    if step % 500 == 0:
        cost, accuracy = model.evaluate(X_test, y_test,
                                        batch_size=y_test.shape[0], verbose=0)
        print('test cost:', cost, 'test accuracy:', accuracy)
```
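To make the model's structure concrete, here is a minimal NumPy sketch of what the SimpleRNN and Dense layers compute for a single image. It assumes the standard SimpleRNN update h_t = tanh(x_t·W_x + h_(t-1)·W_h + b_h), with tanh being SimpleRNN's default activation. The names `W_x`, `W_h`, `b_h`, `W_o`, `b_o`, `simple_rnn_forward` and the random weights are hypothetical and for illustration only; they are not Keras's trained weights or its initialisers. Counting the weight sizes also gives the trainable parameter counts that `model.summary()` should report for the model above, assuming the default `use_bias=True`.

```python
import numpy as np

# Sizes taken from the hyperparameters in the example code
TIME_STEP, INPUT_SIZE, CELL_SIZE, OUTPUT_SIZE = 28, 28, 50, 10

rng = np.random.default_rng(0)

# Hypothetical random weights (illustration only; Keras itself uses a
# glorot_uniform kernel, an orthogonal recurrent kernel and a zero bias)
W_x = rng.normal(scale=0.1, size=(INPUT_SIZE, CELL_SIZE))   # input -> hidden
W_h = rng.normal(scale=0.1, size=(CELL_SIZE, CELL_SIZE))    # hidden -> hidden
b_h = np.zeros(CELL_SIZE)                                    # hidden bias
W_o = rng.normal(scale=0.1, size=(CELL_SIZE, OUTPUT_SIZE))  # Dense weights
b_o = np.zeros(OUTPUT_SIZE)                                  # Dense bias


def simple_rnn_forward(image):
    """Read a (28, 28) image one row per time step; return the last hidden state."""
    h = np.zeros(CELL_SIZE)
    for t in range(TIME_STEP):
        x_t = image[t]  # row t is a 28-dimensional input vector
        h = np.tanh(x_t @ W_x + h @ W_h + b_h)
    return h


def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()


image = rng.random((TIME_STEP, INPUT_SIZE))  # stand-in for one normalised MNIST image
h_last = simple_rnn_forward(image)
probs = softmax(h_last @ W_o + b_o)          # class probabilities for digits 0-9

# Trainable parameters, counted from the weight shapes
rnn_params = W_x.size + W_h.size + b_h.size  # 28*50 + 50*50 + 50
dense_params = W_o.size + b_o.size           # 50*10 + 10
print(h_last.shape, probs.shape)  # (50,) (10,)
print(rnn_params, dense_params)   # 3950 510
```

Because the same `W_x` and `W_h` are reused at every one of the 28 time steps, the RNN's parameter count depends only on `INPUT_SIZE` and `CELL_SIZE`, not on `TIME_STEP`.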
Classification results (test-set loss and accuracy, printed every 500 steps from step 0 to step 4000):
```
test cost: 2.4057323932647705 test accuracy: 0.03909999877214432
test cost: 0.6151769757270813 test accuracy: 0.8148999810218811
test cost: 0.43123072385787964 test accuracy: 0.8719000220298767
test cost: 0.3669806122779846 test accuracy: 0.890999972820282
test cost: 0.34460538625717163 test accuracy: 0.8988999724388123
test cost: 0.27395951747894287 test accuracy: 0.9190000295639038
test cost: 0.2991882860660553 test accuracy: 0.9136999845504761
test cost: 0.23682279884815216 test accuracy: 0.9291999936103821
test cost: 0.2040248066186905 test accuracy: 0.9419000148773193
```
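The nine result lines come from steps 0, 500, ..., 4000 of the training loop. To relate them to training progress, the sketch below replays the `BATCH_INDEX` bookkeeping from the loop, assuming the standard 60,000-image MNIST training split, and records how many training samples had been used at each evaluation. The names `N_TRAIN`, `EVAL_EVERY`, `samples_seen` and `schedule` are hypothetical helpers introduced only for this illustration.

```python
# Mirrors the training loop: 4001 steps, BATCH_SIZE = 50, 60000 training images
N_TRAIN, BATCH_SIZE, STEPS, EVAL_EVERY = 60000, 50, 4001, 500

batch_index = 0
samples_seen = 0
schedule = []  # (step, samples trained on so far, epochs completed)

for step in range(STEPS):
    # Size of the slice X_train[batch_index: batch_index + BATCH_SIZE]
    batch = min(BATCH_SIZE, N_TRAIN - batch_index)
    samples_seen += batch  # train_on_batch runs before the evaluation check
    batch_index += BATCH_SIZE
    batch_index = 0 if batch_index >= N_TRAIN else batch_index
    if step % EVAL_EVERY == 0:
        schedule.append((step, samples_seen, samples_seen / N_TRAIN))

for step, seen, epochs in schedule:
    print(f"step {step:4d}: {seen:6d} samples, {epochs:.2f} epochs")
```

Since 60000 is a multiple of 50, every batch is full. By the final evaluation the model has been trained on 4001 x 50 = 200,050 samples, roughly 3.33 passes over the training set.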