[Keras Deep Learning: First Steps] Hands-on 3 · Classifying the Fashion MNIST Dataset with an RNN

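This post has the same structure as our earlier post, [Keras Deep Learning: First Steps] Hands-on 1. Only two parts change, the network definition and the model training, so the two posts can be read side by side. With the RNN, prediction accuracy improves slightly over that post, and tuning the hyperparameters may give better results.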
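The RNN reads each 28×28 image as a sequence: row t is the input at time step t (28 features), and the layer keeps a 256-dimensional hidden state that is updated row by row. Below is a minimal NumPy sketch of that forward pass, assuming Keras's SimpleRNN defaults (tanh activation, zero initial state, only the final state returned). The weights and the input batch are random stand-ins rather than the trained model, and `simple_rnn_last_state` is a hypothetical helper.

```python
import numpy as np


def simple_rnn_last_state(x, W, U, b):
    """Final hidden state of a tanh SimpleRNN (sketch of the default behaviour).

    x: (batch, timesteps, features), W: (features, units),
    U: (units, units), b: (units,)
    """
    batch, timesteps, _ = x.shape
    h = np.zeros((batch, U.shape[0]), dtype=x.dtype)  # initial state h_0 = 0
    for t in range(timesteps):                        # one image row per time step
        h = np.tanh(x[:, t, :] @ W + h @ U + b)       # h_t = tanh(x_t W + h_{t-1} U + b)
    return h                                          # only the last state is kept


rng = np.random.default_rng(0)
units = 256
# Hypothetical batch of 4 "images" already scaled to [0, 1], shape (4, 28, 28).
images = rng.random((4, 28, 28)).astype(np.float32)
W = rng.normal(0.0, 0.1, (28, units)).astype(np.float32)     # input kernel
U = rng.normal(0.0, 0.1, (units, units)).astype(np.float32)  # recurrent kernel
b = np.zeros(units, dtype=np.float32)                         # bias

h_last = simple_rnn_last_state(images, W, U, b)
print(h_last.shape)  # (4, 256)
```

The three weight arrays hold 7,168 + 65,536 + 256 = 72,960 values, the same count as the `simple_rnn` row of the model summary in the output.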
Code

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import os
# Commonly set on macOS to avoid the crash caused by a duplicate OpenMP runtime (libiomp5)
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import numpy as np
import matplotlib.pyplot as plt

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(train_images.shape, train_labels.shape)

# Layout (samples, time steps, features): each image row is one time step.
# Scale pixel values from [0, 255] to [0, 1].
train_images = train_images.reshape([-1, 28, 28]) / 255.0
test_images = test_images.reshape([-1, 28, 28]) / 255.0

model = keras.Sequential([
    # (-1, 28, 28) -> (-1, 256): only the last hidden state is returned
    keras.layers.SimpleRNN(
        # If batch_input_shape is used instead, with the TensorFlow backend the batch size
        # has to be None, otherwise model.evaluate() raises an error.
        input_shape=(28, 28),  # Or (older Keras API): input_dim=INPUT_SIZE, input_length=TIME_STEPS
        units=256,
        unroll=True,
    ),
    keras.layers.Dropout(rate=0.2),
    # (-1, 256) -> (-1, 10)
    keras.layers.Dense(10, activation=tf.nn.softmax),
])
model.summary()

lr = 0.001
epochs = 5
# Adam with learning rate lr (written as tf.train.AdamOptimizer(lr) in TF 1.x code)
model.compile(optimizer=keras.optimizers.Adam(lr),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Validate on the first 1000 test images after each epoch
model.fit(train_images, train_labels, epochs=epochs,
          validation_data=(test_images[:1000], test_labels[:1000]))
test_loss, test_acc = model.evaluate(test_images, test_labels)
# Predicted classes vs. true labels for the first 10 test images
print(np.argmax(model.predict(test_images[:10]), 1), test_labels[:10])
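The model is compiled with `sparse_categorical_crossentropy`, which takes the integer labels (0 to 9) directly instead of one-hot vectors, and the last line turns the softmax outputs into class ids with `np.argmax`. Below is a minimal NumPy sketch of both steps, assuming a hand-written helper `sparse_ce` (hypothetical, not a Keras function) and made-up probabilities; Keras's own implementation handles clipping and numerics in its own way.

```python
import numpy as np


def sparse_ce(labels, probs, eps=1e-7):
    """Mean negative log-probability of the true class (sketch).

    labels: (n,) integer class ids; probs: (n, num_classes), rows summing to 1.
    """
    probs = np.clip(probs, eps, 1.0)                # guard against log(0)
    picked = probs[np.arange(len(labels)), labels]  # probability given to each true class
    return float(-np.mean(np.log(picked)))


# Hypothetical softmax outputs for 3 samples over 10 classes (each row sums to 1).
probs = np.full((3, 10), 0.05)
probs[0, 9] = 0.55  # sample 0 favours class 9
probs[1, 2] = 0.55  # sample 1 favours class 2
probs[2, 4] = 0.55  # sample 2 favours class 4
labels = np.array([9, 2, 6])  # integer labels, like train_labels

loss = sparse_ce(labels, probs)
preds = np.argmax(probs, axis=1)  # same argmax step as the post's last line
print(preds, labels, loss)
```

The third sample is misclassified, so its small probability for the true class (0.05) dominates the loss.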

Output

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
simple_rnn (SimpleRNN)       (None, 256)               72960
_________________________________________________________________
dropout (Dropout)            (None, 256)               0
_________________________________________________________________
dense (Dense)                (None, 10)                2570
=================================================================
Total params: 75,530
Trainable params: 75,530
Non-trainable params: 0
_________________________________________________________________
None
Train on 60000 samples, validate on 1000 samples
Epoch 1/5
60000/60000 [==============================] - 56s 927us/step - loss: 0.7429 - acc: 0.7307 - val_loss: 0.6208 - val_acc: 0.7750
Epoch 2/5
60000/60000 [==============================] - 46s 759us/step - loss: 0.5935 - acc: 0.7876 - val_loss: 0.5550 - val_acc: 0.8060
Epoch 3/5
60000/60000 [==============================] - 50s 828us/step - loss: 0.5558 - acc: 0.8004 - val_loss: 0.4969 - val_acc: 0.8220
Epoch 4/5
60000/60000 [==============================] - 53s 886us/step - loss: 0.5267 - acc: 0.8100 - val_loss: 0.5298 - val_acc: 0.8080
Epoch 5/5
60000/60000 [==============================] - 62s 1ms/step - loss: 0.5243 - acc: 0.8115 - val_loss: 0.4916 - val_acc: 0.8180
10000/10000 [==============================] - 4s 435us/step
[9 2 1 1 6 1 6 6 5 7] [9 2 1 1 6 1 4 6 5 7]
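Two quick checks on the log above, as a plain-arithmetic sketch. The parameter counts in the summary follow from the layer sizes (SimpleRNN: input kernel, recurrent kernel and bias; Dense: weights and bias; Dropout has no weights), and the last log line shows 9 of the 10 sampled predictions matching the labels. The helper names are hypothetical; the two arrays are copied from the log.

```python
import numpy as np


def simple_rnn_params(input_dim, units):
    # kernel (input_dim x units) + recurrent kernel (units x units) + bias (units)
    return input_dim * units + units * units + units


def dense_params(input_dim, units):
    # weight matrix (input_dim x units) + bias (units)
    return input_dim * units + units


rnn = simple_rnn_params(28, 256)
dense = dense_params(256, 10)
total = rnn + dense  # Dropout contributes no parameters
print(rnn, dense, total)  # 72960 2570 75530

# Predictions and labels from the last line of the log.
preds = np.array([9, 2, 1, 1, 6, 1, 6, 6, 5, 7])
labels = np.array([9, 2, 1, 1, 6, 1, 4, 6, 5, 7])
print(np.mean(preds == labels), np.flatnonzero(preds != labels))  # 0.9 [6]
```

The only mismatch is the seventh sample (index 6): predicted class 6 (Shirt) versus true class 4 (Coat). Ten samples say little about overall accuracy; the `test_acc` returned by `model.evaluate` is the figure to report, although the code above does not print it.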