1. 程式人生 > 多層RNN的定義與理解

多層RNN的定義與理解

程式碼(用 MultiRNNCell 把 3 個 BasicRNNCell 堆疊成 3 層 RNN,並執行一個時間步):


import tensorflow as tf
import numpy as np

def get_a_cell(num_units=128):
    """Create one BasicRNNCell to be used as a single layer of a stacked RNN.

    Args:
        num_units: Length of the cell's state (hidden) vector. Defaults to
            128, the value this example was written with.

    Returns:
        A newly constructed ``tf.nn.rnn_cell.BasicRNNCell``. Because a new
        object is built on every call, each layer of the MultiRNNCell gets
        its own cell instance rather than sharing one.
    """
    return tf.nn.rnn_cell.BasicRNNCell(num_units=num_units)
# Stack three cells into a 3-layer RNN; each layer comes from its own
# get_a_cell() call, so the layers are distinct cell objects.
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])

# One state size per layer, e.g. (128, 128, 128).
print(cell.state_size)

BATCH_SIZE = 32  # number of examples fed in one step
INPUT_DIM = 100  # length of each input vector

inputs = tf.placeholder(np.float32, shape=(BATCH_SIZE, INPUT_DIM))
# zero_state(batch_size, dtype) builds an all-zeros initial state with one
# entry per layer. Its first argument is the BATCH SIZE, not the state-vector
# length: each layer's state length is already fixed by its cell (num_units).
h0 = cell.zero_state(BATCH_SIZE, np.float32)

# Calling the cell directly advances the stacked RNN by exactly one time step.
#   output: output of the top (last) layer, shape (BATCH_SIZE, 128)
#   h1:     tuple holding the new state of every layer, each (BATCH_SIZE, 128)
output, h1 = cell(inputs, h0)
print(output)
print(h1)

 

輸出:

(128, 128, 128)
Tensor("multi_rnn_cell/cell_2/basic_rnn_cell/Tanh:0", shape=(32, 128), dtype=float32)
(
<tf.Tensor 'multi_rnn_cell/cell_0/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>, 
<tf.Tensor 'multi_rnn_cell/cell_1/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>, 
<tf.Tensor 'multi_rnn_cell/cell_2/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>
)