03. LoRA Theory

Chapter 3 of 24 · 15 min

LoRA (Low-Rank Adaptation) builds on the observation that the weight changes produced by fine-tuning a large pre-trained model tend to be low-rank. Instead of updating the full weight matrix W (dimensions d×k), LoRA freezes W and trains two small low-rank matrices, A and B, whose product represents the update.

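During the forward pass, the input passes through the frozen weights W as usual, and LoRA adds a parallel path through the trainable matrices A and B (which are initialized differently). The adapter output is scaled by α/r, where r is the rank and α is a constant scaling hyperparameter, so the adapted forward pass becomes: h = Wx + (α/r)·BAx. This scaling reduces the need to retune hyperparameters when r changes; the original LoRA paper simply sets α to the first r it tries and does not tune it further.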
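
A minimal NumPy sketch of this forward pass, using the α/r scaling from the original LoRA paper. The sizes d, k, r and the value of alpha are arbitrary illustrative choices, and B is filled with random values to mimic an adapter part-way through training. W maps a k-dimensional input to a d-dimensional output, matching the d×k convention used in this chapter.

```python
import numpy as np

rng = np.random.default_rng(0)

d, k, r = 64, 32, 4      # output dim, input dim, LoRA rank (illustrative)
alpha = 8.0              # scaling hyperparameter (illustrative)
scale = alpha / r

W = rng.standard_normal((d, k))  # frozen pre-trained weight, d x k
B = rng.standard_normal((d, r))  # trainable, d x r (random here, as if partly trained)
A = rng.standard_normal((r, k))  # trainable, r x k

def lora_forward(x: np.ndarray) -> np.ndarray:
    """h = W x + (alpha / r) * B (A x); the adapter path never forms B @ A."""
    return W @ x + scale * (B @ (A @ x))

x = rng.standard_normal(k)
h = lora_forward(x)

# Same result as applying the single updated matrix W + (alpha / r) * B A
h_dense = (W + scale * (B @ A)) @ x
print(h.shape, np.allclose(h, h_dense))  # (64,) True
```

Computing B @ (A @ x) rather than (B @ A) @ x keeps the adapter cheap: it only ever multiplies by an r-dimensional intermediate vector.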

A is initialized with random Gaussian values and B with zeros, so the product BA is zero at the start of training: the adaptation contributes nothing, and the model behaves exactly like the pre-trained version. Training then moves BA away from zero, introducing changes only through the low-rank product.
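
A short NumPy check of this initialization, as a sketch with illustrative sizes. The 0.02 standard deviation for A is an arbitrary choice for the example, not a value prescribed by LoRA.

```python
import numpy as np

rng = np.random.default_rng(42)

d, k, r, alpha = 64, 32, 4, 8.0   # illustrative sizes and scaling

W = rng.standard_normal((d, k))          # frozen pre-trained weight
A = rng.normal(0.0, 0.02, size=(r, k))   # random Gaussian init (std 0.02 is arbitrary)
B = np.zeros((d, r))                     # zero init

x = rng.standard_normal(k)
base = W @ x                                # pre-trained output
h = base + (alpha / r) * (B @ (A @ x))      # LoRA output at step 0

# B = 0 makes B @ (A @ x) exactly zero, so the output is unchanged
print(np.array_equal(h, base))  # True
```

Zeroing only B (rather than both matrices) matters: the gradient reaching B is proportional to A·x, which is nonzero because A is random, so B starts moving on the first update. If A were also zero, the gradients of both matrices would be zero and the adapter could never leave its starting point.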

The mathematical intuition is the hypothesis that the update ΔW learned during fine-tuning has low "intrinsic rank". Rather than learning all d×k entries of ΔW independently, LoRA constrains the update to ΔW = BA (up to the constant scaling factor), where B is d×r, A is r×k, and r ≪ min(d, k). This caps the rank of ΔW at r and reduces the trainable parameters from d×k to r(d + k).
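
The rank constraint is easy to check numerically. In this sketch the sizes are arbitrary and random matrices stand in for trained B and A: the d×k product BA has rank r, while only r(d + k) numbers are actually trained.

```python
import numpy as np

rng = np.random.default_rng(1)

d, k, r = 512, 256, 8    # illustrative dimensions and rank

B = rng.standard_normal((d, r))   # d x r
A = rng.standard_normal((r, k))   # r x k
delta_W = B @ A                   # full d x k shape, but rank at most r

full_params = d * k               # entries if delta_W were trained directly
lora_params = r * (d + k)         # entries actually trained (B and A)

print(delta_W.shape)                    # (512, 256)
print(np.linalg.matrix_rank(delta_W))   # 8
print(full_params, lora_params)         # 131072 6144
```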

For a weight matrix with dimensions 4096×4096 and rank r = 8, full fine-tuning updates 4096 × 4096 = 16,777,216 parameters, while LoRA trains 8 × (4096 + 4096) = 65,536, about 0.39% of the original. The memory savings come mainly from not storing gradients and optimizer states (for Adam, two moment estimates per parameter) for the frozen weights; the frozen W itself still has to be held in memory.
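
The same arithmetic in Python, extended with a rough optimizer-memory estimate. The estimate assumes Adam with two fp32 moment buffers (4 bytes each) per trainable parameter and ignores gradients, activations, and the frozen weights, so treat it as an illustration of scale rather than a measured figure.

```python
d = k = 4096
r = 8

full = d * k           # 16,777,216 parameters
lora = r * (d + k)     # 65,536 parameters
print(f"full: {full:,}  lora: {lora:,}  ratio: {100 * lora / full:.2f}%")

# Assumption: Adam keeps two fp32 moment buffers (4 bytes each) per trainable parameter
adam_bytes_per_param = 2 * 4
print(f"Adam state, full: {full * adam_bytes_per_param / 2**20:.0f} MiB")   # 128 MiB
print(f"Adam state, LoRA: {lora * adam_bytes_per_param / 2**20:.2f} MiB")   # 0.50 MiB
```

For this single matrix the first line prints `full: 16,777,216  lora: 65,536  ratio: 0.39%`.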

EXERCISE

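Calculate the parameter reduction for a ~7B model, comparing rank-8 LoRA against full fine-tuning, and compute the resulting trainable parameter count and percentage reduction. The script below uses a LLaMA-style configuration: 32 transformer layers, each with an attention block and an FFN block, hidden size 4096, and FFN intermediate size 11008.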

# lora_parameter_calculator.py

def count_lora_parameters(
    hidden_dim: int,
    intermediate_size: int,
    num_layers: int,
    rank: int,
    num_attention_heads: int,
    num_key_value_heads: int,
) -> dict:
    """Count LoRA vs. full fine-tuning parameters for a LLaMA-style transformer.

    Assumes bias-free projections, LoRA on Q, V, and the three FFN projections,
    and ignores embeddings and norm weights.
    """
    head_dim = hidden_dim // num_attention_heads
    kv_dim = num_key_value_heads * head_dim  # K/V output width (smaller with GQA)

    # Full attention weights per layer: Q, K, V, and the output projection O
    attn_params = num_layers * (
        hidden_dim * hidden_dim      # Q
        + 2 * hidden_dim * kv_dim    # K and V
        + hidden_dim * hidden_dim    # O
    )

    # Full FFN weights per layer: gate and up (hidden -> intermediate), down (intermediate -> hidden)
    ffn_params = num_layers * 3 * hidden_dim * intermediate_size

    # A LoRA pair on a d_in x d_out matrix adds rank * (d_in + d_out) parameters
    lora_q = num_layers * rank * (hidden_dim + hidden_dim)
    lora_v = num_layers * rank * (hidden_dim + kv_dim)
    lora_ffn = num_layers * 3 * rank * (hidden_dim + intermediate_size)

    total_lora = lora_q + lora_v + lora_ffn
    total_full = attn_params + ffn_params

    return {
        "lora_trainable": total_lora,
        "full_finetune": total_full,
        "reduction_percentage": 100 * (1 - total_lora / total_full),
        "compression_ratio": total_full / total_lora,
    }

# Example for a ~7B LLaMA-style model
result = count_lora_parameters(
    hidden_dim=4096,
    intermediate_size=11008,
    num_layers=32,
    rank=8,
    num_attention_heads=32,
    num_key_value_heads=32,
)
print(f"Full fine-tune: {result['full_finetune']:,}")
print(f"Trainable: {result['lora_trainable']:,}")
print(f"Reduction: {result['reduction_percentage']:.2f}%")
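
A per-module breakdown makes it easier to see where the trainable parameters go. This sketch recomputes the LoRA count directly from each targeted matrix's (d_in, d_out) shape, using the same configuration values as the example above. The module names follow the projection names in Hugging Face's LLaMA implementation and are used here only as labels.

```python
# Per-module LoRA counts for a LLaMA-style model (config values as in the example)
hidden, intermediate, num_layers, rank = 4096, 11008, 32, 8
num_heads, num_kv_heads = 32, 32
kv_dim = num_kv_heads * (hidden // num_heads)

# (d_in, d_out) for each targeted projection in one layer
targets = {
    "q_proj": (hidden, hidden),
    "v_proj": (hidden, kv_dim),
    "gate_proj": (hidden, intermediate),
    "up_proj": (hidden, intermediate),
    "down_proj": (intermediate, hidden),
}

# Each adapted matrix gets A (rank x d_in) and B (d_out x rank)
per_layer = {name: rank * (d_in + d_out) for name, (d_in, d_out) in targets.items()}
for name, n in per_layer.items():
    print(f"{name:>9}: {n:>8,} per layer, {n * num_layers:>11,} total")

total = sum(per_layer.values()) * num_layers
print(f"total trainable: {total:,}")   # total trainable: 15,794,176
```

The three FFN projections account for roughly three quarters of the adapter parameters, because their intermediate dimension (11008) is much wider than the attention projections.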