Transformer & Attention

Understanding the basics


Is the point of this page just a fancy way of writing the Readme?

Table of Contents

Attention
Transformer

Attention

The attention mechanism was introduced in the paper Attention Is All You Need, which provided competition to the then state-of-the-art Seq2Seq modeling components such as Recurrent Neural Networks, i.e. LSTMs and GRUs. As the name suggests, attention is all you need to capture sequential information.

Attention Mechanism

Let's get into some equations now. We have three variables in Scaled Dot-Product Attention: \(Q \in \mathbb{R}^{d}\), \(K \in \mathbb{R}^{T \times d}\), \(V \in \mathbb{R}^{T \times d_v}\)

$$ \text{Un-normalized Attention} = \text{Q} \cdot \text{K}^T$$

where \(Q\) is the query vector, \(K\) is the sequence of key vectors, and \(d\) is the dimension of the vectors. \(T\) is the total number of past timesteps, i.e. the sequence length we are considering, which forms the history.

The intuition behind this is to compute the similarity (a dot product, closely related to cosine similarity) between the two vectors \( \text{Q} \) and \( \text{K}_t \) at timestep \(t\). This similarity acts as a weight for the corresponding value vector, i.e. \( V_t \). All the weighted \(V_t\) are then summed up.

Since these are weights, we would like the scores to be small in value, ideally in the range \([0, 1]\). The point is that we want scores that are comparable in magnitude, with the lowest being 0, which means do not give any attention to that timestep in the sequence. Hence we normalize the above quantity by taking a softmax. The dot products are also divided by a factor of \(\sqrt{d}\) to reduce their magnitude, so that they do not push the softmax into regions that have very small gradients.

$$\text{Normalized Attention} = \text{softmax}\left(\frac{\text{Q} \cdot \text{K}^T}{\sqrt{d}}\right)$$

The shape of the above is \( \text{T} \), i.e. a T-dimensional vector where each dimension corresponds to the attention given to the value vector at that timestep.

$$\text{Attention Value} = \text{softmax}\left(\frac{\text{Q} \cdot \text{K}^T}{\sqrt{d}}\right)V$$

Lastly, we get the attention value, which is essentially the normalized attention calculated for each timestep multiplied by the corresponding value vector, summed over all timesteps.

Let's take a look at the code now.
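Below is a minimal PyTorch sketch of the scaled dot-product attention described above; the function name and the optional mask argument are illustrative additions rather than part of the original code.

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Q: (batch, T_q, d), K: (batch, T_k, d), V: (batch, T_k, d_v)."""
    d = Q.size(-1)
    # Un-normalized attention scores, scaled by sqrt(d): (batch, T_q, T_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / d ** 0.5
    if mask is not None:
        # Positions where mask == 0 receive no attention
        scores = scores.masked_fill(mask == 0, float("-inf"))
    # Normalize across the timestep (key) dimension with a softmax
    attn = F.softmax(scores, dim=-1)
    # Weighted sum of the value vectors: (batch, T_q, d_v)
    return torch.matmul(attn, V), attn

# Example usage: a batch of 2 sequences, 10 timesteps of history, d = 64
Q = torch.randn(2, 5, 64)
K = torch.randn(2, 10, 64)
V = torch.randn(2, 10, 64)
out, attn = scaled_dot_product_attention(Q, K, V)  # out: (2, 5, 64), attn: (2, 5, 10)
```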

In practice, we do not pass in a single query; rather, we pass in batches of queries, as highlighted in the code above.

Multi-head Attention

Instead of having one attention mechanism, we can have multiple attention mechanisms, called heads, that run in parallel and learn to extract features independently.

Think about how a CNN acts as a feature extractor. We could have a single filter, but instead we have multiple filters that independently learn to extract features from the image. Similarly, the attention mechanism can be thought of as a way of extracting features from sequential data. Instead of learning one feature, we now learn multiple features simultaneously.

In terms of equations, we first linearly transform Q, K, V for each of these heads. Then we run our attention mechanism. Once the attention values are found, we concatenate them back. Denote each head by \(i\). Then we have:

$$ \text{AttnVal}_i = \text{Attention}(QW^{Q}_i, KW^{K}_i, VW^{V}_i)$$

where the \(W_i\) are learned parameters that transform the vectors. The attention values are concatenated to get

$$\text{MultiHead-AttnVal} = \text{concat}[\dots \text{AttnVal}_i \dots]$$

Lastly, the MultiHead-AttnVal is again transformed using a linear layer:

$$\text{Output} = (\text{MultiHead-AttnVal})\cdot W^{O}$$

Note that there are various dimensions that we could be working with.

For the code, I have taken the second approach.
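Below is a rough PyTorch sketch of multi-head attention. It assumes the variant in which the model dimension \(d_{model}\) is split evenly across the heads (so each head works with \(d_{model}/h\) dimensions); the class and method names are illustrative, not taken from the original code.

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Sketch of multi-head attention that splits d_model evenly across n_heads."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Learned projections W^Q, W^K, W^V (all heads fused into one matrix each)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        # Final output projection W^O
        self.w_o = nn.Linear(d_model, d_model)

    def _split(self, x):
        # (batch, T, d_model) -> (batch, n_heads, T, d_head)
        b, t, _ = x.shape
        return x.view(b, t, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, q, k, v):
        q, k, v = self._split(self.w_q(q)), self._split(self.w_k(k)), self._split(self.w_v(v))
        # Scaled dot-product attention, run independently for every head
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.d_head ** 0.5
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)                       # (batch, n_heads, T, d_head)
        # Concatenate the heads back into a single (batch, T, d_model) tensor
        b, _, t, _ = out.shape
        out = out.transpose(1, 2).contiguous().view(b, t, -1)
        return self.w_o(out)

# mha = MultiHeadAttention(d_model=128, n_heads=8)
# x = torch.randn(2, 10, 128)
# y = mha(x, x, x)   # self-attention: (2, 10, 128)
```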

Positional Encoding

If you noticed, in the attention mechanism we defined above, the sequential order, i.e. who came before whom, gets lost when we calculate the attention. In the case of RNNs, the order in which the tokens are processed stores that information. For attention, we have to add this information about the sequence order in the form of a positional encoding. There are a couple of ways we can do this: (1) having a random vector per position that is learned during the training process, (2) having a fixed vector defined using \(\sin\) and \(\cos\).

These are vectors of the same dimension as the input vector, \(d\), so that they can be added to the input vector. They are defined as

$$ \text{PE}_{\text{pos}, 2i} = \sin\left(\frac{\text{pos}}{10000^{\frac{2i}{d}}}\right)$$ $$ \text{PE}_{\text{pos}, 2i+1} = \cos\left(\frac{\text{pos}}{10000^{\frac{2i}{d}}}\right)$$

Note that \(i\) ranges over \([0, \frac{d}{2})\), because the sine and cosine values together make up the \(d\) dimensions of the vector. This vector, \(\text{PE}_{\text{pos}} \in \mathbb{R}^{d}\), gets added to the input vector, i.e. to the token at sequence position \(\text{pos}\):

$$x_{pos}^{'} = x_{pos} + \text{PE}_{\text{pos}}$$

The positional encoding using sine and cosine forms a rather cool pattern when visualized.

The code for generating this looks like the snippet below. There are more efficient ways to generate it, though, e.g. using PyTorch's arange.
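Here is a simple, loop-based sketch reconstructed from the formulas above; the function name and the max_len parameter are illustrative choices.

```python
import math
import torch

def positional_encoding(max_len, d):
    """Sinusoidal positional encodings: one d-dimensional vector per position."""
    pe = torch.zeros(max_len, d)
    for pos in range(max_len):
        for i in range(d // 2):
            angle = pos / (10000 ** (2 * i / d))
            pe[pos, 2 * i] = math.sin(angle)      # even dimensions use sin
            pe[pos, 2 * i + 1] = math.cos(angle)  # odd dimensions use cos
    return pe  # shape: (max_len, d)

# Usage sketch: x holds the embedded tokens of one sequence, shape (T, d)
# x = x + positional_encoding(x.size(0), x.size(1))
```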

Transformer

The Transformer is a sequence-to-sequence (Seq2Seq) model, which means that it maps an input sequence to another sequence. Common examples where this model is used are translation and summarization tasks.

Architecture

The architecture of the Transformer comprises a standard encoder that takes a sequence as input. For example, in a translation task from English to Hindi, the input sequence would be a sentence in English. The sequence is first converted into a learned embedding, followed by the addition of the positional encoding to the embedded tokens. This goes into an encoder comprising a multi-head attention layer and a fully connected layer, with layer normalization in between. These encoders are stacked on top of each other.

Similarly, you have the decoder, which takes as input tokens from the output sequence, i.e. Hindi in our example. During training, the ground-truth labels are fed in to predict the next token at each timestep, which is known as teacher forcing. During inference, this becomes auto-regressive, i.e. the prediction for the next timestep takes in the prediction from the current timestep. The decoder has an additional multi-head attention that takes its key and value pairs from the encoder itself, hence giving direct attention to the input sequence, as shown in the architecture.
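To make this concrete, here is a rough sketch of one encoder block and one decoder block, using PyTorch's built-in nn.MultiheadAttention for brevity; the class names, the d_ff size, and the ReLU feed-forward are illustrative choices, not taken from the original code.

```python
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    """One encoder block: self-attention -> add & norm -> feed-forward -> add & norm."""

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.self_attn(x, x, x)   # self-attention over the input sequence
        x = self.norm1(x + attn_out)            # residual connection + layer norm
        x = self.norm2(x + self.ff(x))          # position-wise feed-forward + norm
        return x

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, then cross-attention to the encoder output."""

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, y, enc_out, causal_mask=None):
        # Masked self-attention over the (shifted) target tokens
        attn_out, _ = self.self_attn(y, y, y, attn_mask=causal_mask)
        y = self.norm1(y + attn_out)
        # Cross-attention: queries from the decoder, keys/values from the encoder output
        attn_out, _ = self.cross_attn(y, enc_out, enc_out)
        y = self.norm2(y + attn_out)
        y = self.norm3(y + self.ff(y))
        return y
```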