1. 程式人生 > tensorflow實現LSTM進行MNIST資料集分類

tensorflow實現LSTM進行MNIST資料集分類

大大的部落格講得很詳細,先拿過來分享一下:http://blog.csdn.net/jerr__y/article/details/61195257

自己組合的第一部分程式碼:

import sys
reload(sys)
sys.setdefaultencoding('utf8')
import tensorflow as tf
import numpy as np
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as 
plt sess=tf.InteractiveSession() # 首先匯入資料,看一下資料的形式 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) print mnist.train.images.shape lr = 1e-3 keep_prob = tf.placeholder(tf.float32, []) # 在訓練和測試的時候,我們想用不同的 batch_size.所以採用佔位符的方式 batch_size = tf.placeholder(tf.int32,[]) # 注意型別必須為 tf.int32
# 1.0 版本以後請使用 : # keep_prob = tf.placeholder(tf.float32, []) # batch_size = tf.placeholder(tf.int32, []) # 每個時刻的輸入特徵是28維的,就是每個時刻輸入一行,一行有 28 個畫素 input_size = 28 # 時序持續長度為28,即每做一次預測,需要先輸入28timestep_size = 28 # 每個隱含層的節點數 hidden_size = 256 # LSTM layer 的層數 layer_num = 2 # 最後輸出分類類別數量,如果是迴歸預測的話應該是 1 class_num = 10
_X = tf.placeholder(tf.float32, [None, 784]) y = tf.placeholder(tf.float32, [None, class_num]) # 784個點的字元資訊還原成 28 * 28 的圖片 # 下面幾個步驟是實現 RNN / LSTM 的關鍵 #################################################################### # **步驟1RNN 的輸入shape = (batch_size, timestep_size, input_size) X = tf.reshape(_X, [-1, 28, 28]) # **步驟2:定義一層 LSTM_cell,只需要說明 hidden_size, 它會自動匹配輸入的 X 的維度 #lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True) # **步驟3:新增 dropout layer, 一般只設置 output_keep_prob #lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob) # **步驟4:呼叫 MultiRNNCell 來實現多層 LSTM def lstm_cell(): cell = rnn.LSTMCell(hidden_size, reuse=tf.get_variable_scope().reuse) return rnn.DropoutWrapper(cell, output_keep_prob=keep_prob) mlstm_cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(layer_num)], state_is_tuple = True) # **步驟5:用全零來初始化state init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32) # **步驟6:方法一,呼叫 dynamic_rnn() 來讓我們構建好的網路執行起來 # ** time_major==False 時, outputs.shape = [batch_size, timestep_size, hidden_size] # ** 所以,可以取 h_state = outputs[:, -1, :] 作為最後輸出 # ** state.shape = [layer_num, 2, batch_size, hidden_size], # ** 或者,可以取 h_state = state[-1][1] 作為最後輸出 # ** 最後輸出維度是 [batch_size, hidden_size] # outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False) # h_state = outputs[:, -1, :] # 或者 h_state = state[-1][1] # *************** 為了更好的理解 LSTM 工作原理,我們把上面 步驟6 中的函式自己來實現 *************** # 通過檢視文件你會發現, RNNCell 都提供了一個 __call__()函式(見最後附),我們可以用它來展開實現LSTM按時間步迭代。 # **步驟6:方法二,按時間步展開計算 outputs = list() state = init_state with tf.variable_scope('RNN'): for timestep in range(timestep_size): if timestep > 0: tf.get_variable_scope().reuse_variables() # 這裡的state儲存了每一層 LSTM 的狀態 (cell_output, state) = mlstm_cell(X[:, timestep, :], state) outputs.append(cell_output) h_state = outputs[-1] # 上面 LSTM 部分的輸出會是一個 [hidden_size] tensor,我們要分類的話,還需要接一個 softmax # 首先定義 
softmax 的連線權重矩陣和偏置 # out_W = tf.placeholder(tf.float32, [hidden_size, class_num], name='out_Weights') # out_bias = tf.placeholder(tf.float32, [class_num], name='out_bias') # 開始訓練和測試 W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32) bias = tf.Variable(tf.constant(0.1,shape=[class_num]), dtype=tf.float32) y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias) # 損失和評估函式 cross_entropy = -tf.reduce_mean(y * tf.log(y_pre)) train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy) correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(y,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) saver = tf.train.Saver() sess.run(tf.global_variables_initializer()) for i in range(2000): _batch_size = 128 batch = mnist.train.next_batch(_batch_size) if (i+1)%200 == 0: train_accuracy = sess.run(accuracy, feed_dict={ _X:batch[0], y: batch[1], keep_prob: 1.0, batch_size: _batch_size}) # 已經迭代完成的 epoch : mnist.train.epochs_completed print "Iter%d, step %d, training accuracy %g" % ( mnist.train.epochs_completed, (i+1), train_accuracy) sess.run(train_op, feed_dict={_X: batch[0], y: batch[1], keep_prob: 0.5, batch_size: _batch_size}) saver.save(sess, "MNIST_data/save") # 計算測試資料的準確率 #saver.restore(sess,"MNIST_data/save") print "test accuracy %g"% sess.run(accuracy, feed_dict={ _X: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0, batch_size:mnist.test.images.shape[0]}) _batch_size = 5 X_batch, y_batch = mnist.test.next_batch(_batch_size) print X_batch.shape, y_batch.shape _outputs, _state = np.array(sess.run([outputs, state],feed_dict={ _X: X_batch, y: y_batch, keep_prob: 1.0, batch_size: _batch_size})) print '_outputs.shape =', np.asarray(_outputs).shape print 'arr_state.shape =', np.asarray(_state).shape X3 = mnist.train.images[12] img3 = X3.reshape([28, 28]) plt.imshow(img3, cmap='gray') plt.show() X3.shape = [-1, 784] y_batch = mnist.train.labels[0] y_batch.shape = [-1, class_num] X3_outputs = 
np.array(sess.run(outputs, feed_dict={ _X: X3, y: y_batch, keep_prob: 1.0, batch_size: 1})) print X3_outputs.shape X3_outputs.shape = [28, hidden_size] print X3_outputs.shape h_W = sess.run(W, feed_dict={ _X:X3, y: y_batch, keep_prob: 1.0, batch_size: 1}) h_bias = sess.run(bias, feed_dict={ _X:X3, y: y_batch, keep_prob: 1.0, batch_size: 1}) h_bias.shape = [-1, 10] bar_index = range(class_num) for i in xrange(X3_outputs.shape[0]): plt.subplot(7, 4, i+1) X3_h_shate = X3_outputs[i, :].reshape([-1, hidden_size]) pro = sess.run(tf.nn.softmax(tf.matmul(X3_h_shate, h_W) + h_bias)) plt.bar(bar_index, pro[0], width=0.2 , align='center') plt.axis('off') plt.show()

最後的結果:

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(55000, 784)
Iter0, step 200, training accuracy 0.898438
Iter0, step 400, training accuracy 0.960938
Iter1, step 600, training accuracy 0.96875
Iter1, step 800, training accuracy 0.976562
Iter2, step 1000, training accuracy 0.984375
test accuracy 0.9718
(5, 784) (5, 10)
_outputs.shape = (28, 5, 256)
arr_state.shape = (2, 2, 5, 256)
(28, 1, 256)
(28, 256)