07. Transformer Block Design

Chapter 7 of 24 · 15 min

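A transformer block combines multi-head self-attention with a feed-forward network, joined by residual connections and normalization layers. Modern implementations such as Llama differ from the original transformer in small but important ways that are visible in the code below: normalization is applied before each sublayer (pre-norm) instead of after it, RMSNorm replaces LayerNorm, and the feed-forward network uses a gated SwiGLU activation.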

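import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (used in Llama)."""
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Learnable per-feature gain; unlike LayerNorm there is no bias and no mean subtraction
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # x: (batch, seq, d_model)
        # Divide each token vector by its root mean square over the feature dimension
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return (x / rms) * self.weight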

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ffn, dropout=0.1, bias=False):
        super().__init__()
        # MultiHeadAttention and SwiGLUTransformerFFN are assumed to be the
        # modules built in earlier chapters; they are not defined here.
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = SwiGLUTransformerFFN(d_model, d_ffn, dropout, bias=bias)

        # Pre-norm architecture (used in modern transformers)
        self.attention_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)

    def forward(self, x, mask=None):
        # Pre-norm: each sublayer reads a normalized copy of x, while the
        # residual stream x itself is never normalized (better training stability)
        x = x + self.attention(self.attention_norm(x), mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x

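Critical design decision: pre-norm vs. post-norm. The original transformer used post-normalization, applying the norm after each residual addition: x = norm(x + sublayer(x)). Modern models use pre-norm, x = x + sublayer(norm(x)), because it keeps gradients better behaved in deep stacks (roughly more than 12 layers). With post-norm, gradients near the output layers are large at initialization, so deep post-norm models tend to diverge or train poorly unless the learning rate is warmed up carefully.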
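To make the two placements concrete, here is a minimal sketch of one block that can be built either way. It is self-contained, so it swaps in PyTorch's nn.MultiheadAttention, nn.LayerNorm, and a plain GELU feed-forward network for the chapter's MultiHeadAttention, RMSNorm, and SwiGLU modules; the class name NormPlacementBlock and its norm_first flag are hypothetical.

```python
import torch
import torch.nn as nn


class NormPlacementBlock(nn.Module):
    """Hypothetical minimal block: norm_first=True is pre-norm, False is post-norm."""

    def __init__(self, d_model, n_heads, d_ffn, norm_first=True):
        super().__init__()
        # Stand-ins for the chapter's MultiHeadAttention, SwiGLU FFN, and RMSNorm
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ffn), nn.GELU(), nn.Linear(d_ffn, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm_first = norm_first

    def _self_attn(self, x, mask):
        out, _ = self.attn(x, x, x, attn_mask=mask, need_weights=False)
        return out

    def forward(self, x, mask=None):
        if self.norm_first:
            # Pre-norm: sublayers see norm(x); the residual stream x is never normalized
            x = x + self._self_attn(self.norm1(x), mask)
            x = x + self.ffn(self.norm2(x))
        else:
            # Post-norm (original transformer): normalize after each residual addition
            x = self.norm1(x + self._self_attn(x, mask))
            x = self.norm2(x + self.ffn(x))
        return x


torch.manual_seed(0)
x = torch.randn(2, 5, 16)  # (batch, seq, d_model)
causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)  # True = masked
pre = NormPlacementBlock(16, 4, 64, norm_first=True)
post = NormPlacementBlock(16, 4, 64, norm_first=False)
print(pre(x, causal).shape, post(x, causal).shape)
# torch.Size([2, 5, 16]) torch.Size([2, 5, 16])
```

The only difference between the two modes is where the norm sits relative to the residual addition. PyTorch's built-in nn.TransformerEncoderLayer exposes the same choice through its norm_first argument.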

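Failure mode: putting the norm on the wrong side of the residual connection. x = x + attention(norm(x)) correctly implements pre-norm: only the sublayer input is normalized, and the residual stream carries the un-normalized x forward. A common mistake is to write x = norm(x + attention(x)) instead, which silently turns the block into post-norm, changes the network's training behavior, and degrades performance in a model designed for pre-norm.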
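A cheap unit test catches this mistake: if the sublayer outputs zeros, a correct pre-norm step must return its input unchanged, while the buggy form still rescales it. The helpers pre_norm_step, buggy_step, and zero_sublayer are hypothetical, and nn.LayerNorm stands in for RMSNorm.

```python
import torch
import torch.nn as nn


def pre_norm_step(x, sublayer, norm):
    # Correct pre-norm: only the sublayer input is normalized
    return x + sublayer(norm(x))


def buggy_step(x, sublayer, norm):
    # The mistake from the text: this is actually post-norm
    return norm(x + sublayer(x))


def zero_sublayer(x):
    # Hypothetical stand-in for an attention layer whose output is all zeros
    return torch.zeros_like(x)


torch.manual_seed(0)
norm = nn.LayerNorm(8)
x = torch.randn(2, 3, 8) * 5 + 2  # deliberately far from normalized

# With a zero sublayer, a correct pre-norm step must be the identity
print(torch.equal(pre_norm_step(x, zero_sublayer, norm), x))  # True
print(torch.allclose(buggy_step(x, zero_sublayer, norm), x))  # False
```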

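Failure mode: choosing d_ffn. A ratio to d_model that is too large wastes compute; one that is too small limits model capacity. A plain two-matrix FFN, as in the original transformer and GPT-style models, typically uses d_ffn = 4x d_model (d_ffn=16384 for d_model=4096). A gated FFN such as SwiGLU has three weight matrices, so Llama shrinks the hidden size by 2/3 to keep roughly the same parameter count: about 2.67x d_model, rounded up to a hardware-friendly multiple, which gives the familiar d_ffn=11008 for d_model=4096 in Llama 7B. Later models widen this further; Llama 3 8B uses d_ffn=14336 (3.5x).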
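The arithmetic behind these numbers can be sketched as follows. llama_ffn_hidden_dim mirrors the sizing rule as we understand it from Meta's reference Llama code; the multiple_of=1024 and ffn_dim_multiplier=1.3 values in the second call are our assumption about the Llama 3 8B configuration. ffn_params is a hypothetical helper that ignores biases.

```python
def llama_ffn_hidden_dim(d_model, multiple_of=256, ffn_dim_multiplier=None):
    """Sketch of the SwiGLU hidden-size rule (assumed to match reference Llama code)."""
    hidden = 4 * d_model              # start from the classic 4x FFN width
    hidden = int(2 * hidden / 3)      # 3 weight matrices instead of 2: scale by 2/3
    if ffn_dim_multiplier is not None:
        hidden = int(ffn_dim_multiplier * hidden)  # optional widening
    # Round up to a multiple of `multiple_of` for hardware-friendly shapes
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


def ffn_params(d_model, d_ffn, gated):
    # Plain FFN: W1 (d x f) + W2 (f x d); SwiGLU adds a gate matrix W3 (d x f)
    return (3 if gated else 2) * d_model * d_ffn


print(llama_ffn_hidden_dim(4096))  # 11008
print(llama_ffn_hidden_dim(4096, multiple_of=1024, ffn_dim_multiplier=1.3))  # 14336
print(ffn_params(4096, 16384, gated=False))  # 134217728
print(ffn_params(4096, 11008, gated=True))   # 135266304
```

The last two lines show why the 2/3 factor exists: a 4x plain FFN and an 11008-wide SwiGLU FFN at d_model=4096 have nearly the same number of weights.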

Local verification checkpoint

Run the smallest example from this chapter in a local workspace and record the PyTorch version, Python runtime, device (CPU or GPU), random seed, and observed output. If the result depends on model size, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
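One way to capture these details programmatically is sketched below. environment_record is a hypothetical helper, and the "smallest example" here is the RMS computation from this chapter's RMSNorm without the learnable weight.

```python
import platform
import sys

import torch


def environment_record(seed=0):
    """Hypothetical helper: collects the details the checkpoint asks you to record."""
    torch.manual_seed(seed)
    return {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "torch": torch.__version__,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "seed": seed,
    }


record = environment_record()
x = torch.randn(2, 4, 8, device=record["device"])
# RMS normalization from RMSNorm.forward (eps=1e-6, no learnable gain)
rms = x.pow(2).mean(-1, keepdim=True).add(1e-6).sqrt()
y = x / rms
record["output_shape"] = tuple(y.shape)
# Each normalized token vector should have an RMS very close to 1
record["mean_output_rms"] = y.pow(2).mean(-1).sqrt().mean().item()
print(record)
```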

EXERCISE

Implement both pre-norm and post-norm transformer blocks. Train identical small models (at least 8 layers deep) with each for 5 epochs on a simple task, such as predicting the sum of a sequence of random numbers. Compare gradient magnitudes at layer 8 of each stack.
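A possible starting point for the measurement is sketched below, assuming PyTorch's nn.TransformerEncoderLayer (whose norm_first flag switches between the two placements, with LayerNorm and a ReLU FFN rather than RMSNorm and SwiGLU) stands in for your own blocks. make_stack and layer_grad_norms are hypothetical helpers; they measure gradients after a single backward pass at initialization, so for the full exercise wrap the forward and backward pass in a training loop with an optimizer and record the layer-8 norm each epoch.

```python
import torch
import torch.nn as nn


def make_stack(n_layers, d_model, n_heads, norm_first):
    # Hypothetical tiny model for the "sum of random numbers" task
    layers = nn.ModuleList(
        nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=4 * d_model,
            dropout=0.0, batch_first=True, norm_first=norm_first,
        )
        for _ in range(n_layers)
    )
    embed = nn.Linear(1, d_model)  # each number -> d_model vector
    head = nn.Linear(d_model, 1)   # pooled representation -> predicted sum
    return embed, layers, head


def layer_grad_norms(n_layers=8, d_model=32, n_heads=4, norm_first=True, seed=0):
    torch.manual_seed(seed)
    embed, layers, head = make_stack(n_layers, d_model, n_heads, norm_first)
    x = torch.rand(16, 10, 1)       # 16 sequences of 10 numbers in [0, 1)
    target = x.sum(dim=(1, 2))      # task: predict each sequence's sum
    h = embed(x)
    for layer in layers:
        h = layer(h)
    pred = head(h.mean(dim=1)).squeeze(-1)
    loss = nn.functional.mse_loss(pred, target)
    loss.backward()
    # L2 norm of all parameter gradients in each layer (index 0 = first layer)
    return [
        torch.sqrt(sum(p.grad.pow(2).sum() for p in layer.parameters())).item()
        for layer in layers
    ]


for norm_first in (True, False):
    norms = layer_grad_norms(norm_first=norm_first)
    label = "pre-norm " if norm_first else "post-norm"
    # "Layer 8" in the exercise is index 7 here
    print(label, "layer 8 grad norm:", round(norms[7], 4))
```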