AI 缓存策略:Semantic Cache 与 Prompt Cache

语义缓存(Embedding 相似度匹配)、精确缓存、Prompt Caching(Anthropic/OpenAI)与缓存失效策略

引言

LLM 推理成本高、延迟大,而实际业务中大量请求是重复或高度相似的。据统计,典型的客服聊天机器人中 30-50% 的问题本质上是相同的。如果能将这些重复请求的结果缓存起来,既能显著降低成本,又能将延迟从秒级降到毫秒级。

但 LLM 缓存与传统 API 缓存有本质区别:用户很少用完全相同的文字提出相同的问题。"怎么退货"和"我要退商品"是不同的字符串,但语义完全一致。这就需要语义缓存(Semantic Cache)。

缓存策略全景

三层缓存体系

┌──────────────────────────────────────────────────┐
│                LLM 缓存金字塔                      │
│                                                  │
│              ┌───────────┐                       │
│              │ Prompt    │  Provider 端           │
│              │ Cache     │  prefix 复用            │
│              │ (API 级)   │  延迟: -50% 首 token    │
│              └─────┬─────┘                       │
│                    │                             │
│           ┌────────▼─────────┐                   │
│           │  Semantic Cache  │  应用端            │
│           │  (语义相似度)     │  embedding 匹配     │
│           │                  │  延迟: <50ms        │
│           └────────┬─────────┘                   │
│                    │                             │
│       ┌────────────▼──────────────┐              │
│       │     Exact Match Cache     │  应用端       │
│       │     (精确字符串匹配)        │  hash 查找    │
│       │                           │  延迟: <5ms   │
│       └───────────────────────────┘              │
└──────────────────────────────────────────────────┘

三种缓存对比

| 特性 | 精确匹配 | 语义缓存 | Prompt Cache |
|------|----------|----------|--------------|
| 匹配方式 | 字符串 hash | Embedding 相似度 | 前缀 token 匹配 |
| 命中率 | 低 (5-15%) | 中 (20-40%) | 高 (60-80%) |
| 额外成本 | 几乎为零(hash 计算) | Embedding 计算 | 无(Provider 内置) |
| 延迟减少 | 100% | 100% | 50-80% |
| 成本减少 | 100% | 100% | 50-90% |
| 实现复杂度 | 低 | 中(需向量库) | 低(API 参数) |
| 适用场景 | 固定模板查询 | 自然语言查询 | 长 System Prompt |
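三层缓存通常串联成一条查询路径:先查最便宜的精确缓存,未命中再查语义缓存,最后才调用 LLM,并回填两层缓存。下面是按此思路的一个最小示意(`exact_cache`、`semantic_cache` 假定实现了本文后面给出的接口,`llm_call` 为占位函数):

```python
# Tiered lookup: cheapest layer first, LLM call only on a full miss.
from typing import Callable

def cached_completion(
    messages: list[dict],
    model: str,
    temperature: float,
    exact_cache,      # ExactMatchCache-style object: get()/set()
    semantic_cache,   # SemanticCache-style object: get()/set()
    llm_call: Callable[[list[dict], str, float], dict],
) -> dict:
    # Layer 1: exact hash lookup (~<5ms, no extra cost)
    if (hit := exact_cache.get(messages, model, temperature)) is not None:
        return hit
    # Layer 2: semantic lookup (~<50ms, costs one embedding call)
    if (hit := semantic_cache.get(messages, model)) is not None:
        return hit
    # Full miss: call the LLM, then populate both layers
    response = llm_call(messages, model, temperature)
    exact_cache.set(messages, model, temperature, response)
    semantic_cache.set(messages, model, response)
    return response
```

顺序很重要:精确缓存几乎零成本,应始终放在语义缓存(需要一次 embedding 调用)之前。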

精确匹配缓存

实现

# src/cache/exact_cache.py
import hashlib
import json
import time
from typing import Optional
import redis

class ExactMatchCache:
    """Cache LLM responses by exact input hash."""

    def __init__(
        self,
        redis_client: redis.Redis,
        default_ttl: int = 3600,
        prefix: str = "llm:exact:",
    ):
        self.redis = redis_client
        self.default_ttl = default_ttl
        self.prefix = prefix

    def _compute_key(self, messages: list[dict], model: str, temperature: float) -> str:
        """Deterministic hash of the full request."""
        payload = json.dumps({
            "messages": messages,
            "model": model,
            "temperature": temperature,
        }, sort_keys=True, ensure_ascii=False)

        return self.prefix + hashlib.sha256(payload.encode()).hexdigest()

    def get(self, messages: list[dict], model: str, temperature: float) -> Optional[dict]:
        key = self._compute_key(messages, model, temperature)
        cached = self.redis.get(key)
        if cached:
            data = json.loads(cached)
            # Record cache hit for metrics
            self.redis.incr(f"{self.prefix}hits")
            return data
        self.redis.incr(f"{self.prefix}misses")
        return None

    def set(
        self,
        messages: list[dict],
        model: str,
        temperature: float,
        response: dict,
        ttl: Optional[int] = None,
    ) -> None:
        key = self._compute_key(messages, model, temperature)
        self.redis.setex(
            key,
            ttl or self.default_ttl,
            json.dumps(response, ensure_ascii=False),
        )

    def get_hit_rate(self) -> float:
        hits = int(self.redis.get(f"{self.prefix}hits") or 0)
        misses = int(self.redis.get(f"{self.prefix}misses") or 0)
        total = hits + misses
        return hits / total if total > 0 else 0.0
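下面用与 `_compute_key` 相同的哈希逻辑直观展示精确匹配的局限:语义相同但字面不同的两个问题会得到完全不同的键,因此只能命中逐字节相同的请求:

```python
# Why exact-match hit rates are low: any character difference changes the key.
import hashlib
import json

def request_key(messages: list[dict], model: str, temperature: float) -> str:
    payload = json.dumps(
        {"messages": messages, "model": model, "temperature": temperature},
        sort_keys=True, ensure_ascii=False,
    )
    return hashlib.sha256(payload.encode()).hexdigest()

k1 = request_key([{"role": "user", "content": "怎么退货"}], "gpt-4o", 0.0)
k2 = request_key([{"role": "user", "content": "我要退商品"}], "gpt-4o", 0.0)
k3 = request_key([{"role": "user", "content": "怎么退货"}], "gpt-4o", 0.0)

assert k1 != k2   # same intent, different wording -> cache miss
assert k1 == k3   # byte-identical request -> cache hit
```

这正是下一节语义缓存要解决的问题。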

语义缓存

核心原理

查询流程:

1. 用户输入 "怎么申请退货"
2. 计算 embedding: [0.12, -0.34, 0.56, ...]
3. 在向量库中搜索最相似的缓存条目
4. 找到 "如何退货退款" (相似度 0.95)
5. 相似度 > 阈值 (0.90) → 缓存命中
6. 直接返回缓存的 LLM 回答

未命中流程:
1. 用户输入 "你们支持什么支付方式"
2. 计算 embedding
3. 向量库中最相似的是 "怎么退货" (相似度 0.35)
4. 相似度 < 阈值 → 缓存未命中
5. 调用 LLM 获取回答
6. 将 (embedding, 回答) 存入缓存
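上面的命中/未命中判定本质上是一次余弦相似度与阈值的比较,可以用纯函数示意(向量为手造的三维示例,仅演示阈值逻辑,并非真实 embedding):

```python
# Threshold decision at the heart of semantic caching: cosine similarity.
import numpy as np

def cosine_sim(a, b) -> float:
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

THRESHOLD = 0.92                       # mirrors the default used below

cached = np.array([0.6, 0.8, 0.0])     # embedding of a cached query
near   = np.array([0.59, 0.81, 0.02])  # paraphrase-like vector -> hit
far    = np.array([0.0, 0.1, 0.99])    # unrelated query -> miss

assert cosine_sim(cached, near) >= THRESHOLD
assert cosine_sim(cached, far) < THRESHOLD
```

实际实现中这一比较由向量库完成(如下文 Qdrant 的 `score_threshold` 参数),无需手写。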

完整实现

# src/cache/semantic_cache.py
import hashlib
import json
import time
from typing import Optional
from dataclasses import dataclass
import numpy as np

@dataclass
class CacheEntry:
    query: str
    response: dict
    embedding: list[float]
    model: str
    created_at: float
    hit_count: int = 0

class SemanticCache:
    """Semantic similarity-based LLM response cache."""

    def __init__(
        self,
        embedding_model: str = "text-embedding-3-small",
        similarity_threshold: float = 0.92,
        max_entries: int = 100_000,
        ttl_seconds: int = 86400,
    ):
        self.embedding_model = embedding_model
        self.similarity_threshold = similarity_threshold
        self.max_entries = max_entries    # enforced externally (see CacheInvalidator below)
        self.ttl_seconds = ttl_seconds    # enforced externally via invalidate_by_ttl

        # Use Qdrant for vector storage
        from qdrant_client import QdrantClient, models
        self.qdrant = QdrantClient(url="http://localhost:6333")

        # Create collection if not exists
        try:
            self.qdrant.get_collection("semantic_cache")
        except Exception:
            self.qdrant.create_collection(
                collection_name="semantic_cache",
                vectors_config=models.VectorParams(
                    size=1536,  # text-embedding-3-small
                    distance=models.Distance.COSINE,
                ),
            )

    def _get_embedding(self, text: str) -> list[float]:
        import openai
        client = openai.OpenAI()
        response = client.embeddings.create(
            model=self.embedding_model,
            input=text,
        )
        return response.data[0].embedding

    def _extract_query(self, messages: list[dict]) -> str:
        """Extract the semantic query from messages."""
        # Use the last user message as the cache key
        user_messages = [m for m in messages if m["role"] == "user"]
        if not user_messages:
            return ""
        return user_messages[-1]["content"]

    def get(
        self,
        messages: list[dict],
        model: str,
    ) -> Optional[dict]:
        query = self._extract_query(messages)
        if not query:
            return None

        query_embedding = self._get_embedding(query)

        # Search for similar cached queries
        from qdrant_client import models
        results = self.qdrant.search(
            collection_name="semantic_cache",
            query_vector=query_embedding,
            limit=1,
            score_threshold=self.similarity_threshold,
            query_filter=models.Filter(
                must=[
                    models.FieldCondition(
                        key="model",
                        match=models.MatchValue(value=model),
                    ),
                ],
            ),
        )

        if results:
            hit = results[0]
            cached_response = json.loads(hit.payload["response"])

            # Update hit count
            self.qdrant.set_payload(
                collection_name="semantic_cache",
                payload={"hit_count": hit.payload.get("hit_count", 0) + 1},
                points=[hit.id],
            )

            return {
                **cached_response,
                "_cache": {
                    "hit": True,
                    "similarity": hit.score,
                    "original_query": hit.payload["query"],
                },
            }

        return None

    def set(
        self,
        messages: list[dict],
        model: str,
        response: dict,
    ) -> None:
        query = self._extract_query(messages)
        if not query:
            return

        query_embedding = self._get_embedding(query)

        import uuid
        from qdrant_client import models

        self.qdrant.upsert(
            collection_name="semantic_cache",
            points=[
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=query_embedding,
                    payload={
                        "query": query,
                        "response": json.dumps(response, ensure_ascii=False),
                        "model": model,
                        "created_at": time.time(),
                        "hit_count": 0,
                    },
                ),
            ],
        )

Provider Prompt Caching

Anthropic Prompt Caching

Anthropic 的 Prompt Caching 在 API 层面缓存长 System Prompt 的 KV-Cache,避免重复计算:

# Anthropic Prompt Caching
import anthropic

client = anthropic.Anthropic()

# Long system prompt (cached across requests)
system_prompt = """You are a customer service agent for AcmeCorp.
You have access to the following knowledge base:
[... 5000 words of product documentation ...]
[... 3000 words of return policy ...]
[... 2000 words of FAQ ...]
"""

response = client.messages.create(
    model="claude-sonnet-4-20250514",
    max_tokens=1024,
    system=[
        {
            "type": "text",
            "text": system_prompt,
            "cache_control": {"type": "ephemeral"},  # Enable caching
        },
    ],
    messages=[
        {"role": "user", "content": "How do I return a product?"},
    ],
)

# Check cache usage in response
print(f"Input tokens: {response.usage.input_tokens}")
print(f"Cache read tokens: {response.usage.cache_read_input_tokens}")
print(f"Cache creation tokens: {response.usage.cache_creation_input_tokens}")

# First request: cache_creation > 0, cache_read = 0
# Subsequent requests: cache_creation = 0, cache_read > 0 (90% cheaper)

OpenAI Prompt Caching

OpenAI 自动缓存长前缀,无需额外 API 参数:

import openai

client = openai.OpenAI()

# OpenAI automatically caches prefixes >= 1024 tokens
# that are identical across requests
long_system = "..." # 2000+ tokens of instructions

# Request 1: Full price
response1 = client.chat.completions.create(
    model="gpt-4o",
    messages=[
        {"role": "system", "content": long_system},
        {"role": "user", "content": "Question 1"},
    ],
)

# Request 2: Same prefix, 50% discount on cached tokens
response2 = client.chat.completions.create(
    model="gpt-4o",
    messages=[
        {"role": "system", "content": long_system},  # Same prefix
        {"role": "user", "content": "Question 2"},    # Different suffix
    ],
)

# Check cached tokens
print(response2.usage.prompt_tokens_details.cached_tokens)

Prompt Cache 优化技巧

最大化缓存命中的 Prompt 结构:

┌─────────────────────────────────────────┐
│ System Prompt (不变部分, 长且稳定)         │ ← 缓存命中
│   - 角色定义                              │
│   - 知识库内容                            │
│   - 输出格式规范                          │
│   - 工具定义                              │
├─────────────────────────────────────────┤
│ Few-shot Examples (不变部分)              │ ← 缓存命中
│   - 示例 1                               │
│   - 示例 2                               │
│   - 示例 3                               │
├─────────────────────────────────────────┤
│ Dynamic Context (会变部分)               │ ← 不缓存
│   - 检索到的文档                          │
│   - 当前会话历史                          │
│   - 用户输入                              │
└─────────────────────────────────────────┘

原则: 不变的放前面 (被缓存), 变化的放后面
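按这一原则组装请求的最小示意(`STATIC_SYSTEM`、`FEW_SHOT` 的内容均为假设的占位,重点在于稳定部分逐字节不变且排在最前):

```python
# Assemble messages so the stable prefix stays byte-identical across requests.

STATIC_SYSTEM = (
    "You are a customer service agent for AcmeCorp.\n"
    "<role definition, knowledge base, output format, tool specs...>"
)
FEW_SHOT = [
    {"role": "user", "content": "Example question 1"},
    {"role": "assistant", "content": "Example answer 1"},
]

def build_messages(retrieved_docs: list[str], history: list[dict],
                   user_input: str) -> list[dict]:
    # Stable parts first (cacheable), dynamic parts last (not cacheable)
    dynamic_context = "Relevant documents:\n" + "\n---\n".join(retrieved_docs)
    return (
        [{"role": "system", "content": STATIC_SYSTEM}]
        + FEW_SHOT
        + history
        + [{"role": "user", "content": f"{dynamic_context}\n\n{user_input}"}]
    )

msgs = build_messages(["doc A"], [], "How do I return a product?")
```

注意检索结果放在了最后的 user 消息里而不是 system prompt 里:每次请求检索内容都会变,混进 system prompt 会让整个前缀失效。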

缓存失效策略

多维度失效

# src/cache/invalidation.py
from enum import Enum
from datetime import datetime, timedelta

from qdrant_client import models

from src.cache.semantic_cache import SemanticCache
class InvalidationReason(Enum):
    TTL_EXPIRED = "ttl_expired"
    MODEL_UPDATED = "model_updated"
    KNOWLEDGE_UPDATED = "knowledge_updated"
    LOW_QUALITY = "low_quality"
    MANUAL = "manual"

class CacheInvalidator:
    def __init__(self, cache: SemanticCache):
        self.cache = cache

    def invalidate_by_ttl(self, max_age: timedelta) -> int:
        """Remove entries older than max_age."""
        cutoff = (datetime.now() - max_age).timestamp()
        # Delete old entries from Qdrant
        deleted = self.cache.qdrant.delete(
            collection_name="semantic_cache",
            points_selector=models.FilterSelector(
                filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="created_at",
                            range=models.Range(lt=cutoff),
                        ),
                    ],
                ),
            ),
        )
        # Note: Qdrant's delete returns an operation handle, not a deleted count
        return deleted.operation_id

    def invalidate_by_model(self, model: str) -> int:
        """Clear cache when model is updated."""
        deleted = self.cache.qdrant.delete(
            collection_name="semantic_cache",
            points_selector=models.FilterSelector(
                filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="model",
                            match=models.MatchValue(value=model),
                        ),
                    ],
                ),
            ),
        )
        return deleted.operation_id

    def invalidate_by_quality(self, min_hit_count: int = 0, max_age_days: int = 30) -> int:
        """Remove low-engagement entries (never hit, old)."""
        cutoff = (datetime.now() - timedelta(days=max_age_days)).timestamp()
        deleted = self.cache.qdrant.delete(
            collection_name="semantic_cache",
            points_selector=models.FilterSelector(
                filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="hit_count",
                            range=models.Range(lte=min_hit_count),
                        ),
                        models.FieldCondition(
                            key="created_at",
                            range=models.Range(lt=cutoff),
                        ),
                    ],
                ),
            ),
        )
        return deleted.operation_id

监控指标

| 指标 | 计算方式 | 目标值 | 告警阈值 |
|------|----------|--------|----------|
| 精确缓存命中率 | hits / total | >10% | <5% |
| 语义缓存命中率 | hits / total | >25% | <15% |
| Prompt Cache 命中率 | cached_tokens / total_input | >60% | <30% |
| 缓存延迟 (P99) | 从查询到返回 | <50ms | >200ms |
| 缓存节省成本 | hit_cost_savings / total_cost | >20% | <10% |
| 缓存错误率 | stale_or_wrong / hits | <1% | >5% |
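表中的命中率告警阈值可以落成一个简单的检查函数(指标名与阈值均为示意,应按自身业务调整):

```python
# Check cache hit-rate metrics against alert floors from the table above.

def check_cache_alerts(metrics: dict[str, float]) -> list[str]:
    # Alert floors mirror the table; treat them as tuning starting points.
    floors = {
        "exact_hit_rate": 0.05,
        "semantic_hit_rate": 0.15,
        "prompt_cache_hit_rate": 0.30,
    }
    return [name for name, floor in floors.items()
            if metrics.get(name, 0.0) < floor]

alerts = check_cache_alerts({
    "exact_hit_rate": 0.12,
    "semantic_hit_rate": 0.10,   # below the 15% floor -> alert
    "prompt_cache_hit_rate": 0.65,
})
```

这类检查适合接入现有的指标管道(如 Prometheus 告警规则),而不是在应用内轮询。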

总结

  1. 三层缓存互补:精确匹配处理确定性查询,语义缓存处理自然语言变体,Prompt Cache 处理公共前缀。
  2. 语义缓存的阈值需要调优:太低会返回不相关的结果,太高会降低命中率。0.90-0.95 是通常的安全区间。
  3. Prompt Caching 是最低成本优化:只需要调整 Prompt 结构(不变的放前面),就能获得 50-90% 的输入 Token 折扣。
  4. 缓存失效比缓存命中更重要:错误的缓存比没有缓存更危险,多维度失效策略是必需的。
  5. 监控驱动优化:持续监控命中率和节省成本,根据数据调整阈值和策略。

Maurice | maurice_wen@proton.me