Positional Embeddings

In RNN architectures, the decoder state at time step \(t\) was a function of the decoder state at time step \(t-1\) and the input token at time step \(t\). In other words, the order of the tokens in the input sequence was inherently preserved, because a hidden state simply could not be computed before the previous one.
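The sequential dependence can be made concrete with a minimal sketch. It assumes `torch.nn.RNNCell` as a stand-in for one recurrent decoder step. The sizes and the helper name `run` are illustrative and do not come from the original text. Because each step consumes the previous state, feeding the same inputs in reverse order produces a different final state.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, d_in, d_hidden = 6, 1, 4, 8  # arbitrary toy sizes

cell = nn.RNNCell(d_in, d_hidden)       # stand-in for one recurrent step
x = torch.randn(seq_len, batch, d_in)   # one input vector per time step

def run(inputs):
    """Hypothetical helper: unroll the cell over the time dimension."""
    h = torch.zeros(batch, d_hidden)    # initial state h_0
    for x_t in inputs:                  # strictly sequential: h_t needs h_{t-1}
        h = cell(x_t, h)
    return h

h_forward = run(x)
h_reversed = run(x.flip(0))             # same inputs, opposite order
print(torch.allclose(h_forward, h_reversed))  # False: order is baked into the state
```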

In transformers, we got rid of the recurrent connections, so we need to capture the order of the tokens in the input sequence in some other way. Self-attention on its own has no notion of order. It is permutation-equivariant: shuffling the input tokens simply shuffles its outputs in the same way. To inject order information, we use positional embeddings, adopting an approach similar to the one we used for word embeddings such as word2vec. However, instead of learning embeddings of the tokens, we learn embeddings of their positions. We define a learnable embedding matrix \(E\) of size \(T \times d\), where \(T\) is the maximum input sequence length and \(d\) is the embedding dimension. We then add the positional embedding of each token to its corresponding token embedding. The input representation at position \(i\) is therefore \(x_i + E_i\), where \(x_i\) is the token embedding and \(E_i\) is the \(i\)-th row of \(E\).
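To see why order must be injected explicitly, consider a minimal sketch. It assumes a single-head scaled dot-product self-attention with no masking. The class name `SingleHeadSelfAttention` and all sizes are illustrative, not taken from the source. Without positions, permuting the input only permutes the output. After adding the first \(n\) rows of a learnable \(T \times d\) matrix \(E\), this symmetry no longer holds.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, d, T = 5, 16, 32  # sequence length, embedding dim, max positions (arbitrary)

class SingleHeadSelfAttention(nn.Module):
    """Hypothetical minimal self-attention with no positional information."""
    def __init__(self, d):
        super().__init__()
        self.q = nn.Linear(d, d)
        self.k = nn.Linear(d, d)
        self.v = nn.Linear(d, d)

    def forward(self, x):  # x: (batch, n, d)
        scores = self.q(x) @ self.k(x).transpose(-2, -1) / d ** 0.5
        return torch.softmax(scores, dim=-1) @ self.v(x)

attn = SingleHeadSelfAttention(d)
x = torch.randn(1, n, d)               # stand-in token embeddings
perm = torch.tensor([4, 2, 0, 3, 1])   # a fixed shuffle of the positions

# Without positions, shuffling the input shuffles the output identically.
out = attn(x)
out_shuffled = attn(x[:, perm])
print(torch.allclose(out[:, perm], out_shuffled, atol=1e-5))  # True

# With a learnable T x d matrix E, row i is added to the token at position i.
E = nn.Parameter(torch.randn(T, d))
out_pos = attn(x + E[:n])              # (n, d) broadcasts over the batch
out_pos_shuffled = attn(x[:, perm] + E[:n])
print(torch.allclose(out_pos[:, perm], out_pos_shuffled, atol=1e-5))  # False
```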

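import torch
import torch.nn as nn

class Embeddings(nn.Module):
    """Token embeddings plus learned positional embeddings."""
    def __init__(self, config):
        super().__init__()
        # One learned vector per vocabulary entry
        self.token_embeddings = nn.Embedding(config.vocab_size,
                                             config.hidden_size)
        # One learned vector per position: the T x d matrix E, with
        # T = config.max_position_embeddings and d = config.hidden_size
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        # Note: nn.Dropout() defaults to p=0.5; BERT uses 0.1 here
        # (its hidden_dropout_prob setting)
        self.dropout = nn.Dropout()

    def forward(self, input_ids):
        # input_ids: (batch_size, seq_length) tensor of token ids
        seq_length = input_ids.size(1)
        # Position IDs 0..seq_length-1 on the same device as the inputs,
        # shaped (1, seq_length) so they broadcast over the batch
        position_ids = torch.arange(seq_length, dtype=torch.long,
                                    device=input_ids.device).unsqueeze(0)
        # Look up token and position embeddings
        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        # Combine by element-wise addition, then normalize and apply dropout
        embeddings = token_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings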
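The lookup-and-add step inside `forward` can be checked in isolation with two plain `nn.Embedding` tables. The toy sizes and token ids below are arbitrary assumptions. The example shows two things. First, a `(1, seq_length)` tensor of position ids broadcasts across the batch. Second, the same token receives a different vector at each position but the same vector at the same position.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, max_len, d = 100, 16, 8    # arbitrary toy sizes

tok = nn.Embedding(vocab_size, d)      # token table
pos = nn.Embedding(max_len, d)         # positional table (T x d)

input_ids = torch.tensor([[7, 7, 7],
                          [3, 9, 7]])  # (batch=2, seq_len=3)
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)  # (1, 3)

# (2, 3, d) + (1, 3, d): the position embeddings broadcast over the batch
emb = tok(input_ids) + pos(position_ids)
print(emb.shape)  # torch.Size([2, 3, 8])

# Token 7 repeated in row 0 gets a different vector at each position.
print(torch.allclose(emb[0, 0], emb[0, 1]))  # False
# Token 7 at position 2 gets the same vector in both rows.
print(torch.allclose(emb[0, 2], emb[1, 2]))  # True
```

Because the position table has only `max_position_embeddings` rows, learned positional embeddings cannot handle inputs longer than \(T\). Looking up a position id past the end of the table raises an error.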