08. Mixture of Experts
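Mixture of Experts (MoE) scaling decouples parameter count from per-token computation cost. A dense transformer layer applies its whole feed-forward network (FFN) to every token. An MoE layer instead holds many expert FFNs plus a router that sends each token to only a small subset of them (the top k). Total parameters grow with the number of experts, while per-token compute grows only with k and the size of each expert, so the model can reach a very large parameter count at roughly constant per-token compute.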
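A quick parameter count makes the decoupling concrete. The sketch below assumes bias-free two-layer expert FFNs (like the default experts in the layer that follows) and a bias-free router; the sizes (d_model=1024, d_ffn=4096, 8 experts, top-2) are illustrative assumptions, not taken from any particular model.

```python
def ffn_params(d_model: int, d_ffn: int) -> int:
    # Two bias-free linear layers: d_model -> d_ffn and d_ffn -> d_model
    return 2 * d_model * d_ffn


def moe_param_counts(d_model: int, d_ffn: int, n_experts: int, top_k: int):
    per_expert = ffn_params(d_model, d_ffn)
    router = d_model * n_experts              # bias-free router: one logit per expert
    total = n_experts * per_expert + router   # parameters stored
    active = top_k * per_expert + router      # parameters actually used for one token
    return total, active


dense = ffn_params(1024, 4096)
total, active = moe_param_counts(1024, 4096, n_experts=8, top_k=2)
print(dense, total, active)  # 8388608 67117056 16785408
```

With these sizes the MoE layer stores about 8x the parameters of a dense FFN of the same width but touches only about 2x per token. Since each active weight costs roughly one multiply-add (about 2 FLOPs) per token, per-token FFN compute tracks the active count, not the total.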
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    """Mixture of Experts layer with top-k routing."""

    def __init__(self, d_model, d_ffn, n_experts, top_k=2, bias=False):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        # Each expert is an independent FFN
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ffn, bias=bias),
                nn.GELU(),
                nn.Linear(d_ffn, d_model, bias=bias),
            )
            for _ in range(n_experts)
        ])
        # Router network: maps each token to one logit per expert
        self.router = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x):
        """
        x: (batch, seq_len, d_model)
        Returns: (batch, seq_len, d_model)
        """
        batch_size, seq_len, d_model = x.shape
        # Flatten batch and sequence for routing (reshape also accepts non-contiguous x)
        x_flat = x.reshape(-1, d_model)
        # Compute routing logits
        router_logits = self.router(x_flat)  # (batch*seq, n_experts)
        # Select top-k experts per token
        weights, indices = torch.topk(router_logits, self.top_k, dim=-1)
        # Softmax over the selected experts only, so each token's k weights sum to 1
        weights = F.softmax(weights, dim=-1)
        # Process with selected experts
        output = torch.zeros_like(x_flat)
        for i in range(self.top_k):
            expert_idx = indices[:, i]     # expert chosen in slot i, per token
            expert_weight = weights[:, i]  # gate weight for that choice
            # For each expert, run only the tokens routed to it and accumulate weighted outputs
            for e_idx in range(self.n_experts):
                mask = (expert_idx == e_idx)
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.experts[e_idx](expert_input)
                    output[mask] += expert_weight[mask].unsqueeze(-1) * expert_output
        return output.view(batch_size, seq_len, d_model)
```
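One way to check what the routing loop computes is to compare it with a dense reference: build a sparse gate matrix whose row for each token holds its k softmax weights (zeros elsewhere), run every expert on every token, and mix. The toy experts (single nn.Linear layers), the sizes, and the names below are assumptions for illustration; the dense path is only a test oracle, since it throws away the compute savings MoE exists for.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_tokens, d_model, n_experts, top_k = 6, 8, 4, 2

# Hypothetical toy experts and router standing in for MoELayer's modules
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_experts)])
router = nn.Linear(d_model, n_experts, bias=False)
x = torch.randn(n_tokens, d_model)

weights, indices = torch.topk(router(x), top_k, dim=-1)
weights = F.softmax(weights, dim=-1)

# Sparse gate matrix: gates[t, e] is token t's weight for expert e, zero if not selected
gates = torch.zeros(n_tokens, n_experts).scatter(1, indices, weights)

# Dense reference: every expert on every token, mixed by the sparse gates
all_out = torch.stack([expert(x) for expert in experts], dim=1)  # (tokens, experts, d_model)
dense_ref = (gates.unsqueeze(-1) * all_out).sum(dim=1)

# Sparse computation: each expert sees only the tokens routed to it
sparse_out = torch.zeros_like(x)
for e in range(n_experts):
    token_idx, slot_idx = torch.where(indices == e)
    if token_idx.numel() > 0:
        contrib = weights[token_idx, slot_idx].unsqueeze(-1) * experts[e](x[token_idx])
        sparse_out.index_add_(0, token_idx, contrib)

print(torch.allclose(dense_ref, sparse_out, atol=1e-5))  # True
```

In equation form, the layer output for token t is the sum over experts e of gates[t, e] times expert_e(x_t), where only k entries of each gates row are nonzero.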
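Failure mode: Load imbalance. Without constraints, a few experts receive most tokens while the others remain underutilized. Underused experts get few gradient updates, so their parameters are largely wasted capacity, and overloaded experts become compute bottlenecks; the skew can also destabilize training. Addressing this typically requires auxiliary load-balancing losses (covered in Chapter 10).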
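As a preview, here is one common form of auxiliary loss, modeled on the Switch Transformer: N * sum_i(f_i * P_i), where N is the number of experts, f_i is the fraction of routing assignments sent to expert i, and P_i is the mean router probability for expert i. Counting all k slots to extend f_i to top-k routing is an assumption of this sketch, and Chapter 10's formulation may differ.

```python
import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Switch-Transformer-style auxiliary loss (one common formulation, assumed here).

    router_logits: (n_tokens, n_experts)
    """
    n_tokens, n_experts = router_logits.shape
    probs = F.softmax(router_logits, dim=-1)  # full router distribution per token
    _, indices = torch.topk(router_logits, top_k, dim=-1)
    # f_i: fraction of assignments (over all k slots) going to expert i; no gradient flows here
    counts = torch.bincount(indices.flatten(), minlength=n_experts).to(probs.dtype)
    f = counts / (n_tokens * top_k)
    # P_i: mean router probability for expert i; this term carries the gradient
    p = probs.mean(dim=0)
    # Equals 1.0 when routing is perfectly uniform; approaches n_experts under full collapse
    return n_experts * torch.sum(f * p)


# Example: random router logits for 16 tokens and 4 experts
aux_loss = load_balancing_loss(torch.randn(16, 4), top_k=2)
```

The product f_i * P_i penalizes experts that are both frequently chosen and confidently preferred. In training, this value is scaled by a small coefficient and added to the task loss.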
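Failure mode: Memory and compute inefficiency from irregular routing. Each expert receives a different, data-dependent number of tokens per batch. Naive implementations either loop over experts, issuing many small and poorly utilized matrix multiplies (as the layer above does), or pad every expert's batch to a fixed size, wasting memory and FLOPs on padding. Production implementations use specialized kernels, such as the MoE support in Megatron-LM.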
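A common alternative to the nested loop is to sort all token-to-expert assignments by expert id, so that each expert's inputs form one contiguous chunk after a single gather. The per-expert counts are exactly the variable group sizes that grouped-matmul kernels consume. The sketch below is a plain PyTorch illustration under simplifying assumptions (toy single-layer experts, random stand-in router logits), not Megatron-LM's implementation; the `padded` value shows how many slots would be wasted if every expert's batch were instead padded to the largest one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_tokens, d_model, n_experts, top_k = 10, 8, 4, 2

# Hypothetical toy experts and stand-in router logits
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_experts)])
x = torch.randn(n_tokens, d_model)
logits = torch.randn(n_tokens, n_experts)
weights, indices = torch.topk(logits, top_k, dim=-1)
weights = torch.softmax(weights, dim=-1)

# Flatten (token, slot) assignments and sort them by expert id
flat_experts = indices.flatten()                 # (n_tokens * top_k,)
order = torch.argsort(flat_experts)              # assignments grouped by expert
token_ids = order // top_k                       # source token of each sorted assignment
counts = torch.bincount(flat_experts, minlength=n_experts)

# One gather, then one contiguous chunk per expert (sizes may be zero)
x_sorted = x[token_ids]
chunks = torch.split(x_sorted, counts.tolist())
out_sorted = torch.cat([experts[e](chunk) for e, chunk in enumerate(chunks)])

# Weight each assignment and scatter-add it back to its token
w_sorted = weights.flatten()[order].unsqueeze(-1)
output = torch.zeros_like(x).index_add_(0, token_ids, w_sorted * out_sorted)

# Slots that would be wasted by padding every expert batch to the largest one
padded = n_experts * int(counts.max()) - int(counts.sum())
print(counts.tolist(), padded)
```

The sort turns irregular routing into a permutation plus variable-length segments, which avoids both the per-expert masking passes and fixed-size padding.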
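Failure mode: Expert collapse early in training. The router learns to send nearly all tokens to one or two experts before auxiliary losses take effect. The effect is self-reinforcing: experts that are rarely chosen receive little gradient, so they fall further behind the favored ones. Clipping router logits or adding entropy regularization can mitigate this.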
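Both mitigations can be sketched in a few lines, assuming they are applied to the router logits before top-k selection. The tanh soft cap (with a hypothetical cap of 10) is one way to clip; a hard `torch.clamp` is the simpler alternative. The regularizer here maximizes the entropy of the batch-averaged routing distribution, which rewards spreading tokens across experts; maximizing each token's own entropy would instead push every token toward indecisive routing. The coefficient and names are illustrative assumptions.

```python
import torch
import torch.nn.functional as F


def soft_clip(logits: torch.Tensor, cap: float = 10.0) -> torch.Tensor:
    # Smoothly bound logits to (-cap, cap) so no expert can dominate via huge logits
    return cap * torch.tanh(logits / cap)


def router_entropy_bonus(router_logits: torch.Tensor) -> torch.Tensor:
    """Entropy (in nats) of the batch-averaged routing distribution.

    Largest, log(n_experts), when tokens are spread evenly across experts.
    """
    probs = F.softmax(router_logits, dim=-1)
    mean_probs = probs.mean(dim=0)
    return -(mean_probs * torch.log(mean_probs + 1e-9)).sum()


# Example: clipped logits for 16 tokens and 4 experts
logits = soft_clip(torch.randn(16, 4) * 5)
entropy_coef = 0.01  # hypothetical coefficient; would need tuning
reg_loss = -entropy_coef * router_entropy_bonus(logits)  # subtract: higher entropy lowers the loss
```

Adding `reg_loss` to the task loss nudges the router away from collapse while the load-balancing loss is still ineffective.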
Create a small MoE with 4 experts and test routing on random input. Print the expert selection counts to verify routing works. Add entropy regularization to the router to encourage wider expert usage.