
Implementing an attention-based sequence to sequence model in Keras (first draft)

import keras
from keras import backend as K
from keras.layers import GRU
from keras.regularizers import l2


class AttentionGRU(GRU):

  def __init__(self, atten_states, states_len, L2Strength, **kwargs):
    '''
    :param atten_states: previous (encoder) states to attend over
    :param states_len: number of encoder states (time steps)
    :param L2Strength: L2 regularization strength
    :param kwargs: passed through to GRU
    '''
    self.p_states = atten_states
    self.states_len = states_len
    self.size = kwargs['units']
    self.L2Strength = L2Strength
    super(AttentionGRU, self).__init__(**kwargs)

  def build(self, input_shape):
    input_dim = input_shape[-1]  # assumed equal to units, so units + input_dim == 2 * units
    self.W1 = self.add_weight(name='W1',
                              shape=(self.units + input_dim, 1),
                              initializer='random_uniform',
                              regularizer=l2(self.L2Strength),
                              trainable=True)
    self.b1 = self.add_weight(name='b1',
                              shape=(1,),
                              initializer='zero',
                              regularizer=l2(self.L2Strength),
                              trainable=True)
    self.W2 = self.add_weight(name='W2',
                              shape=(self.units + input_dim, self.units),
                              initializer='random_uniform',
                              regularizer=l2(self.L2Strength),
                              trainable=True)
    self.b2 = self.add_weight(name='b2',
                              shape=(self.units,),
                              initializer='zero',
                              regularizer=l2(self.L2Strength),
                              trainable=True)
    super(AttentionGRU, self).build(input_shape)

  def step(self, inputs, states):
    # Standard GRU update for the current decoder step
    h, _ = super(AttentionGRU, self).step(inputs, states)
    # Score every encoder state against the current hidden state
    alfa = K.repeat(h, self.states_len)                   # [batch_size, states_len, units]
    alfa = K.concatenate([self.p_states, alfa], axis=2)   # [batch_size, states_len, 2*units]
    scores = K.tanh(K.dot(alfa, self.W1) + self.b1)       # [batch_size, states_len, 1]
    scores = K.reshape(scores, (-1, 1, self.states_len))  # [batch_size, 1, states_len]
    scores = K.softmax(scores)                            # softmax over the encoder time axis
    # Context vector: attention-weighted sum of encoder states
    attn = K.batch_dot(scores, self.p_states)             # [batch_size, 1, units]
    attn = K.reshape(attn, (-1, self.units))              # [batch_size, units]
    # Combine the GRU hidden state with the context vector
    h = keras.layers.concatenate([h, attn])               # [batch_size, 2*units]
    h = K.dot(h, self.W2) + self.b2                       # [batch_size, units]
    return h, [h]

  def compute_output_shape(self, input_shape):
    if self.return_sequences:
      return (input_shape[0], input_shape[1], self.units)
    return (input_shape[0], self.units)
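For reference, the computation in step above is concat-style attention. Writing the encoder states passed in as atten_states as s_1, ..., s_L (with L = states_len) and the GRU hidden state at the current decoder step as h_t, each step roughly computes:

e_{t,i} = \tanh([s_i ; h_t]\, W_1 + b_1)
\alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j=1}^{L} \exp(e_{t,j})}
c_t = \sum_{i=1}^{L} \alpha_{t,i}\, s_i
\tilde{h}_t = [h_t ; c_t]\, W_2 + b_2

where [x ; y] denotes concatenation; \tilde{h}_t is returned both as the step output and as the new recurrent state, and, as in the code, no extra nonlinearity is applied after W_2.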

To use it, pass the encoder's hidden states to the atten_states argument, and then treat the layer as a standard Keras GRU. Because this is a GRU rather than an LSTM, the computation in step differs slightly from the paper. units is the hidden size; the encoder and decoder are assumed here to have the same hidden size.
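As a minimal usage sketch (the names and sizes below, such as encoder_inputs, src_vocab and n_units, are made-up placeholders, and this assumes an older Keras version whose recurrent layers still expose a step method, which the class above relies on):

from keras.layers import Input, Embedding, GRU

n_units = 128                      # shared hidden size of encoder and decoder
src_len, tgt_len = 20, 20          # source / target sequence lengths
src_vocab, tgt_vocab = 5000, 5000  # vocabulary sizes

# Encoder: return the full sequence of hidden states so the decoder can attend over them
encoder_inputs = Input(shape=(src_len,))
enc_emb = Embedding(src_vocab, n_units)(encoder_inputs)
encoder_states = GRU(n_units, return_sequences=True)(enc_emb)    # [batch, src_len, n_units]

# Decoder: used like a standard GRU, with the encoder states passed as atten_states
decoder_inputs = Input(shape=(tgt_len,))
dec_emb = Embedding(tgt_vocab, n_units)(decoder_inputs)          # embedding size == n_units
decoder_outputs = AttentionGRU(atten_states=encoder_states,
                               states_len=src_len,
                               L2Strength=1e-6,
                               units=n_units,
                               return_sequences=True)(dec_emb)   # [batch, tgt_len, n_units]

A TimeDistributed(Dense(tgt_vocab, activation='softmax')) over decoder_outputs would then give per-step token probabilities. The decoder embedding size is set equal to n_units here, which is what the weight shapes units + input_dim in build assume.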