RUNLOCALAIv38

Independently operated catalog for local-AI hardware and software. Hand-written verdicts. Source-cited claims. Reproducible commands when we have them.

OP·Fredoline Eruo

© 2026 runlocalai.coIndependently operated
Custom LLM Architecture Design

17. Training Stability

Chapter 17 of 24 · 25 min
KEY INSIGHT

Training stability is not a one-time configuration but a continuous monitoring process. Implement gradient monitoring, checkpoint saving on anomalies, and conservative initialization. A 7B model trained for 100K steps represents thousands of dollars of compute, so invest in stability infrastructure.

Training instability is the primary failure mode for custom architectures. Gradient explosions, loss spikes, and NaN values can destroy weeks of compute investment.
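
One way to act on loss spikes and NaN values is to compare each loss against a running average and flag anything that departs sharply from it. The sketch below is a minimal illustration in plain Python; the class name `LossSpikeDetector`, the `spike_factor` of 2.0, and the EMA decay of 0.99 are all assumptions chosen for the example, not values from this chapter.

```python
import math

class LossSpikeDetector:
    """Flags a loss that is non-finite or far above its running average."""

    def __init__(self, spike_factor=2.0, ema_decay=0.99, warmup_steps=10):
        self.spike_factor = spike_factor  # hypothetical threshold multiplier
        self.ema_decay = ema_decay        # hypothetical smoothing constant
        self.warmup_steps = warmup_steps  # steps before spikes are reported
        self.ema = None
        self.seen = 0

    def check(self, loss):
        """Return 'nan', 'spike', or 'ok' for a Python float loss value."""
        if not math.isfinite(loss):
            return "nan"
        self.seen += 1
        if self.ema is None:
            self.ema = loss
            return "ok"
        is_spike = (self.seen > self.warmup_steps
                    and loss > self.spike_factor * self.ema)
        if not is_spike:
            # Only healthy steps update the average, so a single spike
            # does not raise the baseline it is compared against.
            self.ema = self.ema_decay * self.ema + (1 - self.ema_decay) * loss
        return "spike" if is_spike else "ok"

# Synthetic loss curve: slow decrease, then a spike, then a NaN.
detector = LossSpikeDetector()
losses = [2.5 - 0.01 * i for i in range(50)] + [9.0, float("nan")]
flags = [detector.check(x) for x in losses]
print(flags[-2], flags[-1])
```

In a real loop the detector would receive `loss.item()` each step, and any result other than `"ok"` would be the trigger for saving a checkpoint before the optimizer step is applied.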

Gradient Clipping Implementation

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

class GradientController:
    """
    Manages gradient dynamics for stable training.
    Custom architectures often need stricter clipping than standard.
    """
    def __init__(self, max_norm=1.0, clip_cooldown=100):
        self.max_norm = max_norm
        self.clip_count = 0
        self.clip_cooldown = clip_cooldown
        self.last_clip_time = 0
    
    def clip_if_needed(self, model, step):
        """Clip gradients and track clipping frequency"""
        total_norm = clip_grad_norm_(model.parameters(), self.max_norm)
        
        if total_norm > self.max_norm:
            self.clip_count += 1
        
        # Alert if clipping frequency is too high
        if step - self.last_clip_time > self.clip_cooldown:
            if self.clip_count > 10:
                print(f"WARNING: {self.clip_count} clips in last {self.clip_cooldown} steps")
            self.clip_count = 0
            self.last_clip_time = step
        
        return total_norm

# Usage in training loop
controller = GradientController(max_norm=0.5)  # Stricter for custom arch

for step, batch in enumerate(dataloader):
    outputs = model(batch)
    loss = compute_loss(outputs, batch)
    
    optimizer.zero_grad()
    loss.backward()
    
    # clip_grad_norm_ returns the global L2 norm measured *before* clipping,
    # so a single call yields both the diagnostic value and clipped gradients.
    # (Summing per-parameter norms instead would overestimate the global norm.)
    total_norm = controller.clip_if_needed(model, step)
    if total_norm > 100:
        print(f"Gradient explosion at step {step}: norm={total_norm:.2f}")
    
    optimizer.step()
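
The thresholds in the loop above only make sense once "gradient norm" is pinned down. `clip_grad_norm_` works on the global L2 norm, which is the square root of the sum of squared per-parameter norms, and rescales every gradient by one shared coefficient. The plain-Python sketch below walks through that arithmetic; the three per-parameter norms are made-up numbers, and the `1e-6` epsilon mirrors the small constant PyTorch adds to the denominator.

```python
import math

def global_l2_norm(per_param_norms):
    """Global L2 norm from per-parameter L2 norms: sqrt of the sum of squares."""
    return math.sqrt(sum(n * n for n in per_param_norms))

def clip_coefficient(total_norm, max_norm, eps=1e-6):
    """Shared scale factor applied to every gradient; never larger than 1."""
    return min(1.0, max_norm / (total_norm + eps))

# Hypothetical per-parameter gradient norms for a three-tensor model.
norms = [3.0, 4.0, 12.0]

total = global_l2_norm(norms)   # sqrt(9 + 16 + 144) = 13.0
naive = sum(norms)              # 19.0, which is not the quantity being clipped
coef = clip_coefficient(total, max_norm=0.5)
clipped = [n * coef for n in norms]

print(f"global norm = {total:.1f}, sum of norms = {naive:.1f}")
print(f"clip coefficient = {coef:.4f}")
print(f"global norm after clipping = {global_l2_norm(clipped):.4f}")
```

Because the coefficient is shared, clipping preserves the direction of the overall gradient and only shortens it, and gradients already below `max_norm` are left untouched.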

Learning Rate Warmup for Custom Architectures

import math

def get_lr_with_warmup(step, warmup_steps, base_lr, min_lr=1e-6, schedule='linear',
                       total_steps=100_000):
    """
    Learning rate schedule with linear warmup.
    Custom architectures often need longer warmup.
    total_steps is the planned run length; the 100K default is an assumption
    and should be set to match the actual run.
    """
    if step < warmup_steps:
        # Linear warmup from 0 to base_lr
        return base_lr * step / warmup_steps
    
    # Fraction of the post-warmup phase completed, clamped to [0, 1]
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    
    if schedule == 'linear':
        return base_lr - (base_lr - min_lr) * progress
    elif schedule == 'cosine':
        return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
    elif schedule == 'constant':
        return base_lr
    else:
        raise ValueError(f"Unknown schedule: {schedule}")

class CustomArchitectureTrainer:
    def __init__(self, model, optimizer, warmup_steps=2000, base_lr=1e-4):
        self.model = model
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.base_lr = base_lr
        # Named global_step because an attribute called `step` would
        # shadow the step() method and make trainer.step(batch) fail.
        self.global_step = 0
    
    def step(self, batch):
        self.global_step += 1
        
        # Set this step's learning rate on every parameter group
        lr = get_lr_with_warmup(self.global_step, self.warmup_steps, self.base_lr)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        
        outputs = self.model(batch)
        loss = outputs['loss']
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()
        
        return {'loss': loss.item(), 'lr': lr}
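
It can help to tabulate a schedule before committing compute to it. The sketch below is a self-contained check of linear warmup followed by cosine decay; the helper name `lr_at` and the run shape (2,000 warmup steps out of 100,000, peak rate 1e-4) are assumptions for illustration.

```python
import math

def lr_at(step, warmup_steps, total_steps, base_lr, min_lr=1e-6):
    """Linear warmup to base_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Assumed run: 2,000 warmup steps out of 100,000, peak learning rate 1e-4.
for step in (1, 1000, 2000, 51000, 100000):
    print(f"step {step:>6}: lr = {lr_at(step, 2000, 100000, 1e-4):.3e}")
```

The rate reaches half of its peak midway through warmup, hits the peak exactly at the last warmup step, and ends at `min_lr` rather than zero, so late updates stay small but nonzero.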

Mixed Precision Training

import torch
from torch.nn.utils import clip_grad_norm_

class MixedPrecisionTrainer:
    """
    Training with bf16 mixed precision for custom architectures.
    bf16 has the same exponent range as fp32 but less precision, so it
    does not need the loss scaling (GradScaler) that fp16 requires.
    Large custom architectures especially benefit from memory savings.
    """
    def __init__(self, model, optimizer, base_lr=1e-4, warmup_steps=2000):
        self.model = model
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        # Not `self.step`: that attribute would shadow the step() method
        self.global_step = 0
    
    def step(self, batch):
        self.global_step += 1
        
        lr = get_lr_with_warmup(self.global_step, self.warmup_steps, self.base_lr)
        for pg in self.optimizer.param_groups:
            pg['lr'] = lr
        
        # Forward under autocast (assumes the model and batch are on a CUDA device)
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            outputs = self.model(batch)
            loss = outputs['loss']
        
        # Check for NaN before touching the weights
        if torch.isnan(loss):
            # Save checkpoint for debugging
            torch.save({
                'step': self.global_step,
                'model_state': self.model.state_dict(),
                'batch': batch
            }, f'nan_checkpoint_step{self.global_step}.pt')
            raise ValueError(f"NaN loss at step {self.global_step}")
        
        # Plain backward: gradients are unscaled, so they can be clipped directly
        loss.backward()
        clip_grad_norm_(self.model.parameters(), 1.0)
        
        self.optimizer.step()
        self.optimizer.zero_grad()
        
        return loss.item()
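
The claim that bf16 keeps the range of fp32 while giving up precision can be checked without a GPU. The sketch below emulates both 16-bit formats with Python's `struct` module: bf16 is the top 16 bits of an IEEE float32, and fp16 is the `e` format code. The helper names `to_bf16` and `to_fp16` are hypothetical, and the bf16 emulation is not valid for NaN or for inputs outside the float32 range.

```python
import struct

def to_bf16(x):
    """Round a Python float to the nearest bfloat16 value (returned as a float).

    Rounds away the low 16 bits of the float32 encoding using
    round-to-nearest-even, which is how the conversion is usually done.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]

def to_fp16(x):
    """Round a Python float to the nearest IEEE float16; overflow becomes inf."""
    try:
        return struct.unpack(">e", struct.pack(">e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")

big = 1e38     # fits in fp32, far beyond the fp16 maximum of 65504
fine = 1.001   # needs about 10 mantissa bits to distinguish from 1.0

print("large value:", to_bf16(big), to_fp16(big))
print("small step: ", to_bf16(fine), to_fp16(fine))
```

The large value survives in bf16 with under 1% error but overflows to infinity in fp16, while the small step survives in fp16 and collapses to 1.0 in bf16. That trade is why fp16 training relies on loss scaling and bf16 training generally does not.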

Initialization Strategies

import torch.nn.functional as F

class StableMLP(nn.Module):
    """
    MLP block with careful initialization for training stability.
    """
    def __init__(self, d_model, d_ff, init_method='kaiming'):
        super().__init__()
        self.up = nn.Linear(d_model, d_ff)
        self.down = nn.Linear(d_ff, d_model)
        
        # Careful initialization prevents gradient issues
        if init_method == 'kaiming':
            # The 'relu' gain is used as an approximation for SiLU
            nn.init.kaiming_normal_(self.up.weight, nonlinearity='relu')
            # Zero output projection: the block outputs zeros at init, which
            # suits use inside a residual connection
            nn.init.zeros_(self.down.weight)
            nn.init.zeros_(self.up.bias)
            nn.init.zeros_(self.down.bias)
        elif init_method == 'xavier':
            nn.init.xavier_normal_(self.up.weight)
            nn.init.xavier_normal_(self.down.weight)
        elif init_method == 'small':
            nn.init.normal_(self.up.weight, std=0.02)
            nn.init.normal_(self.down.weight, std=0.02)
        else:
            raise ValueError(f"Unknown init_method: {init_method}")
    
    def forward(self, x):
        return self.down(F.silu(self.up(x)))

# Custom attention initialization
class StableAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        
        # Scale initial weights to prevent attention saturation
        nn.init.normal_(self.qkv.weight, std=0.02)
        nn.init.zeros_(self.proj.weight)
    
    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # ... attention computation ...
        return self.proj(out)
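
The attention computation is elided in the snippet above. For readers who want something executable, here is one possible completion as a hedged sketch: the class name `StableAttentionSketch`, the causal mask, and the head-splitting layout are assumptions, not the course's implementation. It keeps the same two initialization choices (small-std `qkv`, zero `proj`).

```python
import math
import torch
import torch.nn as nn

class StableAttentionSketch(nn.Module):
    """Multi-head causal self-attention with a zero-initialised output projection."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        nn.init.normal_(self.qkv.weight, std=0.02)  # small logits at init
        nn.init.zeros_(self.proj.weight)            # block outputs zeros at init

    def forward(self, x):
        batch, seq_len, d_model = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (batch, seq, d_model) -> (batch, heads, seq, d_head)
        q, k, v = (t.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)
                   for t in (q, k, v))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        # Causal mask (an assumption): positions attend only to themselves
        # and to earlier positions.
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool,
                                     device=x.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
        out = torch.softmax(scores, dim=-1) @ v
        # (batch, heads, seq, d_head) -> (batch, seq, d_model)
        out = out.transpose(1, 2).reshape(batch, seq_len, d_model)
        return self.proj(out)

attn = StableAttentionSketch(d_model=64, n_heads=4)
x = torch.randn(2, 8, 64)
y = attn(x)
print(y.shape, y.abs().max().item())
```

With `proj` zero-initialised, the output is all zeros on the first forward pass, so a residual block wrapping this layer starts out as the identity and the attention contribution grows only as training updates `proj`.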

Failure Mode: Embedding Norm Explosion

import torch.nn.functional as F

# BUG: Unbounded embedding values lead to NaN in attention
class UnboundedEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # No explicit initialization: nn.Embedding defaults to N(0, 1)
        # Rows for rare tokens receive few updates, so with vocab_size=100k
        # many of them keep this large initial norm
    
    def forward(self, tokens):
        # At N(0, 1) init every row has norm ~sqrt(d_model)
        # With d_model=4096, norms are ~64
        # Stacks of layers amplify this quickly
        return self.embedding(tokens)

# FIX: Bound the embedding norm
class BoundedEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_norm=1.0):
        super().__init__()
        self.max_norm = max_norm  # stored so forward() can reference it
        self.embedding = nn.Embedding(vocab_size, d_model, max_norm=max_norm)
        nn.init.normal_(self.embedding.weight, std=0.02)
    
    def forward(self, tokens):
        # nn.Embedding(max_norm=...) caps each looked-up row at max_norm;
        # F.normalize then rescales every vector to exactly max_norm
        return F.normalize(self.embedding(tokens), p=2, dim=-1) * self.max_norm
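
The sqrt(d_model) figure in the comments follows from the initialization alone: a vector of d independent N(0, std²) entries has an L2 norm close to std · sqrt(d). The plain-Python sketch below samples a few rows to confirm it; the helper name `mean_row_norm`, the row count, and the seed are arbitrary choices for the example.

```python
import math
import random

def mean_row_norm(d_model, std, n_rows=8, seed=0):
    """Average L2 norm of n_rows vectors with i.i.d. N(0, std^2) entries."""
    rng = random.Random(seed)
    norms = []
    for _ in range(n_rows):
        row = [rng.gauss(0.0, std) for _ in range(d_model)]
        norms.append(math.sqrt(sum(v * v for v in row)))
    return sum(norms) / n_rows

d_model = 4096
print(f"predicted for std=1.0:  {1.0 * math.sqrt(d_model):.2f}")
print(f"sampled   for std=1.0:  {mean_row_norm(d_model, 1.0):.2f}")
print(f"predicted for std=0.02: {0.02 * math.sqrt(d_model):.2f}")
print(f"sampled   for std=0.02: {mean_row_norm(d_model, 0.02):.2f}")
```

Note that even the std=0.02 initialization gives norms of roughly 1.28 at this width, which is already above a `max_norm` of 1.0, so the bound is active from the first step.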
EXERCISE

Implement a training loop that simulates gradient explosion by artificially multiplying gradients by 10 at step 500. Verify that your stability mechanisms catch and recover from this without losing earlier progress.
