ā±ļø 90 min

Transformers & Attention Mechanism

Master the transformer architecture that powers modern NLP

The Transformer Revolution

Transformers revolutionized NLP by replacing recurrence with attention mechanisms.

**Key Innovation: Self-Attention**

Instead of processing sequences step-by-step like RNNs, transformers:
- Process entire sequences in parallel
- Learn which parts of the input are relevant to each other
- Scale to very long sequences
- Enable models like BERT, GPT, and T5

**Architecture:**
1. **Encoder**: Processes the input sequence (BERT-style)
2. **Decoder**: Generates the output sequence (GPT-style)
3. **Both**: Seq2seq tasks like translation

**Advantages over RNNs:**
- Parallelizable (faster training)
- Better at long-range dependencies
- More interpretable (attention weights)
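Before building the full multi-head version, the core computation can be sketched in a few lines. This is a minimal single-head sketch of scaled dot-product attention, softmax(QKᵀ/√d_k)·V, where passing the same tensor as queries, keys, and values makes it *self*-attention:

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V -- the core of attention."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # (seq, seq) relevance scores
    weights = F.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V, weights

# Toy example: 4 tokens with 8-dim embeddings; Q = K = V for self-attention
x = torch.randn(4, 8)
out, w = scaled_dot_product_attention(x, x, x)
print(out.shape, w.shape)  # torch.Size([4, 8]) torch.Size([4, 4])
```

Each row of `w` is a probability distribution saying how much that token attends to every other token; the output is the attention-weighted mix of the value vectors.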

Understanding Self-Attention

Self-attention lets each token attend to all other tokens:

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        assert (self.head_dim * heads == embed_size), "Embed size must be divisible by heads"
        
        # Linear layers for Q, K, V
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
        
    def forward(self, values, keys, query, mask=None):
        N = query.shape[0]  # Batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        
        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
        
        # Apply the learned projections defined in __init__
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        
        # Compute attention scores: Q * K^T
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # energy shape: (N, heads, query_len, key_len)
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        # Scale by sqrt(d_k) (the per-head dimension), then softmax over keys
        attention = torch.softmax(energy / math.sqrt(self.head_dim), dim=3)
        
        # Apply attention to values
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        # Concatenate heads
        out = out.reshape(N, query_len, self.heads * self.head_dim)
        
        # Final linear layer
        out = self.fc_out(out)
        return out, attention

# Example usage
batch_size = 2
seq_len = 10
embed_size = 512
heads = 8

# Create random input
x = torch.randn(batch_size, seq_len, embed_size)

# Create self-attention layer
attention_layer = SelfAttention(embed_size, heads)

# Forward pass
output, attention_weights = attention_layer(x, x, x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print(f"\nAttention weights for first head, first sequence:")
print(attention_weights[0, 0, :, :].detach().numpy().round(3))

# Visualize the attention pattern
print(f"\nAttention pattern (mean across heads):")
avg_attention = attention_weights[0].mean(dim=0)
print(avg_attention.detach().numpy().round(2))
Output:
Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Attention weights shape: torch.Size([2, 8, 10, 10])

Attention weights for first head, first sequence:
[[0.112 0.098 0.105 0.089 0.103 0.097 0.091 0.102 0.099 0.104]
 [0.095 0.108 0.102 0.094 0.098 0.105 0.099 0.097 0.103 0.099]
 [0.101 0.097 0.106 0.098 0.095 0.102 0.104 0.098 0.099 0.100]
 [0.099 0.103 0.097 0.107 0.100 0.096 0.101 0.098 0.102 0.097]
 [0.098 0.101 0.099 0.095 0.109 0.102 0.097 0.103 0.098 0.098]
 [0.102 0.096 0.098 0.104 0.097 0.108 0.099 0.100 0.097 0.099]
 [0.097 0.099 0.103 0.098 0.101 0.095 0.110 0.097 0.102 0.098]
 [0.103 0.098 0.096 0.101 0.098 0.099 0.097 0.109 0.101 0.098]
 [0.098 0.102 0.097 0.100 0.102 0.098 0.096 0.101 0.107 0.099]
 [0.095 0.098 0.097 0.104 0.097 0.098 0.104 0.095 0.098 0.114]]

Attention pattern (mean across heads):
[[0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10]
 [0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10]
 [0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10]
 [0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10]
 [0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10]
 [0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10]
 [0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10]
 [0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10]
 [0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10]
 [0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10]]
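The `mask` argument is what makes decoder-style (GPT) attention causal: positions are prevented from attending to future tokens. A minimal sketch of the masking trick used in the `forward` method above, shown on a single head's raw scores:

```python
import torch

seq_len = 5
scores = torch.randn(seq_len, seq_len)  # raw attention scores (pre-softmax)

# Lower-triangular mask: position i may only attend to positions <= i
mask = torch.tril(torch.ones(seq_len, seq_len))

# Filling masked positions with a large negative number makes their
# softmax probability (effectively) zero
masked = scores.masked_fill(mask == 0, float("-1e20"))
weights = torch.softmax(masked, dim=-1)

print(weights)  # upper triangle is all zeros: no attention to the future
```

After the softmax, each row still sums to 1, but all the probability mass sits on the current and earlier positions.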

Complete Transformer Block

Build a full transformer encoder block:

python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        
        # Multi-head self-attention (batch_first so inputs are (N, seq, embed))
        self.attention = nn.MultiheadAttention(embed_size, heads, dropout=dropout, batch_first=True)
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        
        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, value, key, query, mask=None):
        # Multi-head attention with residual connection
        attention_output, _ = self.attention(query, key, value, attn_mask=mask)
        
        # Add & Norm
        x = self.dropout(self.norm1(attention_output + query))
        
        # Feed-forward with residual connection
        forward = self.feed_forward(x)
        
        # Add & Norm
        out = self.dropout(self.norm2(forward + x))
        
        return out

class Transformer(nn.Module):
    def __init__(self, 
                 vocab_size, 
                 embed_size=512,
                 num_layers=6,
                 heads=8,
                 forward_expansion=4,
                 dropout=0.1,
                 max_length=100):
        super(Transformer, self).__init__()
        
        self.embed_size = embed_size
        
        # Token embedding
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        
        # Positional encoding
        self.position_embedding = nn.Embedding(max_length, embed_size)
        
        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(embed_size, heads, dropout, forward_expansion)
            for _ in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
        
        # Output layer
        self.fc_out = nn.Linear(embed_size, vocab_size)
        
    def forward(self, x, mask=None):
        N, seq_length = x.shape
        
        # Create positions
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(x.device)
        
        # Embeddings
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        
        # Pass through transformer blocks
        for layer in self.layers:
            out = layer(out, out, out, mask)
        
        # Output projection
        out = self.fc_out(out)
        
        return out

# Create transformer model
vocab_size = 10000
model = Transformer(vocab_size, embed_size=512, num_layers=6, heads=8)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print("Transformer Model")
print(f"Total parameters: {total_params:,}")

# Example forward pass
batch_size = 4
seq_len = 20
x = torch.randint(0, vocab_size, (batch_size, seq_len))

with torch.no_grad():
    output = model(x)

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nāœ“ Transformer model ready for training!")
Output:
Transformer Model
Total parameters: 29,215,504

Input shape: torch.Size([4, 20])
Output shape: torch.Size([4, 20, 10000])

āœ“ Transformer model ready for training!
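"Ready for training" here means next-token language modeling: shift the sequence by one position and minimize cross-entropy between the logits and the shifted targets. A minimal sketch of one training step, using a toy stand-in model (any module mapping `(N, seq)` token ids to `(N, seq, vocab_size)` logits, such as the `Transformer` above, works the same way):

```python
import torch
import torch.nn as nn

vocab_size, embed_size = 10000, 512

# Toy stand-in for the Transformer: embedding followed by a projection to logits
model = nn.Sequential(nn.Embedding(vocab_size, embed_size),
                      nn.Linear(embed_size, vocab_size))

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

tokens = torch.randint(0, vocab_size, (4, 21))   # batch of token ids
inputs, targets = tokens[:, :-1], tokens[:, 1:]  # predict the next token

logits = model(inputs)                           # (4, 20, vocab_size)

# CrossEntropyLoss expects (N, C) logits: flatten batch and sequence dims
loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.3f}")
```

With random weights and uniform targets, the initial loss is close to ln(10000) ≈ 9.2; it should fall as training proceeds.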

Key Takeaways

**Transformer Advantages:**
1. **Parallel Processing**: All tokens are processed simultaneously
2. **Long-Range Dependencies**: Attention connects any two positions directly
3. **Scalability**: Scales to very large models (GPT-3 has 175B parameters)
4. **Transfer Learning**: Pre-train on a large corpus, fine-tune on specific tasks

**Common Variants:**
- **BERT**: Encoder-only, bidirectional (understanding)
- **GPT**: Decoder-only, autoregressive (generation)
- **T5**: Encoder-decoder (sequence-to-sequence)
- **Vision Transformer (ViT)**: Transformers for images

**When to Use:**
- Text classification, NER, QA → BERT
- Text generation, completion → GPT
- Translation, summarization → T5
