12. Selective State Space
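Mamba's selective state space model extends SSMs to content-aware sequence modeling. The key innovation is that the parameters Δ, B, and C become functions of the input, so the model can decide at each step what to write into its state and what to ignore. A itself stays input-independent, but its discretized form exp(ΔA) inherits input dependence through Δ.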
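A minimal sketch of how the input-dependent parameters might be produced, modeled loosely on the reference Mamba block: one linear projection of x yields a low-rank Δ input plus B and C, a second projection and softplus turn the Δ input into positive per-channel step sizes, and A is stored as a learned log-parameter that is negated on use. The class name `SelectionProjections`, the layer names, `dt_rank`, and the 1..d_state initialization of A are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectionProjections(nn.Module):
    """Hypothetical module: maps x (batch, seq, d_inner) to input-dependent delta, B, C, plus a fixed A."""

    def __init__(self, d_inner, d_state, dt_rank):
        super().__init__()
        self.d_state, self.dt_rank = d_state, dt_rank
        # One projection produces the low-rank delta input and B, C for every token.
        self.x_proj = nn.Linear(d_inner, dt_rank + 2 * d_state, bias=False)
        # Expands the low-rank delta input back to one step size per channel.
        self.dt_proj = nn.Linear(dt_rank, d_inner)
        # A does not depend on x: store log(-A) so that A = -exp(A_log) stays negative.
        A_init = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A_init))

    def forward(self, x):
        dt, B, C = torch.split(self.x_proj(x), [self.dt_rank, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(dt))  # (batch, seq, d_inner), positive step sizes
        A = -torch.exp(self.A_log)            # (d_inner, d_state), same for every token
        return delta, A, B, C                 # B, C: (batch, seq, d_state)

proj = SelectionProjections(d_inner=8, d_state=4, dt_rank=2)
delta, A, B, C = proj(torch.randn(2, 5, 8))
print(delta.shape, A.shape, B.shape, C.shape)
```

The returned tensors have the shapes that `selective_scan` expects for its `delta`, `A`, `B`, and `C` arguments; only A is shared across tokens.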
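Given these parameters, the selective scan computes the discretized SSM recurrence h_t = exp(Δ_t A) h_{t-1} + Δ_t B_t u_t, y_t = C_t h_t. Because A is diagonal, every product in the state update is elementwise: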
```python
import torch

def selective_scan(u, delta, A, B, C):
    """
    Selective SSM scan (simplified reference implementation).

    Args:
        u:     (batch, seq, d_inner) - input
        delta: (batch, seq, d_inner) - time step (input-dependent)
        A:     (d_inner, d_state)    - diagonal state matrix (input-independent, negative)
        B:     (batch, seq, d_state) - input matrix (input-dependent)
        C:     (batch, seq, d_state) - output matrix (input-dependent)
    Returns:
        y: (batch, seq, d_inner) - output
    """
    batch, seq, d_inner = u.shape
    d_state = A.shape[1]

    # Discretize: A_bar = exp(delta * A) (zero-order hold), B_bar * u ≈ delta * B * u (first order)
    deltaA = torch.exp(delta.unsqueeze(-1) * A)                        # (batch, seq, d_inner, d_state)
    deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # (batch, seq, d_inner, d_state)

    # Sequential scan for clarity; production kernels use a hardware-aware parallel scan.
    h = torch.zeros(batch, d_inner, d_state, device=u.device, dtype=u.dtype)
    ys = []
    for t in range(seq):
        h = deltaA[:, t] * h + deltaB_u[:, t]          # h_t = A_bar_t * h_{t-1} + B_bar_t u_t
        ys.append((h * C[:, t].unsqueeze(1)).sum(-1))  # y_t = C_t h_t -> (batch, d_inner)
    return torch.stack(ys, dim=1)
```
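The discretization step can also be factored out on its own:

```python
import torch

def mamba_discretize(A, B, delta):
    """
    Discretize the continuous SSM parameters for every time step.

    Exact zero-order hold (ZOH):
        h_t = exp(ΔA) h_{t-1} + (ΔA)^{-1} (exp(ΔA) - I) ΔB u_t
    Mamba keeps the exact A_bar = exp(ΔA) but approximates B_bar ≈ ΔB.
    Forward Euler (fully linear) would instead use:
        h_t = (I + ΔA) h_{t-1} + ΔB u_t

    Args:
        A: (d_inner, d_state), B: (batch, seq, d_state), delta: (batch, seq, d_inner)
    Returns:
        deltaA, deltaB: each (batch, seq, d_inner, d_state)
    """
    deltaA = torch.exp(delta.unsqueeze(-1) * A)    # A_bar (ZOH)
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # B_bar ≈ ΔB (simplified)
    return deltaA, deltaB
```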
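To see how far the simplified B̄ ≈ ΔB and a forward-Euler transition drift from exact ZOH, the sketch below evaluates all three for a scalar system with a = -1 and b = 1 (hypothetical values) at several step sizes. The function name `discretize_scalar` is illustrative.

```python
import math

def discretize_scalar(a, b, dt):
    """Compare discretizations of dh/dt = a*h + b*u for a scalar a < 0."""
    a_bar = math.exp(dt * a)                  # ZOH state transition (also what Mamba uses)
    b_zoh = (math.exp(dt * a) - 1.0) / a * b  # exact ZOH input coefficient
    b_simple = dt * b                         # first-order approximation used for B
    a_euler = 1.0 + dt * a                    # forward-Euler state transition
    return a_bar, b_zoh, b_simple, a_euler

for dt in (0.01, 0.1, 1.0):
    a_bar, b_zoh, b_simple, a_euler = discretize_scalar(-1.0, 1.0, dt)
    print(f"dt={dt}: A_bar={a_bar:.4f} (Euler {a_euler:.4f}), "
          f"B_bar ZOH={b_zoh:.4f} vs dt*B={b_simple:.4f}")
```

For small Δ all three agree closely. At Δ = 1 the Euler transition 1 + Δa collapses to 0 while exp(Δa) stays near 0.368, one reason the state transition keeps the exponential form.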
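Failure mode: the sequential scan is slow in practice. It does only O(N) work, but its N steps depend on one another, so a Python loop over the sequence cannot exploit GPU parallelism. Production implementations use a parallel prefix scan (parallel scan), which achieves O(log N) depth with O(N) work, fused into a hardware-aware kernel. A chunked reference version in PyTorch: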
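```python
import torch
import torch.nn.functional as F

def parallel_selective_scan(u, delta, A, B, C, chunk_size=32):
    """
    Chunked selective scan (reference sketch, not the fused CUDA kernel).
    All steps inside a chunk are computed at once from cumulative sums of
    delta * A; the state h is carried sequentially from chunk to chunk.
    Shapes as in selective_scan. Assumes A < 0, so every decay factor is <= 1.
    Memory per chunk is O(chunk_size^2 * d_inner * d_state).
    """
    batch, seq, d_inner = u.shape
    d_state = A.shape[1]

    # Pad the time axis to a multiple of chunk_size. Padded steps have delta = 0,
    # so they leave the state unchanged; their outputs are sliced off at the end.
    pad_len = (chunk_size - seq % chunk_size) % chunk_size
    u, delta, B, C = (F.pad(x, (0, 0, 0, pad_len)) for x in (u, delta, B, C))
    n_chunks = u.shape[1] // chunk_size
    u, delta, B, C = (x.reshape(batch, n_chunks, chunk_size, -1) for x in (u, delta, B, C))

    idx = torch.arange(chunk_size, device=u.device)
    causal = (idx[:, None] >= idx[None, :])[None, :, :, None, None]  # [t, s] True where s <= t
    h = torch.zeros(batch, d_inner, d_state, device=u.device, dtype=u.dtype)
    y_chunks = []
    for c in range(n_chunks):
        S = torch.cumsum(delta[:, c].unsqueeze(-1) * A, dim=1)  # (b, L, d, n): running sum of ΔA
        bu = delta[:, c].unsqueeze(-1) * B[:, c].unsqueeze(2) * u[:, c].unsqueeze(-1)
        # decay[t, s] = exp(S_t - S_s) = product of A_bar over steps s+1..t (zero where s > t)
        diff = (S.unsqueeze(2) - S.unsqueeze(1)).masked_fill(~causal, float("-inf"))
        hs = (torch.exp(diff) * bu.unsqueeze(1)).sum(dim=2)     # inputs from this chunk
        hs = hs + torch.exp(S) * h.unsqueeze(1)                  # state carried in from earlier chunks
        y_chunks.append((hs * C[:, c].unsqueeze(2)).sum(-1))    # y_t = C_t h_t
        h = hs[:, -1]                                            # final state of this chunk
    return torch.cat(y_chunks, dim=1)[:, :seq]
```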
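The parallel prefix scan mentioned above works because one step of the recurrence, h ↦ a·h + b, is an affine map, and composing two such maps yields another: (a₁, b₁) followed by (a₂, b₂) is (a₂a₁, a₂b₁ + b₂). This operator is associative, so the whole scan can be evaluated in O(log N) rounds. The sketch below uses the simple Hillis-Steele scheme, which does O(N log N) work rather than the O(N) of a work-efficient (Blelloch-style) scan; the function names are hypothetical.

```python
import torch

def combine(a_prev, b_prev, a_cur, b_cur):
    """Compose two affine steps h -> a*h + b: apply (a_prev, b_prev) first, then (a_cur, b_cur)."""
    return a_cur * a_prev, a_cur * b_prev + b_cur

def hillis_steele_scan(a, b):
    """
    Inclusive scan for h_t = a_t * h_{t-1} + b_t with h_{-1} = 0, along dim 0.
    Runs ceil(log2(N)) rounds; within a round every position is updated in parallel.
    """
    n, offset = a.shape[0], 1
    while offset < n:
        # Position t absorbs the already-combined block of steps ending at t - offset.
        a_new, b_new = combine(a[:-offset], b[:-offset], a[offset:], b[offset:])
        a = torch.cat([a[:offset], a_new])
        b = torch.cat([b[:offset], b_new])
        offset *= 2
    return b  # b[t] now holds h_t

# Usage: check against the plain sequential loop.
torch.manual_seed(0)
a = torch.rand(37, 4, dtype=torch.float64)
b = torch.randn(37, 4, dtype=torch.float64)
h, ref = torch.zeros(4, dtype=torch.float64), []
for t in range(37):
    h = a[t] * h + b[t]
    ref.append(h)
print(torch.allclose(hillis_steele_scan(a, b), torch.stack(ref)))
```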
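Failure mode: numerical precision on long sequences. Every step multiplies and adds into the same recurrent state, so rounding errors accumulate, and they become more noticeable on long sequences (e.g., seq > 8192). Reduced precision (float16 or bfloat16) saves memory and time but makes the error worse, so keep the recurrent state and the scan's accumulation in float32 even when the rest of the model runs in mixed precision.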
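A small experiment can make this concrete. The sketch below runs the elementwise recurrence h_t = a_t h_{t-1} + b_t for 8192 steps with the state held in bfloat16 and in float32, and compares both against a float64 reference on the same bfloat16-rounded inputs, so only accumulation error is measured. The decay range (0.99 to 0.999, standing in for exp(Δ·A) with A < 0), the sizes, and the seed are arbitrary illustrative choices, and bfloat16 serves as the reduced-precision format.

```python
import torch

def run_scan(a, b, state_dtype):
    """Elementwise recurrence h_t = a_t * h_{t-1} + b_t with the state kept in state_dtype."""
    a, b = a.to(state_dtype), b.to(state_dtype)
    h = torch.zeros(a.shape[1], dtype=state_dtype)
    for t in range(a.shape[0]):
        h = a[t] * h + b[t]
    return h

torch.manual_seed(0)
seq, channels = 8192, 16
# Inputs are rounded to bfloat16 once, so every run sees identical values.
a = (0.99 + 0.009 * torch.rand(seq, channels)).to(torch.bfloat16)
b = torch.randn(seq, channels).to(torch.bfloat16)

ref = run_scan(a, b, torch.float64)      # high-precision reference
h_bf16 = run_scan(a, b, torch.bfloat16)  # state accumulated in bfloat16
h_fp32 = run_scan(a, b, torch.float32)   # same inputs, float32 state

err_bf16 = (h_bf16.double() - ref).abs().max().item()
err_fp32 = (h_fp32.double() - ref).abs().max().item()
print(f"bf16 state error: {err_bf16:.2e}, fp32 state error: {err_fp32:.2e}")
```

The bfloat16-state error is typically orders of magnitude larger than the float32-state error; the exact figures depend on the seed and decay range.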
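Failure mode: d_state trades memory for expressiveness. The original Mamba uses d_state=16 by default. Memory and compute for the expanded state grow linearly with d_state, and larger states such as 64 or 128 pay off only for some architectures and tasks (Mamba-2, for example, uses larger states). The SSM state acts as a fixed-size, compressed memory of the sequence history, so d_state limits how much of that history can be retained.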
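To see the memory side of this trade-off, the sketch below computes the size of the expanded (batch, seq, d_inner, d_state) tensor that the reference `selective_scan` materializes (`deltaA` and `deltaB_u` each have this shape), alongside the recurrent state kept per step during generation. The sizes (batch 8, seq 2048, d_inner 1536, float32) and the helper name `state_bytes` are hypothetical.

```python
def state_bytes(batch, d_inner, d_state, seq=1, bytes_per_elem=4):
    """Bytes for a (batch, seq, d_inner, d_state) tensor; seq=1 gives the recurrent state."""
    return batch * seq * d_inner * d_state * bytes_per_elem

GiB, MiB = 2**30, 2**20
for d_state in (16, 64, 128):
    expanded = state_bytes(8, 1536, d_state, seq=2048) / GiB
    recurrent = state_bytes(8, 1536, d_state) / MiB
    print(f"d_state={d_state:>3}: expanded tensor {expanded:5.2f} GiB, "
          f"recurrent state {recurrent:5.2f} MiB")
```

With these sizes, d_state=16 gives 1.5 GiB per expanded tensor and d_state=128 gives 12 GiB, an eightfold increase, while the recurrent state stays at a few MiB. Avoiding writing these expanded tensors to GPU main memory is one motivation for the fused, hardware-aware scan kernel.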
Exercise: compare Mamba with standard attention on a retrieval task. Place a key-value pair at the start of the sequence (key at position 0) and repeat the key as a query at position 1000; the model must output the matching value. Vary the distance between the pair and the query, measure retrieval accuracy for both models, and explain why accuracy changes differently with distance for the two architectures.
To complete course A002: runlocalai.co/learn/A002
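For the retrieval exercise, a synthetic data generator is a natural starting point. The sketch below is a hypothetical helper, not part of any library: each sequence is [key, value, filler, ..., filler, key], and the target is the value token. The token-id layout, vocabulary sizes, and the name `make_retrieval_batch` are assumptions.

```python
import torch

def make_retrieval_batch(batch, distance, n_keys=32, n_values=32, seed=0):
    """
    Hypothetical generator for the key-value retrieval exercise.
    The pair sits at positions 0 and 1; the same key reappears as a query at
    position `distance`. Token ids: keys in [0, n_keys), values in
    [n_keys, n_keys + n_values), filler = n_keys + n_values.
    Returns (tokens of shape (batch, distance + 1), target values of shape (batch,)).
    """
    if distance < 2:
        raise ValueError("distance must be at least 2")
    g = torch.Generator().manual_seed(seed)
    keys = torch.randint(0, n_keys, (batch,), generator=g)
    values = torch.randint(n_keys, n_keys + n_values, (batch,), generator=g)
    filler = n_keys + n_values
    x = torch.full((batch, distance + 1), filler, dtype=torch.long)
    x[:, 0], x[:, 1], x[:, distance] = keys, values, keys
    return x, values

x, y = make_retrieval_batch(4, 1000)
print(x.shape, y)
```

Sweeping `distance` over several values and evaluating both models at each setting gives the accuracy-versus-distance comparison the exercise asks for.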