18. Scaling Laws
Chapter 18 of 24 · 30 min
Scaling laws predict how model performance changes with compute, data, and parameters. Designing custom architectures requires understanding these relationships to avoid over- or under-investing in model size.
Chinchilla Scaling Law Implementation
import math
def chinchilla_optimal(d_model, n_layers, d_ff=None, compute_budget=1e23,
                       vocab_size=50000, tokens_per_param=20):
    """
    Chinchilla scaling law: optimal tokens = 20 * parameters.

    Estimates the parameter count N of a GPT-style decoder-only transformer
    with the given shape, pairs it with the Chinchilla-optimal token count
    D = tokens_per_param * N, and checks the training cost against the
    budget using FLOPs ≈ 6 * N * D.

    If that run does not fit in ``compute_budget``, N and D are shrunk by the
    same factor (keeping D = tokens_per_param * N) so that 6 * N * D equals
    the budget exactly. If it does fit, the architecture's own N and D are
    returned unchanged; the function never scales a model up.

    Args:
        d_model: hidden size.
        n_layers: number of transformer blocks.
        d_ff: FFN inner size; defaults to 4 * d_model.
        compute_budget: available training FLOPs (must be > 0).
        vocab_size: vocabulary size used for the embedding count.
        tokens_per_param: Chinchilla tokens-per-parameter ratio.

    Returns:
        dict with 'params', 'tokens' and 'scaled_from_optimal' (True when
        the model had to be shrunk to fit the budget).

    Raises:
        ValueError: if compute_budget is not positive.
    """
    if compute_budget <= 0:
        # A non-positive budget would make the sqrt below complex/meaningless.
        raise ValueError("compute_budget must be positive")
    # If d_ff not specified, assume the standard 4x expansion
    if d_ff is None:
        d_ff = int(d_model * 4)
    # Parameter estimate for a GPT-style decoder (linear-layer biases ignored).
    # This is approximate; actual counts vary by architecture.
    params = (
        4 * d_model * d_model * n_layers    # attention: Q, K, V, output projections
        + 2 * d_model * d_ff * n_layers     # FFN: up + down projections
        + 2 * 2 * d_model * n_layers        # 2 LayerNorms per layer (gain + bias)
        + 2 * d_model                       # final LayerNorm
        + d_model * vocab_size              # token embeddings (output head tied)
    )
    # Chinchilla: optimal D = tokens_per_param * N
    optimal_tokens = tokens_per_param * params
    required_flops = 6 * params * optimal_tokens
    if required_flops > compute_budget:
        # With D proportional to N, FLOPs grow with N**2, so scaling N (and D)
        # by sqrt(budget / required) lands exactly on the budget.
        scale_factor = (compute_budget / required_flops) ** 0.5
        params_scaled = params * scale_factor
        return {
            'params': params_scaled,
            'tokens': tokens_per_param * params_scaled,
            'scaled_from_optimal': True,
        }
    return {
        'params': params,
        'tokens': optimal_tokens,
        'scaled_from_optimal': False,
    }
# Example: a 1e23 FLOP training budget (roughly one week on 1024 A100s),
# applied to a 4096-wide, 32-layer decoder.
result = chinchilla_optimal(d_model=4096, n_layers=32, compute_budget=1e23)
print("Recommended: {:.1f}B params, {:.1f}B tokens".format(
    result['params'] / 1e9, result['tokens'] / 1e9))
Compute-Optimal Frontier
import numpy as np
def compute_frontier(n_params_list, compute_budget):
    """
    Compare, for each candidate model size, the training tokens a fixed
    compute budget can pay for against the Chinchilla-optimal token count.

    Training cost is C = 6 * N * D, so a budget C affords D = C / (6 * N)
    tokens for an N-parameter model; Chinchilla recommends D = 20 * N.

    Returns one dict per entry of ``n_params_list`` (same order) with keys:
        'params': N.
        'compute_optimal_tokens': C / (6 * N), the tokens the budget affords.
        'chinchilla_tokens': 20 * N.
        'efficiency': smaller / larger of the two token counts. 1.0 means N is
            exactly compute-optimal for the budget. Below 1.0, either the budget
            cannot reach 20 * N tokens (model too large) or it would train well
            past 20 * N (model too small for the budget).
    """
    def frontier_point(size):
        affordable = compute_budget / (6 * size)
        recommended = 20 * size
        # Symmetric ratio: whichever token count is smaller goes on top.
        ratio = min(affordable, recommended) / max(affordable, recommended)
        return {
            'params': size,
            'compute_optimal_tokens': affordable,
            'chinchilla_tokens': recommended,
            'efficiency': ratio,
        }

    return [frontier_point(size) for size in n_params_list]
# Tabulate the frontier (plotting is left as an exercise)
n_params = [7e9, 13e9, 34e9, 70e9, 140e9]  # Various sizes
compute_budget = 1e23  # Fixed compute
frontier = compute_frontier(n_params, compute_budget)
for r in frontier:
    # 'compute_optimal_tokens' is C / (6N): what the budget affords at this N,
    # not the compute-optimal pairing, so label it accordingly.
    print(f"{r['params']/1e9:.1f}B params: "
          f"budget allows={r['compute_optimal_tokens']/1e9:.1f}B tokens, "
          f"chinchilla={r['chinchilla_tokens']/1e9:.1f}B tokens, "
          f"efficiency={r['efficiency']:.2f}")
Scaling Prediction for Custom Architectures
def predict_performance(current_params, current_tokens, current_loss,
                        target_params, target_tokens,
                        alpha=-0.07, beta=-0.095, param_share=0.5):
    """
    Predict loss for a larger model using scaling laws.

    Model: Loss ≈ a * N^alpha + b * D^beta, with alpha ≈ -0.07 and
    beta ≈ -0.095 (empirically determined; these exponents vary by task
    and architecture).

    A single observed run cannot pin down both a and b, so its loss is split
    between the two terms. A fraction ``param_share`` is attributed to the
    parameter term and the remainder to the data term. This guarantees the
    prediction equals ``current_loss`` when the target equals the current run.
    Calibrating each coefficient against the full loss would double-count it.

    Args:
        current_params, current_tokens, current_loss: the observed run.
        target_params, target_tokens: the run to predict.
        alpha, beta: power-law exponents for parameters and tokens.
        param_share: share of current_loss attributed to the parameter term.

    Returns:
        Predicted loss (float).

    Raises:
        ValueError: if param_share is outside [0, 1].
    """
    if not 0.0 <= param_share <= 1.0:
        raise ValueError("param_share must be in [0, 1]")
    a = param_share * current_loss / (current_params ** alpha)
    b = (1.0 - param_share) * current_loss / (current_tokens ** beta)
    pred_loss = a * (target_params ** alpha) + b * (target_tokens ** beta)
    return pred_loss
# Example: predict improvement from 7B to 70B
current = {'params': 7e9, 'tokens': 100e9, 'loss': 1.8}
target = {'params': 70e9, 'tokens': 200e9}  # Scale up model and data
# The dict keys don't match predict_performance's parameter names (and both
# dicts define 'params'/'tokens'), so `**current, **target` would raise
# TypeError; pass the values explicitly instead.
pred_loss = predict_performance(
    current_params=current['params'],
    current_tokens=current['tokens'],
    current_loss=current['loss'],
    target_params=target['params'],
    target_tokens=target['tokens'],
)
print(f"Predicted loss: {pred_loss:.3f}")
# Inverse: given a target loss and model size, how much training is needed?
def compute_needed(target_loss, n_params, current_params=7e9,
                   current_tokens=100e9, current_loss=1.8,
                   alpha=-0.07, beta=-0.095, param_share=0.5):
    """
    Given target loss and model size, estimate required tokens.

    Inverts Loss ≈ a * N^alpha + b * D^beta for D. The coefficients a and b
    are calibrated from a reference run (current_params, current_tokens,
    current_loss): a fraction ``param_share`` of its loss is attributed to
    the parameter term and the rest to the data term. The defaults describe
    a 7B-parameter model trained on 100B tokens that reached loss 1.8.
    Multiply the result by 6 * n_params for approximate training FLOPs.

    Returns:
        Estimated training tokens D (float).

    Raises:
        ValueError: if param_share is outside [0, 1), or if target_loss is
            at or below a * n_params**alpha, the floor the parameter term
            alone imposes (no finite token count reaches it).
    """
    if not 0.0 <= param_share < 1.0:
        raise ValueError("param_share must be in [0, 1)")
    a = param_share * current_loss / (current_params ** alpha)
    b = (1.0 - param_share) * current_loss / (current_tokens ** beta)
    # Loss the data term must account for at this model size.
    remaining = target_loss - a * (n_params ** alpha)
    if remaining <= 0:
        raise ValueError(
            "target_loss is unreachable at this model size: the parameter "
            "term alone is already at or above it"
        )
    # Solve b * D**beta = remaining for D.
    tokens = (remaining / b) ** (1.0 / beta)
    return tokens
Neural Scaling Laws
class ScalingLawAnalyzer:
    """
    Analyze and predict scaling behavior for custom architectures.

    Observations are recorded with ``log_metrics``. ``fit_power_law`` fits
    L(N, D) ≈ alpha * N^beta + gamma * D^delta to them, and ``predict``
    evaluates that law.
    """

    # Coefficient order used for sequence unpacking and for the fit's dict keys.
    _COEFF_KEYS = ('alpha', 'beta', 'gamma', 'delta')

    def __init__(self):
        # One dict per log_metrics call, in call order.
        self.training_history = []

    def log_metrics(self, step, params, tokens, loss, throughput):
        """Record one observation; the fit uses only params, tokens and loss."""
        self.training_history.append({
            'step': step,
            'params': params,
            'tokens': tokens,
            'loss': loss,
            'throughput': throughput
        })

    def fit_power_law(self, initial_guess=(1e-3, -0.07, 1e-3, -0.095)):
        """
        Fit power law: L(N, D) ≈ α * N^β + γ * D^δ by least squares.

        Note: points logged from a single run share one N, so the parameter
        term is not separately identifiable unless several model sizes are
        logged.

        Args:
            initial_guess: starting (alpha, beta, gamma, delta); the default
                uses literature exponents.

        Returns:
            dict with keys 'alpha', 'beta', 'gamma', 'delta'.

        Raises:
            ValueError: if no metrics have been logged.
        """
        import scipy.optimize as opt
        if not self.training_history:
            # Minimizing an empty sum would silently return the initial guess.
            raise ValueError("fit_power_law() needs at least one logged observation")
        losses = np.array([h['loss'] for h in self.training_history], dtype=float)
        params = np.array([h['params'] for h in self.training_history], dtype=float)
        tokens = np.array([h['tokens'] for h in self.training_history], dtype=float)

        def loss_fn(coeffs):
            alpha, beta, gamma, delta = coeffs
            # numpy power overflows to inf (instead of Python's OverflowError),
            # so an aggressive line-search step is rejected, not fatal.
            with np.errstate(over='ignore', invalid='ignore'):
                preds = alpha * params ** beta + gamma * tokens ** delta
                return float(np.sum((losses - preds) ** 2))

        result = opt.minimize(loss_fn, list(initial_guess))
        return {
            'alpha': result.x[0],
            'beta': result.x[1],
            'gamma': result.x[2],
            'delta': result.x[3]
        }

    def predict(self, n_params, n_tokens, coeffs):
        """
        Evaluate alpha * n_params^beta + gamma * n_tokens^delta.

        ``coeffs`` may be the dict returned by ``fit_power_law`` or a
        4-sequence (alpha, beta, gamma, delta).
        """
        if isinstance(coeffs, dict):
            # Unpacking a dict directly would yield its keys, not its values.
            coeffs = [coeffs[key] for key in self._COEFF_KEYS]
        alpha, beta, gamma, delta = coeffs
        return alpha * (n_params ** beta) + gamma * (n_tokens ** delta)
Failure Mode: Incorrect Parameter Counting
# BUG: Undercounting parameters leads to wrong scaling predictions
def incorrect_param_count(d_model, n_layers, vocab_size=50000):
    """Missing many parameter sources (deliberately wrong, for contrast).

    Counts only the attention projections and the token embeddings. Note
    that 12 * d_model**2 per layer would already be attention (4 * d_model**2)
    plus a 4x FFN (8 * d_model**2), so attention alone is 4 * d_model**2.
    """
    return (
        4 * d_model * d_model * n_layers +  # Q, K, V, O projections (only attention)
        d_model * vocab_size  # embeddings
        # Missing: FFN (2 * d_model * d_ff, d_ff typically 4x d_model),
        # layer norms, untied output head
    )
def correct_param_count(d_model, n_layers, vocab_size=50000, d_ff=None,
                        tie_embeddings=True):
    """
    Complete parameter count for a GPT-style decoder-only transformer.

    Assumes standard multi-head attention (Q, K, V and output projections,
    each d_model x d_model). It also assumes a two-matrix FFN
    (d_model -> d_ff -> d_model), two LayerNorms per block with gain and
    bias, and a final LayerNorm. Linear-layer biases are not counted.

    Args:
        d_model: hidden size.
        n_layers: number of transformer blocks.
        vocab_size: vocabulary size.
        d_ff: FFN inner size; defaults to 4 * d_model.
        tie_embeddings: if True the output head reuses the embedding matrix;
            if False its d_model * vocab_size weights are counted as well.

    Returns:
        dict with per-component counts and their 'total'.
    """
    if d_ff is None:
        d_ff = d_model * 4
    # Attention parameters: W_Q, W_K, W_V, W_O, each d_model x d_model
    attn_params = 4 * d_model * d_model * n_layers
    # FFN parameters: up (d_model x d_ff) + down (d_ff x d_model)
    ffn_params = 2 * d_model * d_ff * n_layers
    # Layer norms (2 per layer, gain + bias each)
    ln_params = 2 * 2 * d_model * n_layers
    # Token embeddings; an untied output head adds a second matrix
    embed_params = d_model * vocab_size * (1 if tie_embeddings else 2)
    # Final layer norm (gain + bias)
    final_ln = 2 * d_model
    return {
        'attention': attn_params,
        'ffn': ffn_params,
        'layer_norms': ln_params,
        'embeddings': embed_params,
        'final_ln': final_ln,
        'total': attn_params + ffn_params + ln_params + embed_params + final_ln
    }
# Mistral 7B (d_model=4096, 32 layers, 32k vocab) has ~7.24B params, but it
# uses grouped-query attention, a SwiGLU FFN with d_ff=14336 and untied
# embeddings, so this GPT-style estimate will not reproduce that figure.
counts = correct_param_count(4096, 32, 32000)
print(f"Total: {counts['total']/1e9:.2f}B params")
EXERCISE
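Using the compute_frontier function, generate a table for model sizes from 1B to 180B parameters that compares, at a compute budget of 10²³ FLOPs, the tokens the budget affords (C / 6N) with the Chinchilla-optimal 20 × N tokens. Plot both curves against model size to show the compute-optimal frontier. The curves cross where 6N · 20N = C, i.e. near N ≈ 29B parameters for this budget, which is the compute-optimal model size.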