RAG技术综述:从基础到高级

检索增强生成从Naive RAG到Modular RAG的演进:分块策略、重排序与生产级模式

引言

检索增强生成(Retrieval-Augmented Generation, RAG)已成为将大语言模型与外部知识结合的标准范式。其核心思想简洁有力:先检索相关信息,再让LLM基于检索结果生成答案。然而,从"能跑通Demo"到"能上生产"之间,存在大量工程细节需要打磨。本文将系统梳理RAG技术从基础到高级的完整演进路径。

RAG发展阶段

三代RAG架构

RAG架构演进

Generation 1: Naive RAG(2023初期)
  Query → Embed → Vector Search → Top-K → LLM → Answer
  特点: 简单直接,但质量不稳定
  问题: 检索噪声大、分块粗糙、无质量控制

Generation 2: Advanced RAG(2024)
  Query → [Rewrite/Expand] → Embed → [Hybrid Search] →
  [Rerank] → [Compress] → LLM → [Verify] → Answer
  特点: 在每个环节加入优化
  改进: 查询改写、混合检索、重排序、答案验证

Generation 3: Modular RAG(2025-2026)
  ┌──────────────────────────────────────────┐
  │  Orchestrator (Router / Agent)            │
  │  ├── Query Understanding Module           │
  │  ├── Retrieval Strategy Selector          │
  │  │   ├── Dense Retrieval                  │
  │  │   ├── Sparse Retrieval (BM25)          │
  │  │   ├── Knowledge Graph Retrieval        │
  │  │   └── SQL / Structured Query           │
  │  ├── Post-Retrieval Processing            │
  │  │   ├── Reranking                        │
  │  │   ├── Compression                      │
  │  │   └── Fusion                           │
  │  ├── Generation Module                    │
  │  └── Evaluation / Feedback Loop           │
  └──────────────────────────────────────────┘
  特点: 模块化、可组合、自适应

分块策略(Chunking)

分块方法对比

| 方法 | 实现复杂度 | 语义完整性 | 适用场景 |
| --- | --- | --- | --- |
| 固定长度分块 | 低 | 低 | 快速原型 |
| 按段落/章节 | 低 | 中 | 结构化文档 |
| 递归字符分割 | 中 | 中 | 通用文本 |
| 语义分块 | 高 | 最高 | 质量优先 |
| 句子窗口 | 中 | 高 | 问答场景 |
| 父子分块 | 高 | 高 | 层级文档 |

分块实现

from dataclasses import dataclass

@dataclass
class Chunk:
    text: str
    metadata: dict
    embedding: list[float] = None

class SemanticChunker:
    """Split text into semantically coherent chunks."""

    def __init__(self, embed_fn, threshold: float = 0.5,
                 min_chunk_size: int = 100, max_chunk_size: int = 1500):
        self.embed_fn = embed_fn
        self.threshold = threshold
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size

    def chunk(self, text: str, source: str = "") -> list[Chunk]:
        """Split text at semantic breakpoints."""
        sentences = self._split_sentences(text)
        if len(sentences) <= 1:
            return [Chunk(text=text, metadata={"source": source})]

        # Compute embeddings for each sentence
        embeddings = self.embed_fn([s for s in sentences])

        # Find semantic breakpoints using cosine similarity
        breakpoints = []
        for i in range(1, len(embeddings)):
            sim = self._cosine_sim(embeddings[i-1], embeddings[i])
            if sim < self.threshold:
                breakpoints.append(i)

        # Build chunks from breakpoints
        chunks = []
        start = 0
        for bp in breakpoints:
            chunk_text = " ".join(sentences[start:bp]).strip()
            if len(chunk_text) >= self.min_chunk_size:
                chunks.append(Chunk(
                    text=chunk_text,
                    metadata={"source": source, "start_sentence": start}
                ))
            start = bp

        # Last chunk
        last = " ".join(sentences[start:]).strip()
        if last:
            chunks.append(Chunk(text=last, metadata={"source": source}))

        # Merge small chunks, split large ones
        return self._normalize_sizes(chunks)

    def _split_sentences(self, text: str) -> list[str]:
        import re
        return [s.strip() for s in re.split(r'[.!?\n]+', text) if s.strip()]

    def _cosine_sim(self, a, b) -> float:
        import math
        dot = sum(x*y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x*x for x in a))
        norm_b = math.sqrt(sum(x*x for x in b))
        return dot / (norm_a * norm_b + 1e-8)

    def _normalize_sizes(self, chunks: list[Chunk]) -> list[Chunk]:
        normalized = []
        for chunk in chunks:
            if len(chunk.text) > self.max_chunk_size:
                # Split oversized chunks at sentence boundaries
                parts = self._split_at_size(chunk.text, self.max_chunk_size)
                for p in parts:
                    normalized.append(Chunk(text=p, metadata=chunk.metadata))
            else:
                normalized.append(chunk)
        return normalized

    def _split_at_size(self, text: str, max_size: int) -> list[str]:
        words = text.split()
        parts, current = [], []
        current_len = 0
        for word in words:
            if current_len + len(word) + 1 > max_size and current:
                parts.append(" ".join(current))
                current, current_len = [], 0
            current.append(word)
            current_len += len(word) + 1
        if current:
            parts.append(" ".join(current))
        return parts

