06. SwiGLU Activation

Chapter 6 of 24 · 15 min

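SwiGLU (Swish-Gated Linear Unit) is a gated feed-forward layer: one linear projection, passed through the Swish activation, gates a second linear projection element-wise. In Shazeer's 2020 study of GLU variants, SwiGLU feed-forward layers outperformed ReLU and GELU baselines on language-modeling perplexity, and SwiGLU is now used in Llama, GLM, and many other modern architectures.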

Standard SwiGLU formula:

SwiGLU(x) = Swish(W₁x) ⊗ (W₃x)
Output = W₂(SwiGLU(x))

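Here Swish(x) = x · sigmoid(x) (Swish with β = 1, also known as SiLU), ⊗ denotes element-wise multiplication, W₁ and W₃ project from d_model up to d_ffn, and W₂ projects back down to d_model.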
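To connect the two formulas to tensors, here is a minimal functional sketch that uses raw weight matrices instead of nn.Linear. The function name swiglu_ffn, the bias-free setup, and the toy sizes are illustrative assumptions; the sketch also checks that PyTorch's F.silu matches x · sigmoid(x).

```python
import torch
import torch.nn.functional as F

def swiglu_ffn(x, W1, W2, W3):
    """Output = W2(Swish(W1 x) ⊗ (W3 x)), written with raw weight matrices.

    Assumed (hypothetical) bias-free shapes:
      x:  (..., d_model)
      W1: (d_ffn, d_model), W3: (d_ffn, d_model), W2: (d_model, d_ffn)
    """
    gate = F.silu(x @ W1.T)   # Swish(W1 x) with beta = 1 (SiLU)
    value = x @ W3.T          # W3 x, the linear "value" branch
    hidden = gate * value     # element-wise product (⊗ in the text)
    return hidden @ W2.T      # project back down to d_model

torch.manual_seed(0)
d_model, d_ffn = 8, 16
x = torch.randn(2, d_model)
W1, W3 = torch.randn(d_ffn, d_model), torch.randn(d_ffn, d_model)
W2 = torch.randn(d_model, d_ffn)

out = swiglu_ffn(x, W1, W2, W3)
print(out.shape)  # torch.Size([2, 8])

# F.silu computes Swish with beta = 1: x * sigmoid(x)
z = torch.linspace(-4, 4, 9)
print(torch.allclose(F.silu(z), z * torch.sigmoid(z)))  # True
```

Because the value branch multiplies the gate, setting W₃ to zero forces the whole output to zero, whatever the gate computes.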

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: W2(Swish(W1 x) ⊗ (W3 x))."""
    def __init__(self, d_model, d_ffn, bias=True):
        super().__init__()
        # Three linear projections. Llama uses bias=False for all three;
        # the default here keeps biases, so pass bias=False to match Llama.
        self.w1 = nn.Linear(d_model, d_ffn, bias=bias)  # gate branch
        self.w2 = nn.Linear(d_ffn, d_model, bias=bias)  # down-projection
        self.w3 = nn.Linear(d_model, d_ffn, bias=bias)  # value branch

    def swish(self, x):
        # Swish with beta = 1, i.e. x * sigmoid(x); F.silu is the fused equivalent
        return F.silu(x)

    def forward(self, x):
        # SwiGLU = Swish(W1x) ⊗ (W3x), then project back to d_model with W2
        return self.w2(self.swish(self.w1(x)) * self.w3(x))

class SwiGLUTransformerFFN(nn.Module):
    """Feed-forward network using SwiGLU in transformer blocks."""
    def __init__(self, d_model, d_ffn, dropout=0.1, bias=True):
        super().__init__()
        self.swiglu = SwiGLU(d_model, d_ffn, bias=bias)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        return self.dropout(self.swiglu(x))
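A variant you may see in other codebases (not used in the classes above) stores W₁ and W₃ as one projection of width 2 · d_ffn and splits the result with chunk, replacing two matrix multiplies with one larger one. The sketch below, with the hypothetical class name FusedSwiGLU and bias-free layers, checks that the fused form gives the same output as separate W₁ and W₃ layers that share its weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedSwiGLU(nn.Module):
    """SwiGLU FFN with W1 and W3 stacked into one (2*d_ffn x d_model) projection.

    Hypothetical variant for illustration; bias-free.
    """
    def __init__(self, d_model, d_ffn):
        super().__init__()
        self.w13 = nn.Linear(d_model, 2 * d_ffn, bias=False)  # rows: [W1; W3]
        self.w2 = nn.Linear(d_ffn, d_model, bias=False)

    def forward(self, x):
        gate, value = self.w13(x).chunk(2, dim=-1)  # first half = W1 x, second = W3 x
        return self.w2(F.silu(gate) * value)

torch.manual_seed(0)
d_model, d_ffn = 16, 32
fused = FusedSwiGLU(d_model, d_ffn)

# Unfused reference that reuses the fused layer's weights.
w1 = nn.Linear(d_model, d_ffn, bias=False)
w3 = nn.Linear(d_model, d_ffn, bias=False)
with torch.no_grad():
    w1.weight.copy_(fused.w13.weight[:d_ffn])
    w3.weight.copy_(fused.w13.weight[d_ffn:])

x = torch.randn(4, 10, d_model)  # (batch, seq, d_model)
reference = fused.w2(F.silu(w1(x)) * w3(x))
print(torch.allclose(fused(x), reference, atol=1e-6))  # True
```

The parameter count is unchanged (still 3 · d_model · d_ffn); only the memory layout and the number of kernel launches differ.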

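Failure mode: Parameter overhead. SwiGLU uses three linear projections instead of two, so at the same hidden width d_ffn it has 50% more FFN weights (3 · d_model · d_ffn vs. 2 · d_model · d_ffn). With d_model=4096 and d_ffn=11008, that is about 135.3M weights per layer versus 90.2M for a ReLU FFN of the same width. Llama offsets this by shrinking the hidden width to roughly (2/3) · 4 · d_model, which is where 11008 comes from (instead of 16384); the result is within about 1% of a standard 4×-wide ReLU FFN (134.2M). If you swap SwiGLU into an existing model without reducing d_ffn, budget for the extra 50%.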
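The parameter counts behind this failure mode can be checked in a few lines of plain Python. The helper names are hypothetical, biases are ignored, and llama_style_hidden mirrors, as far as we know, the rounding rule in Meta's reference Llama code (two-thirds of 4 · d_model, rounded up to a multiple of 256 for the 7B configuration).

```python
def relu_ffn_params(d_model, d_ffn):
    # Two projections: d_model -> d_ffn -> d_model (biases ignored)
    return 2 * d_model * d_ffn

def swiglu_ffn_params(d_model, d_ffn):
    # Three projections: W1 and W3 (d_model -> d_ffn), W2 (d_ffn -> d_model)
    return 3 * d_model * d_ffn

def llama_style_hidden(d_model, multiple_of=256):
    # Shrink the usual 4*d_model width by 2/3, then round up to a multiple of `multiple_of`
    hidden = int(2 * (4 * d_model) / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

d_model = 4096
d_ffn = llama_style_hidden(d_model)
print(d_ffn)                                  # 11008
print(swiglu_ffn_params(d_model, d_ffn))      # 135266304
print(relu_ffn_params(d_model, d_ffn))        # 90177536  (same width: SwiGLU is 1.5x)
print(relu_ffn_params(d_model, 4 * d_model))  # 134217728 (4x-wide ReLU FFN)
```

The last two lines carry the point: the 1.5× overhead applies only when d_ffn is held fixed. With the two-thirds shrink, 135,266,304 vs. 134,217,728 weights is a difference of under 1%.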

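Failure mode: NaN or inf losses when residual connections or normalization are missing or mis-scaled. Because SwiGLU multiplies two linear projections of the same input, its output grows roughly with the square of the input scale for large inputs, versus linearly for a bias-free ReLU FFN, so unnormalized inputs can produce very large activations, particularly in fp16. Pre-normalization (Llama applies RMSNorm before each FFN) keeps the FFN input at a fixed scale, and the residual connection gives gradients a direct path around the block. Check activations and gradient norms for inf/NaN early in training to confirm that gradients flow through the residual path.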
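The sketch below illustrates both halves of this failure mode with toy, bias-free layers; all sizes and names are assumptions. Scaling the input by 10 scales a bias-free ReLU FFN's output by 10 (up to rounding, since it is positively homogeneous), while the SwiGLU output grows far faster. A pre-norm residual block, using a hand-written RMSNorm in the style Llama uses, keeps outputs and gradients finite even for inputs of magnitude around 1000.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ffn = 64, 172

# Bias-free ReLU FFN vs. bias-free SwiGLU FFN (hypothetical toy sizes).
relu_in = nn.Linear(d_model, d_ffn, bias=False)
relu_out = nn.Linear(d_ffn, d_model, bias=False)
w1 = nn.Linear(d_model, d_ffn, bias=False)
w3 = nn.Linear(d_model, d_ffn, bias=False)
w2 = nn.Linear(d_ffn, d_model, bias=False)

def relu_ffn(x):
    return relu_out(F.relu(relu_in(x)))

def swiglu_ffn(x):
    return w2(F.silu(w1(x)) * w3(x))

x = torch.randn(32, d_model)
with torch.no_grad():
    relu_ratio = (relu_ffn(10 * x).norm() / relu_ffn(x).norm()).item()
    swiglu_ratio = (swiglu_ffn(10 * x).norm() / swiglu_ffn(x).norm()).item()
print(f"10x input -> ReLU FFN output x{relu_ratio:.1f}, SwiGLU FFN output x{swiglu_ratio:.1f}")

class RMSNorm(nn.Module):
    """Minimal RMSNorm, hand-written so the sketch does not depend on a PyTorch version."""
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class PreNormSwiGLUBlock(nn.Module):
    """x + FFN(norm(x)): the FFN sees unit-RMS input; the residual carries x through."""
    def __init__(self, d_model, d_ffn):
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.w1 = nn.Linear(d_model, d_ffn, bias=False)
        self.w2 = nn.Linear(d_ffn, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ffn, bias=False)

    def forward(self, x):
        h = self.norm(x)
        return x + self.w2(F.silu(self.w1(h)) * self.w3(h))

block = PreNormSwiGLUBlock(d_model, d_ffn)
big = (1000 * torch.randn(8, d_model)).requires_grad_()
out = block(big)
out.sum().backward()
# The kind of inf/NaN check you would run on activations and gradients during training.
print(torch.isfinite(out).all().item(), torch.isfinite(big.grad).all().item())  # True True
```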

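Failure mode: Inconsistent bias handling. Llama's reference implementation uses no bias in any of the three projections (w1, w2, w3); other implementations include bias everywhere. Loading weights saved under one convention into a model built with the other either fails with missing or unexpected keys or, with strict=False, silently leaves the extra biases at their random initialization, so outputs no longer match the source model.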
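To see what the mismatch looks like in PyTorch, the sketch below loads a bias-free checkpoint into a model that expects biases. The class repeats the chapter's SwiGLU structure so the block runs on its own, and the sizes are arbitrary. load_state_dict raises a RuntimeError in its default strict mode; with strict=False it returns the missing keys instead of failing. Zeroing the orphaned biases is one way to recover the original outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    # Same structure as the chapter's SwiGLU class, repeated so this sketch is self-contained.
    def __init__(self, d_model, d_ffn, bias):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ffn, bias=bias)
        self.w2 = nn.Linear(d_ffn, d_model, bias=bias)
        self.w3 = nn.Linear(d_model, d_ffn, bias=bias)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

torch.manual_seed(0)
bias_free = SwiGLU(16, 32, bias=False)  # stands in for a checkpoint saved without biases
with_bias = SwiGLU(16, 32, bias=True)   # a model definition that expects biases

# Default strict=True: missing keys raise an error.
try:
    with_bias.load_state_dict(bias_free.state_dict())
except RuntimeError as err:
    print("strict load failed:", type(err).__name__)

# strict=False: matching weights are copied, mismatches are reported.
result = with_bias.load_state_dict(bias_free.state_dict(), strict=False)
print(sorted(result.missing_keys))  # ['w1.bias', 'w2.bias', 'w3.bias']
print(result.unexpected_keys)       # []

# The biases still hold random init; zeroing them makes the two models agree.
with torch.no_grad():
    for layer in (with_bias.w1, with_bias.w2, with_bias.w3):
        layer.bias.zero_()
x = torch.randn(3, 16)
print(torch.allclose(with_bias(x), bias_free(x), atol=1e-6))  # True
```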

Local verification checkpoint

Run the smallest example from this chapter in a local workspace and record the PyTorch and Python versions, device (CPU/GPU), dtype, random seed, data path, and observed output. If the result depends on model size, backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
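A minimal sketch of such a record, assuming PyTorch is installed; the seed value and the fields collected are illustrative choices, not a required format. torch.randn on a CUDA device draws from a different generator than on CPU, so the output sum is only comparable across runs on the same device.

```python
import platform

import torch
import torch.nn as nn
import torch.nn.functional as F

SEED = 0  # hypothetical choice; record whatever seed you actually use
torch.manual_seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Smallest possible SwiGLU FFN forward pass (bias-free, toy sizes).
d_model, d_ffn = 8, 16
w1 = nn.Linear(d_model, d_ffn, bias=False).to(device)
w2 = nn.Linear(d_ffn, d_model, bias=False).to(device)
w3 = nn.Linear(d_model, d_ffn, bias=False).to(device)
x = torch.randn(2, 4, d_model, device=device)
y = w2(F.silu(w1(x)) * w3(x))

record = {
    "python": platform.python_version(),
    "torch": torch.__version__,
    "device": device,
    "dtype": str(y.dtype),
    "seed": SEED,
    "output_shape": tuple(y.shape),
    "output_sum": round(y.sum().item(), 6),
}
for key, value in record.items():
    print(f"{key}: {value}")
```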

EXERCISE

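Compare SwiGLU with ReLU and GeGLU (GELU(W₁x) ⊗ (W₃x)) in a small transformer on a text classification task, and measure training-loss convergence over 1000 steps. For a fair comparison, either keep d_ffn equal or shrink d_ffn by 2/3 for the gated variants to match parameter counts, and report which you chose. Then test whether SwiGLU is more sensitive than ReLU to the choice of optimizer (for example, plain SGD vs. Adam). One hypothesis to check: because SwiGLU's gradients are products of its two branches, they are small when both projections are near zero, which could slow plain SGD more than an adaptive optimizer.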
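A possible starting point for the exercise, under loudly stated assumptions: the synthetic make_batch task stands in for a real text classification dataset, the model is a single pre-norm FFN block with no attention rather than a full transformer, and all sizes, learning rates, and names (FFN, TinyClassifier, train) are hypothetical. Swap in your dataset and a real transformer block before drawing conclusions; the script only produces loss curves to compare and does not by itself confirm or refute the optimizer hypothesis.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FFN(nn.Module):
    """Feed-forward block with a swappable activation (hypothetical helper).

    kind="relu":   W2(ReLU(W1 x))            two projections
    kind="geglu":  W2(GELU(W1 x) * (W3 x))   three projections
    kind="swiglu": W2(SiLU(W1 x) * (W3 x))   three projections
    """
    def __init__(self, d_model, d_ffn, kind):
        super().__init__()
        self.kind = kind
        self.w1 = nn.Linear(d_model, d_ffn, bias=False)
        self.w2 = nn.Linear(d_ffn, d_model, bias=False)
        self.w3 = None if kind == "relu" else nn.Linear(d_model, d_ffn, bias=False)

    def forward(self, x):
        if self.kind == "relu":
            return self.w2(F.relu(self.w1(x)))
        act = F.gelu if self.kind == "geglu" else F.silu
        return self.w2(act(self.w1(x)) * self.w3(x))

class TinyClassifier(nn.Module):
    def __init__(self, vocab, d_model, d_ffn, kind, n_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model, d_ffn, kind)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, tokens):
        h = self.embed(tokens)
        h = h + self.ffn(self.norm(h))   # pre-norm residual FFN block
        return self.head(h.mean(dim=1))  # mean-pool over the sequence

def make_batch(batch, seq_len=15, vocab=100):
    # Synthetic placeholder task: label 1 if most tokens come from the lower half of the vocab.
    tokens = torch.randint(0, vocab, (batch, seq_len))
    labels = ((tokens < vocab // 2).sum(dim=1) > seq_len // 2).long()
    return tokens, labels

def train(kind, optimizer_name, steps=1000, lr=1e-3, d_ffn=64, seed=0):
    torch.manual_seed(seed)
    model = TinyClassifier(vocab=100, d_model=32, d_ffn=d_ffn, kind=kind)
    opt_cls = torch.optim.Adam if optimizer_name == "adam" else torch.optim.SGD
    opt = opt_cls(model.parameters(), lr=lr)
    losses = []
    for _ in range(steps):
        tokens, labels = make_batch(64)
        loss = F.cross_entropy(model(tokens), labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses

if __name__ == "__main__":
    for kind in ("relu", "geglu", "swiglu"):
        for opt_name, lr in (("sgd", 0.1), ("adam", 1e-3)):
            losses = train(kind, opt_name, lr=lr)
            tail = sum(losses[-100:]) / 100
            print(f"{kind:7s} {opt_name:4s} first={losses[0]:.3f} last-100 mean={tail:.3f}")
```

As written, every variant uses d_ffn=64, so the gated variants carry 1.5× the FFN parameters; pass a smaller d_ffn to train (for example 43, roughly 64 · 2/3) for the gated runs to compare at matched size.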