tf.contrib.rnn.static_bidirectional_rnn和MultiRNNCell構建多層靜態雙向LSTM
阿新 • 發佈:2018-12-10
# Demo: build a 3-layer static bidirectional LSTM with
# tf.contrib.rnn.static_bidirectional_rnn and MultiRNNCell (TensorFlow 1.x API)
# and print the shapes of the outputs and final states.
import tensorflow as tf
import numpy as np

# Training hyper-parameters
learning_rate = 0.01
max_examples = 40    # batch dimension of the placeholders below
batch_size = 128
display_step = 10    # report training progress every 10 steps
n_input = 100        # word-embedding dimension
n_steps = 300        # number of time steps
fw_n_hidden = 256    # hidden units in each forward-direction LSTM layer
bw_n_hidden = 128    # hidden units in each backward-direction LSTM layer
n_classes = 10

x = tf.placeholder("float", [max_examples, n_steps, n_input])
y = tf.placeholder('float', [max_examples, n_classes])
# Output projection: concatenated fw+bw hidden state -> class logits.
weights = tf.Variable(tf.random_normal([(fw_n_hidden + bw_n_hidden), n_classes]))
biases = tf.Variable(tf.random_normal([n_classes]))

# The static RNN API wants a length-n_steps list of [batch, n_input] tensors,
# so reshape [batch, steps, input] -> [steps, batch, input] -> list of steps.
x = tf.transpose(x, [1, 0, 2])
print(x.shape)
x = tf.reshape(x, [-1, n_input])
print(x.shape)
x = tf.split(x, n_steps)
print(len(x), x[0].shape)

# Three stacked LSTM layers per direction. Each layer needs its own cell
# object — reusing one cell instance across layers would share variables.
lstm_fw_cell = []
lstm_bw_cell = []
for i in range(3):
    lstm_fw_cell.append(tf.contrib.rnn.BasicLSTMCell(fw_n_hidden, forget_bias=1.0))
    lstm_bw_cell.append(tf.contrib.rnn.BasicLSTMCell(bw_n_hidden, forget_bias=1.0))
mul_lstm_fw_cell = tf.contrib.rnn.MultiRNNCell(lstm_fw_cell)
mul_lstm_bw_cell = tf.contrib.rnn.MultiRNNCell(lstm_bw_cell)

outputs, fw_state, bw_state = tf.contrib.rnn.static_bidirectional_rnn(
    mul_lstm_fw_cell, mul_lstm_bw_cell, x, dtype=tf.float32)

print(len(outputs))      # 300, equals n_steps; outputs[-1] (last step) is typically used
print(outputs[0].shape)  # (40, 384) = (batch, fw_n_hidden + bw_n_hidden)
print(outputs[-1].shape) # (40, 384)
print(len(fw_state))     # 3 — one LSTMStateTuple per stacked layer

# c states of the forward RNN's three layers
print(fw_state[0].c.shape)  # (40, 256)
print(fw_state[1].c.shape)  # (40, 256)
print(fw_state[2].c.shape)  # (40, 256)
# h states of the forward RNN's three layers
print(fw_state[0].h.shape)  # (40, 256)
print(fw_state[1].h.shape)  # (40, 256)
print(fw_state[2].h.shape)  # (40, 256)
# c states of the backward RNN's three layers
# (bw_n_hidden = 128, so these are (40, 128), not (40, 256))
print(bw_state[0].c.shape)  # (40, 128)
print(bw_state[1].c.shape)  # (40, 128)
print(bw_state[2].c.shape)  # (40, 128)
# h states of the backward RNN's three layers
print(bw_state[0].h.shape)  # (40, 128)
print(bw_state[1].h.shape)  # (40, 128)
print(bw_state[2].h.shape)  # (40, 128)