检索策略

混合检索

class HybridRetriever:
    """Fuse dense (semantic) and sparse (BM25) retrieval with weighted RRF.

    Both backends are queried for twice the requested ``top_k``; their
    rankings are then merged via Reciprocal Rank Fusion, with each source's
    contribution scaled by its configured weight.
    """

    def __init__(self, vector_store, bm25_index,
                 dense_weight: float = 0.7, sparse_weight: float = 0.3):
        self.vector_store = vector_store
        self.bm25_index = bm25_index
        self.dense_weight = dense_weight
        self.sparse_weight = sparse_weight

    def retrieve(self, query: str, top_k: int = 10) -> list[dict]:
        """Return the top_k fused hits as ``[{"id": ..., "score": ...}]``."""
        rrf_k = 60  # standard RRF damping constant
        fused = {}

        def accumulate(results, weight):
            # Rank r contributes weight / (rrf_k + r + 1) to the document.
            for position, hit in enumerate(results):
                doc_id = hit["id"]
                fused[doc_id] = fused.get(doc_id, 0) + weight / (rrf_k + position + 1)

        # Dense (semantic) hits first, then sparse (BM25) hits.
        accumulate(self.vector_store.similarity_search(query, k=top_k * 2),
                   self.dense_weight)
        accumulate(self.bm25_index.search(query, k=top_k * 2),
                   self.sparse_weight)

        best = sorted(fused.items(), key=lambda item: -item[1])[:top_k]
        return [{"id": doc_id, "score": score} for doc_id, score in best]

查询改写

查询改写策略

原始查询: "公司去年的营收怎么样?"

Strategy 1: HyDE(假设文档嵌入)
  LLM生成假设答案: "XX公司2025年实现营业收入约50亿元,同比增长15%..."
  用假设答案做嵌入检索(而非原始查询)

Strategy 2: Multi-Query(多查询展开)
  子查询1: "公司2025年营业收入总额"
  子查询2: "公司2025年财务报告营收数据"
  子查询3: "公司去年全年收入对比上一年"
  分别检索后合并去重

Strategy 3: Step-back Prompting(回退提问)
  原始: "公司去年的营收怎么样?"
  回退: "公司近三年的财务表现和增长趋势"
  用更宽泛的查询检索更多上下文

重排序(Reranking)

重排序模型选型

| 模型 | 类型 | 性能 | 延迟 | 适用 |
| --- | --- | --- | --- | --- |
| Cohere Rerank | API | 高 | 50-100ms | 生产级 |
| bge-reranker-v2 | 本地 | 高 | 20-50ms | 隐私敏感 |
| cross-encoder | 本地 | 中高 | 30-80ms | 通用 |
| ColBERT | 本地 | 中高 | 10-30ms | 低延迟 |
| LLM-as-reranker | API/本地 | 最高 | 200-500ms | 质量优先 |

