16. Low-Latency Optimization

Chapter 16 of 22 · 25 min

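The user experience of Voice AI hinges on perceived latency. Optimizing the critical path can cut end-to-end latency from seconds down to a few hundred milliseconds.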

Latency Budget Analysis

class LatencyBudget:
    """Track per-stage latencies of a voice pipeline against an end-to-end budget.

    Args:
        total_budget_ms: End-to-end latency target in milliseconds. Defaults
            to 200 ms, the target used here for a "real-time" feel.
    """

    def __init__(self, total_budget_ms: float = 200):
        self.stages: dict[str, float] = {}
        self.total_budget_ms = total_budget_ms

    def add_stage(self, name: str, duration_ms: float):
        """Record (or overwrite) the measured duration of stage *name*."""
        self.stages[name] = duration_ms

    def report(self) -> dict:
        """Summarize total latency and how it compares with the budget.

        Returns:
            A dict with ``stages`` (a copy of the per-stage durations),
            ``total_ms``, ``remaining_ms`` (negative when over budget) and
            ``status`` (``"ok"`` or ``"exceeded"``).
        """
        total = sum(self.stages.values())
        remaining = self.total_budget_ms - total
        return {
            # Copy so callers editing the report cannot change our state.
            "stages": dict(self.stages),
            "total_ms": total,
            "remaining_ms": remaining,
            "status": "ok" if remaining >= 0 else "exceeded",
        }

    def optimize_targets(self) -> dict[str, float]:
        """Rescale stage durations so together they fill the budget exactly.

        Each stage keeps its current share of the total. Returns ``{}`` when
        no stages are recorded; when every duration is zero the shares are
        undefined, so the budget is split evenly instead of dividing by zero.
        """
        if not self.stages:
            return {}
        total = sum(self.stages.values())
        if total == 0:
            even_share = self.total_budget_ms / len(self.stages)
            return {name: even_share for name in self.stages}
        return {
            name: (duration / total) * self.total_budget_ms
            for name, duration in self.stages.items()
        }

# Example: a cascaded VAD -> ASR -> LLM -> TTS pipeline measured stage by stage.
budget = LatencyBudget()
for _stage_name, _stage_ms in (("vad", 15), ("asr", 80), ("llm", 150), ("tts", 100)):
    budget.add_stage(_stage_name, _stage_ms)
print(budget.report())
# {'stages': {'vad': 15, 'asr': 80, 'llm': 150, 'tts': 100},
#  'total_ms': 345, 'remaining_ms': -145, 'status': 'exceeded'}

Pipelined Processing

Overlap computation across stages to reduce wall-clock time.

import asyncio
from typing import AsyncIterator

class PipelinedVoiceProcessor:
    """Run capture -> ASR -> LLM -> TTS as concurrent stages joined by bounded queues.

    Each stage is its own coroutine, so ASR can transcribe new audio while the
    LLM or TTS is still busy with an earlier turn. The ``maxsize`` on each
    queue applies back-pressure: a producer blocks on ``put`` when its
    consumer falls behind.

    This class does not define ``audio_stream()`` (an async iterator of audio
    chunks) or ``_play_audio(audio)`` (a coroutine); a subclass must supply
    both. Collaborators used here: ``asr.min_samples`` and
    ``await asr.transcribe(bytes) -> str``; ``await llm.generate(text)``;
    ``await tts.synthesize(response)``.
    """

    # Sentinel passed down every queue when the audio stream ends, so each
    # stage can flush what it has buffered and exit, letting run() return.
    _END_OF_STREAM = object()

    # Trailing characters that mark a transcript as the end of the user's turn.
    END_OF_TURN_MARKS = (".", "?", "!", "。", "？", "！")

    def __init__(self, asr, llm, tts):
        self.asr = asr
        self.llm = llm
        self.tts = tts
        # Small bounds keep buffered work (and therefore queuing delay) low.
        self.audio_queue = asyncio.Queue(maxsize=5)
        self.transcript_queue = asyncio.Queue(maxsize=3)
        self.response_queue = asyncio.Queue(maxsize=3)

    async def run(self):
        """Run all stages until the audio stream is exhausted and drained.

        If any stage raises, the remaining stages are cancelled and the
        exception propagates to the caller.
        """
        tasks = [
            asyncio.create_task(stage)
            for stage in (
                self._capture_loop(),
                self._asr_loop(),
                self._llm_loop(),
                self._tts_loop(),
            )
        ]
        try:
            await asyncio.gather(*tasks)
        finally:
            # gather() does not cancel the other awaitables when one fails;
            # cancel them here so no stage is left blocked on a queue forever.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _capture_loop(self):
        """Forward audio chunks to the ASR stage, then signal end of stream."""
        async for chunk in self.audio_stream():
            await self.audio_queue.put(chunk)
        await self.audio_queue.put(self._END_OF_STREAM)

    async def _asr_loop(self):
        """Accumulate audio and transcribe it once enough has been buffered.

        At end of stream any leftover audio is transcribed as well (even if
        shorter than ``min_samples``) so the user's final words are not lost.
        """
        # bytearray grows in place; `bytes += chunk` would copy on every chunk.
        buffer = bytearray()
        while True:
            chunk = await self.audio_queue.get()
            if chunk is self._END_OF_STREAM:
                if buffer:
                    tail = await self.asr.transcribe(bytes(buffer))
                    await self.transcript_queue.put(tail)
                await self.transcript_queue.put(self._END_OF_STREAM)
                return
            buffer += chunk
            # NOTE(review): compares a byte count with `min_samples`; for
            # multi-byte PCM samples these units differ -- confirm.
            if len(buffer) > self.asr.min_samples:
                transcript = await self.asr.transcribe(bytes(buffer))
                await self.transcript_queue.put(transcript)
                buffer.clear()

    async def _llm_loop(self):
        """Collect transcript pieces into a user turn and answer each full turn.

        Only complete turns are sent to TTS. Speculative replies to a partial
        utterance must never be queued for playback, or the assistant would
        talk over the user and then answer a second time.
        """
        turn_parts = []
        while True:
            transcript = await self.transcript_queue.get()
            if transcript is self._END_OF_STREAM:
                # The stream ended mid-turn: answer what the user said.
                if turn_parts:
                    response = await self.llm.generate(" ".join(turn_parts))
                    await self.response_queue.put(response)
                await self.response_queue.put(self._END_OF_STREAM)
                return
            text = transcript.strip()
            if not text:
                continue  # silence / empty ASR output adds nothing to the turn
            turn_parts.append(text)
            if self._is_end_of_turn(text):
                response = await self.llm.generate(" ".join(turn_parts))
                await self.response_queue.put(response)
                turn_parts.clear()

    async def _tts_loop(self):
        """Synthesize and play each response in order until end of stream."""
        while True:
            response = await self.response_queue.get()
            if response is self._END_OF_STREAM:
                return
            audio = await self.tts.synthesize(response)
            await self._play_audio(audio)

    def _is_end_of_turn(self, transcript: str) -> bool:
        """Heuristic: a turn ends when the transcript ends in terminal punctuation."""
        return transcript.strip().endswith(self.END_OF_TURN_MARKS)

KV Cache Optimization

# PyTorch model's key-value cache management
import torch


class OptimizedLLM:
    """Wrap a causal LM so calls on a growing sequence reuse its KV cache.

    ``model`` is called as ``model(input_ids, past_key_values=..., use_cache=...)``
    and must return an object with ``.logits`` and ``.past_key_values``
    (the calling convention of Hugging Face causal LMs).

    Args:
        model: The wrapped causal language model.
        use_cache: When False, every call recomputes the full sequence and no
            cache is kept.
    """

    def __init__(self, model, use_cache: bool = True):
        self.model = model
        self.use_cache = use_cache
        self.kv_cache = None
        # Number of leading tokens of the sequence that kv_cache covers.
        self._cached_len = 0

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor = None):
        """Run the model on the part of *input_ids* that is not cached yet.

        Args:
            input_ids: ``(batch, seq)`` token ids for the whole sequence so
                far. Only the suffix past the cached length is fed to the
                model, so any number of newly appended tokens is handled.
            position_ids: Optional positions aligned with *input_ids*; sliced
                the same way and forwarded to the model when given.

        Returns:
            The model's logits for the tokens actually fed, i.e. only the new
            positions when the cache was reused, or all positions otherwise.

        The new input is assumed to extend the previously seen sequence; the
        prefix itself is not compared. Call reset_cache() before starting an
        unrelated sequence. A sequence that is not longer than the cached one
        cannot extend it, so the cache is discarded automatically.
        """
        seq_len = input_ids.shape[1]
        if self.use_cache and self.kv_cache is not None:
            cached_len = self._cached_len
        else:
            cached_len = 0
        if cached_len >= seq_len:
            # No new tokens (repeated or shorter input): the cache cannot
            # describe this sequence, so recompute from scratch.
            self.reset_cache()
            cached_len = 0

        model_kwargs = {"use_cache": self.use_cache}
        if cached_len:
            model_kwargs["past_key_values"] = self.kv_cache
        if position_ids is not None:
            model_kwargs["position_ids"] = position_ids[:, cached_len:]

        outputs = self.model(input_ids[:, cached_len:], **model_kwargs)
        if self.use_cache:
            self.kv_cache = outputs.past_key_values
            self._cached_len = seq_len
        return outputs.logits

    def reset_cache(self):
        """Drop cached keys/values so the next forward() starts from scratch."""
        self.kv_cache = None
        self._cached_len = 0

Continuous Batching

class ContinuousBatchingScheduler:
    """Greedy-decode many requests together, admitting new ones as slots free up.

    A waiting request joins the running set on the first step() after a slot
    opens, and a request leaves as soon as it finishes. Short requests
    therefore do not wait for long ones, unlike with static batching.

    Args:
        model: Must provide ``await model.forward_batch(list_of_token_lists)``
            returning one logits row per sequence, and ``model.eos_token_id``.
        max_batch_size: Maximum number of requests decoded concurrently.
        max_new_tokens: Optional cap on generated tokens per request. ``None``
            (the default) means decode until EOS.
    """

    def __init__(self, model, max_batch_size: int = 16, max_new_tokens: "int | None" = None):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_new_tokens = max_new_tokens
        self.pending_requests = []  # FIFO of requests waiting for a slot
        self.running_batches = []   # requests currently being decoded

    async def add_request(self, request_id: str, input_ids: list[int]):
        """Queue a request. It starts decoding on a later step()."""
        self.pending_requests.append({
            "id": request_id,
            # Copy so later mutation by the caller cannot change the prompt.
            "input_ids": list(input_ids),
            "finished": False,
            "output_ids": [],
        })

    async def step(self):
        """Admit waiting requests, then decode one token for every running request.

        Returns:
            The list of requests that finished on this step. The list is empty
            when none finished, including when nothing is running.
        """
        # Fill free slots in FIFO order with a single slice (no repeated pop(0)).
        free_slots = self.max_batch_size - len(self.running_batches)
        if free_slots > 0:
            self.running_batches.extend(self.pending_requests[:free_slots])
            del self.pending_requests[:free_slots]

        if not self.running_batches:
            return []

        # Forward pass for all running requests (prompt + tokens generated so far).
        batch_input = [r["input_ids"] + r["output_ids"] for r in self.running_batches]
        logits = await self.model.forward_batch(batch_input)

        for req, row in zip(self.running_batches, logits):
            # NOTE(review): assumes each row is a 1-D vocabulary logits vector;
            # argmax() flattens, so (seq, vocab) rows would yield wrong ids.
            next_token = row.argmax().item()
            req["output_ids"].append(next_token)
            hit_limit = (
                self.max_new_tokens is not None
                and len(req["output_ids"]) >= self.max_new_tokens
            )
            if next_token == self.model.eos_token_id or hit_limit:
                req["finished"] = True

        # Retire finished requests so their slots are free on the next step.
        completed = [r for r in self.running_batches if r["finished"]]
        self.running_batches = [r for r in self.running_batches if not r["finished"]]
        return completed

Target Latencies by Component

| Component | Target | Maximum |
|---|---|---|
| Audio capture to VAD | 20ms | 50ms |
| VAD to ASR | 30ms | 80ms |
| ASR to first LLM token | 50ms | 150ms |
| LLM inter-token latency | 20ms | 50ms |
| LLM to TTS first audio | 40ms | 100ms |
| TTS streaming buffer | 100ms | 150ms |
EXERCISE

Measure end-to-end latency for a voice pipeline by logging timestamps at each stage. Identify the bottleneck and implement one optimization (pipelining or KV caching) to reduce total latency by at least 20%. Time: 15 minutes.