10. Load Balancing
Load balancing keeps tokens evenly distributed across experts. Without it, training suffers from expert collapse (a few experts handle most tokens) and gradient staleness (the remaining experts rarely receive updates). Auxiliary load balancing losses address this.
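For reference, the auxiliary loss computed by the code in this section can be written out as follows. The symbols are introduced here for illustration and are not the source's notation: N is n_experts, alpha is importance_factor, f_i is the fraction of routed tokens assigned to expert i, and P_i is the mean router probability for expert i over the T tokens in the batch.

```latex
\mathcal{L}_{\text{aux}} = \alpha \, N \sum_{i=1}^{N} f_i \, P_i,
\qquad
f_i = \frac{\text{count}_i}{\sum_{j=1}^{N} \text{count}_j},
\qquad
P_i = \frac{1}{T} \sum_{t=1}^{T} p_{t,i}
```

Under perfectly uniform routing, f_i = P_i = 1/N and the loss equals alpha. If every token and all probability mass go to a single expert, the loss rises to alpha times N.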
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoadBalancingLoss(nn.Module):
    """Auxiliary loss for load balancing in MoE."""

    def __init__(self, n_experts, importance_factor=0.01):
        super().__init__()
        self.n_experts = n_experts
        self.importance_factor = importance_factor

    def forward(self, expert_counts, router_probs):
        """
        Args:
            expert_counts: (n_experts,) number of tokens routed to each expert
            router_probs: (batch*seq, n_experts) softmax probabilities from router
        Returns:
            load_balancing_loss: scalar auxiliary loss
        """
        n_tokens = expert_counts.sum()
        # Fraction of tokens routed to each expert (not differentiable)
        fractions = expert_counts / n_tokens
        # Mean probability of routing to each expert (carries the gradient)
        mean_probs = router_probs.mean(dim=0)
        # Auxiliary loss: n_experts * sum_i(fractions_i * mean_probs_i).
        # It equals importance_factor under uniform routing and grows as
        # tokens and probability mass concentrate on the same experts.
        loss = self.importance_factor * self.n_experts * (fractions * mean_probs).sum()
        return loss
class BalancedMoELayer(nn.Module):
    """MoE layer with load balancing auxiliary loss."""

    def __init__(self, d_model, d_ffn, n_experts, top_k=2,
                 importance_factor=0.01, bias=False):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        self.importance_factor = importance_factor
        self.router = ExpertRouter(d_model, n_experts, bias)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ffn, bias=bias),
                nn.GELU(),
                nn.Linear(d_ffn, d_model, bias=bias)
            )
            for _ in range(n_experts)
        ])
        self.loss = LoadBalancingLoss(n_experts, importance_factor)

    def forward(self, x, return_aux_loss=True):
        weights, indices, expert_counts = self.router(x, self.top_k)
        # Flatten (batch, seq, d_model) to (batch*seq, d_model) once
        x_flat = x.view(-1, x.shape[-1])
        # Process tokens through selected experts
        output = torch.zeros_like(x_flat)
        for i in range(self.top_k):
            expert_idx = indices[:, i]
            expert_weight = weights[:, i]
            for e_idx in range(self.n_experts):
                # Tokens whose i-th choice is expert e_idx
                mask = (expert_idx == e_idx)
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.experts[e_idx](expert_input)
                    output[mask] += expert_weight[mask].unsqueeze(-1) * expert_output
        output = output.view(*x.shape)
        if return_aux_loss:
            # Recompute the full softmax over experts for the auxiliary loss
            router_probs = F.softmax(self.router.router(x_flat), dim=-1)
            aux_loss = self.loss(expert_counts.float(), router_probs)
            return output, aux_loss
        return output
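To get a feel for the scale of the auxiliary loss, the sketch below evaluates the same formula as LoadBalancingLoss.forward on two hand-built extremes: perfectly balanced routing and full collapse onto one expert. The standalone function load_balancing_loss and the toy sizes (8 tokens, 4 experts) are assumptions made for illustration, not part of the layer above.

```python
import torch


def load_balancing_loss(expert_counts, router_probs, importance_factor=0.01):
    """Same formula as LoadBalancingLoss.forward, as a plain function (illustrative)."""
    n_experts = router_probs.shape[-1]
    fractions = expert_counts / expert_counts.sum()
    mean_probs = router_probs.mean(dim=0)
    return importance_factor * n_experts * (fractions * mean_probs).sum()


n_tokens, n_experts = 8, 4  # toy sizes, chosen only for this example

# Balanced routing: every expert gets the same share of tokens and probability
uniform_counts = torch.full((n_experts,), n_tokens / n_experts)
uniform_probs = torch.full((n_tokens, n_experts), 1.0 / n_experts)
balanced = load_balancing_loss(uniform_counts, uniform_probs)

# Collapsed routing: expert 0 receives every token with probability 1
collapsed_counts = torch.tensor([float(n_tokens), 0.0, 0.0, 0.0])
collapsed_probs = torch.zeros(n_tokens, n_experts)
collapsed_probs[:, 0] = 1.0
collapsed = load_balancing_loss(collapsed_counts, collapsed_probs)

print(f"balanced:  {balanced.item():.4f}")   # 0.0100
print(f"collapsed: {collapsed.item():.4f}")  # 0.0400
```

The balanced case returns exactly importance_factor, and the collapsed case returns importance_factor times n_experts, so the penalty for full collapse scales with the number of experts.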
Failure mode: Hyperparameter sensitivity. importance_factor controls the strength of load balancing. Set it too high and the model optimizes for balance at the expense of the primary task; set it too low and experts still collapse. Typical values: 0.01 to 0.05.
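Failure mode: Expert capacity violation. The load balancing loss only encourages balance; it does not cap the number of tokens an expert receives in a given forward pass. An expert can therefore still exceed its capacity, especially while routing shifts between gradient updates during training. Tokens dropped in the forward pass lose their information at that layer, and no gradient flows back through the expert for them.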
Failure mode: Gradient conflicts. The load balancing loss can pull the router gradient in the opposite direction from the primary language modeling loss, which can make router training unstable. Gradient clipping on router parameters can mitigate this.
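A minimal sketch of clipping only the router's gradients, assuming a single linear gating layer and a toy objective. The modules router and expert, the threshold max_router_norm, and the loss are stand-ins invented for this example; the only library call that matters is torch.nn.utils.clip_grad_norm_, applied to the router's parameters alone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

router = nn.Linear(16, 4, bias=False)  # stand-in for the router's gating layer
expert = nn.Linear(16, 16)             # stand-in for the rest of the model
params = list(router.parameters()) + list(expert.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)

x = torch.randn(32, 16)
probs = torch.softmax(router(x), dim=-1)
# Toy objective that sends gradient into both modules (scaled up on purpose)
loss = 100.0 * (expert(x) * probs[:, :1]).pow(2).mean()

optimizer.zero_grad()
loss.backward()

max_router_norm = 1.0  # hypothetical threshold; tune for the actual model
# Clip the router's gradients only; the expert's gradients are left untouched.
# clip_grad_norm_ returns the total gradient norm measured before clipping.
norm_before = torch.nn.utils.clip_grad_norm_(router.parameters(), max_router_norm)
norm_after = router.weight.grad.norm()

optimizer.step()
print(f"router grad norm before: {norm_before.item():.4f}, after: {norm_after.item():.4f}")
```

Clipping the router separately from the experts bounds the router's update size without shrinking the gradients that train the expert weights.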
Implement expert capacity as a hard constraint during training: if too many tokens route to an expert, randomly drop the excess tokens before the forward pass. Compute the fraction of dropped tokens and verify that the load balancing loss decreases as training proceeds.
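One possible starting point for the capacity constraint is sketched below, under simplifying assumptions: top-1 routing, CPU tensors, and a capacity of capacity_factor * n_tokens / n_experts. The helper apply_expert_capacity and the capacity_factor parameter are hypothetical names introduced here. The function only builds a keep-mask; wiring it into BalancedMoELayer.forward is left to the reader.

```python
import torch


def apply_expert_capacity(expert_idx, n_experts, capacity, generator=None):
    """Return a boolean keep-mask that caps each expert at `capacity` tokens.

    expert_idx: (n_tokens,) expert assigned to each token (top-1 for simplicity).
    Tokens beyond the capacity are chosen uniformly at random and dropped.
    """
    keep = torch.ones_like(expert_idx, dtype=torch.bool)
    for e in range(n_experts):
        token_ids = (expert_idx == e).nonzero(as_tuple=True)[0]
        overflow = token_ids.numel() - capacity
        if overflow > 0:
            # Random permutation of this expert's tokens; drop the first `overflow`
            perm = torch.randperm(token_ids.numel(), generator=generator)
            keep[token_ids[perm[:overflow]]] = False
    return keep


g = torch.Generator().manual_seed(0)
expert_idx = torch.tensor([0, 0, 0, 0, 0, 1, 1, 2])  # expert 0 is overloaded
n_experts = 4
capacity_factor = 1.0  # hypothetical knob; values above 1 leave headroom
capacity = int(capacity_factor * expert_idx.numel() / n_experts)  # 2 tokens per expert

keep = apply_expert_capacity(expert_idx, n_experts, capacity, generator=g)
dropped_fraction = 1.0 - keep.float().mean().item()
print(f"capacity={capacity}, dropped fraction={dropped_fraction:.3f}")  # 0.375
```

Here expert 0 receives five tokens against a capacity of two, so three of the eight tokens are dropped. Which three are dropped depends on the random permutation, but the dropped fraction does not.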