11. Mamba State Space Model
Mamba replaces attention with a selective state space model whose compute scales linearly with sequence length. Where transformer self-attention costs O(N²) in the sequence length N, Mamba runs in O(N) while still modeling long-range dependencies.
The Mamba SSM (state space model) discretizes its continuous parameters A and B with a step size Δ using zero-order hold, so the discrete state matrix is Ā = exp(ΔA):
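import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MambaBlock(nn.Module):
    """
    Mamba selective state space model block.
    Based on 'Mamba: Linear-Time Sequence Modeling with Selective State Spaces'.
    Uses a plain sequential scan for clarity, not a fused kernel.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank=None):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.d_inner = d_model * expand
        # Rank of the low-rank dt projection (ceil(d_model / 16) in the reference code)
        self.dt_rank = dt_rank if dt_rank is not None else math.ceil(d_model / 16)

        # Input projection: produces the SSM branch and the gate branch
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Depthwise causal convolution over the sequence dimension
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
        )

        # x_proj: input-dependent dt (low rank), B, and C
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
        # dt_proj: expands the low-rank dt to one step size per channel
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # A: state matrix, learned but NOT input-dependent. Stored as log(-A)
        # so that A = -exp(A_log) is always negative (stable decay).
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        # D: learned per-channel skip connection (not input-dependent)
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def selective_scan(self, u, dt, A, B, C):
        """
        Sequential reference scan (slow, for clarity).
        u, dt: (batch, seq, d_inner); A: (d_inner, d_state); B, C: (batch, seq, d_state)
        """
        batch, seq_len, d_inner = u.shape
        # Zero-order hold for A; first-order (dt * B) simplification for B
        dA = torch.exp(dt.unsqueeze(-1) * A)  # (batch, seq, d_inner, d_state)
        dBu = dt.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # same shape as dA
        h = u.new_zeros(batch, d_inner, A.shape[1])
        ys = []
        for t in range(seq_len):
            h = dA[:, t] * h + dBu[:, t]
            ys.append((h * C[:, t].unsqueeze(1)).sum(dim=-1))  # (batch, d_inner)
        return torch.stack(ys, dim=1)  # (batch, seq, d_inner)

    def forward(self, x):
        """
        x: (batch, seq_len, d_model)
        """
        batch, seq_len, d_model = x.shape

        # Split into the SSM branch (x_inner) and the gate branch (z)
        xz = self.in_proj(x)
        x_inner, z = xz.chunk(2, dim=-1)

        # Causal convolution (over seq dimension); trim the right-side padding
        x_conv = x_inner.transpose(1, 2)  # (batch, d_inner, seq)
        x_conv = self.conv1d(x_conv)[:, :, :seq_len]
        x_conv = x_conv.transpose(1, 2)  # (batch, seq, d_inner)
        x_conv = F.silu(x_conv)

        # Compute SSM parameters (selective: depend on current input)
        x_dbl = self.x_proj(x_conv)  # (batch, seq, dt_rank + 2 * d_state)
        dt, B, C = x_dbl.split([self.dt_rank, self.d_state, self.d_state], dim=-1)
        dt = F.softplus(self.dt_proj(dt))  # positive step sizes, (batch, seq, d_inner)
        A = -torch.exp(self.A_log)  # (d_inner, d_state), strictly negative

        # Selective scan plus the D skip connection
        y = self.selective_scan(x_conv, dt, A, B, C)
        y = y + x_conv * self.D

        # Output with gating
        return self.out_proj(y * F.silu(z))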
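The zero-order hold step named at the start of this section can be checked by hand for a single diagonal entry of A. The sketch below uses pure Python rather than PyTorch so that it runs anywhere; `zoh_discretize` and `ssm_scan` are illustrative names, not part of any Mamba library. For a scalar a ≠ 0, zero-order hold gives ā = exp(Δa) and b̄ = (exp(Δa) − 1)/a · b, and the first-order simplification b̄ ≈ Δb is computed alongside for comparison.

```python
import math


def zoh_discretize(a, b, dt):
    """Zero-order hold for one scalar (diagonal) SSM entry; assumes a != 0."""
    a_bar = math.exp(dt * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar


def ssm_scan(xs, a, b, c, dts):
    """Run h_t = a_bar_t * h_{t-1} + b_bar_t * x_t, y_t = c * h_t step by step.

    dts holds one step size per position, which is how an input-dependent
    (selective) step size would enter the recurrence.
    """
    h = 0.0
    ys = []
    for x, dt in zip(xs, dts):
        a_bar, b_bar = zoh_discretize(a, b, dt)
        h = a_bar * h + b_bar * x
        ys.append(c * h)
    return ys


a, b, dt = -1.0, 1.0, 0.1
a_bar, b_bar = zoh_discretize(a, b, dt)
euler_b_bar = dt * b  # first-order approximation of b_bar

# A constant input of 1.0 drives the state toward the continuous steady state -b/a.
ys = ssm_scan([1.0] * 200, a, b, c=1.0, dts=[dt] * 200)
```

With a constant input, the zero-order hold recurrence settles at exactly the continuous-time steady state −b/a, which is one reason it is a natural choice for the state matrix.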
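Failure mode: Poorly chosen state dimension. d_state=16 is a typical default, but the right value varies with model size. A larger d_state improves expressiveness but costs memory: A holds d_inner × d_state parameters, and the recurrent state holds d_inner × d_state values per sequence in the batch. At the other extreme, a very small value such as d_state=4 severely limits SSM expressiveness.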
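To make the memory trade-off concrete, the arithmetic can be written out as a small helper. This is a rough sketch that counts elements rather than bytes; `ssm_footprint` is a hypothetical name, and d_model=768, batch=8, and seq_len=2048 are example values chosen for illustration, not settings from any particular model.

```python
def ssm_footprint(d_model, d_state=16, expand=2, batch=1, seq_len=1):
    """Count A parameters and state elements for one block (elements, not bytes)."""
    d_inner = d_model * expand
    return {
        "A_params": d_inner * d_state,
        # Recurrent state carried between steps: one (d_inner, d_state) matrix per sequence.
        "state_per_step": batch * d_inner * d_state,
        # A naive scan that materializes the discretized A for every position at once.
        "naive_scan_elements": batch * seq_len * d_inner * d_state,
    }


small = ssm_footprint(d_model=768, d_state=4, batch=8, seq_len=2048)
typical = ssm_footprint(d_model=768, d_state=16, batch=8, seq_len=2048)
```

Every count grows linearly with d_state, so moving from 4 to 16 multiplies each of them by four. The last entry is the one that tends to hurt in practice when a scan materializes per-position tensors instead of keeping only the running state.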
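Failure mode: Discretization instability. The discretized state matrix is Ā = exp(ΔA), so the recurrence is stable only when A is negative and Δ is positive; a positive A makes the state grow exponentially with sequence length. Parameterize A = -exp(A_log) so that it is negative by construction, and keep A_log moderate: large positive A_log values make |A| huge, so the state is forgotten almost immediately, while very negative values push Ā toward 1 and the state barely decays. A small-magnitude random initialization (for example standard normal) keeps |A| moderate, but the safer choice is the reference implementation's initialization, which sets A_log to log(1, ..., d_state) for each channel and passes Δ through softplus.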
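The effect of the sign of A is easy to see on a scalar recurrence. The following sketch, in pure Python with the hypothetical helper names `state_after` and `softplus`, runs the same recurrence twice, once with A = -exp(A_log) and once with the sign flipped. The step size, A_log value, and step count are arbitrary example values.

```python
import math


def state_after(a, dt, steps, x=1.0, b=1.0):
    """Iterate h = exp(dt * a) * h + dt * b * x and return the final state."""
    a_bar = math.exp(dt * a)
    h = 0.0
    for _ in range(steps):
        h = a_bar * h + dt * b * x
    return h


def softplus(v):
    """Map any real value to a positive step size."""
    return math.log1p(math.exp(v))


a_log = 0.5
stable = state_after(a=-math.exp(a_log), dt=softplus(-2.0), steps=1000)
unstable = state_after(a=+math.exp(a_log), dt=softplus(-2.0), steps=1000)
```

With the negative parameterization the state converges to a bounded value, while the flipped sign compounds a factor greater than one at every step and grows without bound as the sequence gets longer.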
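Implement the basic (non-selective) SSM with fixed A, B, C. Test it on a copy task (output the input shifted by one position), which depends only on position and which a fixed SSM can solve. Then try a task that requires content-based selection, such as copying only marked tokens while ignoring filler at varying positions, and verify that the SSM struggles. Finally, explain why selectivity in Mamba addresses this.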
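As a possible starting point for this exercise, a fixed-parameter SSM fits in a few lines of pure Python. The parameters below are hand-picked rather than learned, and `lti_ssm` is an illustrative name. The sketch shows a two-dimensional state that implements the one-step shift, followed by a linearity check that hints at why a fixed SSM cannot select by content.

```python
def lti_ssm(xs, A, B, C):
    """Fixed-parameter discrete SSM: h_t = A h_{t-1} + B x_t, y_t = C . h_t."""
    n = len(B)
    h = [0.0] * n
    ys = []
    for x in xs:
        h = [sum(A[i][j] * h[j] for j in range(n)) + B[i] * x for i in range(n)]
        ys.append(sum(C[i] * h[i] for i in range(n)))
    return ys


# Hand-picked parameters that implement a one-step delay:
# h[0] stores the current input, h[1] stores the previous one.
A = [[0.0, 0.0], [1.0, 0.0]]
B = [1.0, 0.0]
C = [0.0, 1.0]

xs = [3.0, 1.0, 4.0, 1.0, 5.0]
shifted = lti_ssm(xs, A, B, C)

# Linearity: the response to a sum of inputs is the sum of the responses, so
# the same fixed kernel is applied to every token regardless of its content.
marker = [0.0, 9.0, 0.0, 0.0, 0.0]
combined = lti_ssm([x + m for x, m in zip(xs, marker)], A, B, C)
separate = [u + v for u, v in zip(shifted, lti_ssm(marker, A, B, C))]
```

Because A, B, and C never change, every token passes through the same position-dependent kernel. Making Δ, B, and C functions of the input, as Mamba does, is what lets the model decide per token whether to write it into the state or skip it.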