1. 程式人生 > >LSTM模型簡介及Tensorflow實現

LSTM模型簡介及Tensorflow實現

LSTM模型在RNN模型的基礎上新增加了單元狀態C(cell state)。

一. 模型的輸入和輸出

在t時刻,LSTM的輸入有3個:
(1) 當前時刻LSTM的輸入值x(t);
(2) 上一時刻LSTM的輸出值h(t-1);
(3) 上一時刻的單元狀態c(t-1);

LSTM的輸出有2個:
(1) 當前時刻LSTM的輸出值h(t);
(2) 當前時刻的單元狀態c(t);

二. 模型的計算

這裡寫圖片描述

(1) 遺忘門:forget gate,控制上一時刻的單元狀態有多少傳入:

這裡寫圖片描述

(2) 輸入門:input gate,控制上一時刻LSTM的輸出有多少傳入:

這裡寫圖片描述

(3) 當前時刻輸入的單元狀態:

這裡寫圖片描述

(4) 當前時刻LSTM的單元狀態:

這裡寫圖片描述

(5) 輸出門:output gate,控制有多少傳入到LSTM當前時刻的輸出:

這裡寫圖片描述

(6) 當前時刻LSTM的輸出:

這裡寫圖片描述

note:公式中的X表示對應元素相乘;

三. TensorFlow實現LSTM-regression模型

# load module
from tensorflow.example.tutorial.mmist import input_data
import tensorflow as tf
import numpy as np

# definite hyperparameters
BATCH_SIZE = 64
TIME_STEP = 28 INPUT_SIZE = 28 LR = 0.01 # load data mnist = input_data.read_data_sets('mnist', one_hot=True) # test data test_x = mnist.test.images[:2000] test_y = mnist.test.labels[:2000] # placeholder tf_x = tf.placeholder(tf.float32, [None, TIME_STEP * INPUT_SIZE]) image = tf.reshape(tf_x, [-1, TIME_STEP, INPUT_SIZE]) tf_y = tf.placeholder(tf.int32, [None
, 10]) # RNN rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64) outputs, (h_c, h_n) = tf.nn.dynamic_rnn(rnn_cell, image, dtype=tf.float32) loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y, logits=output) train_op = tf.train.AdamOptimizer(LR).minimize(loss) accuracy = tf.metrics.accuracy(labels=tf.argmax(tf_y, axis=1), predictions=tf.argmax(output, axis=1),)[1] # open an tf session sess = tf.Session() init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess.run(init_op) # train for step in range(1200): b_x, b_y = mnist.train.next_batch(BATCH_SIZE) _, loss_ = sess.run([train_op, loss], {tf_x: b_x, tf_y: b_y}) if step % 50 == 0: accuracy_ = sess.run(accuracy, {tf_x: test_x, tf_y: test_y}) print('train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_) test_output = sess.run(output, {tf_x: test_x[: 10]}) pred_y = np.argmax(test_output, 1) print(pred_y, 'prediction_number') print(np.argmax(test_y[: 10], 1), 'real number')

四. 參考