MNIST LSTM: training, testing, saving and loading the model, and recognition
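In the MNIST database, each character (digit 0-9) corresponds to a 28x28 single-channel image. Each row (or column) of the image can be treated as the feature vector of one step, and all the rows (columns) together as a sequence. A character can then be modeled with an RNN (LSTM) whose input size is 28 and whose sequence length is 28. For a given character, say 0, the row-to-row changes are captured well by an RNN, and the changes across these consecutive rows characterize the specific pattern of that character. An RNN can therefore be used for character recognition.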
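To make the "rows as time steps" idea concrete, here is a small NumPy sketch (not part of the original code) that reshapes a flattened 784-pixel image into 28 steps of 28 features, and splits a batch into a per-step list in the same way tf.unstack(x, n_steps, 1) does in the training code. The helper names image_to_sequence and batch_to_step_list are hypothetical, and random pixels stand in for real MNIST images.

```python
import numpy as np

n_steps, n_input = 28, 28  # 28 rows (time steps), 28 pixels per row (features)

def image_to_sequence(flat_image):
    """Turn one flattened 784-pixel image into an (n_steps, n_input) sequence."""
    return np.asarray(flat_image).reshape(n_steps, n_input)

def batch_to_step_list(flat_batch):
    """Mimic tf.unstack(x, n_steps, 1): a list of n_steps arrays of shape (batch, n_input)."""
    x = np.asarray(flat_batch).reshape(-1, n_steps, n_input)  # (batch, n_steps, n_input)
    return [x[:, t, :] for t in range(n_steps)]

# Usage with a fake batch of 4 images (random values, not real MNIST data)
fake_batch = np.random.rand(4, n_steps * n_input)
steps = batch_to_step_list(fake_batch)
print(len(steps), steps[0].shape)  # 28 (4, 28)
```

Step t of the list holds row t of every image in the batch, which is exactly what the RNN consumes at time step t.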
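TensorFlow provides a convenient RNN interface. The basic approach is:
1. Create the basic unit of the RNN, the cell. TF provides many types of cells, such as BasicRNNCell, BasicLSTMCell and LSTMCell.
2. Chain the cells into an RNN by calling rnn.static_rnn or rnn.static_bidirectional_rnn. This example uses rnn.static_bidirectional_rnn. (The exact functions differ between TensorFlow versions.)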
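The two steps above can be sketched in plain NumPy to show what static_bidirectional_rnn computes. This is an illustration under assumptions, not TensorFlow's implementation: lstm_step follows the gate layout of BasicLSTMCell (gates i, j, f, o computed from the concatenation [x_t, h], with forget_bias added to the forget gate), the weights are random, n_hidden is shrunk to 8, and all function names are hypothetical. It also shows why the output layer weights have shape [2*n_hidden, n_classes]: each step's output concatenates the forward and backward states.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h, c, W, b, forget_bias=1.0):
    """One step of a BasicLSTMCell-style cell (gate order i, j, f, o)."""
    z = np.concatenate([x_t, h], axis=1) @ W + b   # (batch, 4 * n_hidden)
    i, j, f, o = np.split(z, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, new_c

def run_lstm(inputs, W, b, n_hidden):
    """Static unrolling: apply the same cell to each element of a Python list."""
    h = np.zeros((inputs[0].shape[0], n_hidden))
    c = np.zeros_like(h)
    outputs = []
    for x_t in inputs:
        h, c = lstm_step(x_t, h, c, W, b)
        outputs.append(h)
    return outputs

def bidirectional(inputs, fw_params, bw_params, n_hidden):
    """Forward cell over the list, backward cell over the reversed list;
    backward outputs are re-reversed to time order, then concatenated per step."""
    fw = run_lstm(inputs, fw_params[0], fw_params[1], n_hidden)
    bw = run_lstm(inputs[::-1], bw_params[0], bw_params[1], n_hidden)[::-1]
    return [np.concatenate([f, b], axis=1) for f, b in zip(fw, bw)]

# Usage: random inputs and weights (assumed values, only the shapes matter here)
rng = np.random.RandomState(0)
n_steps, n_input, n_hidden, batch = 28, 28, 8, 3
inputs = [rng.randn(batch, n_input) for _ in range(n_steps)]  # like tf.unstack output

def make_params():
    W = 0.1 * rng.randn(n_input + n_hidden, 4 * n_hidden)
    b = np.zeros(4 * n_hidden)
    return W, b

fw_params, bw_params = make_params(), make_params()
outputs = bidirectional(inputs, fw_params, bw_params, n_hidden)
print(len(outputs), outputs[-1].shape)  # 28 (3, 16)
```

Note what outputs[-1], which the BiRNN function in this article feeds into the output layer, contains: the forward half has read all 28 rows, while the backward half has read only the last row, because the backward cell starts at the end of the sequence.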
LSTM training and testing
'''
A Bidirectional Recurrent Neural Network (LSTM) implementation example using TensorFlow library.
This example is using the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/)
Long Short Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''
# __future__ imports must come before any other statement (only the docstring may precede them)
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/data/", one_hot=True)
'''
To classify images using a bidirectional recurrent neural network, we consider
every image row as a sequence of pixels. Because MNIST image shape is 28*28px,
we will then handle 28 sequences of 28 steps for every sample.
'''
# Parameters
learning_rate = 0.001
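# Training stops once step * batch_size reaches this number of samples
training_iters = 100000
# Number of samples per training batch
batch_size = 128
# Print progress every display_step steps
display_step = 10

# Network Parameters
# n_steps * n_input is the image itself: each row is fed to one time step
n_input = 28 # MNIST data input (img shape: 28*28)
n_steps = 28 # timesteps
# Hidden layer size
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST total classes (0-9 digits)

# tf Graph input
# In [None, n_steps, n_input], None leaves the batch (sample count) dimension unspecified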
x = tf.placeholder("float", [None, n_steps, n_input], name="input_x")
y = tf.placeholder("float", [None, n_classes], name="input_y")
# Define weights and biases
weights = tf.Variable(tf.random_normal([2*n_hidden, n_classes]), name="weights")
biases = tf.Variable(tf.random_normal([n_classes]), name="biases")
def BiRNN(x, weights, biases):
    # Prepare data shape to match `static_bidirectional_rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.unstack(x, n_steps, 1)

    # Define lstm cells with tensorflow
    # Forward direction cell
    lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Backward direction cell
    lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Get lstm cell output
    try:
        outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                                     dtype=tf.float32)
    except Exception:  # Old TensorFlow version only returns outputs not states
        outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                               dtype=tf.float32)

    # Linear activation, using the last output of the RNN loop.
    # Each output concatenates the forward and backward states, hence 2*n_hidden rows in weights.
    return tf.add(tf.matmul(outputs[-1], weights), biases)
pred = BiRNN(x, weights, biases)
# Define loss and optimizer
# softmax_cross_entropy_with_logits: measures the probability error in discrete classification tasks in which the classes are mutually exclusive.
# Returns a 1-D Tensor of length batch_size, of the same type as logits, holding the softmax cross-entropy loss.
# reduce_mean averages over all elements (no axis is specified here).
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
# Evaluate model
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Initializing the variables
init = tf.global_variables_initializer()
# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until reach max iterations
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # Calculate batch loss and accuracy in a single run
            loss, acc = sess.run([cost, accuracy], feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step * batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print("Optimization Finished!")

    # Calculate accuracy for 128 mnist test images
    # test_len = 128
    # test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    # test_label = mnist.test.labels[:test_len]

    # For the real test, reshape the whole test set to satisfy the input placeholder:
    # (num_samples, n_steps, n_input) = (10000, 28, 28)
    test_data = mnist.test.images.reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels
    print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
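The loss and evaluation ops used in the script (softmax_cross_entropy_with_logits averaged by reduce_mean, and the argmax comparison behind accuracy) can be reproduced in NumPy on a tiny batch to check what they compute. This is a sketch with made-up logits and labels; the function names mirror the TensorFlow ops but are hypothetical reimplementations, not TensorFlow's code.

```python
import numpy as np

def softmax_cross_entropy_with_logits(logits, labels):
    """Per-example cross-entropy between one-hot labels and softmax(logits)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(labels * log_probs).sum(axis=1)              # shape (batch,)

def accuracy(logits, labels):
    """Fraction of rows whose arg-max logit matches the one-hot label."""
    correct = np.argmax(logits, 1) == np.argmax(labels, 1)
    return correct.astype(np.float32).mean()

# Two examples, three classes (made-up numbers)
logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.3, 3.0]])
labels = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])
per_example = softmax_cross_entropy_with_logits(logits, labels)
cost = per_example.mean()  # plays the role of tf.reduce_mean(...)
print(cost, accuracy(logits, labels))
```

The second example is misclassified (its largest logit is class 2, the label is class 1), so the accuracy is 0.5, while the cost still rewards the first example's confident correct prediction.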
Saving the trained model
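Right after the test accuracy is printed above, add the following code (inside the same with tf.Session() as sess: block) and run the script again: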
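    saver = tf.train.Saver()
    model_path = "./model/my_model"
    # Create ./model first, in case this TensorFlow version refuses to save into a missing directory
    if not os.path.isdir(os.path.dirname(model_path)):
        os.makedirs(os.path.dirname(model_path))
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)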
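This is only one way to save a model, and it saves the whole network structure. In model_path, model is the folder where the model is stored, and my_model is the prefix of the saved files, which can be thought of as the model's name.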
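After the script finishes, a folder named "model" is created in the current directory, containing four files: checkpoint, my_model.data-00000-of-00001, my_model.index and my_model.meta. Briefly, checkpoint is a small text file recording the most recent checkpoint prefix, my_model.meta stores the graph structure, and my_model.index together with my_model.data-00000-of-00001 store the variable values.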
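Note that this step only saves the model. The purpose of saving is to load the model later and run it on input data without rebuilding the whole network, so certain computation nodes must also be saved so that the recognition stage can use them to compute the output. Here one prediction node is needed: in the training code, right after pred = BiRNN(x, weights, biases), add: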
tf.add_to_collection('predict', pred)
This binds the pred computation to the name "predict", so after loading, the operation can be retrieved through this name.
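Why does the loading code later take tf.get_collection('predict')[0]? A collection maps a name to a list of items, and get_collection returns that list (empty if nothing was added). The toy registry below is a hypothetical stand-in that only models this name-to-list behavior; it is not TensorFlow's graph class.

```python
from collections import defaultdict

class ToyGraph:
    """Hypothetical stand-in for a graph's collections: name -> list of items."""

    def __init__(self):
        self._collections = defaultdict(list)

    def add_to_collection(self, name, value):
        # Appends, so adding under the same name twice keeps both entries
        self._collections[name].append(value)

    def get_collection(self, name):
        # Returns a copy of the list, or an empty list for an unknown name
        return list(self._collections.get(name, []))

g = ToyGraph()
g.add_to_collection('predict', 'pred_op')
print(g.get_collection('predict'))     # ['pred_op']
print(g.get_collection('predict')[0])  # pred_op
```

Because the result is a list, [0] is needed to get the node itself; if the same name were added more than once, [0] would still pick the first entry.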
Loading the trained model and recognition
Loading the model is simple. The main code is as follows:
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./model/my_model.meta')
    new_saver.restore(sess, './model/my_model')
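Note that the path passed to restore() must be the same as the one used when saving, i.e. the checkpoint prefix without any file extension.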
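Next, read the required nodes from the loaded model: first the pred operation bound to "predict", and second the input that pred needs, x, which is the placeholder named "input_x" in the training code. Continue with the following code: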
    graph = tf.get_default_graph()
    predict = tf.get_collection('predict')[0]
    input_x = graph.get_operation_by_name("input_x").outputs[0]
Finally, feed in an image and classify it.
    x = mnist.test.images[0].reshape((1, n_steps, n_input))
    res = sess.run(predict, feed_dict={input_x: x})
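This uses the first image of the test set. The process is similar to the testing part, except that the second input, the label, is not fed. The result holds one logit per class, and the class index can be obtained with argmax.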
When using argmax, first confirm the shape of the data, then choose the axis to compute along. The complete code for this part is as follows:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/data/", one_hot=True)

# Must match the values used during training
n_input = 28
n_steps = 28
n_classes = 10

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./model/my_model.meta')
    new_saver.restore(sess, './model/my_model')

    graph = tf.get_default_graph()
    predict = tf.get_collection('predict')[0]
    input_x = graph.get_operation_by_name("input_x").outputs[0]

    x = mnist.test.images[0].reshape((1, n_steps, n_input))
    y = mnist.test.labels[0].reshape(-1, n_classes)  # one-hot label with a batch axis: (1, 10)
    res = sess.run(predict, feed_dict={input_x: x})  # logits, shape (1, 10)

    # y and res are NumPy arrays, so take argmax along axis 1 (the class axis) directly
    actual = np.argmax(y, 1)
    predicted = np.argmax(res, 1)
    print("Actual class: ", actual,
          ", predict class ", predicted,
          ", predict ", actual == predicted)
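The shape check mentioned above matters because res has shape (1, n_classes): a batch axis of size 1 followed by the class axis. A small NumPy illustration with made-up logits (not output of the trained model) shows what happens with each axis:

```python
import numpy as np

# Assumed stand-in for res: logits for a single image, shape (1, 10)
res = np.array([[0.1, -1.2, 0.3, 0.0, 2.5, -0.4, 0.2, 0.9, -0.1, 0.6]])

print(res.shape)          # (1, 10)
print(np.argmax(res, 1))  # [4]: the predicted class for the one image
print(np.argmax(res, 0))  # [0 0 0 0 0 0 0 0 0 0]: argmax down each length-1 column, meaningless here
```

Axis 1 is the class axis, so np.argmax(res, 1) (or tf.argmax(res, 1)) yields one class index per image; the same call works unchanged for a batch of shape (N, 10).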