Using artificial neural networks to build chatbots is increasingly popular. However, teaching a computer to hold natural conversations is very difficult and often requires large and complicated language models.
With all the changes and improvements made in TensorFlow 2.0 we can build complicated models with ease. In this post, we will demonstrate how to build a Transformer chatbot. All of the code used in this post is available in this colab notebook, which will run end to end (including installing TensorFlow 2.0).
This article assumes some knowledge of text generation, attention, and the Transformer architecture. In this tutorial we are going to focus on:
tf.data
input: where have you been ?
output: i m not talking about that .
input: i am not crazy , my mother had me tested .
output: i m not sure . i m not hungry .
input: i m not sure . i m not hungry .
output: you re a liar .
input: you re a liar .
output: i m not going to be a man . i m gonna need to go to school .
Sample conversations of a Transformer chatbot trained on Movie-Dialogs Corpus.
Transformer, proposed in the paper Attention is All You Need, is a neural network architecture based solely on the self-attention mechanism and is highly parallelizable.
tf.keras model plot of our Transformer
A Transformer model handles variable-sized input using stacks of self-attention layers instead of RNNs or CNNs. This general architecture has a number of advantages:
The disadvantage of this architecture:
If you are interested in knowing more about Transformer, check out The Annotated Transformer and Illustrated Transformer.
We are using the Cornell Movie-Dialogs Corpus as our dataset, which contains more than 220k conversational exchanges between more than 10k pairs of movie characters.
“+++$+++” is being used as a field separator in all the files within the corpus dataset.
movie_conversations.txt has the following format: ID of the first character, ID of the second character, ID of the movie in which the conversation occurred, and a list of line IDs. The character and movie information can be found in movie_characters_metadata.txt and movie_titles_metadata.txt respectively.
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L198', 'L199']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L200', 'L201', 'L202', 'L203']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L204', 'L205', 'L206']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L207', 'L208']
Samples of conversation pairs from movie_conversations.txt
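As a minimal sketch of how one such row could be parsed, assuming the field layout described above, the snippet below splits a row on the separator and turns the last field into a Python list. The helper name parse_conversation and the dictionary keys are hypothetical and are not taken from the colab notebook.

```python
import ast

SEPARATOR = " +++$+++ "


def parse_conversation(line):
    """Split one movie_conversations.txt row into its four fields."""
    first_char, second_char, movie_id, line_ids = line.strip().split(SEPARATOR)
    return {
        "characters": (first_char, second_char),
        "movie": movie_id,
        # the last field is written as a Python-style list literal of line IDs
        "line_ids": ast.literal_eval(line_ids),
    }


sample = "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']"
conversation = parse_conversation(sample)
print(conversation["line_ids"])
```

ast.literal_eval is used instead of eval so that only literals are evaluated, which is the safer choice when reading data files.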
movie_lines.txt has the following format: ID of the conversation line, ID of the character who uttered this phrase, ID of the movie, name of the character and the text of the line.
L901 +++$+++ u5 +++$+++ m0 +++$+++ KAT +++$+++ He said everyone was doing it. So I did it.
L900 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ As in…
L899 +++$+++ u5 +++$+++ m0 +++$+++ KAT +++$+++ Now I do. Back then, was a different story.
L898 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ But you hate Joey
L897 +++$+++ u5 +++$+++ m0 +++$+++ KAT +++$+++ He was, like, a total babe
Samples of conversation text from movie_lines.txt
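To illustrate how the two files fit together, here is a rough sketch that maps line IDs to text and then pairs each utterance with the reply that follows it. The helpers load_lines and make_pairs are hypothetical names, and treating L897 to L899 as one conversation is an assumption made only for this example.

```python
SEPARATOR = " +++$+++ "


def load_lines(rows):
    """Map each line ID to its text (the last of the five fields)."""
    id_to_text = {}
    for row in rows:
        # maxsplit=4 keeps the text intact even if it contained the separator
        fields = row.rstrip("\n").split(SEPARATOR, 4)
        id_to_text[fields[0]] = fields[4]
    return id_to_text


def make_pairs(line_ids, id_to_text):
    """Pair every utterance with the utterance that follows it."""
    return [(id_to_text[a], id_to_text[b])
            for a, b in zip(line_ids, line_ids[1:])]


rows = [
    "L897 +++$+++ u5 +++$+++ m0 +++$+++ KAT +++$+++ He was, like, a total babe",
    "L898 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ But you hate Joey",
    "L899 +++$+++ u5 +++$+++ m0 +++$+++ KAT +++$+++ Now I do. Back then, was a different story.",
]
# assumption for illustration: these three lines form one conversation
pairs = make_pairs(["L897", "L898", "L899"], load_lines(rows))
for question, answer in pairs:
    print(question, "->", answer)
```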
We are going to build the input pipeline with the following steps:
- Extract a list of conversation pairs from movie_conversations.txt and movie_lines.txt.
- Tokenize each sentence and add START_TOKEN and END_TOKEN to indicate the start and end of each sentence.
- Filter out sentences that contain more than MAX_LENGTH tokens.
- Pad tokenized sentences to MAX_LENGTH.
- Build a tf.data.Dataset with the tokenized sentences.

Notice that the Transformer is an autoregressive model: it makes predictions one part at a time and uses its output so far to decide what to do next. During training this example uses teacher forcing, which passes the true output to the next time step regardless of what the model predicts at the current time step.
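One common way to set up teacher forcing, consistent with the MAX_LENGTH - 1 shapes that appear in the loss function later in this post, is to feed the decoder the target sequence without its last token and to use the same sequence shifted left by one as the labels. The sketch below shows this on plain Python lists; the token IDs and the small MAX_LENGTH are hypothetical values chosen for readability.

```python
START_TOKEN, END_TOKEN, PAD = 8001, 8002, 0  # hypothetical token IDs
MAX_LENGTH = 8  # hypothetical, kept small for display


def pad(tokens, length):
    """Right-pad a token list with PAD up to the given length."""
    return tokens + [PAD] * (length - len(tokens))


answer = pad([START_TOKEN, 15, 27, 4, END_TOKEN], MAX_LENGTH)

dec_inputs = answer[:-1]  # what the decoder is given at each time step
labels = answer[1:]       # what it should predict at each time step

for given, expected in zip(dec_inputs, labels):
    print(given, "->", expected)
```

At every position the decoder receives the true previous token, so an early mistake does not derail the rest of the sequence during training.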
The full preprocessing code can be found at the Prepare Dataset section of the colab notebook.
i really , really , really wanna go , but i can t . not unless my sister goes .
i m workin on it . but she doesn t seem to be goin for him .
Sample preprocessed conversation pair
Like many sequence-to-sequence models, the Transformer consists of an encoder and a decoder. However, instead of recurrent or convolutional layers, it uses multi-head attention layers, each of which combines several scaled dot-product attention heads.
Attention architecture diagrams from Attention is All You Need
Scaled dot product attention
The scaled dot-product attention function takes three inputs: Q (query), K (key), V (value). The equation used to calculate the attention weights is:
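Following Attention is All You Need, and consistent with the implementation given later in this section, the computation can be written as follows. The symbol d_k, the depth of the key vectors, is notation assumed here; it corresponds to the depth variable in the code.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```

Dividing by the square root of d_k keeps the dot products from growing with the depth, which would otherwise push the softmax into regions with very small gradients.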
Because the softmax normalization is applied along the key axis, the resulting weights decide how much importance each key receives for a given query. The output is the product of the attention weights and the value. This ensures that the words we want to focus on are kept as is and the irrelevant words are flushed out.
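import tensorflow as tf


def scaled_dot_product_attention(query, key, value, mask):
    """Calculate the attention weights and apply them to value."""
    matmul_qk = tf.matmul(query, key, transpose_b=True)

    # scale the dot products by the square root of the key depth
    depth = tf.cast(tf.shape(key)[-1], tf.float32)
    logits = matmul_qk / tf.math.sqrt(depth)

    # add the mask to zero out padding tokens
    if mask is not None:
        logits += (mask * -1e9)

    # softmax is normalized over the last axis (the key positions)
    attention_weights = tf.nn.softmax(logits, axis=-1)

    return tf.matmul(attention_weights, value)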
Implementation of a scaled dot-product attention layer
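To make the effect of the mask concrete, here is a small pure-Python sketch of the same computation for a single query. It is only an illustration under stated assumptions: attention_for_one_query and softmax are hypothetical helpers, the vectors are made up, and a mask value of 1 marks a padded key, matching the mask * -1e9 convention in the function above.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    largest = max(xs)
    exps = [math.exp(x - largest) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_for_one_query(query, keys, values, mask):
    """Scaled dot-product attention for one query vector."""
    depth = len(query)
    logits = [sum(q * k for q, k in zip(query, key)) / math.sqrt(depth)
              for key in keys]
    # a mask entry of 1 pushes the logit towards minus infinity
    logits = [logit + m * -1e9 for logit, m in zip(logits, mask)]
    weights = softmax(logits)
    output = [sum(w * value[d] for w, value in zip(weights, values))
              for d in range(len(values[0]))]
    return weights, output


keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
# the third key is treated as padding
weights, output = attention_for_one_query([1.0, 0.0], keys, values, mask=[0, 0, 1])
print(weights, output)
```

The masked key receives a weight of effectively zero, so the output is a blend of the first two values only.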
The Sequential models allow us to build models very quickly by simply stacking layers on top of each other; however, for more complicated and non-sequential models, the Functional API and Model subclassing are needed. The tf.keras API allows us to mix and match different API styles. My favourite feature of Model subclassing is the capability for debugging. I can set a breakpoint in the call() method and observe the values for each layer’s inputs and outputs like a numpy array, and this makes debugging a lot simpler.
Here, we are using Model subclassing to implement our MultiHeadAttention layer.
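Multi-head attention consists of four parts: linear layers followed by a split into heads, scaled dot-product attention, concatenation of the heads, and a final linear layer.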
Each multi-head attention block takes a dictionary as input, which consists of query, key and value. Notice that when using Model subclassing with the Functional API, the inputs have to be kept as a single argument, hence we have to wrap query, key and value in a dictionary.
The inputs are then put through dense layers and split up into multiple heads. scaled_dot_product_attention() defined above is applied to each head (broadcast for efficiency). An appropriate mask must be used in the attention step. The attention output for each head is then concatenated and put through a final dense layer.
Instead of one single attention head, query, key, and value are split into multiple heads because it allows the model to jointly attend to information at different positions from different representational spaces. After the split each head has a reduced dimensionality, so the total computation cost is the same as a single head attention with full dimensionality.
class MultiHeadAttention(tf.keras.layers.Layer):

    def __init__(self, d_model, num_heads, name="multi_head_attention"):
        super(MultiHeadAttention, self).__init__(name=name)
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        self.query_dense = tf.keras.layers.Dense(units=d_model)
        self.key_dense = tf.keras.layers.Dense(units=d_model)
        self.value_dense = tf.keras.layers.Dense(units=d_model)

        self.dense = tf.keras.layers.Dense(units=d_model)

    def split_heads(self, inputs, batch_size):
        inputs = tf.reshape(
            inputs, shape=(batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(inputs, perm=[0, 2, 1, 3])

    def call(self, inputs):
        query, key, value, mask = inputs['query'], inputs['key'], inputs[
            'value'], inputs['mask']
        batch_size = tf.shape(query)[0]

        # linear layers
        query = self.query_dense(query)
        key = self.key_dense(key)
        value = self.value_dense(value)

        # split heads
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)

        # scaled dot-product attention, applied to all heads at once
        scaled_attention = scaled_dot_product_attention(query, key, value, mask)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

        # concatenation of heads
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))

        # final linear layer
        outputs = self.dense(concat_attention)

        return outputs
Implementation of multi-head attention layer with model subclassing
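The head split is easier to see on a single token vector. The sketch below, written with plain lists rather than tensors, cuts a d_model-sized vector into num_heads chunks and joins them back together. The helpers split_heads and merge_heads here are hypothetical per-token stand-ins for the reshape and transpose in the layer above, and the sizes mirror the hyper-parameters used later in this post.

```python
D_MODEL, NUM_HEADS = 256, 8
DEPTH = D_MODEL // NUM_HEADS  # dimensionality seen by each head


def split_heads(vector, num_heads):
    """Cut one d_model-sized vector into num_heads equal chunks."""
    depth = len(vector) // num_heads
    return [vector[h * depth:(h + 1) * depth] for h in range(num_heads)]


def merge_heads(heads):
    """Concatenate the per-head chunks back into one vector."""
    return [x for head in heads for x in head]


token = list(range(D_MODEL))
heads = split_heads(token, NUM_HEADS)
print(len(heads), len(heads[0]))
```

Because NUM_HEADS * DEPTH equals D_MODEL, the heads together process the same number of values as a single full-width head would.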
Transformer architecture diagram from Attention is All You Need
Transformer uses stacked multi-head attention and dense layers for both the encoder and decoder. The encoder maps an input sequence of symbol representations to a sequence of continuous representations. Then the decoder takes the continuous representation and generates an output sequence of symbols one element at a time.
Since Transformer doesn’t contain any recurrence or convolution, positional encoding is added to give the model some information about the relative position of the words in the sentence.
The formula for calculating the positional encoding
The positional encoding vector is added to the embedding vector. Embeddings represent a token in a d-dimensional space where tokens with similar meaning will be closer to each other. But the embeddings do not encode the relative position of words in a sentence. So after adding the positional encoding, words will be closer to each other based on the similarity of their meaning and their position in the sentence, in the d-dimensional space. To learn more about Positional Encoding, check out this tutorial.
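For reference, the sinusoidal positional encoding from Attention is All You Need is usually written as follows, where pos is the position of the token, i indexes pairs of dimensions, and d_model is the embedding size. The subscript notation is the paper's convention and is assumed here.

```latex
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right),
\qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)
```

Each pair of dimensions therefore oscillates at its own frequency, from fast-changing in the first dimensions to slow-changing in the last.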
We implemented the Positional Encoding with Model subclassing where we apply the encoding matrix to the input in call().
class PositionalEncoding(tf.keras.layers.Layer):

    def __init__(self, position, d_model):
        super(PositionalEncoding, self).__init__()
        self.pos_encoding = self.positional_encoding(position, d_model)

    def get_angles(self, position, i, d_model):
        angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32))
        return position * angles

    def positional_encoding(self, position, d_model):
        angle_rads = self.get_angles(
            position=tf.range(position, dtype=tf.float32)[:, tf.newaxis],
            i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
            d_model=d_model)
        # apply sin to even index in the array
        sines = tf.math.sin(angle_rads[:, 0::2])
        # apply cos to odd index in the array
        cosines = tf.math.cos(angle_rads[:, 1::2])

        pos_encoding = tf.concat([sines, cosines], axis=-1)
        pos_encoding = pos_encoding[tf.newaxis, ...]
        return tf.cast(pos_encoding, tf.float32)

    def call(self, inputs):
        return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]
Implementation of Positional Encoding with Model subclassing
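The same table can be computed by hand for a tiny case to see what the layer adds to the embeddings. This is a pure-Python sketch that mirrors the layer above, including the fact that it concatenates the sine half and the cosine half rather than interleaving them; the function name and the tiny sizes are chosen only for illustration.

```python
import math


def positional_encoding(num_positions, d_model):
    """Return a num_positions x d_model table of sinusoidal encodings."""
    table = []
    for pos in range(num_positions):
        angles = [pos / math.pow(10000, (2 * (i // 2)) / d_model)
                  for i in range(d_model)]
        sines = [math.sin(a) for a in angles[0::2]]    # even indices
        cosines = [math.cos(a) for a in angles[1::2]]  # odd indices
        table.append(sines + cosines)
    return table


pe = positional_encoding(num_positions=3, d_model=4)
for row in pe:
    print([round(x, 4) for x in row])
```

Position 0 always encodes to zeros in the sine half and ones in the cosine half, since every angle is zero there.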
With the Functional API, we can stack our layers much as in a Sequential model, but without the constraint that the model be sequential, and without declaring all the variables and layers in advance as Model subclassing requires. One advantage of the Functional API is that it validates the model as we build it, for example by checking the input and output shape of each layer, and raises a meaningful error message when there is a mismatch.
We are implementing our encoding layers, encoder, decoding layers, decoder and the Transformer itself using the Functional API.
Check out how to implement the same models with Model subclassing in this tutorial.
Encoding Layer
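Each encoder layer consists of two sublayers: multi-head attention (with the padding mask), and two dense layers followed by dropout. Each sublayer is wrapped in a residual connection followed by layer normalization.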
def encoder_layer(units, d_model, num_heads, dropout, name="encoder_layer"):
    inputs = tf.keras.Input(shape=(None, d_model), name="inputs")
    padding_mask = tf.keras.Input(shape=(1, 1, None), name="padding_mask")

    attention = MultiHeadAttention(
        d_model, num_heads, name="attention")({
            'query': inputs,
            'key': inputs,
            'value': inputs,
            'mask': padding_mask
        })
    attention = tf.keras.layers.Dropout(rate=dropout)(attention)
    attention = tf.keras.layers.LayerNormalization(
        epsilon=1e-6)(inputs + attention)

    outputs = tf.keras.layers.Dense(units=units, activation='relu')(attention)
    outputs = tf.keras.layers.Dense(units=d_model)(outputs)
    outputs = tf.keras.layers.Dropout(rate=dropout)(outputs)
    outputs = tf.keras.layers.LayerNormalization(
        epsilon=1e-6)(attention + outputs)

    return tf.keras.Model(
        inputs=[inputs, padding_mask], outputs=outputs, name=name)
Implementation of an encoder layer with Functional API
We can use tf.keras.utils.plot_model() to visualize our model. (Check out all the model plots in the colab notebook.)
Flow diagram of an encoder layer
The Encoder consists of an input embedding, a positional encoding, and num_layers encoder layers.
The input is put through an embedding which is summed with the positional encoding. The output of this summation is the input to the encoder layers. The output of the encoder is the input to the decoder.
def encoder(vocab_size,
            num_layers,
            units,
            d_model,
            num_heads,
            dropout,
            name="encoder"):
    inputs = tf.keras.Input(shape=(None,), name="inputs")
    padding_mask = tf.keras.Input(shape=(1, 1, None), name="padding_mask")

    embeddings = tf.keras.layers.Embedding(vocab_size, d_model)(inputs)
    embeddings *= tf.math.sqrt(tf.cast(d_model, tf.float32))
    embeddings = PositionalEncoding(vocab_size, d_model)(embeddings)

    outputs = tf.keras.layers.Dropout(rate=dropout)(embeddings)

    for i in range(num_layers):
        outputs = encoder_layer(
            units=units,
            d_model=d_model,
            num_heads=num_heads,
            dropout=dropout,
            name="encoder_layer_{}".format(i),
        )([outputs, padding_mask])

    return tf.keras.Model(
        inputs=[inputs, padding_mask], outputs=outputs, name=name)
Implementation of encoder with Functional API
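Each decoder layer consists of three sublayers: masked multi-head attention (with the look-ahead mask), multi-head attention over the encoder output (with the padding mask), and two dense layers followed by dropout. Each sublayer is wrapped in a residual connection followed by layer normalization.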
As query receives the output from decoder’s first attention block, and key receives the encoder output, the attention weights represent the importance given to the decoder’s input based on the encoder’s output. In other words, the decoder predicts the next word by looking at the encoder output and self-attending to its own output.
def decoder_layer(units, d_model, num_heads, dropout, name="decoder_layer"):
    inputs = tf.keras.Input(shape=(None, d_model), name="inputs")
    enc_outputs = tf.keras.Input(shape=(None, d_model), name="encoder_outputs")
    look_ahead_mask = tf.keras.Input(
        shape=(1, None, None), name="look_ahead_mask")
    padding_mask = tf.keras.Input(shape=(1, 1, None), name='padding_mask')

    attention1 = MultiHeadAttention(
        d_model, num_heads, name="attention_1")(inputs={
            'query': inputs,
            'key': inputs,
            'value': inputs,
            'mask': look_ahead_mask
        })
    attention1 = tf.keras.layers.LayerNormalization(
        epsilon=1e-6)(attention1 + inputs)

    attention2 = MultiHeadAttention(
        d_model, num_heads, name="attention_2")(inputs={
            'query': attention1,
            'key': enc_outputs,
            'value': enc_outputs,
            'mask': padding_mask
        })
    attention2 = tf.keras.layers.Dropout(rate=dropout)(attention2)
    attention2 = tf.keras.layers.LayerNormalization(
        epsilon=1e-6)(attention2 + attention1)

    outputs = tf.keras.layers.Dense(units=units, activation='relu')(attention2)
    outputs = tf.keras.layers.Dense(units=d_model)(outputs)
    outputs = tf.keras.layers.Dropout(rate=dropout)(outputs)
    outputs = tf.keras.layers.LayerNormalization(
        epsilon=1e-6)(outputs + attention2)

    return tf.keras.Model(
        inputs=[inputs, enc_outputs, look_ahead_mask, padding_mask],
        outputs=outputs,
        name=name)
Implementation of decoder layer with Functional API
The Decoder consists of an output embedding, a positional encoding, and num_layers decoder layers.
The target is put through an embedding which is summed with the positional encoding. The output of this summation is the input to the decoder layers. The output of the decoder is the input to the final linear layer.
def decoder(vocab_size,
            num_layers,
            units,
            d_model,
            num_heads,
            dropout,
            name='decoder'):
    inputs = tf.keras.Input(shape=(None,), name='inputs')
    enc_outputs = tf.keras.Input(shape=(None, d_model), name='encoder_outputs')
    look_ahead_mask = tf.keras.Input(
        shape=(1, None, None), name='look_ahead_mask')
    padding_mask = tf.keras.Input(shape=(1, 1, None), name='padding_mask')

    embeddings = tf.keras.layers.Embedding(vocab_size, d_model)(inputs)
    embeddings *= tf.math.sqrt(tf.cast(d_model, tf.float32))
    embeddings = PositionalEncoding(vocab_size, d_model)(embeddings)

    outputs = tf.keras.layers.Dropout(rate=dropout)(embeddings)

    for i in range(num_layers):
        outputs = decoder_layer(
            units=units,
            d_model=d_model,
            num_heads=num_heads,
            dropout=dropout,
            name='decoder_layer_{}'.format(i),
        )(inputs=[outputs, enc_outputs, look_ahead_mask, padding_mask])

    return tf.keras.Model(
        inputs=[inputs, enc_outputs, look_ahead_mask, padding_mask],
        outputs=outputs,
        name=name)
Implementation of a decoder with Functional API
Transformer consists of the encoder, decoder and a final linear layer. The output of the decoder is the input to the linear layer and its output is returned.
enc_padding_mask and dec_padding_mask are used to mask out all the padding tokens. look_ahead_mask is used to mask out future tokens in a sequence. As the length of the masks changes with different input sequence length, we are creating these masks with Lambda layers.
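The mask helpers themselves are not shown in this post, so here is a pure-Python sketch of the logic on a single sequence. It assumes that token ID 0 is padding and that a mask value of 1 means "hide this position", in line with the mask * -1e9 term in the attention function. The names padding_mask_1d and look_ahead_mask_2d are hypothetical, combining the padding mask into the look-ahead mask is an assumption, and the notebook's TensorFlow versions may differ.

```python
def padding_mask_1d(tokens, pad_id=0):
    """1.0 where the token is padding, 0.0 elsewhere."""
    return [1.0 if token == pad_id else 0.0 for token in tokens]


def look_ahead_mask_2d(tokens, pad_id=0):
    """Row i hides every position j > i, and also hides padding tokens."""
    padding = padding_mask_1d(tokens, pad_id)
    size = len(tokens)
    return [[max(1.0 if j > i else 0.0, padding[j]) for j in range(size)]
            for i in range(size)]


# a three-token sequence followed by one padding token
mask = look_ahead_mask_2d([5, 9, 3, 0])
for row in mask:
    print(row)
```

Reading the rows top to bottom, each position may attend to itself and to earlier positions only, and the padded last column stays hidden for every row.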
def transformer(vocab_size,
                num_layers,
                units,
                d_model,
                num_heads,
                dropout,
                name="transformer"):
    inputs = tf.keras.Input(shape=(None,), name="inputs")
    dec_inputs = tf.keras.Input(shape=(None,), name="dec_inputs")

    enc_padding_mask = tf.keras.layers.Lambda(
        create_padding_mask, output_shape=(1, 1, None),
        name='enc_padding_mask')(inputs)
    # mask the future tokens for decoder inputs at the 1st attention block
    look_ahead_mask = tf.keras.layers.Lambda(
        create_look_ahead_mask,
        output_shape=(1, None, None),
        name='look_ahead_mask')(dec_inputs)
    # mask the encoder outputs for the 2nd attention block
    dec_padding_mask = tf.keras.layers.Lambda(
        create_padding_mask, output_shape=(1, 1, None),
        name='dec_padding_mask')(inputs)

    enc_outputs = encoder(
        vocab_size=vocab_size,
        num_layers=num_layers,
        units=units,
        d_model=d_model,
        num_heads=num_heads,
        dropout=dropout,
    )(inputs=[inputs, enc_padding_mask])

    dec_outputs = decoder(
        vocab_size=vocab_size,
        num_layers=num_layers,
        units=units,
        d_model=d_model,
        num_heads=num_heads,
        dropout=dropout,
    )(inputs=[dec_inputs, enc_outputs, look_ahead_mask, dec_padding_mask])

    outputs = tf.keras.layers.Dense(units=vocab_size, name="outputs")(dec_outputs)

    return tf.keras.Model(inputs=[inputs, dec_inputs], outputs=outputs, name=name)
Implementation of Transformer with Functional API
We can initialize our Transformer as follows:
NUM_LAYERS = 2
D_MODEL = 256
NUM_HEADS = 8
UNITS = 512
DROPOUT = 0.1
model = transformer(
    vocab_size=VOCAB_SIZE,
    num_layers=NUM_LAYERS,
    units=UNITS,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    dropout=DROPOUT)
After defining our loss function, optimizer and metrics, we can simply train our model with model.fit(). Notice that we have to mask our loss function so that the padding tokens are ignored, and that we use a custom learning rate schedule.
def loss_function(y_true, y_pred):
    y_true = tf.reshape(y_true, shape=(-1, MAX_LENGTH - 1))

    loss = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')(y_true, y_pred)

    # zero out the loss at padding positions (token ID 0)
    mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
    loss = tf.multiply(loss, mask)

    return tf.reduce_mean(loss)


class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # the optimizer may pass an integer step, so cast before rsqrt
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps**-1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)


learning_rate = CustomSchedule(D_MODEL)

optimizer = tf.keras.optimizers.Adam(
    learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)


def accuracy(y_true, y_pred):
    # ensure labels have shape (batch_size, MAX_LENGTH - 1)
    y_true = tf.reshape(y_true, shape=(-1, MAX_LENGTH - 1))
    # use the stateless function so no metric object is created on every call
    return tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)


model.compile(optimizer=optimizer, loss=loss_function, metrics=[accuracy])

EPOCHS = 20
model.fit(dataset, epochs=EPOCHS)
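To get a feel for the custom learning rate, the schedule can be evaluated at a few steps in plain Python. This sketch reproduces the same formula as CustomSchedule without TensorFlow; the function name learning_rate_at is hypothetical, and step is assumed to start from 1.

```python
def learning_rate_at(step, d_model=256, warmup_steps=4000):
    """d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), for step >= 1."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000):
    print(step, learning_rate_at(step))
```

The rate grows linearly during the warm-up steps, peaks at step warmup_steps, and then decays in proportion to the inverse square root of the step.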
To evaluate, we have to run inference one time-step at a time, and pass in the output from the previous time-step as input.
Notice that we don’t normally apply dropout during inference, but we didn’t specify a training argument for our model. This is because training and mask are already built in for us: if we want to run the model for evaluation, we can simply call model(inputs, training=False) to run it in inference mode.
def evaluate(sentence):
    sentence = preprocess_sentence(sentence)

    sentence = tf.expand_dims(
        START_TOKEN + tokenizer.encode(sentence) + END_TOKEN, axis=0)

    output = tf.expand_dims(START_TOKEN, 0)

    for i in range(MAX_LENGTH):
        predictions = model(inputs=[sentence, output], training=False)

        # select the last word from the seq_len dimension
        predictions = predictions[:, -1:, :]
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

        # return the result if the predicted_id is equal to the end token
        if tf.equal(predicted_id, END_TOKEN[0]):
            break

        # concatenate the predicted_id to the output, which is given to the
        # decoder as its input
        output = tf.concat([output, predicted_id], axis=-1)

    return tf.squeeze(output, axis=0)


def predict(sentence):
    prediction = evaluate(sentence)
    predicted_sentence = tokenizer.decode(
        [i for i in prediction if i < tokenizer.vocab_size])
    return predicted_sentence
Transformer evaluation implementation
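The decoding loop above is a form of greedy decoding: at every step it keeps only the highest-scoring token. The following sketch isolates that loop with a toy scoring function in place of the Transformer so it can be run without a trained model. Everything here is hypothetical (the token IDs, the tiny vocabulary, toy_model), and unlike evaluate() it drops the start token before returning.

```python
START, END = 1, 2  # hypothetical token IDs
MAX_LENGTH = 10    # hypothetical cap on the number of generated tokens


def greedy_decode(next_token_scores, sentence_ids):
    """Generate tokens one at a time, always taking the best-scoring one."""
    output = [START]
    for _ in range(MAX_LENGTH):
        scores = next_token_scores(sentence_ids, output)
        predicted_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        if predicted_id == END:
            break
        # the prediction becomes part of the decoder input for the next step
        output.append(predicted_id)
    return output[1:]  # drop the start token


def toy_model(sentence_ids, output):
    """Stand-in for the Transformer: echoes the input, then emits END."""
    position = len(output) - 1
    target = sentence_ids[position] if position < len(sentence_ids) else END
    return [1.0 if token == target else 0.0 for token in range(8)]


reply = greedy_decode(toy_model, [5, 6, 7])
print(reply)
```

Greedy decoding is simple and fast, but it can lock in an early poor choice; sampling or beam search are common alternatives when more varied replies are wanted.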
To test our model, we can call predict(sentence).
>>> output = predict('Where have you been?')
>>> print(output)
i don t know . i m not sure . i m a paleontologist .
Here we are, we have implemented a Transformer in TensorFlow 2.0 in around 500 lines of code.
In this tutorial, we focus on the two different approaches to implement complex models with Functional API and Model subclassing, and how to incorporate them.
If you want to know more about the two different approaches and their pros and cons, check out when to use the functional API section on TensorFlow’s guide.
Try using a different dataset or hyper-parameters to train the Transformer!