How to Scale FAISS Across Multiple Machines

Advanced · 40 min · By Fredoline Eruo
Target environment
Ubuntu 24.04 · Ollama 0.4.x
PREREQUISITES

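Multiple machines with network connectivity, with FAISS installed on each. The examples also use the fastapi, uvicorn, httpx, numpy, and ollama Python packages and the nomic-embed-text embedding model.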

What this does

A single-machine FAISS index hits RAM capacity and I/O bandwidth ceilings. When the dataset grows beyond one server, shard vectors across multiple machines and fan queries out in parallel. This guide covers the sharding architecture, load balancing, and distributed query aggregation pattern.
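
As a rough way to decide how many shards are needed, the sketch below estimates the memory footprint of an HNSW flat index using the rule of thumb of `dim * 4 + M * 2 * 4` bytes per vector. The 16 GiB per-machine budget and the 50 million vector count are illustrative assumptions, and real usage will be somewhat higher once process overhead is included.

```python
import math

def hnsw_flat_bytes_per_vector(dim, m):
    # Raw float32 vector plus roughly 2 * M int32 neighbor links per vector.
    return dim * 4 + m * 2 * 4

def shards_needed(num_vectors, dim=768, m=32, ram_budget_gib=16.0):
    # ram_budget_gib is a hypothetical per-machine budget reserved for the index.
    total_bytes = num_vectors * hnsw_flat_bytes_per_vector(dim, m)
    budget_bytes = ram_budget_gib * 1024**3
    return math.ceil(total_bytes / budget_bytes)

per_vector = hnsw_flat_bytes_per_vector(768, 32)
print(f"{per_vector} bytes per vector")          # 3328 bytes per vector
print(f"{shards_needed(50_000_000)} shards")     # 10 shards
```

Treat the result as a lower bound and leave headroom for the embedding service and the operating system on each machine.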

Steps

  1. Design the sharding key.

    import hashlib
    
    def shard_key(document_id, num_shards=4):
        hash_bytes = hashlib.sha256(document_id.encode()).digest()
        return int.from_bytes(hash_bytes[:4], byteorder="big") % num_shards
    
    for doc_id in ["doc-001", "doc-002"]:
        print(f"{doc_id} -> shard {shard_key(doc_id)}")
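
To see how the hash key spreads a corpus, the following sketch groups document IDs into per-shard buckets with the same `shard_key` function. The `partition_ids` helper and the synthetic `doc-00000` style IDs are assumptions made for illustration.

```python
import hashlib
from collections import defaultdict

def shard_key(document_id, num_shards=4):
    hash_bytes = hashlib.sha256(document_id.encode()).digest()
    return int.from_bytes(hash_bytes[:4], byteorder="big") % num_shards

def partition_ids(document_ids, num_shards=4):
    # Hypothetical helper: bucket every document ID under its shard number.
    buckets = defaultdict(list)
    for doc_id in document_ids:
        buckets[shard_key(doc_id, num_shards)].append(doc_id)
    return buckets

doc_ids = [f"doc-{i:05d}" for i in range(10_000)]  # synthetic IDs
buckets = partition_ids(doc_ids)
for shard in sorted(buckets):
    print(f"shard {shard}: {len(buckets[shard])} documents")
```

Because the key is a hash modulo `num_shards`, changing the shard count reassigns most documents, so it helps to choose the count with growth in mind.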
    
  2. Build a separate FAISS index on each machine.

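    import faiss, numpy as np
    
    def build_local_index(partition_vectors, dim=768):
        # FAISS Python constructors take positional arguments; 32 is M (links per node).
        index = faiss.IndexHNSWFlat(dim, 32)
        index.hnsw.efConstruction = 200  # set before add(); higher values build a better graph
        index.add(partition_vectors.astype("float32"))
        return index
    
    # Placeholder data standing in for this shard's partition of the embeddings.
    my_vectors = np.random.rand(250000, 768).astype("float32")
    local_idx = build_local_index(my_vectors)
    print(f"Shard built: {local_idx.ntotal} vectors")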
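
Note that `index.add` assigns sequential labels starting at 0 on every shard, so labels returned by different shards can collide when results are merged. One possible remedy, sketched here, packs the shard number into the high bits of a 64-bit ID. The bit widths and helper names are assumptions. IDs like these could then be stored by wrapping the index in `faiss.IndexIDMap` and calling `add_with_ids`.

```python
SHARD_BITS = 8    # assumption: at most 256 shards
LOCAL_BITS = 55   # 8 + 55 = 63 bits, so the ID fits a signed 64-bit integer

def make_global_id(shard, local_id):
    # Hypothetical helper: shard number in the high bits, local row in the low bits.
    if not 0 <= shard < (1 << SHARD_BITS):
        raise ValueError("shard out of range")
    if not 0 <= local_id < (1 << LOCAL_BITS):
        raise ValueError("local_id out of range")
    return (shard << LOCAL_BITS) | local_id

def split_global_id(global_id):
    # Inverse of make_global_id: recover (shard, local_id).
    return global_id >> LOCAL_BITS, global_id & ((1 << LOCAL_BITS) - 1)

gid = make_global_id(3, 42)
print(gid, split_global_id(gid))
```

With this scheme the aggregator can tell from a label alone which shard holds the document.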
    
  3. Expose search via REST API on each shard.

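    from fastapi import FastAPI
    from pydantic import BaseModel
    import faiss, numpy as np, ollama, uvicorn
    
    app = FastAPI()
    local_index = faiss.read_index("shard.faiss")  # hypothetical path to this shard's saved index
    
    class SearchRequest(BaseModel):  # JSON body: {"query_text": ..., "k": ...}
        query_text: str
        k: int = 5
    
    @app.get("/health")
    def health():
        return {"status": "ok"}
    
    # Plain def: FastAPI runs it in a worker thread, so the blocking calls do not stall the event loop.
    @app.post("/search")
    def search_endpoint(req: SearchRequest):
        resp = ollama.embeddings(model="nomic-embed-text", prompt=req.query_text)
        vec = np.array([resp["embedding"]], dtype="float32")
        distances, labels = local_index.search(vec, req.k)
        return {"labels": labels[0].tolist(), "distances": distances[0].tolist()}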
    
  4. Implement fan-out query aggregation.

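    import asyncio, httpx, heapq
    
    SHARD_HOSTS = ["http://192.168.1.10:8000", "http://192.168.1.11:8000"]
    
    async def distributed_search(query_text, k=5):
        payload = {"query_text": query_text, "k": k}
        async with httpx.AsyncClient(timeout=10.0) as client:
            tasks = [client.post(f"{h}/search", json=payload) for h in SHARD_HOSTS]
            # return_exceptions=True keeps one failed shard from aborting the whole query.
            responses = await asyncio.gather(*tasks, return_exceptions=True)
        global_results = []
        for resp in responses:
            if isinstance(resp, Exception) or resp.status_code != 200:
                continue  # skip unreachable or failing shards
            data = resp.json()
            for label, dist in zip(data["labels"], data["distances"]):
                if label != -1:  # FAISS pads with -1 when fewer than k hits exist
                    global_results.append((dist, label))
        # Smallest distance first; assumes labels are unique across shards.
        return heapq.nsmallest(k, global_results)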
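
The fan-out pattern can be tried without any network by simulating shards with coroutines. The sketch below also applies a per-shard deadline so that a straggler is dropped rather than delaying the merged answer. The `fake_shard` stand-in, the 0.2 second deadline, and the sample distances are all illustrative assumptions.

```python
import asyncio
import heapq

async def fake_shard(name, delay, hits):
    # Stand-in for an HTTP call to one shard; hits are (distance, label) pairs.
    await asyncio.sleep(delay)
    return [(dist, f"{name}:{label}") for dist, label in hits]

async def search_with_deadline(shard_calls, k=3, deadline=0.2):
    # Give every shard the same deadline; timeouts come back as exceptions.
    guarded = [asyncio.wait_for(call, timeout=deadline) for call in shard_calls]
    results = await asyncio.gather(*guarded, return_exceptions=True)
    merged, skipped = [], 0
    for result in results:
        if isinstance(result, Exception):
            skipped += 1  # shard timed out or failed
            continue
        merged.extend(result)
    return heapq.nsmallest(k, merged), skipped

async def main():
    calls = [
        fake_shard("s0", 0.01, [(0.12, 7), (0.40, 3)]),
        fake_shard("s1", 0.01, [(0.05, 1), (0.33, 9)]),
        fake_shard("s2", 1.00, [(0.01, 4)]),  # slower than the deadline
    ]
    return await search_with_deadline(calls)

top, skipped = asyncio.run(main())
print(top, skipped)
```

Dropping a slow shard trades recall for latency: here the closest hit lives on the skipped shard, so a production system would likely log or report the number of skipped shards alongside the results.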
    

Verification

curl -s http://localhost:8000/health 2>/dev/null || echo "Host unreachable"
# Expected: {"status":"ok"}
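
To run the same check against every shard at once, a small script along these lines may help. It assumes each shard serves a `/health` route returning `{"status": "ok"}` as in the expected output above. The `check_shards` helper and its injectable `fetch` parameter are hypothetical names.

```python
import json
import urllib.request

def fetch_health(host, timeout=2.0):
    # Assumes GET {host}/health returns a JSON object such as {"status": "ok"}.
    with urllib.request.urlopen(f"{host}/health", timeout=timeout) as resp:
        return json.loads(resp.read().decode())

def check_shards(hosts, fetch=fetch_health):
    # Map each host to True (healthy) or False (unreachable or bad response).
    status = {}
    for host in hosts:
        try:
            status[host] = fetch(host).get("status") == "ok"
        except (OSError, ValueError):  # connection errors and invalid JSON
            status[host] = False
    return status

# Example call (requires running shards):
# print(check_shards(["http://192.168.1.10:8000", "http://192.168.1.11:8000"]))
```

Passing a custom `fetch` function makes the check easy to exercise in tests without real network access.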

Common failures

  • Network latency dominates query time. Every fan-out query waits for the slowest shard, so co-locate shards in the same availability zone.
  • Inconsistent embedding models. Ensure the query embedding model matches the one used during index construction.
  • Overlapping document IDs. Assert global uniqueness or use prefixed namespaces.
  • Hot shard under load. Monitor query distribution and implement adaptive routing around slow shards.
  • Index file corruption. Persist to local SSD and use checksums to detect corruption.
  • Version mismatch. The installed package or runtime differs from the one this guide targets. Check the version first, then rerun the smallest verification command.
  • Local environment drift. A different service, virtual environment, model, or path is in use. Print the active binary path and configuration before changing the guide steps.
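
For the index corruption case, one simple approach is to store a SHA-256 digest next to each persisted index file and compare it before loading. The sketch below uses a stand-in file in place of real `faiss.write_index` output. The `.sha256` sidecar naming and the helper functions are assumptions.

```python
import hashlib
import os
import tempfile

def file_sha256(path, chunk_size=1 << 20):
    # Hash the file in chunks so large index files do not need to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def write_checksum(path):
    # Hypothetical convention: store the digest in a sidecar file next to the index.
    checksum = file_sha256(path)
    with open(path + ".sha256", "w") as fh:
        fh.write(checksum)
    return checksum

def verify_checksum(path):
    with open(path + ".sha256") as fh:
        expected = fh.read().strip()
    return file_sha256(path) == expected

with tempfile.TemporaryDirectory() as tmp:
    index_path = os.path.join(tmp, "shard0.faiss")
    with open(index_path, "wb") as fh:
        fh.write(b"stand-in for a saved FAISS index")
    write_checksum(index_path)
    ok_before = verify_checksum(index_path)
    with open(index_path, "ab") as fh:
        fh.write(b"\x00")  # simulate corruption
    ok_after = verify_checksum(index_path)

print(ok_before, ok_after)  # True False
```

A shard that fails the check at startup can refuse to serve and rebuild or re-fetch its index instead of returning wrong neighbors.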
