Tensorflow 學習筆記之使用LSTM實現MNIST資料集
阿新 • • 發佈:2019-02-19
LSTM實現MNIST手寫集識別
這幾天剛好看了RNN之後瞭解了LSTM(原理可以去參考這個)。雖然LSTM主要用於處理自然語言、語音、機器人翻譯等領域,但圖片也可以看做一個有序列的資料。所以用LSTM來識別Tensorflow入門資料集。
配置神經網路引數
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('C:\\Users\\Qigq\\Desktop'
'\\P_Data\\LSTM\\LSTM_TEST\\data' ,one_hot=True)
batch_size = 128 # 批次大小
num_units = 128 # 單元數(就是有幾個A)
input_size = 28 # 輸入資料單個序列的長度
time_steps = 28 # 序列本身的長度
classes = 10 # 分類
train_step = 10000 # 訓練次數
learnning_rate = 1e-4 # 學習率
LSTM之後有一層全連線層和softmax分類層所以要配置全連線層weights和biases
w = tf. Variable(tf.truncated_normal([num_units,classes]),dtype=tf.float32)
b = tf.Variable(tf.constant(value=0.1,shape=[classes]),dtype=tf.float32)
x = tf.placeholder(dtype=tf.float32,shape=[None,784]) # 28*28
y = tf.placeholder(dtype=tf.float32,shape=[None,10])
定義神經網路
def lstm_softmax(x,w,b):
x = tf.reshape(x,[-1,num_units,input_size]) #這裡需要reshape一下以符合神經網路輸入
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units)
outputs,final_state = tf.nn.dynamic_rnn(lstm_cell,x,dtype=tf.float32)
softmax = tf.nn.softmax(tf.matmul(final_state[1],w)+b)
return softmax
定義方向傳播過程
prediction = lstm_softmax(x,w,b)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
optimizer = tf.train.AdamOptimizer(learnning_rate).minimize(loss)
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一維張量中最大的值所在的位置
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#把correct_prediction變為float32型別
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
for i in range(10000):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(optimizer,feed_dict={x:batch_xs,y:batch_ys})
loss_value = sess.run(loss,feed_dict={x:batch_xs,y:batch_ys})
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("step",i," loss:",loss_value)
print("acc:",acc)
大概經過3000步的時候準確率到達94%,這個準確率在只有一層LSTM的時候應該算挺好了吧。
這裡需要特別說一下tf.nn.dynamic_rnn
這個函式的兩個返回值。
outputs,final_state = tf.nn.dynamic_rnn(lstm_cell,x,dtype=tf.float32)
outputs是所有的h(看lstm結構圖)
final_state是(c,h)
outputs[:,-1,:]和final_state[1]是相同的。
如有錯處還請多多指點。