RUNLOCALAIv38

Independently operated catalog for local-AI hardware and software. Hand-written verdicts. Source-cited claims. Reproducible commands when we have them.

OP·Fredoline Eruo
© 2026 runlocalai.coIndependently operated
Custom LLM Architecture Design

10. Load Balancing

Chapter 10 of 24 · 15 min
KEY INSIGHT

Load balancing is a multi-objective optimization problem. The router must simultaneously (1) select good experts for each token and (2) distribute tokens evenly across experts. These objectives can conflict, so how much weight the balancing objective gets relative to the primary task is a crucial design choice.
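Written out, the two objectives combine into one training loss. The form below is a sketch assuming the Switch-Transformer-style auxiliary term that LoadBalancingLoss computes later in this chapter: α is importance_factor, N is n_experts, c_i is expert_counts[i], T is the number of tokens in router_probs, and p_{t,i} is the router's softmax probability of expert i for token t. The symbols L_LM, c_i, T and p_{t,i} are our notation, not the chapter's.

```latex
\mathcal{L} = \mathcal{L}_{\text{LM}} + \alpha \, N \sum_{i=1}^{N} f_i \, P_i,
\qquad
f_i = \frac{c_i}{\sum_{j=1}^{N} c_j},
\qquad
P_i = \frac{1}{T} \sum_{t=1}^{T} p_{t,i}
```

Under uniform routing f_i = P_i = 1/N and the auxiliary term equals α; if every token goes to a single expert with probability 1, it equals αN. The weight α therefore sets how hard the optimizer is pushed away from collapse relative to the language modeling loss.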

Load balancing keeps tokens evenly distributed across experts. Without it, training suffers from expert collapse (a few experts handle most tokens) and gradient staleness (experts that rarely receive tokens rarely receive gradient updates, so their weights fall behind). Auxiliary load balancing losses address both problems.
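A quick numeric check of what the auxiliary term rewards. This pure-Python sketch (the helper `switch_aux_loss` is ours, not from the chapter) computes N·Σ f_i·P_i without the importance_factor scaling, for four experts under balanced, skewed and collapsed routing. For simplicity it assumes the mean router probabilities match the routed fractions.

```python
def switch_aux_loss(expert_counts, mean_probs):
    """Return N * sum_i f_i * P_i (no importance_factor scaling).

    expert_counts: tokens routed to each expert (f_i = count / total).
    mean_probs: mean router probability per expert (P_i).
    """
    n_experts = len(expert_counts)
    total = sum(expert_counts)
    fractions = [c / total for c in expert_counts]
    return n_experts * sum(f * p for f, p in zip(fractions, mean_probs))


# Four experts, 100 token assignments each time.
balanced = switch_aux_loss([25, 25, 25, 25], [0.25, 0.25, 0.25, 0.25])
skewed = switch_aux_loss([70, 10, 10, 10], [0.7, 0.1, 0.1, 0.1])
collapsed = switch_aux_loss([100, 0, 0, 0], [1.0, 0.0, 0.0, 0.0])
print(balanced, skewed, collapsed)
```

Balanced routing scores 1.0, the skewed case about 2.08, and full collapse 4.0 (that is, N). In the PyTorch layer only P_i carries gradient, because expert_counts comes from a discrete top-k choice. The gradient of N·Σ f_i·P_i with respect to P_i is N·f_i, so gradient descent lowers the router probability of the most heavily loaded experts the most.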

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoadBalancingLoss(nn.Module):
    """Auxiliary loss for load balancing in MoE."""
    def __init__(self, n_experts, importance_factor=0.01):
        super().__init__()
        self.n_experts = n_experts
        self.importance_factor = importance_factor
    
    def forward(self, expert_counts, router_probs):
        """
        Args:
            expert_counts: (n_experts,) number of tokens routed to each expert
            router_probs: (batch*seq, n_experts) softmax probabilities from router
        
        Returns:
            load_balancing_loss: scalar auxiliary loss
        """
        # Total routing assignments (tokens x top_k when each token picks top_k experts)
        n_tokens = expert_counts.sum()
        
        # f_i: fraction of assignments routed to each expert (no gradient: discrete choice)
        token_fractions = expert_counts / n_tokens
        
        # P_i: mean router probability for each expert (differentiable)
        mean_probs = router_probs.mean(dim=0)
        
        # N * sum_i f_i * P_i is 1 under uniform routing and grows toward N
        # as routing concentrates on a single expert
        loss = self.importance_factor * self.n_experts * (token_fractions * mean_probs).sum()
        
        return loss


class BalancedMoELayer(nn.Module):
    """MoE layer with load balancing auxiliary loss."""
    def __init__(self, d_model, d_ffn, n_experts, top_k=2, 
                 importance_factor=0.01, bias=False):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        self.importance_factor = importance_factor
        
        # ExpertRouter is defined in Chapter 9. Assumed here: calling it returns
        # (weights, indices, expert_counts) and it exposes its linear projection as .router
        self.router = ExpertRouter(d_model, n_experts, bias)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ffn, bias=bias),
                nn.GELU(),
                nn.Linear(d_ffn, d_model, bias=bias)
            )
            for _ in range(n_experts)
        ])
        self.loss = LoadBalancingLoss(n_experts, importance_factor)
    
    def forward(self, x, return_aux_loss=True):
        weights, indices, expert_counts = self.router(x, self.top_k)
        
        # Flatten (batch, seq, d_model) -> (batch*seq, d_model); reshape also
        # handles non-contiguous inputs, unlike view
        x_flat = x.reshape(-1, x.shape[-1])
        output = torch.zeros_like(x_flat)
        
        # For each of the top_k routing slots, send every token to its chosen expert
        for i in range(self.top_k):
            expert_idx = indices[:, i]
            expert_weight = weights[:, i]
            
            for e_idx in range(self.n_experts):
                mask = (expert_idx == e_idx)
                if mask.any():
                    expert_output = self.experts[e_idx](x_flat[mask])
                    # Accumulate the expert output, scaled by its router weight
                    output[mask] += expert_weight[mask].unsqueeze(-1) * expert_output
        
        output = output.view(*x.shape)
        
        if return_aux_loss:
            # Full router probabilities over all experts, for the P_i term
            router_probs = F.softmax(self.router.router(x_flat), dim=-1)
            aux_loss = self.loss(expert_counts.float(), router_probs)
            return output, aux_loss
        
        return output
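To see what `expert_counts` holds when `top_k=2`, here is a hypothetical pure-Python stand-in for the router's selection step (`route_top_k` is ours, not the Chapter 9 `ExpertRouter`). It assumes the router counts one entry per (token, chosen expert) pair, which the `indices[:, i]` loop above implies. Under that assumption the counts sum to tokens × top_k, so `n_tokens` inside `LoadBalancingLoss` is really the number of assignments; dividing by it still yields fractions that sum to 1.

```python
def route_top_k(router_probs, top_k):
    """Pick the top_k experts per token and count assignments per expert.

    router_probs: list of per-token probability lists, one entry per expert.
    Returns (indices, counts): chosen expert ids per token, and how many
    (token, expert) assignments each expert received.
    """
    n_experts = len(router_probs[0])
    indices = []
    counts = [0] * n_experts
    for probs in router_probs:
        # Rank expert ids by probability, highest first, and keep top_k of them
        chosen = sorted(range(n_experts), key=lambda e: probs[e], reverse=True)[:top_k]
        indices.append(chosen)
        for e in chosen:
            counts[e] += 1
    return indices, counts


probs = [
    [0.5, 0.3, 0.1, 0.1],
    [0.4, 0.1, 0.35, 0.15],
    [0.1, 0.2, 0.3, 0.4],
]
indices, counts = route_top_k(probs, top_k=2)
fractions = [c / sum(counts) for c in counts]
print(indices, counts, fractions)
```

Three tokens with `top_k=2` give six assignments: counts of [2, 1, 2, 1] and fractions of 1/3, 1/6, 1/3 and 1/6.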

Failure mode: Hyperparameter sensitivity. importance_factor controls the strength of load balancing. Too high, and the model optimizes for balance at the expense of the primary task; too low, and experts still collapse. Typical values are 0.01 to 0.05.

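Failure mode: Expert capacity violation. Even with the load balancing loss, the number of tokens routed to an individual expert in a given batch can exceed its capacity, because the loss only encourages balance on average and routing keeps shifting as the router's weights are updated during training. Tokens dropped during the forward pass lose that expert's contribution, and no gradient flows back through the expert for them.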
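A minimal top-1 sketch of a hard capacity limit, to make "exceeding capacity" concrete. The capacity formula ceil(capacity_factor × tokens / experts) and the `capacity_factor` parameter are assumptions; this chapter does not define them. This version keeps tokens in batch order and drops the overflow, whereas the exercise below asks for random dropping instead.

```python
import math


def apply_capacity(assignments, n_experts, capacity_factor=1.25):
    """Enforce a per-expert token limit for top-1 routing.

    assignments: one expert id per token, in batch order.
    Returns (kept, drop_fraction, capacity), where kept[t] is False for
    tokens that overflowed their expert and therefore skip it.
    """
    n_tokens = len(assignments)
    # Even share of tokens per expert, padded by capacity_factor
    capacity = math.ceil(capacity_factor * n_tokens / n_experts)
    used = [0] * n_experts
    kept = []
    for e in assignments:
        if used[e] < capacity:
            used[e] += 1
            kept.append(True)
        else:
            kept.append(False)  # overflow: this token is dropped for expert e
    drop_fraction = kept.count(False) / n_tokens
    return kept, drop_fraction, capacity


# Eight tokens, four experts, expert 0 heavily overloaded
routing = [0, 0, 0, 0, 1, 2, 3, 3]
kept, dropped, cap = apply_capacity(routing, n_experts=4, capacity_factor=1.0)
print(kept, dropped, cap)
```

With `capacity_factor=1.0` each expert accepts 2 tokens, so two of expert 0's four tokens are dropped (a drop fraction of 0.25). Raising the factor to 1.25 lifts capacity to 3 and halves the drop fraction, at the cost of more padded, mostly idle expert slots.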

Failure mode: Gradient conflicts. In the router's gradient, the load balancing loss can pull against the primary language modeling loss, which can destabilize router training. Clipping gradients on the router parameters mitigates this.
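Clipping by norm rescales a gradient vector whenever its L2 norm exceeds a threshold. In PyTorch, `torch.nn.utils.clip_grad_norm_(parameters, max_norm)` does this in place over an iterable of parameters, treating them as one combined vector, so restricting it to the router means passing only the router's parameters (for example, `layer.router.parameters()` on a `BalancedMoELayer`). The pure-Python sketch below shows the rule on a flat list of gradient values; the function name and threshold are illustrative choices, not part of this chapter.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale grads so their combined L2 norm is at most max_norm.

    Returns (clipped_grads, original_norm); grads already within the
    threshold are returned unchanged.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm


clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(clipped, norm)
```

A gradient of [3, 4] has norm 5, so clipping to 1.0 scales it by 0.2 to roughly [0.6, 0.8]. Because the direction is preserved, clipping limits how far a single conflicting step can move the router without changing which way it moves.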

EXERCISE

Implement expert capacity as a hard constraint during training: if too many tokens route to an expert, randomly drop the excess tokens before the expert's forward pass. Compute the fraction of dropped tokens and verify that the load balancing loss decreases.

← Chapter 9
Expert Routing
Chapter 11 →
Mamba State Space Model