Defining and Understanding Multi-Layer RNNs
阿新 • Published: 2018-11-10
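A multi-layer RNN stacks several RNN cells. At each time step, every layer after the first takes the output of the layer below as its input, and each layer keeps its own hidden state. The equations below assume each layer uses BasicRNNCell's tanh update. The notation is introduced here only for illustration. Under that assumption, one time step of an $L$-layer stack is:

$$
h_t^{(0)} = x_t, \qquad
h_t^{(l)} = \tanh\!\left(\left[\,h_t^{(l-1)},\ h_{t-1}^{(l)}\,\right] W^{(l)} + b^{(l)}\right), \quad l = 1, \dots, L,
\qquad
y_t = h_t^{(L)}
$$

Here $[a, b]$ denotes concatenation along the feature axis. $x_t$ has length $d$ (100 in the example) and each state has length $n$ (128). As a result, $W^{(1)}$ is $(d+n) \times n$, and $W^{(l)}$ for $l \ge 2$ is $2n \times n$. The cell's output $y_t$ is just the top layer's new state. The new state is the tuple $(h_t^{(1)}, \dots, h_t^{(L)})$, one entry per layer, and TensorFlow names these layers `cell_0` through `cell_{L-1}`.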
Code:
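```python
import tensorflow as tf  # TensorFlow 1.x API (tf.nn.rnn_cell, tf.placeholder)
import numpy as np

def get_a_cell():
    # 128 is the length of the state vector (num_units).
    # A fresh cell is built per call, so the layers do not share weights.
    return tf.nn.rnn_cell.BasicRNNCell(num_units=128)

# Stack three BasicRNNCells into one three-layer cell
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])
print(cell.state_size)

# 32 is batch_size; 100 is the length of each input vector
inputs = tf.placeholder(np.float32, shape=(32, 100))
# zero_state returns an all-zero initial state, one (32, 128) tensor per layer.
# It only needs batch_size and dtype; each layer's state size is already known from the cell.
h0 = cell.zero_state(32, np.float32)
output, h1 = cell(inputs, h0)
print(output)
print(h1)
```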
Output:
```
(128, 128, 128)
Tensor("multi_rnn_cell/cell_2/basic_rnn_cell/Tanh:0", shape=(32, 128), dtype=float32)
(<tf.Tensor 'multi_rnn_cell/cell_0/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>,
 <tf.Tensor 'multi_rnn_cell/cell_1/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>,
 <tf.Tensor 'multi_rnn_cell/cell_2/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>)
```
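The following is a minimal NumPy sketch of what the three-layer `cell(inputs, h0)` call computes. It assumes each layer follows BasicRNNCell's update `new_h = tanh([x, h] @ kernel + bias)`, where the output and the new state are the same tensor. The helper names (`basic_rnn_step`, `make_params`, `multi_rnn_step`) and the random initialisation are hypothetical and exist only for illustration. This is not TensorFlow's own code. The sketch reproduces what the printed output shows: a state size of 128 per layer, an output of shape (32, 128), and an output identical to the top layer's new state. That last point explains why the printed output is named `cell_2/.../Tanh:0`.

```python
import numpy as np

def basic_rnn_step(x, h, kernel, bias):
    # Assumed BasicRNNCell update: concatenate input and previous state, then affine + tanh.
    # The returned array serves as both the layer's output and its new state.
    return np.tanh(np.concatenate([x, h], axis=1) @ kernel + bias)

def make_params(input_size, num_units, num_layers, rng):
    # Hypothetical initialisation: small random kernels, zero biases.
    params = []
    in_size = input_size
    for _ in range(num_layers):
        kernel = rng.standard_normal((in_size + num_units, num_units)) * 0.01
        bias = np.zeros(num_units)
        params.append((kernel, bias))
        in_size = num_units  # layers above the first consume the output of the layer below
    return params

def multi_rnn_step(x, states, params):
    # One time step through the stack: each layer's output feeds the next layer.
    new_states = []
    layer_input = x
    for h, (kernel, bias) in zip(states, params):
        h_new = basic_rnn_step(layer_input, h, kernel, bias)
        new_states.append(h_new)
        layer_input = h_new
    # The stack's output is the top layer's output; the new state is one array per layer.
    return layer_input, tuple(new_states)

rng = np.random.default_rng(0)
params = make_params(input_size=100, num_units=128, num_layers=3, rng=rng)
inputs = rng.standard_normal((32, 100))               # batch_size=32, input length 100
h0 = tuple(np.zeros((32, 128)) for _ in range(3))     # all-zero initial state, like zero_state
output, h1 = multi_rnn_step(inputs, h0, params)

print(tuple(h.shape[1] for h in h0))               # (128, 128, 128), like cell.state_size
print(output.shape)                                # (32, 128)
print(output is h1[-1])                            # True: output is the top layer's new state
print(sum(k.size + b.size for k, b in params))     # 95104 trainable parameters
```

The layers' kernel shapes differ. Layer 0 sees the 100-dimensional input, so its kernel is (100 + 128) × 128. Layers 1 and 2 see the 128-dimensional output of the layer below, so theirs are (128 + 128) × 128. Adding kernels and biases gives 29,312 + 2 × 32,896 = 95,104 parameters for this stack.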