1. 程式人生 > >Tensorflow 學習筆記之使用LSTM實現MNIST資料集

Tensorflow 學習筆記之使用LSTM實現MNIST資料集

LSTM實現MNIST手寫集識別

這幾天剛好看了RNN之後瞭解了LSTM(原理可以去參考這個)。雖然LSTM主要用於處理自然語言、語音、機器人翻譯等領域,但圖片也可以看做一個有序列的資料。所以用LSTM來識別Tensorflow入門資料集。

配置神經網路引數

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('C:\\Users\\Qigq\\Desktop'
                                  '\\P_Data\\LSTM\\LSTM_TEST\\data'
,one_hot=True) batch_size = 128 # 批次大小 num_units = 128 # 單元數(就是有幾個A) input_size = 28 # 輸入資料單個序列的長度 time_steps = 28 # 序列本身的長度 classes = 10 # 分類 train_step = 10000 # 訓練次數 learnning_rate = 1e-4 # 學習率

在這裡插入圖片描述

LSTM之後有一層全連線層和softmax分類層所以要配置全連線層weights和biases

w = tf.
Variable(tf.truncated_normal([num_units,classes]),dtype=tf.float32) b = tf.Variable(tf.constant(value=0.1,shape=[classes]),dtype=tf.float32) x = tf.placeholder(dtype=tf.float32,shape=[None,784]) # 28*28 y = tf.placeholder(dtype=tf.float32,shape=[None,10])

定義神經網路

def lstm_softmax(x,w,b):
    x =
tf.reshape(x,[-1,num_units,input_size]) #這裡需要reshape一下以符合神經網路輸入 lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units) outputs,final_state = tf.nn.dynamic_rnn(lstm_cell,x,dtype=tf.float32) softmax = tf.nn.softmax(tf.matmul(final_state[1],w)+b) return softmax

定義方向傳播過程

prediction = lstm_softmax(x,w,b)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
optimizer = tf.train.AdamOptimizer(learnning_rate).minimize(loss)
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一維張量中最大的值所在的位置
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#把correct_prediction變為float32型別
init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    for i in range(10000):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(optimizer,feed_dict={x:batch_xs,y:batch_ys})
        loss_value = sess.run(loss,feed_dict={x:batch_xs,y:batch_ys})
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("step",i," loss:",loss_value)
        print("acc:",acc)

大概經過3000步的時候準確率到達94%,這個準確率在只有一層LSTM的時候應該算挺好了吧。
在這裡插入圖片描述
這裡需要特別說一下tf.nn.dynamic_rnn這個函式的兩個返回值。

outputs,final_state = tf.nn.dynamic_rnn(lstm_cell,x,dtype=tf.float32)
outputs是所有的h(看lstm結構圖)
final_state是(c,h)
outputs[:,-1,:]和final_state[1]是相同的。

如有錯處還請多多指點。