09. Memory Persistence

Chapter 9 of 24 · 20 min

Persistence is where memory systems become production-ready. Working memory lives in RAM; episodic memory stores events; semantic memory holds knowledge. In production all three need durable storage, because anything held only in RAM is lost when the process exits or crashes.

Three-tier storage strategy:

class PersistentMemorySystem:
    """Connect working memory to episodic memory and a session store.

    ``save_session`` mirrors working-memory messages into episodic memory as
    ``session_<role>`` events and passes the working memory to the session
    store; ``load_session`` rebuilds a ``WorkingMemory`` from those events.
    """

    # Event-type prefix marking episodic events that mirror chat messages.
    _EVENT_PREFIX = "session_"

    def __init__(
        self,
        working: WorkingMemory,
        episodic: EpisodicMemory,
        semantic: SemanticMemory,
        session_store: SessionStore
    ):
        self.working = working
        self.episodic = episodic
        self.semantic = semantic
        self.session_store = session_store
        # session_id -> how many of working.messages are already recorded in
        # episodic memory. save_session may run many times per session (for
        # example from periodic checkpoints); tracking this lets each call
        # append only new messages instead of re-recording, and later
        # restoring, the whole history.
        self._persisted_counts: dict[str, int] = {}

    def save_session(self, session_id: str) -> None:
        """Persist working memory to episodic storage.

        Only messages added since the previous save of *session_id* are
        recorded, so repeated calls do not duplicate episodic events. The
        whole working memory is then passed to ``session_store.save``.
        """
        # Snapshot so the slice and the count below refer to the same list.
        messages = list(self.working.messages)
        start = self._persisted_counts.get(session_id, 0)
        if start > len(messages):
            # Fewer messages than were already persisted: working memory was
            # presumably cleared or replaced, so treat everything as new.
            # NOTE(review): assumes messages is otherwise append-only; if
            # WorkingMemory evicts old messages, a count alone is not enough.
            start = 0
        self._persisted_counts[session_id] = start
        for index, msg in enumerate(messages[start:], start=start):
            self.episodic.record(
                event_type=f"{self._EVENT_PREFIX}{msg.role}",
                payload={"content": msg.content, "tool_call": msg.tool_call},
                session_id=session_id
            )
            # Advance per message so a failure part-way through does not
            # cause already-recorded messages to be recorded again.
            self._persisted_counts[session_id] = index + 1
        self.session_store.save(session_id, self.working)

    def load_session(self, session_id: str) -> WorkingMemory:
        """Restore working memory from episodic storage.

        Replays this session's ``session_<role>`` events in timestamp order
        into a new ``WorkingMemory``; events of other types are ignored. The
        result is returned, not assigned to ``self.working``.
        """
        prefix = self._EVENT_PREFIX
        events = self.episodic.get_session(session_id)
        restored = WorkingMemory()
        for event in sorted(events, key=lambda e: e.timestamp):
            if not event.event_type.startswith(prefix):
                continue
            restored.add_message(
                # removeprefix strips only the leading marker; str.replace
                # would also delete "session_" appearing inside a role name.
                role=event.event_type.removeprefix(prefix),
                content=event.payload.get("content"),
                tool_call=event.payload.get("tool_call")
            )
        return restored

PostgreSQL storage for production:

import json
import time
from typing import AsyncIterator, Callable, Optional

import asyncpg

class PostgresEpisodeStorage(EpisodeStorage):
    """``EpisodeStorage`` backed by a PostgreSQL ``episode_events`` table.

    Payloads are written as JSON text. Each method borrows a connection from
    the asyncpg pool for the duration of a single statement.
    """

    def __init__(self, pool: asyncpg.Pool):
        self.pool = pool

    async def save(self, event: EpisodeEvent) -> None:
        """Insert one event row."""
        # ensure_ascii=False keeps non-ASCII text as real characters rather
        # than \uXXXX escapes, so substring search over payload::text can
        # match it when the column stores the JSON text verbatim.
        payload = json.dumps(event.payload, ensure_ascii=False)
        async with self.pool.acquire() as conn:
            await conn.execute("""
                INSERT INTO episode_events 
                (timestamp, event_type, payload, session_id)
                VALUES ($1, $2, $3, $4)
            """, event.timestamp, event.event_type, payload, event.session_id)

    async def get_by_session(self, session_id: str) -> list[EpisodeEvent]:
        """Return all events recorded for *session_id*, oldest first."""
        async with self.pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT * FROM episode_events 
                WHERE session_id = $1 
                ORDER BY timestamp ASC
            """, session_id)
            return [self._row_to_event(row) for row in rows]

    async def get_recent(self, limit: int) -> list[EpisodeEvent]:
        """Return up to *limit* events across all sessions, newest first."""
        async with self.pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT * FROM episode_events 
                ORDER BY timestamp DESC 
                LIMIT $1
            """, limit)
            return [self._row_to_event(row) for row in rows]

    async def search(self, query: str, limit: int) -> list[EpisodeEvent]:
        """Return up to *limit* events whose payload contains *query*, newest first.

        This is a case-insensitive substring match over the payload's JSON
        text, so it can also match payload keys and sees quotes/backslashes
        in their JSON-escaped form. ``%``, ``_`` and ``!`` in *query* are
        escaped so they match literally instead of acting as LIKE wildcards.
        For ranked full-text search, use a PostgreSQL ``tsvector`` column.
        """
        pattern = f"%{self._escape_like(query)}%"
        async with self.pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT * FROM episode_events 
                WHERE payload::text ILIKE $1 ESCAPE '!'
                ORDER BY timestamp DESC
                LIMIT $2
            """, pattern, limit)
            return [self._row_to_event(row) for row in rows]

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE metacharacters in *text* using ``!`` as the escape char."""
        # Escape the escape character first so later replacements are not doubled.
        return text.replace("!", "!!").replace("%", "!%").replace("_", "!_")

    def _row_to_event(self, row: asyncpg.Record) -> EpisodeEvent:
        """Convert one ``episode_events`` row into an ``EpisodeEvent``."""
        payload = row["payload"]
        # asyncpg returns json/jsonb values as text unless a type codec is
        # registered on the connection; accept an already-decoded value too.
        if isinstance(payload, (str, bytes)):
            payload = json.loads(payload)
        return EpisodeEvent(
            timestamp=row["timestamp"],
            event_type=row["event_type"],
            payload=payload,
            session_id=row["session_id"]
        )

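Failure mode: session loss. If the agent crashes mid-session, any working memory that has not yet been persisted is lost. Guard against this with a checkpoint that periodically saves working memory to episodic storage—every N tool calls or every N seconds. Because checkpoints run many times per session, each save should record only the messages added since the previous one; otherwise every checkpoint duplicates the session history.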

class CheckpointManager:
    """Checkpoint a session's working memory as tool calls complete.

    For each session, ``memory_system.save_session`` is called on every
    ``interval``-th tool call since that session's last checkpoint. If
    ``interval_seconds`` is given, a tool call arriving at least that many
    seconds after the session's last checkpoint (or after its first tracked
    call) also triggers one. Time is only checked when ``on_tool_executed``
    runs; there is no background timer.

    Args:
        memory_system: Object whose ``save_session(session_id)`` persists
            the session.
        interval: Tool calls per checkpoint; must be >= 1.
        interval_seconds: Optional maximum seconds between checkpoints;
            must be > 0 when given.
        clock: Monotonic time source in seconds (injectable for tests).

    Raises:
        ValueError: If ``interval`` or ``interval_seconds`` is out of range.
    """

    def __init__(
        self,
        memory_system: "PersistentMemorySystem",
        interval: int = 5,
        interval_seconds: Optional[float] = None,
        clock: Callable[[], float] = time.monotonic,
    ):
        if interval < 1:
            raise ValueError(f"interval must be >= 1, got {interval!r}")
        if interval_seconds is not None and interval_seconds <= 0:
            raise ValueError(
                f"interval_seconds must be > 0, got {interval_seconds!r}"
            )
        self.memory_system = memory_system
        self.interval = interval
        self.interval_seconds = interval_seconds
        self._clock = clock
        # Per-session state so sessions sharing one manager do not advance
        # each other's checkpoint schedule.
        self._call_counts: dict[str, int] = {}
        self._last_checkpoint: dict[str, float] = {}

    def on_tool_executed(self, session_id: str) -> None:
        """Record one tool call for *session_id* and checkpoint if one is due."""
        now = self._clock()
        count = self._call_counts.get(session_id, 0) + 1
        self._call_counts[session_id] = count
        # The first tracked call starts the session's time window.
        last = self._last_checkpoint.setdefault(session_id, now)
        due = count >= self.interval or (
            self.interval_seconds is not None
            and now - last >= self.interval_seconds
        )
        if due:
            self.flush(session_id)

    def flush(self, session_id: str) -> None:
        """Checkpoint *session_id* immediately and restart its schedule.

        Counters are reset only after ``save_session`` returns, so a save
        that raises is retried on the session's next tool call.
        """
        self.memory_system.save_session(session_id)
        self._call_counts[session_id] = 0
        self._last_checkpoint[session_id] = self._clock()
EXERCISE

Implement a session recovery system. On startup, check for incomplete sessions in the database, restore them, and present the user with "Welcome back. You left off at: [last message]. Continue?"