
python3 Keras: getting the GRU output at every time step, with an example of splitting the GRU output

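1. Keras's GRU is easy to use. With return_sequences=False it returns only the hidden state of the last time step, and with return_sequences=True it returns the hidden state of every time step.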

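2. However, in Keras the number of time steps is determined by the input. Sometimes we need to split the output of one layer along the time axis and feed the pieces into different downstream layers. Operating on the tensor directly (for example by slicing it) produces confusing errors, because the result is no longer the output of a Keras layer. Below is a sample I wrote, partly based on code by others: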


# Written for Keras 2.x with the TensorFlow backend
from keras.models import Model
from keras.layers import Input, Lambda, Reshape, GRU, Concatenate
import tensorflow as tf


# 20 time steps, each a 2048-dimensional feature vector
image_input = Input(shape=(20, 2048), name='image')
# return_sequences=True: one hidden state per time step -> (batch, 20, 10)
recurrent_network = GRU(units=10, return_sequences=True,
                        name='recurrent_network')(image_input)

new_input = recurrent_network
# tf.split is wrapped in a Lambda layer so the pieces stay Keras tensors;
# the result is a list of 20 tensors, each of shape (batch, 1, 10)
splits = Lambda(lambda x: tf.split(x, num_or_size_splits=20, axis=1))(new_input)
print(splits[0])

list_recur = []
for i in range(20):
    # Failed attempts, kept for reference (K = keras.backend). Slicing or
    # reshaping the tensor outside a layer gives a tensor that is not the
    # output of any Keras layer, so Model() cannot trace it:
    # x = new_input[:, i, :]   # AttributeError: 'NoneType' object has no attribute '_inbound_nodes'
    # x = K.reshape(x, (1, 10))
    # x = K.expand_dims(x, axis=-1)
    # x = Lambda(lambda x: K.expand_dims(x, axis=-1))(x)
    x = Reshape((1, 10))(splits[i])  # one time step: (batch, 1, 10)
    print(x)
    # A separate GRU for each time step
    recurrent_network = GRU(units=10, return_sequences=True,
                            name='recurrent_network{}'.format(i))(x)
    list_recur.append(recurrent_network)

# Join the 20 per-step outputs back along the time axis -> (batch, 20, 10)
decoder_outputs = Concatenate(axis=1)(list_recur)

# Optional per-time-step classifier on top (needs Dense and TimeDistributed):
# output = TimeDistributed(Dense(units=20, activation='softmax'),
#                          name='output')(decoder_outputs)

model = Model(inputs=image_input, outputs=decoder_outputs)

model.summary()  # prints the table itself and returns None
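Point 1 (return_sequences) can be illustrated without depending on any particular Keras version. The following is a minimal NumPy sketch of a single GRU layer, not Keras's actual implementation: the hypothetical gru_forward helper, the weight layout, the gate order (update z, reset r, candidate h), the zero initial state and the small sizes (8 input features instead of 2048) are all assumptions chosen for illustration. It uses the classic update h = z * h_prev + (1 - z) * h_tilde.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_forward(x, W, U, b, return_sequences=False):
    """Hypothetical helper: run one GRU layer over x of shape (batch, timesteps, input_dim).

    Assumed layout: W (input_dim, 3*units), U (units, 3*units), b (3*units,),
    with gate columns ordered z (update), r (reset), h (candidate).
    """
    batch, timesteps, _ = x.shape
    units = U.shape[0]
    h = np.zeros((batch, units))  # zero initial state
    outputs = []
    for t in range(timesteps):
        xw = x[:, t, :] @ W + b  # input contribution to all three gates
        z = sigmoid(xw[:, :units] + h @ U[:, :units])
        r = sigmoid(xw[:, units:2 * units] + h @ U[:, units:2 * units])
        h_tilde = np.tanh(xw[:, 2 * units:] + (r * h) @ U[:, 2 * units:])
        h = z * h + (1.0 - z) * h_tilde  # new hidden state
        outputs.append(h)
    if return_sequences:
        return np.stack(outputs, axis=1)  # (batch, timesteps, units)
    return h  # (batch, units): last time step only


rng = np.random.default_rng(0)
batch, timesteps, input_dim, units = 4, 20, 8, 10
x = rng.normal(size=(batch, timesteps, input_dim))
W = rng.normal(scale=0.1, size=(input_dim, 3 * units))
U = rng.normal(scale=0.1, size=(units, 3 * units))
b = np.zeros(3 * units)

all_states = gru_forward(x, W, U, b, return_sequences=True)
last_state = gru_forward(x, W, U, b, return_sequences=False)
print(all_states.shape)  # (4, 20, 10)
print(last_state.shape)  # (4, 10)
print(np.allclose(all_states[:, -1, :], last_state))  # True
```

With return_sequences=True every step's hidden state is kept, so the last slice of the full sequence is exactly what return_sequences=False returns.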
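The shape bookkeeping behind point 2 can be checked in plain NumPy, where np.split behaves like tf.split for equal-sized pieces. This is only an illustration of the shapes, using a hypothetical stand-in array of shape (2, 20, 10) in place of the GRU output; no Keras layers are involved.

```python
import numpy as np

# Hypothetical stand-in for the GRU output: (batch, timesteps, units) = (2, 20, 10)
seq = np.arange(2 * 20 * 10, dtype=float).reshape(2, 20, 10)

# 20 equal pieces along the time axis, like tf.split(x, num_or_size_splits=20, axis=1)
steps = np.split(seq, 20, axis=1)
print(len(steps), steps[0].shape)  # 20 (2, 1, 10)

# Piece i equals the slice x[:, i:i+1, :], which keeps the time axis
assert all(np.array_equal(steps[i], seq[:, i:i + 1, :]) for i in range(20))

# Indexing a single step with x[:, i, :] drops the time axis instead
print(seq[:, 3, :].shape)  # (2, 10)

# Concatenating the pieces along axis=1 restores the original array
restored = np.concatenate(steps, axis=1)
print(np.array_equal(restored, seq))  # True
```

The x[:, i, :] case is why the commented-out attempts in the code reach for expand_dims or Reshape: indexing a single step removes the time axis, while a GRU expects a (batch, timesteps, features) input.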
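model.summary() also lists the number of parameters, which can be worked out by hand as a sanity check. The sketch below assumes the standard Keras GRU weight layout: a kernel of shape (input_dim, 3 * units), a recurrent kernel of shape (units, 3 * units) and a bias of 3 * units values, doubled when reset_after=True. The default of reset_after differs between versions (False in standalone Keras 2.x, True for tf.keras in TensorFlow 2), so both cases are computed. The Input, Lambda, Reshape and concatenation layers have no weights. The helper names are hypothetical.

```python
def gru_param_count(input_dim, units, reset_after=False):
    """Hypothetical helper: number of weights in one Keras-style GRU layer."""
    kernel = input_dim * 3 * units           # input -> z, r, h
    recurrent_kernel = units * 3 * units     # previous state -> z, r, h
    bias = 3 * units * (2 if reset_after else 1)
    return kernel + recurrent_kernel + bias


def model_param_count(reset_after=False):
    """Hypothetical helper: parameters of the model built in the code above."""
    first = gru_param_count(2048, 10, reset_after)          # first GRU, 2048 features in
    per_step = 20 * gru_param_count(10, 10, reset_after)    # 20 GRUs on (1, 10) slices
    return first, per_step, first + per_step


print(model_param_count(reset_after=False))  # (61770, 12600, 74370)
print(model_param_count(reset_after=True))   # (61800, 13200, 75000)
```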
