Hello and welcome back to another article, where we discuss an architecture that drew mixed impressions: some called it brilliant, while others said, “Nah, this ain’t no transformer!” The architecture in question is the Fastformer, from the paper “Fastformer: Additive Attention Can Be All You Need.” As we all know by now, Transformers are quite inefficient when scaling up, and we have seen a plethora of architectures that claim to mitigate this in their own ways. Fastformer is different, though: no pairwise interactions. Interesting, isn’t it? Pairwise interactions were the essence of Transformers and helped them perform excellently on NLP tasks. Yet without pairwise interactions, Fastformer performed very well, matching and even extending the SOTA against models like Longformer, Linformer, Linear Transformers, etc., while scaling linearly!
The Transformers’ rise was due to self-attention, which models each token in a sequence using the context around it. Since every token attends to, and is attended by, every other token, this amounts to a brute-force dot-product computation between all pairs of tokens. Architectures like BigBird, Linformer, Longformer, Transformer-XL, and Reformer have solved this problem to some extent. Fastformer claims to do it in linear time complexity, with a nice technique for computing a global query vector and a global key vector. The authors go as far as to say,
Fastformer is the most efficient Transformer architecture
Let’s do a deep dive to validate their claim, or not.
Fastformer uses additive attention to model the global context and combines it with the sequence via element-wise products, which reduces computation while remaining effective at modeling contextual information. What is this additive attention? Have a look at the diagram below, and we shall get started.
To prepare the model for its inputs, there are three transformations, Query, Key, and Value, each creating a representation of the input sequence independent of the surrounding tokens: e1 corresponds to q1, k1, and v1; e2 corresponds to q2, k2, and v2; and so on. Next, each query q(i) is assigned a weight α(i) indicating its contribution to the overall computation of the global query vector. Scratching my head: isn’t this the opposite of attention altogether? Continuing, once we have calculated α(i) for all the queries, each is multiplied with its q(i) and the results are summed up to form a global query vector. The formula is given below,
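Under the hood, the global query computation is just a softmax over per-token scalar scores followed by a weighted sum. A minimal NumPy sketch of that pooling step, assuming `w` plays the role of the learned attention vector w(q) (names here are mine, not the paper’s):

```python
import numpy as np

def additive_pool(X, w):
    """Additive attention: pool (N, d) token vectors into one (d,) global vector."""
    scores = X @ w / np.sqrt(X.shape[1])   # one scalar score per token
    alpha = np.exp(scores - scores.max())
    alpha /= alpha.sum()                   # softmax over tokens -> weights α(i)
    return alpha @ X                       # weighted sum of the token vectors
```

Note that the scores are scalars, one per token: computing them costs O(N·d), not O(N²·d).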
Once we have computed the global query vector q, we need to combine it with the key vectors. Multiplying them the original transformer way is where the scaling problem lies, so the authors instead take an element-wise product of the global query vector with each key vector, which preserves contextual information better than concatenating or plainly adding the global query to each key would.
Now we have intermediate representations p(i) in which the query and keys are combined. Following this, a global key vector k is computed in the same way as the global query vector: instead of α we have β, and instead of w(q).q(i) in the attention weight formula, we use w(k).p(i).
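These two steps can be sketched together: broadcasting the global query over every key gives the p(i) vectors, and the same additive-attention pooling (now with β and w(k)) condenses them into the global key. A sketch, assuming `q_global` is the (d,) global query vector, `K` the (N, d) key matrix, and `w_k` the learned vector w(k):

```python
import numpy as np

def global_key(q_global, K, w_k):
    """Element-wise query-key interaction, then additive-attention pooling."""
    P = q_global * K                        # p(i) = q ⊙ k(i), broadcast to (N, d)
    scores = P @ w_k / np.sqrt(K.shape[1])  # β logits, one scalar per token
    beta = np.exp(scores - scores.max())
    beta /= beta.sum()                      # softmax -> β(i)
    return beta @ P                         # (d,) global key vector
```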
With this, we have our global key vector, which holds a lot of summarized information. It is then combined with the value matrix by an element-wise product, resulting in another intermediate representation u.
Finally, a linear transformation is applied to u, just like in the vanilla transformer, to learn the hidden representations. The resulting matrix R is then added to the original query matrix to give the final output of a single-head self-attention block. Stacking layers and multiple heads of such blocks gives us the overall Fastformer model; the parameters are shared across the layers to prevent overfitting. If we carefully observe the process, we can notice that there are no pairwise computations: everything is modelled as element-wise products and additive attention, none of which is quadratic. The time complexity therefore comes to O(N.d) instead of the O(N².d) of the original transformer.
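The whole single-head block described above can be sketched end to end. This is a simplified NumPy illustration under my own naming (Wq/Wk/Wv/Wo for the projections, wq/wk for the attention vectors), with random weights standing in for learned parameters:

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def fastformer_head(E, Wq, Wk, Wv, wq, wk, Wo):
    """One Fastformer self-attention head over (N, d) embeddings E."""
    d = E.shape[1]
    Q, K, V = E @ Wq, E @ Wk, E @ Wv       # the three input transformations
    q = softmax(Q @ wq / np.sqrt(d)) @ Q   # global query vector (d,)
    P = q * K                              # element-wise query-key interaction
    k = softmax(P @ wk / np.sqrt(d)) @ P   # global key vector (d,)
    U = k * V                              # element-wise key-value interaction
    return U @ Wo + Q                      # linear transform + add the query matrix

# Every step above is O(N·d) or O(d²); nothing scales as O(N²·d).
```

Each matrix product involves either a length-N axis or a d×d projection, never an N×N interaction, which is exactly where the linear scaling comes from.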
The model was trained and evaluated on five different datasets: Amazon, IMDB, MIND, CNN/Daily Mail, and PubMed. The tasks were text classification, news topic classification, news recommendation, and text summarization.
Fastformer is pretty much on top of the game in both classification and summarization tasks, outperforming almost all the other models. Along with the performance, Fastformer is very efficient compared to other benchmark models, as shown in the graph below.
The paper also studies the influence of different interaction functions among the query, key, and value matrices, as well as the influence of parameter sharing; I recommend a thorough read for a better intuition of these design choices.
Fastformer was fast, efficient, and delivered results. The design choices paid off. We assumed that without pairwise interaction modelling (the essence of the vanilla transformer) the model would not perform as robustly as it did, for the obvious reason that the global context vectors might drop vital information; was the tradeoff worth it? It looks like it was, and I think the smart attention weighting deserves the credit. Hope you were able to follow this article, and if you did, share it and put down your thoughts in the comments below.
- Transformer (Attention is All You Need): https://arxiv.org/pdf/1706.03762.pdf
- Fastformer: https://arxiv.org/pdf/2108.09084.pdf
- Fastformer Implementation: https://github.com/microsoft/fastformers