class RerankerPipeline:
    """Two-stage search: broad recall retrieval, then precise reranking."""

    def __init__(self, retriever, reranker, first_stage_k: int = 50,
                 final_k: int = 5):
        self.retriever = retriever
        self.reranker = reranker
        self.first_stage_k = first_stage_k
        self.final_k = final_k

    def search(self, query: str) -> list[dict]:
        """Return the final_k best documents for *query*, highest score first."""
        # Stage 1: cheap, recall-oriented candidate generation.
        candidates = self.retriever.retrieve(query, top_k=self.first_stage_k)

        # Stage 2: the cross-encoder scores every (query, document) pair.
        scores = self.reranker.score([(query, c["text"]) for c in candidates])
        for candidate, rerank_score in zip(candidates, scores):
            candidate["rerank_score"] = rerank_score

        # Precision-ordered final cut.
        return sorted(candidates, key=lambda c: -c["rerank_score"])[:self.final_k]

生产级RAG模式

模式一:自适应RAG

自适应RAG决策流

用户查询
    │
    ▼
┌─────────────┐    "简单事实查询"    ┌──────────────┐
│ 查询分类器   │──────────────────→  │ 直接检索+生成  │
│ (LLM/规则)  │                     └──────────────┘
│             │    "复杂推理查询"    ┌──────────────┐
│             │──────────────────→  │ 多步检索+COT  │
│             │                     └──────────────┘
│             │    "不需要检索"      ┌──────────────┐
│             │──────────────────→  │ 直接LLM回答   │
└─────────────┘                     └──────────────┘

模式二:Corrective RAG

class CorrectiveRAG:
    """Self-correcting RAG with relevance grading and web fallback."""

    def __init__(self, retriever, llm, web_search):
        self.retriever = retriever
        self.llm = llm
        self.web_search = web_search

    def answer(self, query: str) -> dict:
        """Answer *query*, escalating to web search when retrieval is weak."""
        # Retrieve, then keep only LLM-graded relevant documents.
        retrieved = self.retriever.retrieve(query, top_k=5)
        relevant_docs = [
            d for d in retrieved
            if self._grade_relevance(query, d["text"]) == "relevant"
        ]

        kb_texts = [d["text"] for d in relevant_docs]
        if not relevant_docs:
            # Nothing usable locally: answer purely from web snippets.
            web_results = self.web_search.search(query)
            context = "\n".join([r["snippet"] for r in web_results])
            source = "web_search"
        elif len(relevant_docs) < 2:
            # Weak evidence: blend knowledge base with the top web snippets.
            web_results = self.web_search.search(query)
            context = "\n".join(kb_texts + [r["snippet"] for r in web_results[:2]])
            source = "hybrid"
        else:
            context = "\n".join(kb_texts)
            source = "knowledge_base"

        answer = self._generate(query, context)

        # Regenerate in strict mode if the draft isn't grounded in context.
        if not self._check_grounded(answer, context):
            answer = self._generate(query, context, strict=True)

        return {"answer": answer, "source": source, "docs_used": len(relevant_docs)}

    def _grade_relevance(self, query: str, doc: str) -> str:
        """Ask the LLM for a binary relevance verdict on one document."""
        prompt = f"Is this document relevant to the query?\nQuery: {query}\nDocument: {doc}\nAnswer: relevant or irrelevant"
        return self.llm.generate(prompt).strip().lower()

    def _generate(self, query: str, context: str, strict: bool = False) -> str:
        """Produce an answer; strict mode forbids going beyond the context."""
        if strict:
            system = "Answer based strictly on the provided context."
        else:
            system = "Answer the question using the context."
        return self.llm.generate(f"{system}\n\nContext:\n{context}\n\nQuestion: {query}")

    def _check_grounded(self, answer: str, context: str) -> bool:
        """Return True when the LLM judges the answer supported by the context."""
        prompt = f"Is this answer fully supported by the context?\nContext: {context}\nAnswer: {answer}\nRespond: yes or no"
        return "yes" in self.llm.generate(prompt).lower()

