12. Selective State Space

Chapter 12 of 24 · 20 min

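Mamba's selective state space model generalizes the SSM to content-aware sequence modeling. The key innovation: the parameters B, C, and Δ become input-dependent, enabling selective information propagation. A itself stays a learned, input-independent matrix, but its discretized form exp(ΔA) varies with the input through Δ.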
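
To make the input dependence concrete, the per-token recurrence for a single channel d can be sketched as below. The projection names W_B, W_C, W_Δ and the bias b_Δ are notation assumed here for illustration; u_t is the input at step t, h_t the state, and A^{(d)} the row of A belonging to channel d.

```latex
\begin{aligned}
\Delta_t &= \operatorname{softplus}(W_\Delta u_t + b_\Delta), \qquad
B_t = W_B u_t, \qquad C_t = W_C u_t, \\
h_t^{(d)} &= \exp\!\big(\Delta_t^{(d)} A^{(d)}\big) \odot h_{t-1}^{(d)}
            + \Delta_t^{(d)} B_t \, u_t^{(d)}, \\
y_t^{(d)} &= C_t^{\top} h_t^{(d)} .
\end{aligned}
```

A small Δ_t keeps exp(Δ_t A) close to 1 and the input term close to 0, so the state passes through almost unchanged; a large Δ_t (with negative A) decays the old state and writes the current input.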

The selective scan algorithm implements the SSM recurrence:

import torch

def selective_scan(u, delta, A, B, C):
    """
    Selective SSM scan (simplified reference implementation).
    
    Args:
        u: (batch, seq, d_inner) - input
        delta: (batch, seq, d_inner) - time step (input-dependent)
        A: (d_inner, d_state) - state matrix
        B: (batch, seq, d_state) - input matrix (input-dependent)
        C: (batch, seq, d_state) - output matrix (input-dependent)
    
    Returns:
        y: (batch, seq, d_inner) - output
    """
    batch, seq, d_inner = u.shape
    d_state = A.shape[1]
    
    # Discretize: A_bar = exp(Δ·A) (zero-order hold), B_bar·u ≈ Δ·B·u (first-order approximation)
    deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (batch, seq, d_inner, d_state)
    deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)
    
    # Sequential reference loop for clarity. Production kernels replace this
    # loop with a hardware-aware parallel prefix scan.
    y = torch.zeros(batch, seq, d_inner, device=u.device, dtype=u.dtype)
    h = torch.zeros(batch, d_inner, d_state, device=u.device, dtype=u.dtype)
    
    for t in range(seq):
        h = deltaA[:, t] * h + deltaB_u[:, t]
        y[:, t] = (h @ C[:, t].unsqueeze(-1)).squeeze(-1)
    
    return y

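def mamba_discretize(A, B, C, delta):
    """
    Discretize continuous SSM parameters with step size delta.
    Exact zero-order hold (ZOH):
        h_t = exp(ΔA) h_{t-1} + A^{-1} (exp(ΔA) - I) B u_t
    Simplified form used here and in selective_scan (ZOH for A, first-order for B):
        h_t = exp(ΔA) h_{t-1} + ΔB u_t
    C needs no discretization and is accepted only for a uniform signature.
    """
    # delta: (batch, seq, d_inner), A: (d_inner, d_state), B: (batch, seq, d_state)
    deltaA = torch.exp(delta.unsqueeze(-1) * A)    # (batch, seq, d_inner, d_state)
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (batch, seq, d_inner, d_state)

    return deltaA, deltaB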
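
To get a feel for when the simplified input term ΔB is a reasonable stand-in for the exact zero-order-hold term, the following scalar sketch compares the two. It is plain Python rather than PyTorch, and the values of a, b, and the step sizes are hypothetical choices for illustration only.

```python
import math

def zoh_b(a: float, b: float, delta: float) -> float:
    """Exact zero-order-hold input coefficient for a scalar system: (exp(delta*a) - 1) / a * b."""
    return math.expm1(delta * a) / a * b

def euler_b(b: float, delta: float) -> float:
    """First-order approximation of the same coefficient: delta * b."""
    return delta * b

def relative_error(a: float, b: float, delta: float) -> float:
    exact = zoh_b(a, b, delta)
    return abs(euler_b(b, delta) - exact) / abs(exact)

a, b = -1.0, 1.0  # hypothetical scalar parameters; a < 0 keeps the state stable
for delta in (0.001, 0.01, 0.1, 1.0):
    print(f"delta={delta:<6} exact={zoh_b(a, b, delta):.6f} "
          f"approx={euler_b(b, delta):.6f} rel_err={relative_error(a, b, delta):.4f}")
```

The approximation is tight when |Δ·a| is small and degrades as the step grows, which is worth keeping in mind when Δ is produced by the network and can vary widely.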

Failure mode: the sequential scan takes O(N) dependent steps. The total work is linear, but each step waits on the previous one, so a naive loop cannot exploit GPU parallelism and is slow in training. Production implementations use a parallel prefix sum (parallel scan) to achieve O(log N) depth with O(N) work:

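import torch.nn.functional as F

def parallel_selective_scan(u, delta, A, B, C, chunk_size=32):
    """
    Chunked selective scan (reference sketch, not a fused kernel).
    The sequence is split into chunks and the state h is carried from one
    chunk to the next, so the output matches selective_scan. The loops here
    are sequential; hardware-aware kernels parallelize the work within and
    across chunks.
    """
    batch, seq, d_inner = u.shape
    d_state = A.shape[1]

    # Pad the sequence dimension of every per-token tensor to a multiple of chunk_size
    pad_len = (chunk_size - seq % chunk_size) % chunk_size
    if pad_len > 0:
        u, delta, B, C = (F.pad(x, (0, 0, 0, pad_len)) for x in (u, delta, B, C))

    # Reshape to (batch, n_chunks, chunk_size, features)
    u, delta, B, C = (x.reshape(batch, -1, chunk_size, x.shape[-1]) for x in (u, delta, B, C))

    h = torch.zeros(batch, d_inner, d_state, device=u.device, dtype=u.dtype)
    y_chunks = []
    for k in range(u.shape[1]):
        deltaA = torch.exp(delta[:, k].unsqueeze(-1) * A)
        deltaB_u = delta[:, k].unsqueeze(-1) * B[:, k].unsqueeze(2) * u[:, k].unsqueeze(-1)
        y = torch.empty(batch, chunk_size, d_inner, device=u.device, dtype=u.dtype)
        for t in range(chunk_size):
            h = deltaA[:, t] * h + deltaB_u[:, t]  # h persists across chunk boundaries
            y[:, t] = (h @ C[:, k, t].unsqueeze(-1)).squeeze(-1)
        y_chunks.append(y)

    # Drop the padded positions
    return torch.cat(y_chunks, dim=1)[:, :seq]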
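
The parallel scan mentioned above works because one recurrence step h -> a*h + b is an affine map, and composing affine maps is associative. The scalar sketch below, in plain Python, shows the idea with a simple recursive-doubling scan. It is an illustration under simplifying assumptions, not the kernel used by any particular library: the function names are hypothetical, the initial state is taken to be 0, and this simple variant does O(N log N) work, whereas work-efficient variants reach O(N).

```python
from typing import List, Tuple

Pair = Tuple[float, float]  # (a, b) represents the map h -> a * h + b

def combine(left: Pair, right: Pair) -> Pair:
    """Compose two steps: apply `left` first, then `right`."""
    a_l, b_l = left
    a_r, b_r = right
    return (a_l * a_r, a_r * b_l + b_r)

def sequential_scan(a: List[float], b: List[float]) -> List[float]:
    """Reference loop: h_t = a_t * h_{t-1} + b_t with h_{-1} = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out

def doubling_scan(a: List[float], b: List[float]) -> Tuple[List[float], int]:
    """Inclusive scan by recursive doubling. Returns the states and the number of rounds."""
    elems = list(zip(a, b))
    offset, rounds = 1, 0
    while offset < len(elems):
        # Each position reads only values from the previous round,
        # so all updates within a round could run in parallel.
        elems = [
            combine(elems[i - offset], elems[i]) if i >= offset else elems[i]
            for i in range(len(elems))
        ]
        offset *= 2
        rounds += 1
    # With a zero initial state, the state after step t is the b-component of the prefix.
    return [b_t for _, b_t in elems], rounds

a = [0.9, 0.5, 1.1, 0.7, 0.3, 0.8, 1.0, 0.6]
b = [1.0, -2.0, 0.5, 0.0, 3.0, 1.5, -1.0, 2.0]
states, rounds = doubling_scan(a, b)
print(rounds, all(abs(x - y) < 1e-9 for x, y in zip(states, sequential_scan(a, b))))
```

For 8 steps the doubling scan needs 3 rounds instead of 8 dependent steps, which is where the O(log N) depth comes from.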

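Failure mode: numerical precision on long sequences. The recurrence multiplies and accumulates into the same state at every step, so floating-point rounding errors compound. With seq > 8192, errors accumulate noticeably. Mixed precision (float16) reduces memory and compute cost but makes this worse, so keep the state and its accumulation in float32 even when the rest of the model runs in float16.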
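
The effect of state precision can be seen in a toy scalar recurrence. The sketch below uses only the Python standard library and simulates a float16 or float32 state by rounding after every step. The values a = 0.999, b = 0.001, and 20000 steps are hypothetical, chosen so the exact state converges to 1.0; the example illustrates the mechanism and does not reproduce any specific sequence-length threshold.

```python
import struct

def round_to(fmt: str, x: float) -> float:
    """Round a Python float to a narrower IEEE format ('e' = float16, 'f' = float32)."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]

def run_recurrence(steps: int, a: float, b: float, fmt: str = "") -> float:
    """Iterate h = a*h + b, rounding the stored state to `fmt` after each step."""
    h = 0.0
    for _ in range(steps):
        h = a * h + b
        if fmt:
            h = round_to(fmt, h)
    return h

a, b, steps = 0.999, 0.001, 20000  # hypothetical values; fixed point is b / (1 - a) = 1.0
reference = run_recurrence(steps, a, b)  # float64 state
err16 = abs(run_recurrence(steps, a, b, "e") - reference)
err32 = abs(run_recurrence(steps, a, b, "f") - reference)
print(f"float16 state error: {err16:.6f}")
print(f"float32 state error: {err32:.6f}")
```

Once the per-step update becomes smaller than half the spacing between representable float16 values, the rounded state stops moving, so the low-precision run stalls well short of the true value while the float32 run stays close to it.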

Failure mode: d_state must balance memory and expressiveness. The SSM state acts as a compressed memory of sequence history, and its size grows linearly with d_state. The original Mamba models use d_state=16 (recommended). Increasing to 64 or 128 multiplies state memory by 4x or 8x and often provides little benefit.
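
The memory side of this trade-off is simple arithmetic. The sketch below computes the size of the recurrent state, shape (batch, d_inner, d_state), and of one materialized scan buffer such as deltaA in the reference selective_scan, shape (batch, seq, d_inner, d_state). The batch, seq, and d_inner values are hypothetical and only serve to show the scaling.

```python
def state_bytes(batch: int, d_inner: int, d_state: int, bytes_per_elem: int = 4) -> int:
    """Size of the recurrent state h with shape (batch, d_inner, d_state)."""
    return batch * d_inner * d_state * bytes_per_elem

def scan_buffer_bytes(batch: int, seq: int, d_inner: int, d_state: int,
                      bytes_per_elem: int = 4) -> int:
    """Size of one (batch, seq, d_inner, d_state) buffer, e.g. deltaA in the reference scan."""
    return batch * seq * d_inner * d_state * bytes_per_elem

batch, seq, d_inner = 8, 2048, 2048  # hypothetical sizes for illustration
for d_state in (16, 64, 128):
    state_mib = state_bytes(batch, d_inner, d_state) / 2**20
    buffer_gib = scan_buffer_bytes(batch, seq, d_inner, d_state) / 2**30
    print(f"d_state={d_state:>3}: state {state_mib:.0f} MiB, "
          f"materialized buffer {buffer_gib:.0f} GiB (float32, per layer)")
```

The state itself is small, but any implementation that materializes per-step buffers pays the d_state factor times the sequence length, which is one reason fused kernels avoid writing those buffers to main memory.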

EXERCISE

Compare Mamba with standard attention on a retrieval task: a key-value pair is presented at position 0, and the model is queried with the key at position 1000 and must output the corresponding value. Vary the distance between the pair and the query, and measure retrieval accuracy for both models. Explain why the two models behave differently at long distances. For course A002 complete: runlocalai.co/learn/A002