Building a multi-layer static bidirectional LSTM with tf.contrib.rnn.static_bidirectional_rnn and MultiRNNCell

The script below stacks three BasicLSTMCells per direction with MultiRNNCell, runs them through static_bidirectional_rnn, and prints the shapes of the outputs and the per-layer final states (TensorFlow 1.x).

import tensorflow as tf
import numpy as np

# Training hyperparameters (learning_rate, batch_size and display_step
# are declared here but not used in this shape demo)
learning_rate = 0.01
max_examples = 40
batch_size = 128
display_step = 10  # report progress every 10 training steps

n_input = 100       # word-vector (embedding) dimension
n_steps = 300       # number of time steps
fw_n_hidden = 256   # forward LSTM units
bw_n_hidden = 128   # backward LSTM units
n_classes = 10      # number of output classes

x = tf.placeholder("float", [max_examples, n_steps, n_input])
y = tf.placeholder("float", [max_examples, n_classes])
# The output layer sees the concatenated forward + backward outputs,
# hence fw_n_hidden + bw_n_hidden input columns.
weights = tf.Variable(tf.random_normal([(fw_n_hidden + bw_n_hidden), n_classes]))
biases = tf.Variable(tf.random_normal([n_classes]))

# static_bidirectional_rnn expects a Python list of n_steps tensors of
# shape (batch, n_input). Keep the placeholder handle in x and build the
# list under a new name so x can still be fed at run time.
x_seq = tf.transpose(x, [1, 0, 2])         # (n_steps, max_examples, n_input)
print(x_seq.shape)                         # (300, 40, 100)
x_seq = tf.reshape(x_seq, [-1, n_input])   # (n_steps * max_examples, n_input)
print(x_seq.shape)                         # (12000, 100)
x_seq = tf.split(x_seq, n_steps)           # list of n_steps tensors of (max_examples, n_input)
print(len(x_seq), x_seq[0].shape)          # 300 (40, 100)
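
The transpose → reshape → split pipeline is easiest to see on a concrete array. A minimal NumPy analogue (illustration only, not part of the graph), assuming batch=2, steps=3, dim=4:

demo = np.arange(24).reshape(2, 3, 4)   # (batch, steps, dim)
demo = demo.transpose(1, 0, 2)          # (steps, batch, dim)
demo = demo.reshape(-1, 4)              # (steps*batch, dim)
parts = np.split(demo, 3)               # list of 3 arrays, each (batch, dim)
print(len(parts), parts[0].shape)       # 3 (2, 4)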

# Single-layer version, for comparison:
# lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(fw_n_hidden, forget_bias=1.0)  # forward RNN, 256 units
# lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(bw_n_hidden, forget_bias=1.0)  # backward RNN, 128 units

# Each direction gets three stacked layers. Every layer needs its own
# cell instance; reusing one BasicLSTMCell object across layers leads to
# variable-sharing errors in later TF 1.x releases.
lstm_fw_cells = []
lstm_bw_cells = []
for i in range(3):
    lstm_fw_cells.append(tf.contrib.rnn.BasicLSTMCell(fw_n_hidden, forget_bias=1.0))
    lstm_bw_cells.append(tf.contrib.rnn.BasicLSTMCell(bw_n_hidden, forget_bias=1.0))

mul_lstm_fw_cell = tf.contrib.rnn.MultiRNNCell(lstm_fw_cells)
mul_lstm_bw_cell = tf.contrib.rnn.MultiRNNCell(lstm_bw_cells)
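
As a quick sanity check before running anything, the stacked cells report their per-layer state sizes; each BasicLSTMCell exposes an LSTMStateTuple of (c, h) sizes:

print(mul_lstm_fw_cell.state_size)  # (LSTMStateTuple(c=256, h=256),) * 3
print(mul_lstm_bw_cell.state_size)  # (LSTMStateTuple(c=128, h=128),) * 3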

# Returns: outputs (a list of n_steps concatenated fw+bw outputs) and the
# final states of the forward and backward stacks.
outputs, fw_state, bw_state = tf.contrib.rnn.static_bidirectional_rnn(
    mul_lstm_fw_cell, mul_lstm_bw_cell, x_seq, dtype=tf.float32)

print(len(outputs))       # 300, one output per time step
print(outputs[0].shape)   # (40, 384): forward 256 + backward 128, concatenated
print(outputs[-1].shape)  # (40, 384); the last step's output is what is usually fed downstream
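
A relationship worth noting (this relies on static_bidirectional_rnn's layout, where each outputs[t] concatenates the forward output at step t with the time-reversed backward output): the top layer's final hidden states are slices of the output list. The following graph nodes would evaluate to True in a session:

# The backward RNN reads the sequence reversed, so it finishes at the
# original step 0; its final h therefore matches the first output.
same_fw = tf.reduce_all(tf.equal(fw_state[-1].h, outputs[-1][:, :fw_n_hidden]))
same_bw = tf.reduce_all(tf.equal(bw_state[-1].h, outputs[0][:, fw_n_hidden:]))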

print(len(fw_state))  # 3: one final LSTMStateTuple per stacked layer
# print(fw_state)

# c state of each forward LSTM layer
print(fw_state[0].c.shape)  # (40, 256)
print(fw_state[1].c.shape)  # (40, 256)
print(fw_state[2].c.shape)  # (40, 256)

# h state of each forward LSTM layer
print(fw_state[0].h.shape)  # (40, 256)
print(fw_state[1].h.shape)  # (40, 256)
print(fw_state[2].h.shape)  # (40, 256)

# c state of each backward LSTM layer (note: 128 units, not 256)
print(bw_state[0].c.shape)  # (40, 128)
print(bw_state[1].c.shape)  # (40, 128)
print(bw_state[2].c.shape)  # (40, 128)

# h state of each backward LSTM layer
print(bw_state[0].h.shape)  # (40, 128)
print(bw_state[1].h.shape)  # (40, 128)
print(bw_state[2].h.shape)  # (40, 128)
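
The weights and biases defined at the top are never applied above. A minimal sketch of the usual continuation, projecting the last time step onto the classes and running the graph once on random data to confirm the shapes (assumes TensorFlow 1.x with tf.contrib available; the loss/optimizer that learning_rate was declared for is omitted):

logits = tf.matmul(outputs[-1], weights) + biases  # (max_examples, n_classes)
pred = tf.nn.softmax(logits)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed_x = np.random.randn(max_examples, n_steps, n_input).astype(np.float32)
    print(sess.run(pred, feed_dict={x: feed_x}).shape)  # (40, 10)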