模式三:Graph RAG

Graph RAG架构

文档 → [实体/关系抽取] → 知识图谱
                              │
                              ▼
查询 → [实体识别] → [图遍历/子图检索] → 结构化上下文
                                              │
                                              ▼
                                    LLM + 图上下文 → 答案

优势:
  多跳推理: A→B→C 的关系链
  全局视角: 社区检测 + 摘要
  可解释性: 推理路径可追溯

评估体系

RAGAS评测框架

| 指标 | 含义 | 计算方法 | 目标 |
| --- | --- | --- | --- |
| Faithfulness | 答案忠于检索内容 | LLM判断每个声明是否有据 | >0.85 |
| Answer Relevancy | 答案与问题相关 | 答案生成问题与原问题的相似度 | >0.80 |
| Context Precision | 检索内容的精准度 | 相关文档在排序中的位置 | >0.75 |
| Context Recall | 检索内容的召回率 | 参考答案的信息被检索覆盖比例 | >0.80 |

class RAGEvaluator:
    """Evaluate RAG pipeline quality using RAGAS-style metrics."""

    def __init__(self, llm, embed_fn):
        self.llm = llm
        self.embed_fn = embed_fn

    def evaluate(self, query: str, answer: str, contexts: list[str],
                 ground_truth: str = None) -> dict:
        scores = {
            "faithfulness": self._faithfulness(answer, contexts),
            "answer_relevancy": self._answer_relevancy(query, answer),
            "context_precision": self._context_precision(query, contexts),
        }
        if ground_truth:
            scores["context_recall"] = self._context_recall(ground_truth, contexts)
        scores["overall"] = sum(scores.values()) / len(scores)
        return scores

    def _faithfulness(self, answer: str, contexts: list[str]) -> float:
        context_str = "\n".join(contexts)
        prompt = (
            f"Extract claims from the answer, then check each against the context.\n"
            f"Context: {context_str}\nAnswer: {answer}\n"
            f"Return the fraction of supported claims (0.0 to 1.0)."
        )
        result = self.llm.generate(prompt)
        try:
            return float(result.strip())
        except ValueError:
            return 0.5

    def _answer_relevancy(self, query: str, answer: str) -> float:
        q_emb = self.embed_fn([query])[0]
        a_emb = self.embed_fn([answer])[0]
        return self._cosine_sim(q_emb, a_emb)

    def _context_precision(self, query: str, contexts: list[str]) -> float:
        q_emb = self.embed_fn([query])[0]
        c_embs = self.embed_fn(contexts)
        sims = [self._cosine_sim(q_emb, c) for c in c_embs]
        # Average precision weighted by rank
        sorted_sims = sorted(enumerate(sims), key=lambda x: -x[1])
        precision_sum = sum(
            (i+1) / (rank+1) for rank, (i, sim) in enumerate(sorted_sims)
            if sim > 0.5
        )
        return precision_sum / max(len(contexts), 1)

    def _context_recall(self, ground_truth: str, contexts: list[str]) -> float:
        context_str = "\n".join(contexts)
        gt_emb = self.embed_fn([ground_truth])[0]
        ctx_emb = self.embed_fn([context_str])[0]
        return self._cosine_sim(gt_emb, ctx_emb)

    def _cosine_sim(self, a, b) -> float:
        import math
        dot = sum(x*y for x, y in zip(a, b))
        return dot / (math.sqrt(sum(x*x for x in a)) * math.sqrt(sum(x*x for x in b)) + 1e-8)

结论

RAG技术已从简单的"检索+生成"管道演进为高度模块化、自适应的知识增强系统。在生产部署中,分块质量、混合检索、重排序和答案验证是影响最终效果的四个关键环节。建议采用渐进式优化策略:先用Naive RAG验证可行性,再逐步引入查询改写、混合检索、重排序等Advanced RAG组件,最后根据业务需求选择性引入Graph RAG、Corrective RAG等高级模式。


Maurice | maurice_wen@proton.me