So far on our journey through the interesting architecture of the Transformer, we have covered several topics. First, we had a chance to see how this huge system looks from a high level. We saw how this type of sequence-to-sequence model harnesses the same principles as Recurrent Neural Networks and LSTMs, but we were also able to see which mechanisms it utilizes to overcome their shortcomings. Because Transformers consist of so many parts, we decided to split the implementation into several articles. We started from the ground up and built the low-level elements first.

So, we have already covered many topics, like pre-processing of data, attention layers, and the Encoder and Decoder. The main goal of our Transformer is to translate Russian into English, so the first thing we had to do was implement positional encoding and attention layers. Then we used those "low-level" parts and combined them into Encoder and Decoder layers. After that, we stacked those layers and created the big Encoder and Decoder components. Now we have to combine those elements too and go one step up. To sum it up, we are finalizing our Transformer, which should look something like this:

High-level overview of Transformer architecture

Of course, this is just a high-level overview of the architecture. If you want to see how each individual Encoder and Decoder layer is built, check out our previous article, where the Encoder and Decoder layers are stacked together and connected to each other.

It is important to note that the complete implementation is based on the amazing "Attention Is All You Need" paper, so we rely heavily on the concepts defined there. We suggest reading this paper if you are serious about doing any kind of development in sequence-to-sequence modeling.

Transformer Class

Ok, so in the previous article we implemented the big Encoder and Decoder blocks. We stacked a bunch of layers into an architecture that looks like this:

We also added pre-processing layers that perform Embedding and Positional Encoding. So, we created and connected each Encoder and Decoder individually and added data processing beforehand. Now let's combine them into a Transformer class and add a final Linear layer on top of that. The Linear layer is practically just one Dense layer. Here is how the Transformer class looks:
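A sketch of how such a Transformer class might look. The Encoder and Decoder here are stand-ins for the full classes from the previous articles, reduced to single Dense layers so the snippet runs on its own; the real ones contain the stacked layers, Embedding, and Positional Encoding.

```python
import tensorflow as tf

# Stand-ins for the Encoder and Decoder built in the previous articles,
# reduced to single Dense layers so this sketch is self-contained.
class Encoder(tf.keras.layers.Layer):
    def __init__(self, d_model):
        super().__init__()
        self.proj = tf.keras.layers.Dense(d_model)

    def call(self, sequence, training, padding_mask):
        return self.proj(sequence)

class Decoder(tf.keras.layers.Layer):
    def __init__(self, d_model):
        super().__init__()
        self.proj = tf.keras.layers.Dense(d_model)

    def call(self, sequence, encoder_output, training,
             look_ahead_mask, padding_mask):
        return self.proj(sequence)

class Transformer(tf.keras.Model):
    """Encoder + Decoder + final Linear (Dense) layer."""

    def __init__(self, d_model, target_vocab_size):
        super().__init__()
        self.encoder = Encoder(d_model)
        self.decoder = Decoder(d_model)
        # Final linear layer that maps decoder output to vocabulary logits.
        self.linear_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inputs, targets, training,
             encoder_padding_mask, look_ahead_mask, decoder_padding_mask):
        encoder_output = self.encoder(inputs, training, encoder_padding_mask)
        decoder_output = self.decoder(targets, encoder_output, training,
                                      look_ahead_mask, decoder_padding_mask)
        return self.linear_layer(decoder_output)
```

Note that the masks are simply threaded through to the Encoder and Decoder calls; the Transformer itself only adds the final projection.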

Since we have done all the heavy lifting in previous articles, this one is a cakewalk. We just instantiate the Encoder and Decoder classes we implemented in the previous article and add a Dense layer on top of that. It is important to note that we inherit from the Model class, so we are able to perform training and get predictions using this class. Apart from that, note that we need to pass on the masks that the Encoder and Decoder use during the training process.

Training

Ok, now to the fun part: training. Since we follow the "Attention Is All You Need" paper, we use the Adam optimizer, as the authors of the paper suggested. However, since the learning rate varies in this paper, we need to create a custom scheduler that is able to do this.

Scheduler and Optimizer

The formula used for changing the learning rate during training is:
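Written out (in the paper's notation, where d_model is the model dimension and step is the current training step):

```latex
\mathit{lrate} = d_{\text{model}}^{-0.5} \cdot
  \min\!\left(\mathit{step}^{-0.5},\;
              \mathit{step} \cdot \mathit{warmup\_steps}^{-1.5}\right)
```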

In a nutshell, the learning rate increases during the first part of training, namely until the number of training steps reaches warmup_steps. After that, it decreases proportionally to the inverse square root of the step number. The paper uses the value 4000 for warmup_steps, so we do the same. This means that for the first 4000 steps the learning rate will increase, and then it will slowly decay. Something like this:

Variable Learning Rate

The previously mentioned formula is implemented within the Schedule class:
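A sketch of how such a Schedule class might look (the class and argument names here are our choice):

```python
import tensorflow as tf

class Schedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning rate schedule from "Attention Is All You Need"."""

    def __init__(self, d_model=512, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        increasing = step * self.warmup_steps ** -1.5  # linear warmup part
        decreasing = tf.math.rsqrt(step)               # step^-0.5 decay part
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(
            increasing, decreasing)
```

The two branches of the min intersect exactly at step = warmup_steps, which is why the learning rate peaks there.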

Note that this class inherits from LearningRateSchedule. Because of this, we can pass an object of this class to the optimizer and control the learning rate during the training process. Something like this:
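For example (the beta and epsilon values are the ones suggested in the paper; a compact version of the Schedule class is repeated so the snippet runs on its own):

```python
import tensorflow as tf

# Compact version of the Schedule class, repeated here so that this
# snippet is self-contained.
class Schedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model=512, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(
            tf.math.rsqrt(step), step * self.warmup_steps ** -1.5)

# Adam with the hyperparameters from the paper; the schedule object is
# passed in place of a fixed learning rate.
optimizer = tf.keras.optimizers.Adam(Schedule(),
                                     beta_1=0.9, beta_2=0.98, epsilon=1e-9)
```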

Padding Loss Function

Since all sequences are padded, we need to apply the padding mask when the loss is calculated as well. SparseCategoricalCrossentropy is used as the objective function, and the padding mask is applied to it in the padded_lossfunction function like this:
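A sketch of how padded_lossfunction might be implemented, assuming padding tokens have id 0:

```python
import tensorflow as tf

# Per-token loss (reduction='none') so that we can mask it afterwards.
loss_objective = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def padded_lossfunction(real, predictions):
    # Positions where the target token id is 0 are padding.
    mask = tf.cast(tf.math.not_equal(real, 0), tf.float32)
    loss = loss_objective(real, predictions)
    loss *= mask
    # Average only over the non-padded positions.
    return tf.reduce_sum(loss) / tf.reduce_sum(mask)
```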

There is nothing special about this function. We calculate the loss using the predefined objective function and then apply the padding mask to it.

Training Process

Finally, we can start the training process. First, we need to initialize all the necessary parameters and instantiate an object of the Transformer class:
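The values below are hypothetical example hyperparameters (smaller than the base model from the paper, so training finishes faster); the vocabulary sizes come from the pre-processing step:

```python
# Hypothetical hyperparameters, smaller than the paper's base model.
num_layers = 4        # number of stacked Encoder/Decoder layers
d_model = 128         # embedding / model dimension
dff = 512             # inner dimension of the feed-forward sublayers
num_heads = 8         # number of attention heads
dropout_rate = 0.1

# The Transformer class is then instantiated with these values plus the
# vocabulary sizes obtained during pre-processing, e.g.:
# transformer = Transformer(num_layers, d_model, num_heads, dff,
#                           input_vocab_size, target_vocab_size,
#                           dropout_rate)
```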

Then we create the train_step function. This is the TensorFlow function that is in charge of the training process. We made this choice because we wanted to speed up execution using the TensorFlow graph. Here is how it looks:
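A sketch of how train_step might look. The Transformer is replaced here by a tiny stand-in so the snippet runs on its own, and padding tokens are assumed to have id 0; the mask helpers and padded loss function follow the ones described in this series.

```python
import tensorflow as tf

# Tiny stand-in for the real Transformer so this sketch is self-contained.
class StubTransformer(tf.keras.Model):
    def __init__(self, vocab_size=100, d_model=16):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.output_layer = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, targets, training,
             encoder_padding_mask, look_ahead_mask, decoder_padding_mask):
        return self.output_layer(self.embedding(targets))

def create_padding_mask(sequence):
    # 1 where the token is padding (id 0), broadcastable over heads.
    mask = tf.cast(tf.math.equal(sequence, 0), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    # Upper-triangular mask that hides future tokens from the decoder.
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

loss_objective = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def padded_lossfunction(real, predictions):
    mask = tf.cast(tf.math.not_equal(real, 0), tf.float32)
    loss = loss_objective(real, predictions) * mask
    return tf.reduce_sum(loss) / tf.reduce_sum(mask)

transformer = StubTransformer()
optimizer = tf.keras.optimizers.Adam()
training_loss = tf.keras.metrics.Mean(name='training_loss')

# Shapes are defined broadly (None) to avoid re-tracing the graph for
# every new sequence length.
train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train_step(inputs, targets):
    # Teacher forcing: the decoder input is the target shifted right.
    target_input = targets[:, :-1]
    target_real = targets[:, 1:]

    # Create the masks the Encoder and Decoder use.
    encoder_padding_mask = create_padding_mask(inputs)
    decoder_padding_mask = create_padding_mask(inputs)
    look_ahead_mask = tf.maximum(
        create_look_ahead_mask(tf.shape(target_input)[1]),
        create_padding_mask(target_input))

    # Run the Transformer under GradientTape and calculate the loss.
    with tf.GradientTape() as tape:
        predictions = transformer(inputs, target_input, True,
                                  encoder_padding_mask, look_ahead_mask,
                                  decoder_padding_mask)
        loss = padded_lossfunction(target_real, predictions)

    # Modify the Transformer's trainable parameters and record the loss.
    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
    training_loss(loss)
```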

In essence, this function receives two inputs, i.e., two sequences, which are defined in the signature. The shapes are defined broadly to avoid variable re-tracing. At the beginning, we need to create the masks for the Encoder and Decoder. They are passed on to the call of the transformer. Then we utilize GradientTape and run the Transformer.

We pick up the predictions and use them to calculate the loss. For that, we use the padded loss function we defined previously. Once that is done, we utilize the optimizer and modify the Transformer's trainable parameters. In the end, we update training_loss and training_accuracy. After this we can easily start Transformer training:
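The training loop itself might look something like this. Here train_dataset is the batched tf.data pipeline from pre-processing and train_step the function described above; a no-op stand-in and a tiny synthetic dataset are used so the loop runs on its own.

```python
import tensorflow as tf

training_loss = tf.keras.metrics.Mean(name='training_loss')

# Stand-ins for the real train_step and dataset, so this loop runs
# on its own.
def train_step(inputs, targets):
    training_loss(1.0)

train_dataset = tf.data.Dataset.from_tensor_slices(
    (tf.zeros((4, 3), tf.int64), tf.zeros((4, 4), tf.int64))).batch(2)

EPOCHS = 2  # real training would use many more epochs
for epoch in range(EPOCHS):
    training_loss.reset_state()
    for inputs, targets in train_dataset:
        train_step(inputs, targets)
    print(f'Epoch {epoch + 1}: loss {float(training_loss.result()):.4f}')
```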

The output looks something like this:

Training process output

When we run the Transformer, here are the results:

Input:  это проблема, которую мы должны решить.
Predicted: this is a problem that we have to solve .
Real: this is a problem we have to solve .

Conclusion

In this article, we finalized our journey through the world of Transformers. We finally put all the pieces from previous articles together and ran this massive architecture. We saw how to create a scheduler that controls the learning rate in the optimizer, and how to create the training process for this structure. In the end, we got really good results, as expected.

Thank you for reading!



