1. 程式人生 > >tensorflow學習之MultiRNNCell詳解

tensorflow學習之MultiRNNCell詳解

tf.contrib.rnn.MultiRNNCell

Aliases:

  • Class tf.contrib.rnn.MultiRNNCell
  • Class tf.nn.rnn_cell.MultiRNNCell

由多個簡單的cells組成的RNN cell。用於構建多層迴圈神經網路。

__init__(

    cells,

    state_is_tuple=True

)

引數:

  • cells:RNNCells的list。
  • state_is_tuple:如果為True,接受和返回的states是n-tuples,其中n=len(cells)。如果為False,states是concatenated沿著列軸.後者即將棄用。

程式碼例項:

import tensorflow as tf



batch_size=10

depth=128

inputs=tf.Variable(tf.random_normal([batch_size,depth]))

previous_state0=(tf.random_normal([batch_size,100]),tf.random_normal([batch_size,100]))

previous_state1=(tf.random_normal([batch_size,200]),tf.random_normal([batch_size,200]))

previous_state2=(tf.random_normal([batch_size,300]),tf.random_normal([batch_size,300]))

num_units=[100,200,300]

print(inputs)

cells=[tf.nn.rnn_cell.BasicLSTMCell(num_unit) for num_unit in num_units]

mul_cells=tf.nn.rnn_cell.MultiRNNCell(cells)



outputs,states=mul_cells(inputs,(previous_state0,previous_state1,previous_state2))

print(outputs.shape) #(10, 300)

print(states[0]) #第一層LSTM

print(states[1]) #第二層LSTM

print(states[2]) ##第三層LSTM

print(states[0].h.shape) #第一層LSTM的h狀態,(10, 100)

print(states[0].c.shape) #第一層LSTM的c狀態,(10, 100)

print(states[1].h.shape) #第二層LSTM的h狀態,(10, 200)