16. Low-Latency Optimization
Chapter 16 of 22 · 25 min
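The user experience of voice AI depends on perceived latency. Optimizing the critical path can reduce end-to-end latency from seconds to a few hundred milliseconds.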
Latency Budget Analysis
class LatencyBudget:
    """Track per-stage latencies of a pipeline against an end-to-end budget (ms)."""

    def __init__(self, total_budget_ms: float = 200):
        # stage name -> measured duration in ms, in insertion order.
        self.stages = {}
        # Default of 200 ms is the target for a real-time conversational feel.
        self.total_budget_ms = total_budget_ms

    def add_stage(self, name: str, duration_ms: float):
        """Record the duration of stage ``name``; re-adding a name overwrites it."""
        self.stages[name] = duration_ms

    def report(self) -> dict:
        """Summarize total latency versus the budget.

        Returns a dict with a copy of the per-stage durations, the total, the
        remaining budget (negative when over), and ``"ok"``/``"exceeded"``.
        """
        total = sum(self.stages.values())
        remaining = self.total_budget_ms - total
        return {
            # Copy so callers cannot mutate this budget through the report.
            "stages": dict(self.stages),
            "total_ms": total,
            "remaining_ms": remaining,
            "status": "ok" if remaining >= 0 else "exceeded",
        }

    def optimize_targets(self) -> dict[str, float]:
        """Split the total budget across stages in proportion to their measured cost.

        With no stages the result is empty. If every stage measured 0 ms, no
        proportion exists, so the budget is split evenly instead of dividing
        by zero.
        """
        total = sum(self.stages.values())
        if not total:
            if not self.stages:
                return {}
            even_share = self.total_budget_ms / len(self.stages)
            return {name: even_share for name in self.stages}
        return {
            name: (duration / total) * self.total_budget_ms
            for name, duration in self.stages.items()
        }
# Example budget: measured per-stage latencies (ms) for a sample voice pipeline.
budget = LatencyBudget()
budget.add_stage("vad", 15)   # voice activity detection
budget.add_stage("asr", 80)   # speech recognition
budget.add_stage("llm", 150)  # language-model response generation
budget.add_stage("tts", 100)  # speech synthesis
print(budget.report())
# Output — 345 ms in total against the 200 ms budget, i.e. 145 ms over:
# {'stages': {'vad': 15, 'asr': 80, 'llm': 150, 'tts': 100},
# 'total_ms': 345, 'remaining_ms': -145, 'status': 'exceeded'}
Pipelined Processing
Overlap computation across stages to reduce wall-clock time.
import asyncio
from typing import AsyncIterator
class PipelinedVoiceProcessor:
    """Run capture -> ASR -> LLM -> TTS as concurrent stages linked by bounded queues.

    Each stage is its own coroutine, so ASR can transcribe new audio while TTS
    is still synthesizing an earlier reply. The bounded queues apply
    back-pressure: when a stage falls behind, upstream ``put`` calls wait
    instead of piling up stale work.

    The stage objects are duck-typed; the calls below require:

    * ``asr.min_samples`` and ``await asr.transcribe(audio) -> str``
    * ``await llm.generate(text)``
    * ``await tts.synthesize(response)``

    ``audio_stream()`` (an async iterator of audio chunks) and
    ``_play_audio(audio)`` (a coroutine) are not defined in this class; a
    subclass must provide them.
    """

    # Sentence-final punctuation (ASCII plus full-width CJK forms) that ends a turn.
    _END_OF_TURN_MARKS = (".", "?", "!", "。", "？", "！")

    def __init__(self, asr, llm, tts):
        self.asr = asr
        self.llm = llm
        self.tts = tts
        # Small queue bounds keep the amount of queued (and thus delayed) work low.
        self.audio_queue = asyncio.Queue(maxsize=5)
        self.transcript_queue = asyncio.Queue(maxsize=3)
        self.response_queue = asyncio.Queue(maxsize=3)

    async def run(self):
        """Run all four stages concurrently until one fails or ``run`` is cancelled.

        ``asyncio.gather`` does not cancel its other awaitables when one raises.
        The ``finally`` block therefore cancels the remaining stages, so a
        crashed stage does not leave the others blocked on their queues forever.
        """
        tasks = [
            asyncio.create_task(stage)
            for stage in (
                self._capture_loop(),
                self._asr_loop(),
                self._llm_loop(),
                self._tts_loop(),
            )
        ]
        try:
            await asyncio.gather(*tasks)
        finally:
            for task in tasks:
                task.cancel()
            # Wait for the cancellations to land so no task is left pending.
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _capture_loop(self):
        """Forward audio chunks from ``audio_stream()`` into the ASR queue."""
        async for chunk in self.audio_stream():
            await self.audio_queue.put(chunk)

    async def _asr_loop(self):
        """Buffer audio and transcribe once more than ``asr.min_samples`` units arrive."""
        # bytearray grows in place; repeated ``bytes += chunk`` copies the whole buffer.
        buffer = bytearray()
        while True:
            chunk = await self.audio_queue.get()
            buffer += chunk
            # NOTE(review): len(buffer) counts bytes while the attribute is named
            # min_samples; for 16-bit PCM these differ by 2x — confirm the unit.
            if len(buffer) > self.asr.min_samples:
                transcript = await self.asr.transcribe(bytes(buffer))
                await self.transcript_queue.put(transcript)
                buffer.clear()

    async def _llm_loop(self):
        """Collect transcript fragments into one user turn and answer it once.

        Fragments are buffered until ``_is_end_of_turn`` fires. Only then is the
        LLM asked for a reply, which is queued for TTS. Queuing a reply for each
        partial fragment would make the assistant answer half-finished
        utterances and then speak again for the full turn.
        """
        fragments: list[str] = []
        while True:
            transcript = await self.transcript_queue.get()
            text = transcript.strip()
            if text:
                fragments.append(text)
            if fragments and self._is_end_of_turn(transcript):
                response = await self.llm.generate(" ".join(fragments))
                await self.response_queue.put(response)
                fragments.clear()

    async def _tts_loop(self):
        """Synthesize each queued response and play it back."""
        while True:
            response = await self.response_queue.get()
            audio = await self.tts.synthesize(response)
            await self._play_audio(audio)

    def _is_end_of_turn(self, transcript: str) -> bool:
        """Heuristic endpointing: the turn ends when the text ends in sentence punctuation."""
        return transcript.strip().endswith(self._END_OF_TURN_MARKS)
KV Cache Optimization
# Key/value (KV) cache management for incremental decoding with a PyTorch model.
class OptimizedLLM:
    """Wrap a causal LM so repeated calls only run the model on uncached tokens.

    ``model`` is called as ``model(ids, past_key_values=..., use_cache=...)``,
    and its output is read through ``.logits`` and ``.past_key_values``. This
    matches the Hugging Face ``transformers`` convention (assumed; confirm for
    your model).
    """

    def __init__(self, model, use_cache: bool = True):
        self.model = model
        self.use_cache = use_cache
        self.kv_cache = None
        # Copy of the token ids whose keys/values are currently held in kv_cache.
        self._cached_ids = None

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor = None):
        """Return logits for ``input_ids``, reusing the KV cache when possible.

        ``input_ids`` is the full sequence so far, indexed ``[batch, seq]``.
        When it strictly extends the previously processed sequence, only the new
        suffix is fed to the model, and the returned logits cover just those new
        positions. Otherwise, for example after a new prompt, the whole sequence
        is recomputed and the cache rebuilt. If ``position_ids`` is given, it
        should cover the full sequence (presumably ``[batch, seq]``; confirm).
        It is sliced the same way and passed on to the model.
        """
        cached = self._cached_ids
        reuse = (
            self.use_cache
            and self.kv_cache is not None
            and cached is not None
            and input_ids.shape[1] > cached.shape[1]
            # Only reuse the cache if the new input really starts with the cached tokens.
            and torch.equal(input_ids[:, : cached.shape[1]], cached)
        )
        start = cached.shape[1] if reuse else 0

        extra = {}
        if position_ids is not None:
            extra["position_ids"] = position_ids[:, start:]

        if reuse:
            # Feed every token added since the last call, not just the final one.
            outputs = self.model(
                input_ids[:, start:],
                past_key_values=self.kv_cache,
                use_cache=True,
                **extra,
            )
        else:
            outputs = self.model(input_ids, use_cache=self.use_cache, **extra)

        if self.use_cache:
            self.kv_cache = outputs.past_key_values
            # Clone so later in-place edits by the caller cannot desync the cache.
            self._cached_ids = input_ids.clone()
        return outputs.logits

    def reset_cache(self):
        """Drop the cached keys/values; the next forward recomputes from scratch."""
        self.kv_cache = None
        self._cached_ids = None
Continuous Batching
class ContinuousBatchingScheduler:
    """Iteration-level batching: requests join and leave the batch between decode steps.

    ``model`` must provide ``await model.forward_batch(sequences)``, returning
    one logits object per sequence, and ``model.eos_token_id``.
    """

    def __init__(self, model, max_batch_size: int = 16, *, max_new_tokens=None):
        self.model = model
        self.max_batch_size = max_batch_size
        # Optional cap on generated tokens per request. None means "until EOS".
        # Without a cap, a request that never emits EOS would hold its slot forever.
        self.max_new_tokens = max_new_tokens
        self.pending_requests = []   # FIFO of requests waiting for a batch slot
        self.running_batches = []    # requests currently being decoded

    async def add_request(self, request_id: str, input_ids: list[int]):
        """Queue a request; it is admitted to the batch on a later ``step``."""
        self.pending_requests.append({
            "id": request_id,
            # Copy so later changes to the caller's list do not alter the prompt.
            "input_ids": list(input_ids),
            "finished": False,
            "output_ids": [],
        })

    async def step(self):
        """Admit waiting requests, decode one token for each running request.

        Returns the requests that finished on this step. The result is always a
        list, empty when none finished or nothing was running.
        """
        # Fill free batch slots from the front of the pending queue.
        while self.pending_requests and len(self.running_batches) < self.max_batch_size:
            self.running_batches.append(self.pending_requests.pop(0))
        if not self.running_batches:
            return []

        # One forward pass covers every running request (prompt + tokens so far).
        batch_input = [r["input_ids"] + r["output_ids"] for r in self.running_batches]
        logits = await self.model.forward_batch(batch_input)

        for req, logit in zip(self.running_batches, logits):
            # NOTE(review): a single argmax assumes forward_batch returns one
            # next-token logits vector per request — confirm the output shape.
            token = int(logit.argmax())
            req["output_ids"].append(token)
            hit_limit = (
                self.max_new_tokens is not None
                and len(req["output_ids"]) >= self.max_new_tokens
            )
            if token == self.model.eos_token_id or hit_limit:
                req["finished"] = True

        # Finished requests leave the batch, freeing slots for the next step.
        completed = [r for r in self.running_batches if r["finished"]]
        self.running_batches = [r for r in self.running_batches if not r["finished"]]
        return completed
Target Latencies by Component
| Component | Target | Maximum |
|---|---|---|
| Audio capture to VAD | 20ms | 50ms |
| VAD to ASR | 30ms | 80ms |
| ASR to first LLM token | 50ms | 150ms |
| LLM inter-token latency | 20ms | 50ms |
| LLM to TTS first audio | 40ms | 100ms |
| TTS streaming buffer | 100ms | 150ms |
EXERCISE
Measure end-to-end latency for a voice pipeline by logging timestamps at each stage. Identify the bottleneck and implement one optimization (pipelining or KV caching) to reduce total latency by at least 20%. Time: 15 minutes.