03. LoRA Theory
LoRA (Low-Rank Adaptation) exploits the observation that fine-tuning large models involves relatively low-rank changes to the original weight matrices. Instead of updating the full weight matrix W (dimensions d×k), LoRA introduces trainable low-rank decomposition matrices A and B while freezing W.
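During the forward pass, the original computation uses the frozen weights W, while LoRA adds the product of the trainable matrices B and A (which are initialized differently), scaled by α/r, where r is the rank. The adapted forward pass becomes: h = Wx + (α/r)·BAx, where α is a constant scaling hyperparameter. The α/r factor reduces the need to retune hyperparameters such as the learning rate when r changes. The original LoRA paper sets α to the first r it tries and does not tune it, and settings such as α = r or α = 2r are common in practice.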
Matrix A is initialized with random Gaussian values and B with zeros, so the product BA is zero at the start of training: the adaptation contributes nothing, and the model behaves identically to the pre-trained version. Only one factor is zeroed because if both were zero, neither would receive a nonzero gradient. Training then gradually introduces changes through the low-rank product.
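To make the forward pass and initialization concrete, here is a minimal sketch in NumPy (an assumption; the document's own script mentions PyTorch). The helper names `init_lora` and `lora_forward`, the toy dimensions, and the 0.01 standard deviation for A are illustrative choices, not any library's API. It checks that the adapted layer reproduces Wx exactly while B is zero.

```python
import numpy as np


def init_lora(d: int, k: int, r: int, rng: np.random.Generator):
    """Hypothetical helper: LoRA factors for a frozen d x k weight.

    A (r x k) gets small Gaussian values (std 0.01 is an illustrative
    assumption); B (d x r) starts at zero, so B @ A == 0 at step 0.
    """
    A = rng.normal(0.0, 0.01, size=(r, k))
    B = np.zeros((d, r))
    return A, B


def lora_forward(W, A, B, x, alpha: float, r: int):
    """h = W x + (alpha / r) * B (A x); only A and B would be trained."""
    return W @ x + (alpha / r) * (B @ (A @ x))


rng = np.random.default_rng(0)
d, k, r, alpha = 6, 5, 2, 4.0
W = rng.normal(size=(d, k))  # stands in for the frozen pre-trained weight
A, B = init_lora(d, k, r, rng)
x = rng.normal(size=k)

# With B = 0 the adapted layer matches the frozen layer exactly.
h0 = lora_forward(W, A, B, x, alpha, r)
print(np.allclose(h0, W @ x))

# Once B is nonzero (as after some training), the scaled low-rank
# update (alpha / r) * B A x is added on top of W x.
B = rng.normal(size=(d, r))
h1 = lora_forward(W, A, B, x, alpha, r)
```

Computing B(Ax) rather than forming BA avoids materializing a d×k matrix during training. For inference, the LoRA paper notes that the update can be merged into the weight (here W + (α/r)·BA), so the adapted layer adds no extra latency.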
The mathematical intuition comes from the observation that the parameter update ΔW during fine-tuning often has low intrinsic rank. Rather than updating all d×k parameters independently, LoRA constrains ΔW = BA, where B is d×r and A is r×k with r ≪ min(d, k). This guarantees rank(ΔW) ≤ r and reduces the trainable parameters from d×k to r(d+k).
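A quick numerical check of the rank bound and the parameter count, again as a NumPy sketch with arbitrary toy dimensions; `low_rank_update` is a hypothetical helper. Multiplying a d×r matrix by an r×k matrix yields a full-size d×k update whose rank cannot exceed r, while only r(d+k) numbers are stored.

```python
import numpy as np


def low_rank_update(d: int, k: int, r: int, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: Delta W = B @ A with B (d x r) and A (r x k)."""
    rng = np.random.default_rng(seed)
    B = rng.normal(size=(d, r))
    A = rng.normal(size=(r, k))
    return B @ A


d, k, r = 64, 48, 4
delta_w = low_rank_update(d, k, r)

# Delta W has the full d x k shape but rank at most r (generically exactly r).
print(delta_w.shape, np.linalg.matrix_rank(delta_w))

# Numbers needed to store Delta W directly vs. through its factors B and A.
full_count = d * k
lora_count = r * (d + k)
print(full_count, lora_count)  # 3072 448
```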
For a 4096×4096 weight matrix and rank r=8, full fine-tuning updates 4096 × 4096 = 16,777,216 parameters, while LoRA trains 8 × (4096 + 4096) = 65,536 (about 0.39% of the original). The memory savings come from not storing gradients or optimizer states (for example, Adam's two moment estimates) for the frozen weights; W itself must still be kept in memory.
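The arithmetic above, plus a rough estimate of the training state that LoRA avoids, fits in a few lines of Python. The 12 bytes per trainable parameter is an assumption (an fp32 gradient plus Adam's two fp32 moment buffers); mixed-precision setups and other optimizers change the exact figures, and the frozen W occupies its own storage in both cases. `lora_counts` is a hypothetical helper.

```python
def lora_counts(d: int, k: int, r: int) -> tuple[int, int]:
    """Trainable parameters for one d x k weight: (full fine-tuning, LoRA)."""
    return d * k, r * (d + k)


# Assumed training-state cost per *trainable* parameter:
# fp32 gradient (4 B) + Adam first and second moments (4 B each).
BYTES_PER_TRAINABLE = 12
MIB = 2**20

full, lora = lora_counts(4096, 4096, 8)
print(f"full: {full:,}  lora: {lora:,}  ratio: {100 * lora / full:.2f}%")
# full: 16,777,216  lora: 65,536  ratio: 0.39%

print(f"grad + Adam state, full: {full * BYTES_PER_TRAINABLE / MIB:.2f} MiB")
# grad + Adam state, full: 192.00 MiB
print(f"grad + Adam state, LoRA: {lora * BYTES_PER_TRAINABLE / MIB:.2f} MiB")
# grad + Adam state, LoRA: 0.75 MiB
```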
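Calculate the parameter reduction for a ~7B model with 32 transformer layers (hidden size 4096, FFN intermediate size 11008, 32 attention heads), comparing rank-8 LoRA on the Q, V and FFN projections against full fine-tuning of those layers' Q, K, V and FFN weights. Compute the percentage reduction and the resulting trainable parameter count.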
```python
# lora_parameter_calculator.py


def count_lora_parameters(
    hidden_dim: int,
    intermediate_size: int,
    num_layers: int,
    rank: int,
    num_attention_heads: int,
    num_key_value_heads: int,
) -> dict:
    """Compare rank-r LoRA against full fine-tuning of the same projection
    matrices (Q, K, V and the gate/up/down FFN weights; biases ignored).

    A LoRA adapter on a weight mapping d_in -> d_out adds r * (d_in + d_out)
    trainable parameters (A is r x d_in, B is d_out x r).
    """
    head_dim = hidden_dim // num_attention_heads
    kv_dim = num_key_value_heads * head_dim  # K/V output width (smaller with GQA)

    # Attention: Q, K, V projections (LoRA typically targets Q and V)
    q_params = num_layers * hidden_dim * hidden_dim
    k_params = num_layers * hidden_dim * kv_dim
    v_params = num_layers * hidden_dim * kv_dim

    # LoRA parameters for Q (hidden -> hidden) and V (hidden -> kv_dim)
    lora_q = num_layers * rank * (hidden_dim + hidden_dim)
    lora_v = num_layers * rank * (hidden_dim + kv_dim)

    # FFN: gate and up (hidden -> intermediate), down (intermediate -> hidden);
    # each of the three has d_in + d_out = hidden_dim + intermediate_size
    ffn_params = num_layers * 3 * hidden_dim * intermediate_size
    lora_ffn = num_layers * 3 * rank * (hidden_dim + intermediate_size)

    total_lora = lora_q + lora_v + lora_ffn
    # Baseline excludes the attention output projection, embeddings and norms
    total_full = q_params + k_params + v_params + ffn_params

    return {
        "lora_trainable": total_lora,
        "full_finetune": total_full,
        "reduction_percentage": 100 * (1 - total_lora / total_full),
        "compression_ratio": total_full / total_lora,
    }


# Example for a ~7B LLaMA-style model (no grouped-query attention)
result = count_lora_parameters(
    hidden_dim=4096,
    intermediate_size=11008,
    num_layers=32,
    rank=8,
    num_attention_heads=32,
    num_key_value_heads=32,
)
print(f"Trainable: {result['lora_trainable']:,}")  # Trainable: 15,794,176
print(f"Reduction: {result['reduction_percentage']:.2f}%")  # Reduction: 99.73%
```