05. Dataset Streaming

Chapter 5 of 18 · 20 min

Large datasets don't fit in RAM. Streaming loads data on demand, but naive streaming stalls training on I/O: each sample read can turn into its own disk seek or network round trip.

Memory-Mapped Datasets

Memory mapping lets you access disk data as if it were in memory without loading it all:

from pathlib import Path

import numpy as np
from torch.utils.data import IterableDataset, get_worker_info

class StreamingDataset(IterableDataset):
    """Memory-efficient streaming for large numpy arrays."""
    
    def __init__(self, data_path: str, shard_pattern: str = "shard_*.npy"):
        self.data_path = Path(data_path)
        self.shard_pattern = shard_pattern  # glob pattern used to find shards
        self.shards = sorted(self.data_path.glob(shard_pattern))
        
    def __len__(self):
        return sum(self._count_samples(shard) for shard in self.shards)
    
    def _count_samples(self, shard_path):
        # Fast count: memory-mapping parses only the .npy header, not the data
        return np.load(shard_path, mmap_mode='r').shape[0]
    
    def __iter__(self):
        # With num_workers > 0, every worker would otherwise replay all shards,
        # so give each worker its own slice of the shard list.
        info = get_worker_info()
        shards = self.shards if info is None else self.shards[info.id::info.num_workers]
        for shard in shards:
            data = np.load(shard, mmap_mode='r')  # Memory map, don't load
            indices = np.random.permutation(len(data))  # Shuffle within the shard
            for idx in indices:
                yield np.array(data[idx])  # Copy one sample out of the read-only mapping
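
To make the mechanism concrete without NumPy, here is a minimal sketch that uses only the standard library's mmap module. The record layout (four float32 values per sample) and the helper names write_records and read_record are assumptions made up for this illustration; they are not part of the chapter's dataset class.

```python
import mmap
import os
import struct
import tempfile

# Hypothetical layout: four little-endian float32 values per sample.
RECORD = struct.Struct("<4f")

def write_records(path, samples):
    """Write fixed-size records back to back, so record i starts at i * RECORD.size."""
    with open(path, "wb") as f:
        for sample in samples:
            f.write(RECORD.pack(*sample))

def read_record(mm, index):
    """Decode one record directly from the mapping without reading the whole file."""
    return RECORD.unpack_from(mm, index * RECORD.size)

path = os.path.join(tempfile.mkdtemp(), "samples.bin")
write_records(path, [(float(i), i + 0.5, i * 2.0, -float(i)) for i in range(1000)])

with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
    num_records = len(mm) // RECORD.size  # file size is known without loading the data
    sample = read_record(mm, 742)         # the OS pages in only the bytes that are touched
```

Fixed-size records are what make this cheap: the byte offset of any sample is a multiplication, so random access needs no index file.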

WebDataset for Distributed Streaming

WebDataset shards data into tar files, which are ideal for distributed training where different workers need different shards:

import webdataset as wds

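# train_transform is assumed to be defined elsewhere (your image transform).
# Piping through the AWS CLI is a widely used way to read shards from S3;
# a bare s3:// URL may not be handled by your WebDataset version.
dataset = (
    wds.WebDataset("pipe:aws s3 cp s3://my-bucket/data/{000000..000999}.tar -")
    .shuffle(1000)  # Shuffle samples within a 1000-sample buffer
    .decode("pil")  # Decode image files to PIL images
    .to_tuple("input.jpg", "target.pt")
    .map_tuple(train_transform, lambda x: x)  # Transform inputs, pass targets through
    .batched(32, partial=False)  # Batch inside the dataset, drop the last partial batch
)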

loader = wds.WebLoader(dataset, num_workers=4, batch_size=None)
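
The idea that different workers read different shards can be sketched in plain Python. This is an illustration of the principle only, not WebDataset's own splitting code; the helper assign_shards and its parameters are hypothetical.

```python
def assign_shards(shards, rank, world_size, worker_id, num_workers):
    """Return the shards that one DataLoader worker on one node should read.

    Hypothetical helper: split by node first, then by worker, so that no
    shard is read twice within an epoch.
    """
    node_shards = shards[rank::world_size]      # every world_size-th shard for this node
    return node_shards[worker_id::num_workers]  # then every num_workers-th for this worker

shards = [f"{i:06d}.tar" for i in range(16)]

# 2 nodes x 2 workers give 4 disjoint groups of 4 shards each.
groups = [
    assign_shards(shards, rank, 2, worker, 2)
    for rank in range(2)
    for worker in range(2)
]
```

Strided slicing keeps the groups balanced as long as the shard count is a multiple of world_size * num_workers; otherwise some workers finish early, which is one reason shard counts are usually chosen to be much larger than the worker count.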

Avoiding Epoch Boundaries

Streaming datasets require explicit handling for reproducible shuffling across epochs:

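import numpy as np
import torch

def epoch_iterator(loader, epoch, seed=42):
    """Reproducible epoch with a different shuffle each time."""
    # Reseed per epoch: the same (seed, epoch) repeats the order, a new epoch changes it.
    # DataLoader draws its worker base seed from torch's global RNG when iteration
    # starts (unless a generator is passed to the loader).
    torch.manual_seed(seed + epoch)
    np.random.seed(seed + epoch)  # Covers np.random-based shuffling when num_workers=0
    
    for batch in loader:
        yield batch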
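
A framework-free sketch shows what "reproducible, but different each epoch" means for a stream. It implements a bounded shuffle buffer driven by an RNG seeded from the seed and the epoch number. The function shuffled_stream and its parameters are assumptions for illustration and do not reproduce any library's internal shuffle.

```python
import random

def shuffled_stream(samples, buffer_size, seed, epoch):
    """Approximately shuffle a stream using a bounded buffer and a per-epoch RNG."""
    rng = random.Random(seed + epoch)  # same (seed, epoch) gives the same order
    buffer = []
    for sample in samples:
        buffer.append(sample)
        if len(buffer) >= buffer_size:
            # Move a randomly chosen element to the end, then emit it.
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    # The stream is exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer

epoch0 = list(shuffled_stream(range(100), buffer_size=16, seed=42, epoch=0))
epoch0_again = list(shuffled_stream(range(100), buffer_size=16, seed=42, epoch=0))
epoch1 = list(shuffled_stream(range(100), buffer_size=16, seed=42, epoch=1))
```

A larger buffer_size mixes samples more thoroughly at the cost of memory; with a buffer of 16, a sample can only be emitted later than its original position, never more than 15 places earlier.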

Local verification checkpoint

Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.

EXERCISE

Measure throughput of your current dataset with time.time() around the loader loop. If below 500 samples/second on a fast SSD, implement mmap-based loading and re-benchmark.
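
One possible way to take the measurement is sketched below. It swaps time.time() for time.perf_counter(), which is monotonic and has higher resolution, and it skips a few warmup batches so worker startup does not distort the number. The helper measure_throughput, its parameters, and the stand-in fake_loader are hypothetical; replace fake_loader with your real loader.

```python
import time

def measure_throughput(loader, batch_size, warmup_batches=2, max_batches=50):
    """Return samples/second over up to max_batches, ignoring the warmup batches."""
    timed_batches = 0
    start = None
    for i, _batch in enumerate(loader):
        if i < warmup_batches:
            continue  # let workers spin up and caches fill
        if start is None:
            start = time.perf_counter()  # clock starts once the warmup is over
            continue
        timed_batches += 1
        if timed_batches >= max_batches:
            break
    if start is None or timed_batches == 0:
        raise ValueError("loader yielded too few batches to time")
    elapsed = time.perf_counter() - start
    return timed_batches * batch_size / elapsed

def fake_loader(num_batches, batch_size=32, delay=0.001):
    """Stand-in loader: sleeps to simulate I/O, then yields a dummy batch."""
    for _ in range(num_batches):
        time.sleep(delay)
        yield [0] * batch_size

rate = measure_throughput(fake_loader(20), batch_size=32)
print(f"{rate:.0f} samples/second")
```

Record the batch size and num_workers next to the result, since both change the number substantially.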