09. Expert Routing
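Expert routing determines which experts process each token. The router is a learnable linear layer that maps each token to one logit per expert. For each token we keep the top-k logits and apply a softmax over them, which gives the weights used to combine the selected experts' outputs.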
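```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ExpertRouter(nn.Module):
    """Top-k expert router with capacity enforcement."""

    def __init__(self, d_model, n_experts, bias=True):
        super().__init__()
        self.n_experts = n_experts
        # One logit per expert for every token
        self.router = nn.Linear(d_model, n_experts, bias=bias)

    def forward(self, x, top_k=2, capacity_factor=1.0):
        """
        Route tokens to top-k experts.
        Args:
            x: (batch, seq, d_model)
            top_k: number of experts per token
            capacity_factor: max tokens per expert relative to the mean
        Returns:
            routing_weights: (batch*seq, top_k) normalized weights;
                assignments dropped for capacity have weight 0
            expert_indices: (batch*seq, top_k) expert indices
            expert_counts: (n_experts,) assignments per expert before
                dropping, for the load-balancing loss
        """
        # reshape (not view) also accepts non-contiguous inputs
        x_flat = x.reshape(-1, x.shape[-1])
        # Router logits (importance scores)
        logits = self.router(x_flat)
        # Selection plus capacity enforcement
        return self._select_topk_with_capacity(logits, top_k, capacity_factor)

    def _select_topk_with_capacity(self, logits, top_k, capacity_factor):
        """Select top-k experts, then drop assignments beyond capacity."""
        n_tokens = logits.shape[0]
        # Mean load is n_tokens * top_k / n_experts; keep at least one slot
        capacity = max(1, math.ceil(n_tokens * top_k / self.n_experts * capacity_factor))
        # Select the top-k logits per token
        top_logits, indices = torch.topk(logits, k=top_k, dim=-1)
        # Softmax over the selected logits only, so weights sum to 1 per token
        weights = F.softmax(top_logits, dim=-1)
        # One-hot assignments: (n_tokens, top_k, n_experts)
        one_hot = F.one_hot(indices, self.n_experts)
        # Count assignments per expert (for the load-balancing loss)
        expert_counts = one_hot.sum(dim=(0, 1))
        # Queue position of each assignment within its expert. Ordering by
        # choice rank first gives every 1st choice priority over 2nd choices.
        flat = one_hot.transpose(0, 1).reshape(-1, self.n_experts)
        position = (torch.cumsum(flat, dim=0) - 1) * flat
        position = position.sum(dim=-1).view(top_k, n_tokens).t()
        # Zero the weight of assignments that overflow capacity (dropped)
        weights = weights * (position < capacity)
        return weights, indices, expert_counts
```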
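Failure mode: Capacity exceeded. If more tokens are routed to an expert than its capacity allows, the overflow must be dropped. Switch Transformer, which sends each token to a single expert (top-1), processes tokens up to capacity and skips the rest. A dropped token gets no expert computation in that layer and is carried forward only by the residual connection, so information the expert would have added is lost. Implementations therefore add an auxiliary load-balancing loss that pushes the router toward uniform expert usage, which keeps capacity overflows rare.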
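A widely used form of that auxiliary loss comes from the Switch Transformer paper: alpha * N * sum_i f_i * P_i, where N is the number of experts, f_i is the fraction of tokens dispatched to expert i, and P_i is the mean router probability for expert i. The sketch below assumes top-1 dispatch, as in that paper; the function name `switch_load_balancing_loss` is hypothetical, and `alpha=0.01` is the value the paper reports, offered only as a starting point.

```python
import torch
import torch.nn.functional as F


def switch_load_balancing_loss(logits, alpha=0.01):
    """Switch-Transformer-style load-balancing loss (top-1 dispatch).

    logits: (n_tokens, n_experts) raw router logits.
    Equals alpha under perfectly uniform routing and grows toward
    alpha * n_experts as routing collapses onto a single expert.
    """
    n_tokens, n_experts = logits.shape
    probs = F.softmax(logits, dim=-1)
    # f_i: fraction of tokens whose top-1 choice is expert i (no gradient flows here)
    top1 = probs.argmax(dim=-1)
    f = torch.bincount(top1, minlength=n_experts).float() / n_tokens
    # P_i: mean router probability for expert i (this term carries the gradient)
    p = probs.mean(dim=0)
    return alpha * n_experts * torch.sum(f * p)


balanced = switch_load_balancing_loss(10.0 * torch.eye(4))  # each token picks a different expert
collapsed_logits = torch.zeros(4, 4)
collapsed_logits[:, 0] = 10.0                               # every token picks expert 0
collapsed = switch_load_balancing_loss(collapsed_logits)
print(balanced.item(), collapsed.item())
```

Because f_i comes from a hard argmax, the gradient reaches the router only through P_i, which the loss pushes down for experts that already receive many tokens.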
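Failure mode: Numerical overflow in softmax. Router logits can grow very large in magnitude during training. The max-subtracted softmax that frameworks such as PyTorch compute is stable for any finite logits, but in fp16 a large logit overflows to inf, and the resulting inf - inf inside the softmax produces NaN. Large logits also saturate the softmax, producing near-one-hot weights with vanishing gradients. As safeguards, compute the router in float32 and clip logits to a bounded range (e.g., [-5, 5]).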
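A minimal sketch of both safeguards, assuming the router logits may arrive in fp16. The helper name `safe_router_probs` and the default `clip=5.0` are illustrative choices, not a fixed API. The demo stores a logit of 1e5, which exceeds the fp16 maximum (65504) and therefore becomes inf.

```python
import torch


def safe_router_probs(logits, clip=5.0):
    """Router softmax with two safeguards.

    1. Upcast to float32 so the softmax does not run in fp16.
    2. Clamp logits to [-clip, clip]; this also maps +/-inf to finite values.
    """
    logits = logits.float().clamp(min=-clip, max=clip)
    return torch.softmax(logits, dim=-1)


# 1e5 is above the fp16 maximum, so the stored value overflows to inf.
logits_fp16 = torch.tensor([[1e5, 0.0, -3.0]]).half()
naive = torch.softmax(logits_fp16.float(), dim=-1)  # inf - inf inside softmax -> NaN
safe = safe_router_probs(logits_fp16)               # finite, sums to 1
print(naive)
print(safe)
```

The clip bound also caps how confident the router can be: with `clip=5.0`, no two experts' probabilities can differ by more than a factor of e^10.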
Failure mode: Greedy routing is not globally optimal. Top-k selection is made independently for each token, with no view of the rest of the batch. Many tokens can pick the same expert and overflow its capacity while another expert sits idle. The auxiliary load-balancing loss discourages this during training, but nothing enforces balance at inference, so the imbalance can still occur there.
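Since no loss keeps routing balanced at inference, monitoring per-expert load is a cheap way to catch the problem. The sketch below uses a hypothetical helper, `expert_load`, that counts top-k assignments per expert with `torch.bincount` and reports the max-to-mean ratio (1.0 is perfectly balanced; larger values mean the busiest expert is overloaded relative to the mean). The logits are made-up values chosen so that every token prefers expert 0.

```python
import torch


def expert_load(indices, n_experts):
    """Per-expert assignment counts and max/mean imbalance.

    indices: (n_tokens, top_k) expert indices from top-k routing.
    """
    counts = torch.bincount(indices.reshape(-1), minlength=n_experts)
    mean = counts.float().mean()
    return counts, (counts.max().float() / mean).item()


# Four tokens, three experts; each token's top-1 choice is made independently.
logits = torch.tensor([[3.0, 1.0, 0.0],
                       [2.5, 2.0, 0.0],
                       [4.0, 0.5, 0.0],
                       [2.0, 1.9, 0.0]])
_, idx = logits.topk(1, dim=-1)
counts, imbalance = expert_load(idx, n_experts=3)
print(counts, imbalance)  # expert 0 receives all four tokens; experts 1 and 2 receive none
```

Tokens 2 and 4 have nearly tied logits for experts 0 and 1, so an assignment made with a view of the whole batch could move them to expert 1 at little cost; per-token top-k cannot see that.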
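Exercise: Implement a router with an auxiliary routing-entropy loss. The goal is to raise the entropy of expert usage across the batch (balanced load) without flattening each token's own routing distribution, since uniform per-token routing would erase expert specialization. Train a small MoE with and without entropy regularization and compare the resulting expert utilization distributions.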
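One possible starting point for the exercise, offered as a sketch rather than a reference solution. It separates two entropies: the entropy of the batch-averaged routing distribution, which is high when experts are used evenly, and the average per-token entropy, which is high when individual tokens route indecisively. The design (reward the first, lightly penalize the second) and the names `routing_entropy_loss`, `expert_utilization`, and `token_weight` are all hypothetical.

```python
import torch
import torch.nn.functional as F


def routing_entropy_loss(logits, token_weight=0.5, eps=1e-9):
    """Hypothetical entropy regularizer for router logits (lower is better).

    logits: (n_tokens, n_experts).
    Rewards high entropy of average expert usage across the batch and
    penalizes high entropy of each token's own routing distribution.
    """
    probs = F.softmax(logits, dim=-1)
    mean_probs = probs.mean(dim=0)
    # Entropy of average expert usage across the batch: maximize
    batch_entropy = -(mean_probs * (mean_probs + eps).log()).sum()
    # Mean entropy of per-token routing distributions: keep low
    token_entropy = -(probs * (probs + eps).log()).sum(dim=-1).mean()
    return -batch_entropy + token_weight * token_entropy


def expert_utilization(logits):
    """Fraction of tokens whose top-1 expert is each expert."""
    top1 = logits.argmax(dim=-1)
    return torch.bincount(top1, minlength=logits.shape[-1]).float() / logits.shape[0]


balanced = routing_entropy_loss(20.0 * torch.eye(4))  # decisive and evenly spread
uniform = routing_entropy_loss(torch.zeros(4, 4))     # evenly spread but indecisive
collapsed_logits = torch.zeros(4, 4)
collapsed_logits[:, 0] = 20.0                          # decisive but all on expert 0
collapsed = routing_entropy_loss(collapsed_logits)
print(balanced.item(), uniform.item(), collapsed.item())
print(expert_utilization(20.0 * torch.eye(4)), expert_utilization(collapsed_logits))
```

For the comparison, one run could add `lam * routing_entropy_loss(router_logits)` to the task loss (with `lam` a hypothetical coefficient to tune) while the other omits it, and both could report `expert_utilization` on held-out data.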