How to Visualize Your Recurrent Neural Network with Attention in Keras

Now for the interesting part: the decoder. For any given character at position t in the sequence, our decoder accepts the encoded sequence h = (h_1, ..., h_T) as well as the previous hidden state s_{t-1} (shared within the decoder cell) and the previous character y_{t-1}. Our decoder layer outputs y = (y_1, ..., y_T), the characters in the standardized date. Our overall architecture is summarized in Figure 7.


As shown in Figure 6, the decoder is quite complicated, so let's break it down into the steps the decoder cell executes when predicting character t. In the following equations, the capital-letter variables represent trainable parameters. (Note that I have dropped the bias terms for brevity.)

Equation 1: A feed-forward neural network that calculates the unnormalized importance of character j in predicting character t. Equation 2: The softmax operation that normalizes the probability.
1. Calculate the attention probabilities α = (α_1, …, α_T) based on the encoded sequence and the internal hidden state of the decoder cell, s_{t-1}. These are shown in Equation 1 and Equation 2.
Figure 3: Calculation of the context vector for the t-th character.

2. Calculate the context vector c_t, the sum of the encoded sequence weighted by the attention probabilities. Intuitively, this vector summarizes how much each encoded character matters for predicting the t-th character.

Equation 4: Reset gate. Equation 5: Update gate. Equation 6: Proposal hidden state. Equation 7: New hidden state.

3. We then update our hidden state. If you are familiar with the equations of a GRU cell (a simplified relative of the LSTM), these might ring a bell: the reset gate, r, the update gate, z, and the proposal state. We use the reset gate to control how much information from the previous hidden state s_{t-1} is used to create a proposal hidden state. The update gate controls how much of the proposal we use in the new hidden state s_t. (Confused? See a step-by-step walkthrough of the LSTM equations.)

Equation 8: A simple neural network to predict the next character.

4. Calculate the t-th character using a simple one-layer neural network that takes the context vector, the hidden state, and the previous character. This departs from the paper, which uses a maxout layer, but since we are trying to keep things as simple as possible, it works fine!

Equations 1–8 are applied at every position of the sequence to produce a decoded sequence y, in which each element is a probability distribution over the possible characters at that position.
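The four steps can be sketched for a single, unbatched sequence in NumPy. Since the equation images are not reproduced here, this sketch assumes the usual Bahdanau-style forms: an additive score V_a · tanh(W_a s_{t-1} + U_a h_j), a softmax over positions, GRU-style gates, and a softmax output layer that (by assumption) reads the new state s_t. The weight names in the dict P are hypothetical labels that follow the subscripts a, r, z, p and o, and biases are omitted as in the text.

```python
import numpy as np

def softmax(v):
    # numerically stable softmax over a 1-D vector
    e = np.exp(v - v.max())
    return e / e.sum()

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def decoder_step(h, s_prev, y_prev, P):
    """One decoder step for a single sequence (no batch axis, no biases).

    h: encoded sequence (T, d); s_prev: s_{t-1} (n,); y_prev: y_{t-1} (k,).
    P: dict of weights with hypothetical names W_a, U_a, V_a, W_r, ...
    """
    # Step 1: additive score for every position j, then softmax over j
    e = np.tanh(s_prev @ P["W_a"] + h @ P["U_a"]) @ P["V_a"]      # (T,)
    alpha = softmax(e)                                              # (T,)
    # Step 2: context vector, the attention-weighted sum of the encodings
    c = alpha @ h                                                   # (d,)
    # Step 3: reset gate, update gate, proposal state, new state
    r = sigmoid(y_prev @ P["W_r"] + s_prev @ P["U_r"] + c @ P["C_r"])
    z = sigmoid(y_prev @ P["W_z"] + s_prev @ P["U_z"] + c @ P["C_z"])
    s_prop = np.tanh(y_prev @ P["W_p"] + (r * s_prev) @ P["U_p"] + c @ P["C_p"])
    s = (1 - z) * s_prev + z * s_prop                               # (n,)
    # Step 4: one-layer network giving a distribution over k characters
    y = softmax(y_prev @ P["W_o"] + s @ P["U_o"] + c @ P["C_o"])   # (k,)
    return y, s, alpha

# Usage with random weights: decode T characters, feeding each output back in.
rng = np.random.default_rng(0)
T, d, n, k, m = 6, 4, 5, 3, 7   # sizes chosen arbitrarily for the demo
shapes = {"W_a": (n, m), "U_a": (d, m), "V_a": (m,)}
for g in "rzp":
    shapes.update({f"W_{g}": (k, n), f"U_{g}": (n, n), f"C_{g}": (d, n)})
shapes.update({"W_o": (k, k), "U_o": (n, k), "C_o": (d, k)})
P = {name: rng.normal(size=shape) for name, shape in shapes.items()}

h = rng.normal(size=(T, d))
s, y = np.zeros(n), np.zeros(k)
outputs, attention = [], []
for t in range(T):
    y, s, alpha = decoder_step(h, s, y, P)
    outputs.append(y)
    attention.append(alpha)
attention = np.stack(attention)   # (T, T): row t is the attention used for y_t
```

Stacking the α vectors row by row is exactly what produces an attention map: one row per output character, one column per input character.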


Our custom layer is implemented in . This part is somewhat complicated, mainly because of the manipulations needed to vectorize the operations over the complete encoded sequence. I promise it becomes easier the more you look at the equations and the code side by side.

A minimal custom Keras layer has to implement a few methods: __init__, compute_output_shape, build and call. For completeness, we also implement get_config, which makes it easy to load the model back into memory. In addition to these, a Keras recurrent layer implements a step method that holds all the computations of our cell.

First, let's break down the boilerplate layer code:

  • __init__ is called when the layer is first instantiated. It sets the functions that will eventually initialize the weights, regularizers, and constraints. Since the output of our layer is a sequence, we hard-code self.return_sequences=True.
  • build is called the first time the layer is applied to an input tensor, once the input shape is known. Since our model is quite complicated, you can see there are a ton of weights to initialize here. The call self.add_weight automatically handles initializing the weights and registering them as trainable within the model. Weights with the subscript a are used to calculate the context vector (steps 1 and 2). Weights with the subscripts r, z and p calculate the new hidden state in step 3. Finally, weights with the subscript o calculate the output of the layer.
  • Some convenience functions are implemented as well: (a) compute_output_shape calculates the output shape for any given input shape; (b) get_config lets us load the model from just a saved file (once we are done training).

Now for the cell logic:

  • By default, each execution of the cell only has information from the previous time step. Since we need to access the entire encoded sequence within the cell, we need to save it somewhere, so we make a simple modification in call. The only way I could find to do this was to store the sequence being fed into the cell as self.x_seq so that we can access it later.
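To see why storing the sequence matters, here is a plain-Python imitation of the call/step split. It is not Keras code: the class name and its running-share computation are hypothetical and only illustrate the pattern of stashing the full input on self inside call so that step, which receives one element at a time, can still read all of it.

```python
class ToySequenceLayer:
    """Plain-Python imitation (not a Keras layer) of the call/step split."""

    def call(self, x):
        # the modification: keep the whole sequence before the recurrent loop
        self.x_seq = list(x)
        state = 0.0
        outputs = []
        for x_t in self.x_seq:        # the loop hands step() one element at a time
            out, state = self.step(x_t, state)
            outputs.append(out)
        return outputs

    def step(self, x_t, state):
        # toy computation: running share of the whole sequence seen so far,
        # which step() could not compute without access to self.x_seq
        share = x_t / sum(self.x_seq)
        new_state = state + share
        return new_state, new_state

outputs = ToySequenceLayer().call([1, 2, 3, 4])
```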

Now we need to think vectorized. The _time_distributed_dense function computes the last term of Equation 1 for all elements of the encoded sequence at once.
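A time-distributed dense layer simply applies the same weight matrix at every time step. The NumPy sketch below shows that computation; the function name and signature are hypothetical and are not the repository's _time_distributed_dense helper.

```python
import numpy as np

def time_distributed_dense_np(x, W, b=None):
    """Apply one dense layer (x_t @ W + b) independently at every time step.

    x: (batch, timesteps, input_dim), W: (input_dim, output_dim), b: (output_dim,)
    returns: (batch, timesteps, output_dim)
    """
    batch, timesteps, input_dim = x.shape
    # merge batch and time so one matrix product covers every position
    out = x.reshape(batch * timesteps, input_dim) @ W
    if b is not None:
        out = out + b
    return out.reshape(batch, timesteps, W.shape[1])

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 4))   # e.g. a batch of encoded sequences h
U = rng.normal(size=(4, 3))
b = rng.normal(size=(3,))
projected = time_distributed_dense_np(x, U, b)
```

Because this term depends only on the encoded sequence and not on s_{t-1}, it can be computed once per sequence, which is presumably why it lives in call rather than inside every step.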

  • We now walk through the most important part of the code: the step method, which executes the cell logic. Recall that step is applied to every element of the input sequence.

In this cell we want to access the previous character ytm (y_{t-1}) and the previous hidden state stm (s_{t-1}), which are obtained from states in line 4.

Think vectorized: we repeat stm once for every character in our input sequence.

On lines 11–18 we implement a version of Equation 1 that performs the calculation for all characters in the sequence at once.
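Here is a NumPy sketch of that vectorized score, assuming the additive form V_a · tanh(W_a s_{t-1} + U_a h_j); the weight names and sizes are hypothetical. It computes all positions in one shot and compares the result with a loop over batch elements and positions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, T, d, n, m = 2, 6, 4, 5, 7
x_seq = rng.normal(size=(batch, T, d))   # encoded sequence h
stm = rng.normal(size=(batch, n))        # previous decoder state s_{t-1}
W_a = rng.normal(size=(n, m))
U_a = rng.normal(size=(d, m))
V_a = rng.normal(size=(m,))

# the "last term" U_a h_j, computed once for every position
uh = x_seq @ U_a                                 # (batch, T, m)
# repeat s_{t-1} over time, project, add, squash and score in one go
_stm = np.repeat(stm[:, None, :], T, axis=1)     # (batch, T, n)
et = np.tanh(_stm @ W_a + uh) @ V_a[:, None]     # (batch, T, 1)

# the same scores one position at a time, for comparison
loop = np.array([[np.tanh(stm[b] @ W_a + x_seq[b, j] @ U_a) @ V_a
                  for j in range(T)] for b in range(batch)])
```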

In lines 24–28 we implement Equation 2 in vectorized form for the whole sequence. We use repeat so that every time step can be divided by the respective sum.
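The exponentiate, sum, repeat and divide pattern can be sketched in NumPy with arbitrary scores. For comparison, the sketch also computes the more common form that subtracts the maximum before exponentiating (for numerical stability) and relies on keepdims broadcasting instead of an explicit repeat.

```python
import numpy as np

rng = np.random.default_rng(1)
batch, T = 2, 6
et = rng.normal(size=(batch, T, 1))       # unnormalized scores from Equation 1

at = np.exp(et)
at_sum = at.sum(axis=1)                   # (batch, 1): one total per sequence
# repeat each total along time so it lines up element-wise with at
at_sum_repeated = np.repeat(at_sum[:, np.newaxis, :], T, axis=1)  # (batch, T, 1)
at = at / at_sum_repeated

# same result with a max-shift for stability and keepdims broadcasting
shifted = np.exp(et - et.max(axis=1, keepdims=True))
at_stable = shifted / shifted.sum(axis=1, keepdims=True)
```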

To calculate the context vector, we need to keep in mind that self.x_seq and at both have a "batch dimension", so we use batch_dot to avoid multiplying across that dimension. The squeeze operation then removes the left-over dimension. This is done in lines 33–37.
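In NumPy, the same per-example contraction over the time axis can be written with einsum. With at of shape (batch, T, 1) and x_seq of shape (batch, T, d), this yields (batch, 1, d), the shape batch_dot(at, x_seq, axes=1) produces, and squeeze then gives (batch, d). The data below is random and only for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
batch, T, d = 2, 6, 4
x_seq = rng.normal(size=(batch, T, d))          # encoded sequence per example
at = rng.random(size=(batch, T, 1))
at = at / at.sum(axis=1, keepdims=True)         # attention weights, sum to 1 over T

# contract the time axis separately for each example -> (batch, 1, d)
context = np.einsum("bti,btd->bid", at, x_seq)
context = context.squeeze(axis=1)               # drop the left-over axis -> (batch, d)

# reference: one example at a time
expected = np.stack([at[b, :, 0] @ x_seq[b] for b in range(batch)])
```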

The next few lines of code are a more straightforward implementation of Equations 4–8.

Now we think a bit ahead: we would like to calculate those fancy attention maps in Figure 1. To do this, we need a "toggle" that makes the step return the attention probabilities at instead of the predicted characters.
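The idea behind the toggle can be shown with a toy step function. This is a hypothetical sketch, not the layer's code, and it swaps in simple dot-product scoring instead of the additive score of Equation 1 to stay short. The key point is that the flag only changes what is emitted; the state passed to the next step is identical, so two models sharing the same weights follow exactly the same recurrence.

```python
import numpy as np

def toy_step(h, s_prev, return_probabilities=False):
    """Hypothetical decoder step with an output toggle (toy dot-product attention)."""
    scores = h @ s_prev                       # (T,) similarity of state to each h_j
    alpha = np.exp(scores - scores.max())
    alpha = alpha / alpha.sum()               # attention probabilities
    context = alpha @ h                       # (d,)
    s_new = np.tanh(s_prev + context)         # toy state update
    # the toggle changes only what is emitted, never the state carried forward
    output = alpha if return_probabilities else s_new
    return output, s_new

def run(h, steps, return_probabilities):
    s = np.full(h.shape[1], 0.1)
    emitted = []
    for _ in range(steps):
        out, s = toy_step(h, s, return_probabilities)
        emitted.append(out)
    return np.stack(emitted)

h = np.random.default_rng(3).normal(size=(5, 4))            # T=5 encodings, d=4
states = run(h, steps=3, return_probabilities=False)         # (3, 4)
attention_map = run(h, steps=3, return_probabilities=True)   # (3, 5)
```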



Any good learning problem needs training data. In this case, it's easy enough thanks to the Faker library, which can generate fake dates with ease. I also use the Babel library to generate dates in different languages and formats, as inspired by rasmusbergpalm/normalization. The script generates some fake data; I won't bore you with the details, but I invite you to poke around or to make it better.
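As a dependency-free stand-in (a different technique from the Faker/Babel script, not the author's generator), the standard library can already produce (human-readable, standardized) training pairs. The handful of formats and the ISO YYYY-MM-DD target are assumptions for illustration, and the multilingual output that Babel provides is not reproduced: under Python's default C locale, %B and %A give English names.

```python
import random
from datetime import date, timedelta

# a few human-readable strftime formats (hypothetical choice)
FORMATS = ["%d %B %Y", "%A, %B %d, %Y", "%m/%d/%y", "%d.%m.%Y", "%b %d %Y"]

def fake_date_pair(rng):
    """Return (human-readable date, ISO date) for a random day."""
    start = date(1950, 1, 1)
    d = start + timedelta(days=rng.randrange(365 * 100))
    return d.strftime(rng.choice(FORMATS)), d.isoformat()

rng = random.Random(42)
pairs = [fake_date_pair(rng) for _ in range(5)]
```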

The script also generates a vocabulary that converts characters into integers so that the neural network can understand them. Also included is a script in to read and prepare data for the neural network to consume.
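A character vocabulary of this kind can be sketched in a few lines. Reserving id 0 for padding and id 1 for unknown characters, and encoding to a fixed length, are assumptions of this sketch; the repository's script may lay things out differently.

```python
def build_vocab(texts, pad="<pad>", unk="<unk>"):
    """Map every character seen in texts to an integer id; 0 and 1 are reserved."""
    symbols = [pad, unk] + sorted(set("".join(texts)))
    return {symbol: i for i, symbol in enumerate(symbols)}

def encode(text, vocab, length, pad="<pad>", unk="<unk>"):
    """Turn a string into a fixed-length list of ids (truncate, then pad)."""
    ids = [vocab.get(ch, vocab[unk]) for ch in text[:length]]
    return ids + [vocab[pad]] * (length - len(ids))

vocab = build_vocab(["2017-01-02", "Jan 2 2017"])
ids = encode("2017?", vocab, length=8)   # [6, 4, 5, 7, 1, 0, 0, 0]
```

The unseen "?" maps to the unknown id, and the padding keeps every input the same length so examples can be stacked into a batch.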


This simple model, a bidirectional LSTM followed by the decoder we wrote above, is implemented in . You can run it using where I have set some default arguments (the Readme has more information). I'd recommend training on a machine with a GPU, as training can be prohibitively slow on a CPU-only machine.

If you want to skip the training part, I have provided some weights in


Now the easy part. For the visualizer implemented in , we need to load the weights twice: once into the predictive model and once into a model that returns the attention probabilities. Since we implemented the model architecture in models/NMT.py, we can simply call the function twice:

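from models.NMT import simpleNMT
# model that outputs the predicted characters
predictive_model = simpleNMT(..., return_probabilities=False)
predictive_model.load_weights(...)
# same architecture, but the decoder emits its attention probabilities
probability_model = simpleNMT(..., return_probabilities=True)
probability_model.load_weights(...)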

To use the implemented visualizer directly, type:

python visualizer.py -h

to see available command line arguments.

Example visualizations

Let us now examine the attention maps generated by probability_model. The predictive_model above returns the translated date that you see on the y-axis. On the x-axis is our input date. The map shows which input characters (on the x-axis) were used in predicting each output character (on the y-axis). The whiter a cell, the more weight that input character had. Here are some I thought were quite interesting.
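One quick way to read such a map programmatically is to take, for each output character (row), the input character (column) with the largest weight. The attention matrix below is made up purely for illustration and is not output from the trained model.

```python
import numpy as np

def strongest_inputs(attention, input_text, output_text):
    """Pair each output character with the input character it attends to most.

    attention: (len(output_text), len(input_text)); row i is the attention
    distribution used to predict output_text[i].
    """
    cols = attention.argmax(axis=1)
    return [(out_ch, input_text[j]) for out_ch, j in zip(output_text, cols)]

input_text, output_text = "Jan 05", "01-05"
peaks = [0, 0, 3, 4, 5]                  # made-up focus position for each output row
attention = np.full((len(output_text), len(input_text)), 0.02)
for i, j in enumerate(peaks):
    attention[i, j] = 0.9
attention /= attention.sum(axis=1, keepdims=True)   # each row sums to 1

pairs = strongest_inputs(attention, input_text, output_text)
```

Here the month digits point at "J", the separator at the space, and the day digits at the matching input digits, which is the kind of pattern the bright cells reveal.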

A correct example which doesn’t pay attention to unnecessary information like days of the week:


