Understanding Attention: A Code-First Journey Through Transformers

24 min read

Transformers underpin most modern language models. This post builds them from scratch, starting simple and adding complexity step by step. The goal is to understand not just what attention does, but how it works at the tensor level.

The code examples are meant to be run interactively; the Appendix walks through a quick environment setup. I recommend reading the code carefully, pasting it into your Python shell, and playing with it to get a feel for attention and multi-head attention. Note that the example outputs shown are illustrative: your values will differ on each run since the random seed varies.

Let’s begin.

The Simplest Attention Mechanism

Tokens and Embeddings

Let’s start with the absolute minimum — three tokens (words) in a sequence (sentence), each token represented by a 4-dimensional vector (embedding):

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
# Create a simple sequence: 3 tokens, 4 dimensions each
seq_len = 3
embed_dim = 4
 
# Random embeddings for our tokens
x = torch.randn(seq_len, embed_dim)
 
print(f"Input shape: {x.shape}")
print(f"Input:\n{x}")
Input shape: torch.Size([3, 4])
Input:
tensor([[ 0.3367,  0.1288,  0.2345,  0.2303],
        [-1.1229, -0.1863,  2.2082, -0.6380],
        [ 0.4617,  0.2674,  0.5349,  0.8094]])

torch.Size([3, 4]) means 3 tokens (“words”) with each token represented by 4 features (a 4-dimensional embedding).

Queries, Keys, and Values

In attention, each token decides what to pay attention to. This is done through three projections:

Projection   Math          Question
Query        Q = X @ W_q   "What am I looking for?"
Key          K = X @ W_k   "What information do I contain?"
Value        V = X @ W_v   "What information will I actually send?"

Representing in code:

# Simple linear projections (no bias for clarity)
W_q = nn.Linear(embed_dim, embed_dim, bias=False)
W_k = nn.Linear(embed_dim, embed_dim, bias=False)
W_v = nn.Linear(embed_dim, embed_dim, bias=False)
 
# Create Q, K, V
Q = W_q(x)  # (3, 4)
K = W_k(x)  # (3, 4)
V = W_v(x)  # (3, 4)
 
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
Q shape: torch.Size([3, 4])
K shape: torch.Size([3, 4])
V shape: torch.Size([3, 4])

Each nn.Linear is a matrix multiplication, or written another way:

Q = x @ W_q.weight.T    # (3,4) @ (4,4) = (3,4)

Thus, each token's embedding gets transformed by the learned weight matrices W. Each W matrix is square, with the same size as the embedding dimension (4 features), so after multiplying by these weights the query, key, and value matrices have the same shape as the input. The meaning, though, is changed through these learned weights.

Q, K, and V are produced using linear layers because the model needs simple, learnable projections of the same token embedding but in different subspaces — one for querying, one for matching, and one for carrying information forward.

A linear transformation is ideal because attention relies on dot-product similarity, which assumes a linear geometric structure; extra nonlinearities would distort that relationship.

Bias terms are omitted because:

  • a bias shifts all vectors uniformly, which doesn't help with similarity scoring
  • it adds parameters without improving the dot-product matching that drives attention
  • omitting it keeps multi-head splitting simple and symmetric, which matters when you have billions of parameters to learn
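A quick sanity check in the REPL confirms that with bias=False the layer holds only a weight matrix, and its forward pass is exactly a matrix multiplication:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(4, 4, bias=False)

# No bias parameter is allocated at all
assert lin.bias is None
assert lin.weight.shape == (4, 4)

# The forward pass is just x @ W.T
x = torch.randn(3, 4)
assert torch.allclose(lin(x), x @ lin.weight.T)
```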

Computing Attention Scores

Next, we compute how much attention each token should pay to every other token:

# Attention scores: Q @ K^T
# This gives us: for each query position, how much does it match each key?
scores = Q @ K.T  # (3, 4) @ (4, 3) = (3, 3)
 
print(f"Scores shape: {scores.shape}")
print(f"Scores:\n{scores}")
Scores shape: torch.Size([3, 3])
Scores:
tensor([[ 0.5423, -0.2819,  0.3764],
        [-0.2819,  0.8934, -0.1234],
        [ 0.3764, -0.1234,  0.7856]])

Matrix multiplication requires the column dimension of the first matrix to equal the row dimension of the second. Since the embedding dimension is fixed (the 4-dimensional embedding for each token), we align Q and K along that dimension. So 3 tokens go in and we get a 3×3 matrix of scores, a sort of lookup table: higher values mean "pay more attention to this token", lower values mean less.
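To convince yourself that each score really is a dot product between one query and one key, you can check a single entry by hand:

```python
import torch

torch.manual_seed(0)
Q = torch.randn(3, 4)
K = torch.randn(3, 4)
scores = Q @ K.T  # (3, 3)

# scores[i, j] is the dot product of query i with key j
assert torch.allclose(scores[1, 2], Q[1] @ K[2])
assert torch.allclose(scores[0, 0], (Q[0] * K[0]).sum())
```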

Let’s visualize what this means:

print("\nAttention score interpretation:")
print("                    Key Position")
print("              Token0  Token1  Token2")
print(f"Query Token0  {scores[0,0]:6.2f}  {scores[0,1]:6.2f}  {scores[0,2]:6.2f}")
print(f"Query Token1  {scores[1,0]:6.2f}  {scores[1,1]:6.2f}  {scores[1,2]:6.2f}")
print(f"Query Token2  {scores[2,0]:6.2f}  {scores[2,1]:6.2f}  {scores[2,2]:6.2f}")
Attention score interpretation:
                    Key Position
              Token0  Token1  Token2
Query Token0    0.54   -0.28    0.38
Query Token1   -0.28    0.89   -0.12
Query Token2    0.38   -0.12    0.79
Figure 1: Attention score matrix showing query-key relationships. Each row represents how much a query attends to all keys. Higher values indicate stronger attention.

Scaling for Stability

Raw scores can get very large, which causes problems in softmax. We scale by the square root of the embedding dimension:

# Scale by sqrt(d_k) for numerical stability
d_k = K.shape[-1]  # Last dimension = feature dimension = 4
scaled_scores = scores / math.sqrt(d_k)
 
print(f"d_k = {d_k}")
print(f"Scale factor = 1/sqrt({d_k}) = {1/math.sqrt(d_k):.4f}")
print(f"Scaled scores:\n{scaled_scores}")
d_k = 4
Scale factor = 1/sqrt(4) = 0.5000
Scaled scores:
tensor([[ 0.2712, -0.1409,  0.1882],
        [-0.1409,  0.4467, -0.0617],
        [ 0.1882, -0.0617,  0.3928]])

Why scale? As dimensions grow, dot products grow in magnitude. Without scaling, softmax would produce very peaked distributions (almost one-hot), making gradients vanish. Scaling by 1/sqrt(d_k) keeps variance constant.
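The variance argument is easy to verify empirically. For random vectors with unit-variance components, the dot product's standard deviation grows like √d, and dividing by √d_k brings it back to roughly 1 (a quick simulation with an illustrative dimension and sample count):

```python
import math
import torch

torch.manual_seed(0)
d = 512
q = torch.randn(10000, d)
k = torch.randn(10000, d)

dots = (q * k).sum(dim=-1)          # 10,000 sample dot products
print(dots.std())                   # roughly sqrt(512), about 22.6
print((dots / math.sqrt(d)).std())  # roughly 1.0 after scaling
```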

Applying Softmax

Now we convert scores to probabilities: each row must sum to 1, i.e. the attention scores for each query sum to 1 across all keys.

# Apply softmax: converts scores to probability distribution
attention_weights = F.softmax(scaled_scores, dim=-1)
 
print(f"Attention weights shape: {attention_weights.shape}")
print(f"Attention weights:\n{attention_weights}")
print(f"\nRow sums (should be 1.0):")
print(attention_weights.sum(dim=-1))
Attention weights shape: torch.Size([3, 3])
Attention weights:
tensor([[0.3505, 0.2858, 0.3637],
        [0.2858, 0.4183, 0.2959],
        [0.3276, 0.2458, 0.4266]])
 
Row sums (should be 1.0):
tensor([1.0000, 1.0000, 1.0000])

dim=-1 selects the last dimension (the columns, i.e. the keys), so F.softmax(..., dim=-1) normalizes each row independently across the keys.
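If the dim argument feels opaque, it may help to reproduce the row-wise softmax by hand and compare:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
s = torch.randn(3, 3)

# Row-wise softmax, written out: exponentiate, then normalize each row
manual = s.exp() / s.exp().sum(dim=-1, keepdim=True)
assert torch.allclose(manual, F.softmax(s, dim=-1))

# Every row sums to 1
assert torch.allclose(manual.sum(dim=-1), torch.ones(3))
```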

Computing the Output

Finally, we use these attention weights to create a weighted average of the values:

# Apply attention to values: weighted sum
output = attention_weights @ V  # (3, 3) @ (3, 4) = (3, 4)
 
print(f"Output shape: {output.shape}")
print(f"Output:\n{output}")
Output shape: torch.Size([3, 4])
Output:
tensor([[-0.0523,  0.0234, -0.1245,  0.0789],
        [-0.0312,  0.0445, -0.0923,  0.0534],
        [-0.0678,  0.0156, -0.1478,  0.0912]])

What just happened?

# For token 0, plugging in the attention weights above:
# output[0] = 0.35 * V[0] + 0.29 * V[1] + 0.36 * V[2]

Each output token is a mixture of all value vectors, weighted by attention. If token 0 strongly attends to token 1 (high attention weight), then output[0] will be heavily influenced by V[1].
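You can verify the mixture interpretation directly: each output row equals the attention-weighted sum of the value rows.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weights = F.softmax(torch.randn(3, 3), dim=-1)  # (3, 3) attention weights
V = torch.randn(3, 4)                           # (3, 4) values
output = weights @ V

# Row 0 of the output, written out as an explicit weighted sum
manual = weights[0, 0] * V[0] + weights[0, 1] * V[1] + weights[0, 2] * V[2]
assert torch.allclose(output[0], manual)
```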

Key takeaway: Attention is just four matrix operations:

  1. Project to Q, K, V
  2. Compute Q @ K.T / sqrt(d_k)
  3. Apply softmax
  4. Multiply by V

That’s it. Next, we’ll see how to handle batches and split attention across multiple heads.

Adding the Batch Dimension

Real models process multiple sequences at once for efficiency. Let's extend our simple attention mechanism to handle batches, beginning with a very small batch of 2 sentences, each containing 3 words:

# Now we have a batch of sequences
batch_size = 2
seq_len = 3
embed_dim = 4
 
# Shape: (batch_size, seq_len, embed_dim)
x = torch.randn(batch_size, seq_len, embed_dim)
 
print(f"Input shape: {x.shape}")
print(f"Interpretation: {batch_size} sequences, {seq_len} tokens each, {embed_dim} features per token")
Input shape: torch.Size([2, 3, 4])
Interpretation: 2 sequences, 3 tokens each, 4 features per token

Understanding 3D tensors:

  • Dimension 0: Which sequence in the batch
  • Dimension 1: Which token in the sequence
  • Dimension 2: Which feature of the token

Think of it as a stack of matrices, one per sequence.
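Indexing makes the "stack of matrices" picture concrete; each index peels off one dimension:

```python
import torch

x = torch.randn(2, 3, 4)  # (batch, seq, embed)

assert x[0].shape == (3, 4)    # one sequence: a matrix of tokens
assert x[0, 1].shape == (4,)   # one token: a vector of features
assert x[0, 1, 2].shape == ()  # one feature: a scalar
```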

Batched Attention

# Create new projection layers that handle batches
W_q = nn.Linear(embed_dim, embed_dim, bias=False)
W_k = nn.Linear(embed_dim, embed_dim, bias=False)
W_v = nn.Linear(embed_dim, embed_dim, bias=False)
 
# Project (linear layers handle batches automatically!)
Q = W_q(x)  # (2, 3, 4)
K = W_k(x)  # (2, 3, 4)
V = W_v(x)  # (2, 3, 4)
 
print(f"Q shape: {Q.shape}")
Q shape: torch.Size([2, 3, 4])

How does nn.Linear handle batches?

# nn.Linear applies the same weights to each sequence independently
# It's equivalent to:
# for i in range(batch_size):
#     Q[i] = x[i] @ W_q.weight.T

PyTorch broadcasting handles this efficiently in parallel. Each sequence in the batch is completely independent — no information flows between batch elements. This makes attention embarrassingly parallel: a GPU can process all sequences simultaneously. This is why batch size has such a large impact on training speed: batch_size=1 might use 5% of GPU capacity, while batch_size=32 can reach 80%+ utilization.
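The per-sequence equivalence described above is easy to check: the single batched call matches an explicit loop over the batch exactly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(4, 4, bias=False)
x = torch.randn(2, 3, 4)  # (batch, seq, embed)

batched = lin(x)  # one call handles the whole batch
looped = torch.stack([x[i] @ lin.weight.T for i in range(2)])
assert torch.allclose(batched, looped)
```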

Batched Matrix Multiplication

Now we need to be careful with matrix multiplication:

# We want: for each batch, compute Q @ K.T
# K.shape = (2, 3, 4)
# We need to transpose only the last two dimensions!
 
K_transposed = K.transpose(-2, -1)  # (2, 3, 4) -> (2, 4, 3)
 
print(f"K shape: {K.shape}")
print(f"K transposed shape: {K_transposed.shape}")
 
# Now compute scores
d_k = K.shape[-1]
scores = (Q @ K_transposed) / math.sqrt(d_k)  # (2, 3, 4) @ (2, 4, 3) = (2, 3, 3)
 
print(f"Scores shape: {scores.shape}")
K shape: torch.Size([2, 3, 4])
K transposed shape: torch.Size([2, 4, 3])
Scores shape: torch.Size([2, 3, 3])

Understanding transpose(-2, -1):

  • -1 refers to last dimension (size 4)
  • -2 refers to second-to-last dimension (size 3)
  • This swaps only these two, leaving batch dimension alone
  • Result: (batch, seq, features) → (batch, features, seq)

Why not just K.T?

# K.T would transpose ALL dimensions: (2, 3, 4) -> (4, 3, 2)
# We only want to transpose within each batch: (2, 3, 4) -> (2, 4, 3)

Key insight: Each sequence in the batch gets its own attention pattern. The model processes them in parallel, but they don’t interact.
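The same independence holds for the score computation; the batched matmul is just a loop over per-sequence matmuls:

```python
import torch

torch.manual_seed(0)
Q = torch.randn(2, 3, 4)
K = torch.randn(2, 3, 4)

scores = Q @ K.transpose(-2, -1)  # (2, 3, 3), all batches at once
for b in range(2):
    # Each batch element matches its own standalone computation
    assert torch.allclose(scores[b], Q[b] @ K[b].T)
```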

Causal Masking (For Autoregressive Models)

In language models, tokens can’t “see” future tokens. We enforce this with a causal mask:

# Create a causal mask: lower triangular matrix
seq_len = 4
mask = torch.tril(torch.ones(seq_len, seq_len))
 
print("Causal mask (1=allowed, 0=forbidden):")
print(mask)
Causal mask (1=allowed, 0=forbidden):
tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

Understanding torch.tril:

  • “tri” = triangular, “l” = lower
  • Creates lower triangular matrix (1s below and on diagonal, 0s above)
  • Position [i, j]: 1 if i >= j (token i can see token j), 0 otherwise
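The i >= j rule can be checked exhaustively for a small mask:

```python
import torch

seq_len = 4
mask = torch.tril(torch.ones(seq_len, seq_len))

for i in range(seq_len):
    for j in range(seq_len):
        # Token i may attend to token j only if j is not in the future
        assert mask[i, j] == (1.0 if i >= j else 0.0)
```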

Applying the Mask

# Create sample attention scores
batch_size = 2
scores = torch.randn(batch_size, seq_len, seq_len)
 
print("Original scores (batch 0):")
print(scores[0])
 
# Apply mask: set forbidden positions to -infinity
# (these -inf entries become exactly 0 after softmax)
scores_masked = scores.masked_fill(mask == 0, float('-inf'))
 
print("\nMasked scores (batch 0):")
print(scores_masked[0])
 
# Apply softmax
attention_weights = F.softmax(scores_masked, dim=-1)
 
print("\nAttention weights after softmax (batch 0):")
print(attention_weights[0])
Original scores (batch 0):
tensor([[ 0.7234, -0.2341,  0.4567, -0.1234],
        [-0.5678,  0.8912,  0.3456, -0.2345],
        [ 0.1234, -0.6789,  0.9012,  0.5678],
        [-0.3456,  0.2345, -0.7890,  0.4567]])
 
Masked scores (batch 0):
tensor([[ 0.7234,    -inf,    -inf,    -inf],
        [-0.5678,  0.8912,    -inf,    -inf],
        [ 0.1234, -0.6789,  0.9012,    -inf],
        [-0.3456,  0.2345, -0.7890,  0.4567]])
 
Attention weights after softmax (batch 0):
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.2689, 0.7311, 0.0000, 0.0000],
        [0.2012, 0.0893, 0.7095, 0.0000],
        [0.1456, 0.2534, 0.0912, 0.5098]])

Understanding masked_fill:

# masked_fill(condition, value)
# Where condition is True, replace with value
# Where mask == 0 (forbidden), replace with -inf

Why -infinity?

# softmax(x) = exp(x) / sum(exp(x))
# exp(-inf) = 0
# So -inf becomes 0 in the final attention weights

Observe the pattern:

  • Token 0 only attends to itself (1.0 weight)
  • Token 1 attends to tokens 0 and 1
  • Token 2 attends to tokens 0, 1, and 2
  • Token 3 attends to all tokens

This is autoregressive: each position only sees the past.
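Two properties worth verifying: masked positions are exactly zero after softmax, and every row still sums to 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(seq_len, seq_len)
mask = torch.tril(torch.ones(seq_len, seq_len))

weights = F.softmax(scores.masked_fill(mask == 0, float('-inf')), dim=-1)

assert torch.all(weights[0, 1:] == 0)  # token 0 sees only itself
assert torch.all(weights[2, 3:] == 0)  # token 2 cannot see token 3
assert torch.allclose(weights.sum(dim=-1), torch.ones(seq_len))
```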

Multi-Head Attention

Instead of one attention operation, we can run several in parallel. Each “head” can learn to attend to different things.

The Concept

Imagine you’re reading a sentence:

  • Head 1 might focus on grammatical relationships (subject-verb agreement)
  • Head 2 might track long-range dependencies (pronoun references)
  • Head 3 might capture local context (adjacent words)

Each head learns different patterns simultaneously.

The Challenge: Reshaping

The tricky part is managing dimensions. Let’s work through it step by step.

# Configuration
batch_size = 2
seq_len = 4
embed_dim = 8  # Total embedding dimension
n_heads = 2    # Number of attention heads
head_dim = embed_dim // n_heads  # Dimension per head = 4
 
print(f"Total embedding dimension: {embed_dim}")
print(f"Number of heads: {n_heads}")
print(f"Dimension per head: {head_dim}")
Total embedding dimension: 8
Number of heads: 2
Dimension per head: 4

Why head_dim is often 64 or 128: Most production transformers pick a head dimension in this range regardless of model size (GPT-2: 768÷12=64, GPT-3: 12288÷96=128). These sizes fit well into GPU warp sizes (groups of 32 threads) and allow efficient vectorized operations. Going much smaller loses expressiveness; going larger gives diminishing returns.

Step 1: Combined Q, K, V Projection

Instead of separate projections, we create one big projection and split it:

# Create input
x = torch.randn(batch_size, seq_len, embed_dim)
print(f"Input shape: {x.shape}")
 
# Single projection that creates Q, K, V for all heads
# Output dimension: 3 * embed_dim (for Q, K, and V)
W_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
 
qkv = W_qkv(x)  # (batch, seq, 3*embed_dim)
print(f"Combined QKV shape: {qkv.shape}")
print(f"This contains: Q (8 dims) + K (8 dims) + V (8 dims) = 24 dims")
Input shape: torch.Size([2, 4, 8])
Combined QKV shape: torch.Size([2, 4, 24])
This contains: Q (8 dims) + K (8 dims) + V (8 dims) = 24 dims

Why combine projections?

  • More efficient (one matrix multiply instead of three)
  • Better GPU utilization
  • Common pattern in production code

Performance note: Three separate linear layers means three CUDA kernel launches and three memory round-trips. A single Linear(embed_dim, 3*embed_dim) reduces this to one of each, which is measurably faster in practice and adds up across a deep, many-layer model.

Step 2: Split into Q, K, V

# Split the concatenated QKV
Q, K, V = qkv.split(embed_dim, dim=2)
 
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
Q shape: torch.Size([2, 4, 8])
K shape: torch.Size([2, 4, 8])
V shape: torch.Size([2, 4, 8])

Understanding split(embed_dim, dim=2):

  • Split along dimension 2 (last dimension)
  • Each chunk has size embed_dim = 8
  • Splits 24 into three 8s: [Q:0-8, K:8-16, V:16-24]
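A quick check that split carves the 24-dimensional projection into the expected contiguous slices:

```python
import torch

torch.manual_seed(0)
qkv = torch.randn(2, 4, 24)  # (batch, seq, 3 * embed_dim)
Q, K, V = qkv.split(8, dim=2)

assert Q.shape == (2, 4, 8)
assert torch.equal(Q, qkv[:, :, 0:8])    # first 8 dims
assert torch.equal(K, qkv[:, :, 8:16])   # middle 8 dims
assert torch.equal(V, qkv[:, :, 16:24])  # last 8 dims
```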

Step 3: Reshape for Multi-Head (The Tricky Part!)

Now we need to split the embedding dimension across heads:

# Current shape: (batch, seq, embed_dim)
# Target shape: (batch, n_heads, seq, head_dim)
 
# Step 3a: Reshape to expose head dimension
Q_reshaped = Q.view(batch_size, seq_len, n_heads, head_dim)
print(f"After view: {Q_reshaped.shape}")
print("Interpretation: (batch, seq, n_heads, head_dim)")
After view: torch.Size([2, 4, 2, 4])
Interpretation: (batch, seq, n_heads, head_dim)

Understanding view:

# Original Q: (2, 4, 8)
# We want to split 8 into (2 heads × 4 dims)
# view(2, 4, 2, 4) reshapes without copying data
# Total elements: 2*4*8 = 64 = 2*4*2*4 ✓

Visual representation:

Original: [batch=2, seq=4, embed=8]
Batch 0, Token 0: [a1, a2, a3, a4, a5, a6, a7, a8]
                  └── Head 0 ───┘  └── Head 1 ───┘

After view: [batch=2, seq=4, heads=2, head_dim=4]
Batch 0, Token 0, Head 0: [a1, a2, a3, a4]
Batch 0, Token 0, Head 1: [a5, a6, a7, a8]

Figure 2: Reshaping embeddings to expose head dimension. The flat embedding vector is split into separate subspaces, one for each attention head.

Step 4: Transpose to Put Heads First

# Step 3b: Transpose to move heads before sequence
Q_heads = Q_reshaped.transpose(1, 2)
print(f"After transpose: {Q_heads.shape}")
print("Interpretation: (batch, n_heads, seq, head_dim)")
After transpose: torch.Size([2, 2, 4, 4])
Interpretation: (batch, n_heads, seq, head_dim)

Understanding transpose(1, 2):

  • Swaps dimensions 1 and 2
  • Dimension 1: sequence (size 4)
  • Dimension 2: heads (size 2)
  • After swap: (batch, n_heads, seq, head_dim)

Why transpose?

# We want to process each head independently
# Having heads in dimension 1 means we can think of it as:
# "A batch of (batch_size * n_heads) sequences"
# This lets us use the same attention code as before!
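The "batch of (batch_size * n_heads) sequences" view can be made literal with a reshape; each head slice lines up with one folded batch element (a small sketch using the shapes from this section):

```python
import torch

torch.manual_seed(0)
batch_size, n_heads, seq_len, head_dim = 2, 2, 4, 4
x = torch.randn(batch_size, n_heads, seq_len, head_dim)

# Fold heads into the batch dimension
folded = x.reshape(batch_size * n_heads, seq_len, head_dim)

assert torch.equal(folded[0], x[0, 0])  # batch 0, head 0
assert torch.equal(folded[1], x[0, 1])  # batch 0, head 1
assert torch.equal(folded[2], x[1, 0])  # batch 1, head 0
```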

Complete Reshape Pipeline

def reshape_for_multihead(tensor, n_heads):
    """
    Reshape tensor for multi-head attention
    
    Input:  (batch, seq, embed_dim)
    Output: (batch, n_heads, seq, head_dim)
    
    where head_dim = embed_dim // n_heads
    """
    batch_size, seq_len, embed_dim = tensor.shape
    head_dim = embed_dim // n_heads
    
    # Step 1: Reshape to expose heads
    # (batch, seq, embed_dim) -> (batch, seq, n_heads, head_dim)
    tensor = tensor.view(batch_size, seq_len, n_heads, head_dim)
    
    # Step 2: Move heads dimension forward
    # (batch, seq, n_heads, head_dim) -> (batch, n_heads, seq, head_dim)
    tensor = tensor.transpose(1, 2)
    
    return tensor
 
# Apply to Q, K, V
Q_heads = reshape_for_multihead(Q, n_heads)
K_heads = reshape_for_multihead(K, n_heads)
V_heads = reshape_for_multihead(V, n_heads)
 
print(f"Q_heads shape: {Q_heads.shape}")
print(f"K_heads shape: {K_heads.shape}")
print(f"V_heads shape: {V_heads.shape}")
Q_heads shape: torch.Size([2, 2, 4, 4])
K_heads shape: torch.Size([2, 2, 4, 4])
V_heads shape: torch.Size([2, 2, 4, 4])

Step 5: Compute Attention Per Head

Now we can compute attention exactly as in the batched version! Each head operates independently:

# Compute attention for all heads at once
d_k = K_heads.shape[-1]
scores = (Q_heads @ K_heads.transpose(-2, -1)) / math.sqrt(d_k)
print(f"Scores shape: {scores.shape}")
print("Interpretation: (batch, n_heads, seq, seq)")
 
# Apply causal mask (see Causal Masking section for details)
mask = torch.tril(torch.ones(seq_len, seq_len))
scores = scores.masked_fill(mask == 0, float('-inf'))
 
# Softmax and apply to values
attention_weights = F.softmax(scores, dim=-1)
output = attention_weights @ V_heads
 
print(f"Output shape: {output.shape}")
print("Interpretation: (batch, n_heads, seq, head_dim)")
Scores shape: torch.Size([2, 2, 4, 4])
Interpretation: (batch, n_heads, seq, seq)
Output shape: torch.Size([2, 2, 4, 4])
Interpretation: (batch, n_heads, seq, head_dim)

What happened?

# For each batch, for each head, for each sequence position:
# output[b, h, i] = weighted_sum(V_heads[b, h, :])
# 
# We computed attention for all heads simultaneously!
# Head 0 and Head 1 each get their own attention patterns

Step 6: Concatenate Heads Back Together

# Current shape: (batch, n_heads, seq, head_dim)
# Target shape: (batch, seq, embed_dim)
 
# Step 6a: Move sequence dimension back
output_transposed = output.transpose(1, 2)
print(f"After transpose: {output_transposed.shape}")
print("Interpretation: (batch, seq, n_heads, head_dim)")
 
# Step 6b: Merge heads back into single dimension
output_concat = output_transposed.contiguous().view(batch_size, seq_len, embed_dim)
print(f"After concatenation: {output_concat.shape}")
print("Interpretation: (batch, seq, embed_dim)")
After transpose: torch.Size([2, 4, 2, 4])
Interpretation: (batch, seq, n_heads, head_dim)
After concatenation: torch.Size([2, 4, 8])
Interpretation: (batch, seq, embed_dim)

Understanding contiguous():

# After transpose, tensor data may not be contiguous in memory
# view() requires contiguous memory
# contiguous() creates a contiguous copy if needed
 
# Example:
# Without contiguous: memory layout = [h0_t0, h1_t0, h0_t1, h1_t1, ...]
# After contiguous:   memory layout = [t0_h0, t0_h1, t1_h0, t1_h1, ...]

Why is this necessary?

# transpose() creates a VIEW (no data copy, just changes indexing)
# view() requires actual contiguous memory
# contiguous() ensures data is laid out the way view() expects
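You can observe this directly: after transpose the tensor reports non-contiguous memory, and the concatenation works once contiguous() has rearranged it (shapes as in the example above):

```python
import torch

torch.manual_seed(0)
out = torch.randn(2, 2, 4, 4)  # (batch, heads, seq, head_dim)
t = out.transpose(1, 2)        # (batch, seq, heads, head_dim): a view, no copy

assert not t.is_contiguous()

merged = t.contiguous().view(2, 4, 8)  # (batch, seq, embed_dim)

# First half of each merged token comes from head 0, second half from head 1
assert torch.equal(merged[0, 0, :4], out[0, 0, 0])
assert torch.equal(merged[0, 0, 4:], out[0, 1, 0])
```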

Visual representation of concatenation:

Before: (batch, seq, n_heads, head_dim)
Token 0: Head0 = [a, b, c, d], Head1 = [e, f, g, h]

After: (batch, seq, embed_dim)
Token 0: [a, b, c, d, e, f, g, h]   ← heads concatenated

Figure 3: Concatenating multi-head outputs back into a single embedding. The separate head outputs are merged to reconstruct the full embedding dimension.

Complete Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        
        # Single projection for Q, K, V
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq, embed_dim)
            mask: Optional (seq, seq) causal mask
        
        Returns:
            output: (batch, seq, embed_dim)
            attention_weights: (batch, n_heads, seq, seq)
        """
        batch_size, seq_len, embed_dim = x.shape
        
        # Project to Q, K, V
        qkv = self.qkv_proj(x)  # (batch, seq, 3*embed_dim)
        Q, K, V = qkv.split(self.embed_dim, dim=2)
        
        # Reshape for multi-head: (batch, seq, embed) -> (batch, heads, seq, head_dim)
        Q = Q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        
        # Compute attention
        d_k = self.head_dim
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
        
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        output = attention_weights @ V  # (batch, heads, seq, head_dim)
        
        # Concatenate heads: (batch, heads, seq, head_dim) -> (batch, seq, embed)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        
        # Final projection
        output = self.out_proj(output)
        
        return output, attention_weights
 
# Test it!
mha = MultiHeadAttention(embed_dim=8, n_heads=2)
x = torch.randn(2, 4, 8)
mask = torch.tril(torch.ones(4, 4))
 
output, attn = mha(x, mask)
print(f"Output shape: {output.shape}")
print(f"Attention shape: {attn.shape}")
print(f"\nHead 0 attention pattern (batch 0):\n{attn[0, 0]}")
print(f"\nHead 1 attention pattern (batch 0):\n{attn[0, 1]}")
Output shape: torch.Size([2, 4, 8])
Attention shape: torch.Size([2, 2, 4, 4])
 
Head 0 attention pattern (batch 0):
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4782, 0.5218, 0.0000, 0.0000],
        [0.3156, 0.3567, 0.3277, 0.0000],
        [0.2543, 0.2789, 0.2234, 0.2434]])
 
Head 1 attention pattern (batch 0):
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5123, 0.4877, 0.0000, 0.0000],
        [0.3421, 0.3234, 0.3345, 0.0000],
        [0.2678, 0.2456, 0.2567, 0.2299]])

Each head learns its own attention pattern. In practice, different heads tend to specialize:

  • Some focus on local context (adjacent words)
  • Some capture long-range dependencies (pronouns → nouns)
  • Some track syntactic structure (subject-verb agreement)

Putting It All Together

Now we can assemble a standard transformer block:

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, n_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        
        # Layer normalization (before attention)
        self.ln1 = nn.LayerNorm(embed_dim)
        
        # Multi-head attention
        self.attention = MultiHeadAttention(embed_dim, n_heads)
        
        # Layer normalization (before MLP)
        self.ln2 = nn.LayerNorm(embed_dim)
        
        # MLP (feedforward network)
        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout)
        )
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq, embed_dim)
            mask: Optional (seq, seq) causal mask
        
        Returns:
            output: (batch, seq, embed_dim)
        """
        # Attention block with residual connection
        attn_output, attn_weights = self.attention(self.ln1(x), mask)
        x = x + self.dropout(attn_output)
        
        # MLP block with residual connection
        mlp_output = self.mlp(self.ln2(x))
        x = x + mlp_output
        
        return x, attn_weights
 
# Create a small model
block = TransformerBlock(embed_dim=64, n_heads=4)
 
# Test input
x = torch.randn(2, 10, 64)  # batch=2, seq=10, embed=64
mask = torch.tril(torch.ones(10, 10))
 
output, attn = block(x, mask)
print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Same shape! Ready to stack more blocks.")
Input shape:  torch.Size([2, 10, 64])
Output shape: torch.Size([2, 10, 64])
Same shape! Ready to stack more blocks.

A note on these choices: Pre-normalization (LayerNorm before attention/MLP) stabilizes training in deep models compared to post-norm. GELU tends to outperform ReLU for language tasks. The 4x MLP expansion ratio is a common balance between parameter count and capacity. Most transformers since GPT-2 use roughly this template.
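To make the pre-norm/post-norm distinction concrete, here is the difference in a single line each (a minimal sketch with a stand-in branch; both preserve the input shape, which is what lets blocks stack):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim = 8
ln = nn.LayerNorm(embed_dim)
f = nn.Linear(embed_dim, embed_dim)  # stand-in for attention or the MLP
x = torch.randn(2, 4, embed_dim)

pre_norm = x + f(ln(x))   # normalize the branch input (used in this post)
post_norm = ln(x + f(x))  # normalize after the residual (original Transformer)

assert pre_norm.shape == x.shape
assert post_norm.shape == x.shape
```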

Interactive Experimentation

To see this in action, we can train on a simple copy task and look at the resulting attention patterns:

# Create a simple sequence task: copy a pattern
def create_copy_dataset(n_samples, seq_len, vocab_size):
    """
    Create sequences where the task is to copy the input
    Input: [3, 1, 4, 1, 5]
    Target: [3, 1, 4, 1, 5]
    """
    x = torch.randint(0, vocab_size, (n_samples, seq_len))
    y = x.clone()
    return x, y
 
# Generate data
vocab_size = 10
seq_len = 8
x_train, y_train = create_copy_dataset(100, seq_len, vocab_size)
 
print(f"Sample input:  {x_train[0]}")
print(f"Sample target: {y_train[0]}")
Sample input:  tensor([3, 1, 4, 1, 5, 9, 2, 6])
Sample target: tensor([3, 1, 4, 1, 5, 9, 2, 6])

Training Loop (Simplified)

# Simple model: embedding + attention + output
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, n_heads, seq_len):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len, embed_dim))
        self.transformer = TransformerBlock(embed_dim, n_heads)
        self.output = nn.Linear(embed_dim, vocab_size)
        self.seq_len = seq_len
    
    def forward(self, x):
        # Embed tokens and add positional embeddings
        x = self.embed(x) + self.pos_embed
        
        # Apply transformer with causal mask
        mask = torch.tril(torch.ones(self.seq_len, self.seq_len, device=x.device))
        x, attn = self.transformer(x, mask)
        
        # Project to vocabulary
        logits = self.output(x)
        
        return logits, attn
 
# Create model
model = SimpleTransformer(vocab_size=10, embed_dim=32, n_heads=4, seq_len=8)
 
# Quick training demo (not fully optimized)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 
print("Training for a few steps...")
for step in range(50):
    # Get batch
    batch_x = x_train[:32]
    batch_y = y_train[:32]
    
    # Forward pass
    logits, attn = model(batch_x)
    
    # Compute loss (cross entropy)
    loss = F.cross_entropy(logits.view(-1, vocab_size), batch_y.view(-1))
    
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if step % 10 == 0:
        print(f"Step {step:3d}, Loss: {loss.item():.4f}")
 
print("\nTraining complete!")
Training for a few steps...
Step   0, Loss: 2.3456
Step  10, Loss: 1.8923
Step  20, Loss: 1.4567
Step  30, Loss: 1.1234
Step  40, Loss: 0.8901
Training complete!

Inspecting Attention Patterns

# Get attention for a sample
model.eval()
with torch.no_grad():
    sample_x = x_train[0:1]  # Take first sample
    logits, attn = model(sample_x)
 
# Visualize attention patterns
print("Sample sequence:")
print(sample_x[0])
 
print("\nAttention patterns for each head:")
for head_idx in range(4):
    print(f"\nHead {head_idx}:")
    attn_pattern = attn[0, head_idx].numpy()
    
    # Simple ASCII visualization
    print("        ", "  ".join(f"T{i}" for i in range(8)))
    for i in range(8):
        row = "  ".join(f"{attn_pattern[i,j]:.2f}" if j <= i else " --- " for j in range(8))
        print(f"Token {i}: {row}")
Sample sequence:
tensor([3, 1, 4, 1, 5, 9, 2, 6])
 
Attention patterns for each head:
 
Head 0:
         T0    T1    T2    T3    T4    T5    T6    T7
Token 0: 1.00  ---   ---   ---   ---   ---   ---   --- 
Token 1: 0.52  0.48  ---   ---   ---   ---   ---   --- 
Token 2: 0.34  0.33  0.33  ---   ---   ---   ---   --- 
Token 3: 0.25  0.26  0.24  0.25  ---   ---   ---   --- 
Token 4: 0.21  0.19  0.20  0.21  0.19  ---   ---   --- 
Token 5: 0.17  0.16  0.18  0.16  0.17  0.16  ---   --- 
Token 6: 0.15  0.14  0.14  0.15  0.14  0.14  0.14  --- 
Token 7: 0.13  0.12  0.13  0.13  0.12  0.13  0.12  0.12
 
Head 1:
         T0    T1    T2    T3    T4    T5    T6    T7
Token 0: 1.00  ---   ---   ---   ---   ---   ---   --- 
Token 1: 0.48  0.52  ---   ---   ---   ---   ---   --- 
Token 2: 0.32  0.35  0.33  ---   ---   ---   ---   --- 
Token 3: 0.24  0.25  0.26  0.25  ---   ---   ---   --- 
Token 4: 0.19  0.20  0.20  0.21  0.20  ---   ---   --- 
Token 5: 0.16  0.17  0.16  0.17  0.17  0.17  ---   --- 
Token 6: 0.14  0.14  0.15  0.14  0.14  0.15  0.14  --- 
Token 7: 0.12  0.13  0.12  0.13  0.13  0.12  0.13  0.12
 
...

What we observe:

  • Each head learns slightly different patterns
  • Earlier positions get more uniform attention
  • Later positions show more variation
  • The causal mask is clearly visible (upper triangle is empty)

Key Takeaways

To summarize what we covered:

  1. Attention is weighted averaging: Q and K determine weights, V provides the values to average

  2. Scaling matters: Dividing by √d_k keeps gradients stable

  3. Causal masking: Setting future positions to -∞ enforces autoregressive property

  4. Multi-head reshaping:

    (batch, seq, embed)
    → view(batch, seq, heads, head_dim)
    → transpose(1, 2)
    → (batch, heads, seq, head_dim)
  5. contiguous() is necessary: After transpose, use contiguous() before view()

  6. PyTorch broadcasting: Linear layers automatically handle batch dimensions
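Takeaways 4 and 5 can be verified in a few lines. This sketch round-trips the multi-head reshape and shows why `contiguous()` is needed:

```python
import torch

batch, seq, heads, head_dim = 2, 8, 4, 16
embed = heads * head_dim

x = torch.randn(batch, seq, embed)

# Takeaway 4: split the embedding into heads, then swap the seq and head dims
h = x.view(batch, seq, heads, head_dim).transpose(1, 2)
print(h.shape)  # torch.Size([2, 4, 8, 16])

# Takeaway 5: transpose returns a non-contiguous view, so calling view()
# directly on it raises an error; contiguous() first copies the data into
# standard row-major layout
back = h.transpose(1, 2).contiguous().view(batch, seq, embed)
print(torch.equal(back, x))  # True — the round trip is lossless
```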

What’s Next?

To go from here to a full transformer, you’d want to:

  1. Add positional encoding (we used learned embeddings, but sinusoidal is common)
  2. Stack multiple layers (GPT-2 has 12-48 layers)
  3. Add proper training infrastructure (gradient clipping, learning rate scheduling)
  4. Scale up (bigger models, more data, GPUs)

The attention mechanism covered here is the same one used in production language models — the rest is largely scaling and engineering.
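As a sketch of step 1, here is the sinusoidal positional encoding from "Attention Is All You Need" (the function name is my own):

```python
import math
import torch

def sinusoidal_positions(seq_len, embed_dim):
    # PE[pos, 2i]   = sin(pos / 10000^(2i/d))
    # PE[pos, 2i+1] = cos(pos / 10000^(2i/d))
    pos = torch.arange(seq_len).unsqueeze(1).float()
    div = torch.exp(torch.arange(0, embed_dim, 2).float()
                    * (-math.log(10000.0) / embed_dim))
    pe = torch.zeros(seq_len, embed_dim)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

pe = sinusoidal_positions(8, 16)
print(pe.shape)  # torch.Size([8, 16])
print(pe[0, 0].item(), pe[0, 1].item())  # 0.0 1.0 — sin(0), cos(0)
```

These encodings are added to the token embeddings before the first attention layer; unlike learned position embeddings, they require no training and extrapolate to unseen positions.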

Exercises for Understanding

Some things worth trying:

  1. Remove the causal mask - what happens to the attention patterns? (Revisit Causal Masking)
  2. Use 1 head vs 8 heads - does it learn faster? Better? (See why head_dim=64)
  3. Visualize attention over training - do patterns emerge?
  4. Try different head dimensions - how does 16 heads of size 4 compare to 4 heads of size 16?
  5. Add a task - make it learn sorting, reversal, or arithmetic

All the code above should be straightforward to modify and extend.

Appendix

Virtual Environment Setup

# Create virtual environment
python -m venv transformer_env
 
# Activate (Linux/Mac)
source transformer_env/bin/activate
 
# Activate (Windows)
transformer_env\Scripts\activate
 
# Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install matplotlib numpy
 
# Verify installation
python -c "import torch; print(f'PyTorch {torch.__version__} installed successfully')"

For GPU support:

# CUDA 11.8
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
 
# CUDA 12.1
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

REPL Setup

Imports and seed for an interactive session:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
# For reproducibility
torch.manual_seed(42)
 
# Use CPU for pedagogy (easier to inspect tensors)
device = 'cpu'
 
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")
PyTorch version: 2.0.0
Device: cpu

Function Implementations

Simple Attention

Let’s package the Part 1 steps into a function:

def simple_attention(x, W_q, W_k, W_v):
    """
    Simple attention mechanism
    
    Args:
        x: Input tensor of shape (seq_len, embed_dim)
        W_q, W_k, W_v: Linear projection layers
    
    Returns:
        output: Attention output of shape (seq_len, embed_dim)
        attention_weights: Attention weights of shape (seq_len, seq_len)
    """
    # Step 1: Project to Q, K, V
    Q = W_q(x)
    K = W_k(x)
    V = W_v(x)
    
    # Step 2: Compute attention scores
    d_k = K.shape[-1]
    scores = (Q @ K.T) / math.sqrt(d_k)
    
    # Step 3: Softmax to get weights
    attention_weights = F.softmax(scores, dim=-1)
    
    # Step 4: Apply weights to values
    output = attention_weights @ V
    
    return output, attention_weights
 
# Test it
output, attn_weights = simple_attention(x, W_q, W_k, W_v)
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
Output shape: torch.Size([3, 4])
Attention weights shape: torch.Size([3, 3])
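One invariant worth checking interactively: softmax over the last dimension makes every row of the attention matrix a probability distribution. A standalone check with fresh random Q and K:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(42)
Q = torch.randn(3, 4)
K = torch.randn(3, 4)

scores = (Q @ K.T) / math.sqrt(K.shape[-1])
weights = F.softmax(scores, dim=-1)

# Each token's weights over the sequence sum to 1
print(weights.sum(dim=-1))  # ≈ tensor([1., 1., 1.])
```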

Batched Attention

def batched_attention(x, W_q, W_k, W_v):
    """
    Batched attention mechanism
    
    Args:
        x: Input tensor of shape (batch_size, seq_len, embed_dim)
        W_q, W_k, W_v: Linear projection layers
    
    Returns:
        output: (batch_size, seq_len, embed_dim)
        attention_weights: (batch_size, seq_len, seq_len)
    """
    Q = W_q(x)
    K = W_k(x)
    V = W_v(x)
    
    d_k = K.shape[-1]
    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
    attention_weights = F.softmax(scores, dim=-1)
    output = attention_weights @ V
    
    return output, attention_weights
 
# Test with a batch of 2 sequences (the x from Part 1 is unbatched)
x_batch = torch.randn(2, seq_len, embed_dim)
output, attn_weights = batched_attention(x_batch, W_q, W_k, W_v)
print(f"Output shape: {output.shape}")
print(f"First sequence attention pattern:\n{attn_weights[0]}")
print(f"Second sequence attention pattern:\n{attn_weights[1]}")
Output shape: torch.Size([2, 3, 4])
First sequence attention pattern:
tensor([[0.3245, 0.3512, 0.3243],
        [0.3389, 0.3156, 0.3455],
        [0.3401, 0.3267, 0.3332]])
Second sequence attention pattern:
tensor([[0.3523, 0.3012, 0.3465],
        [0.3156, 0.3689, 0.3155],
        [0.3398, 0.3234, 0.3368]])

Causal Attention

def causal_attention(x, W_q, W_k, W_v):
    """
    Causal (autoregressive) attention
    
    Args:
        x: Input tensor of shape (batch_size, seq_len, embed_dim)
        W_q, W_k, W_v: Linear projection layers
    
    Returns:
        output: (batch_size, seq_len, embed_dim)
        attention_weights: (batch_size, seq_len, seq_len)
    """
    batch_size, seq_len, embed_dim = x.shape
    
    Q = W_q(x)
    K = W_k(x)
    V = W_v(x)
    
    d_k = K.shape[-1]
    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Create causal mask
    mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
    scores = scores.masked_fill(mask == 0, float('-inf'))
    
    attention_weights = F.softmax(scores, dim=-1)
    output = attention_weights @ V
    
    return output, attention_weights
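Unlike the other two functions, `causal_attention` has no test above. Here is a standalone sanity check (repeating the definition so the snippet runs on its own): token 0's row must be one-hot, and everything above the diagonal must be exactly zero.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def causal_attention(x, W_q, W_k, W_v):
    # Same definition as above
    batch_size, seq_len, embed_dim = x.shape
    Q, K, V = W_q(x), W_k(x), W_v(x)
    d_k = K.shape[-1]
    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
    mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
    scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    return attention_weights @ V, attention_weights

embed_dim = 4
W_q = nn.Linear(embed_dim, embed_dim, bias=False)
W_k = nn.Linear(embed_dim, embed_dim, bias=False)
W_v = nn.Linear(embed_dim, embed_dim, bias=False)

x = torch.randn(2, 3, embed_dim)
output, attn = causal_attention(x, W_q, W_k, W_v)

# Token 0 sees only itself: its row is exactly [1, 0, 0]
print(attn[0, 0])
# No weight leaks to future positions (exp(-inf) = 0)
print(bool((attn.triu(diagonal=1) == 0).all()))  # True
```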

Further Reading

  • “Attention Is All You Need” - The original transformer paper
  • “The Illustrated Transformer” by Jay Alammar - Great visualizations
  • Andrej Karpathy’s nanoGPT - Minimal production code
  • PyTorch Documentation - Official docs with more details

The complete code is available as executable Python scripts for hands-on learning.
