本文来源于数据从业者全栈知识库,更多体系化内容请访问知识库。
理论讲完了,代码才是考场。 本文从零实现一个生产可用的 Corrective RAG 系统——能自动评估检索质量,质量不够就调 Web Search,生成后还要查幻觉。 用 LangGraph 把这些逻辑串成图,用 FastAPI 包成 API,拿来就能用。
目录
- #为什么选 Corrective RAG 作为实战目标
- #完整系统架构
- #第一步:环境搭建
- #第二步:State 定义
- #第三步:各节点实现
- #第四步:条件边(决策逻辑)
- #第五步:图的编译与执行
- #第六步:FastAPI 封装
- #生产级优化
- #效果评估:与 Naive RAG 的对比
- #常见问题与调试技巧
为什么选 Corrective RAG 作为实战目标
Agentic RAG进阶架构 介绍了四种 Agentic RAG 架构。选 Corrective RAG 作为工程实战,原因很直接:
实用性最高:大多数企业知识库的核心痛点就是”覆盖不全”——内部文档有,但员工还会问外部信息。Corrective RAG 正好解决这个问题。
改造成本合理:不需要微调模型(Self-RAG 原版需要),不需要建知识图谱(Graph RAG 成本高),只需要在现有 RAG 管道上加评估层和 Web Search fallback。
LangGraph 天生适配:Corrective RAG 的条件分支逻辑(CORRECT / INCORRECT / AMBIGUOUS)用 LangGraph 的条件边来表达非常自然。
完整系统架构
flowchart TD
START(["用户问题"]) --> REWRITE["查询改写节点<br>rewrite_query"]
REWRITE --> RETRIEVE["向量检索节点<br>retrieve"]
RETRIEVE --> GRADE["文档评分节点<br>grade_documents"]
GRADE --> DECISION{"检索质量决策<br>decide_to_generate"}
DECISION -->|"质量良好<br>score ≥ 0.6"| GENERATE["答案生成节点<br>generate_answer"]
DECISION -->|"质量不足<br>score < 0.6"| WEBSEARCH["Web搜索节点<br>web_search"]
WEBSEARCH --> MERGE["合并文档节点<br>merge_docs"]
MERGE --> GENERATE
GENERATE --> HALLCHECK{"幻觉检测<br>check_hallucination"}
HALLCHECK -->|"发现幻觉<br>且迭代次数 < 3"| GENERATE
HALLCHECK -->|"无幻觉 or<br>迭代超限"| ANSWERCHECK{"答案质量检测<br>check_answer_quality"}
ANSWERCHECK -->|"回答了问题"| OUTPUT(["输出最终答案"])
ANSWERCHECK -->|"没有回答问题<br>且迭代次数 < 3"| REWRITE
style START fill:#4CAF50,color:#fff
style OUTPUT fill:#4CAF50,color:#fff
style WEBSEARCH fill:#FF9800,color:#fff
style HALLCHECK fill:#F44336,color:#fff
这个架构有三个关键循环保护点:
- 幻觉检测 → 重新生成(最多 3 次)
- 答案质量不足 → 重新改写查询(最多 3 次)
- 全局
iterations计数器,超限强制输出,防止死循环
第一步:环境搭建
# 核心依赖pip install langgraph==0.2.28 \ langchain-openai==0.2.5 \ langchain-community==0.3.5 \ langchain-core==0.3.15 \ weaviate-client==4.8.1 \ tavily-python==0.5.0 \ fastapi==0.115.4 \ uvicorn==0.32.0 \ python-dotenv==1.0.1
# 可选:RAGAS 评估框架pip install ragas==0.2.6# config.py - 统一配置管理import osfrom dotenv import load_dotenv
load_dotenv()
# LLM 配置OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")LLM_TEMPERATURE = 0 # 评估和生成都用 0,保持确定性
# Weaviate 向量库配置WEAVIATE_URL = os.getenv("WEAVIATE_URL", "http://localhost:8080")WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")WEAVIATE_COLLECTION = os.getenv("WEAVIATE_COLLECTION", "KnowledgeBase")
# Tavily Web Search 配置TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
# 系统控制参数MAX_ITERATIONS = 3 # 最大循环次数,防止死循环RELEVANCE_THRESHOLD = 0.6 # 低于此分数触发 Web SearchCACHE_TTL = 3600 # 相同问题缓存 1 小时第二步:State 定义
LangGraph 的核心是 State,它在图的所有节点之间流转,记录整个执行过程的状态。
from typing import TypedDict, List, Optional, Annotatedfrom langchain_core.documents import Documentimport operator
class GraphState(TypedDict): """ Corrective RAG 的全局状态
所有节点通过读写这个 State 来通信, 不要用全局变量或者把数据存到节点外面。 """ # 输入 question: str # 用户原始问题(不变)
# 查询改写 query: str # 改写后的查询(用于检索)
# 检索结果 documents: List[Document] # 向量库检索到的文档 web_results: List[Document] # Web Search 的结果 final_documents: List[Document] # 合并后的最终文档集
# 评分 relevance_score: float # 检索质量综合得分(0-1) graded_documents: List[dict] # 每个文档的详细评分
# 生成结果 generation: str # 当前生成的答案
# 质量检测 hallucination_check: str # "supported" 或 "hallucination" answer_quality: str # "useful" 或 "not_useful"
# 控制变量 iterations: int # 当前循环次数 web_search_triggered: bool # 是否触发了 Web Search(用于日志)为什么不用普通的 dict?TypedDict 让 IDE 有类型提示,减少拼写错误。在 LangGraph 里,每个节点返回的是 State 的部分更新,框架会自动合并到全局 State。
第三步:各节点实现
节点 1:查询改写(rewrite_query)
用户的原始提问往往表达不够精确。改写的目标是让查询更适合向量检索。
from langchain_core.prompts import ChatPromptTemplatefrom langchain_openai import ChatOpenAIfrom langchain_core.output_parsers import StrOutputParserfrom config import LLM_MODEL, LLM_TEMPERATURE
llm = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)
REWRITE_PROMPT = ChatPromptTemplate.from_messages([ ("system", """你是一个查询优化专家,负责将用户的自然语言问题改写为更适合向量检索的查询。
改写规则: 1. 保留问题的核心意图,不要改变语义 2. 展开缩写和口语表达("怎么搞" → "如何实现") 3. 添加相关的专业术语(如果能判断领域的话) 4. 如果问题包含多个子问题,只保留最核心的那个 5. 输出单一的改写后查询,不要解释
直接输出改写后的查询,不要任何前缀。"""), ("human", "原始问题:{question}")])
rewrite_chain = REWRITE_PROMPT | llm | StrOutputParser()
def rewrite_query(state: GraphState) -> dict: """查询改写节点""" question = state["question"]
# 首次进入或迭代重试时都改写 if state.get("iterations", 0) == 0: # 首次:基于原始问题改写 rewritten = rewrite_chain.invoke({"question": question}) else: # 重试:基于原始问题 + 上次答案失败的信息改写 rewritten = rewrite_chain.invoke({ "question": f"""原始问题:{question} 上次生成的答案没有很好地回答这个问题,请换一个角度改写查询。""" })
return { "query": rewritten.strip(), "iterations": state.get("iterations", 0) + 1 }节点 2:向量检索(retrieve)
import weaviatefrom langchain_weaviate import WeaviateVectorStorefrom langchain_openai import OpenAIEmbeddingsfrom config import WEAVIATE_URL, WEAVIATE_API_KEY, WEAVIATE_COLLECTION
# 初始化 Weaviate 客户端(复用连接,不要在每次请求时重新连接)weaviate_client = weaviate.connect_to_local( host="localhost", port=8080, grpc_port=50051,)
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vectorstore = WeaviateVectorStore( client=weaviate_client, index_name=WEAVIATE_COLLECTION, text_key="content", embedding=embeddings,)
# 使用混合检索(向量 + BM25),比纯向量检索效果更好retriever = vectorstore.as_retriever( search_type="hybrid", # Weaviate 支持混合搜索 search_kwargs={ "k": 5, # 返回 top-5 文档 "alpha": 0.7, # 0=纯BM25, 1=纯向量, 0.7偏向语义 })
def retrieve(state: GraphState) -> dict: """向量检索节点""" query = state["query"] documents = retriever.invoke(query)
return {"documents": documents}节点 3:文档评分(grade_documents)
这是 Corrective RAG 的核心节点——用 LLM 评估每个检索到的文档的质量。
from langchain_core.prompts import ChatPromptTemplatefrom langchain_core.output_parsers import JsonOutputParserfrom langchain_openai import ChatOpenAIfrom langchain_core.documents import Documentfrom typing import Listfrom config import LLM_MODEL, RELEVANCE_THRESHOLD
llm = ChatOpenAI(model=LLM_MODEL, temperature=0)
GRADER_PROMPT = ChatPromptTemplate.from_messages([ ("system", """你是一个专业的检索质量评估员。
评估检索到的文档片段是否能帮助回答用户问题。
返回 JSON 格式,不要包含其他内容: {{ "score": 0.0到1.0之间的数值, "reason": "评分理由,一句话" }}
评分标准: - 1.0:文档直接包含问题的精确答案 - 0.8:文档高度相关,包含回答问题所需的核心信息 - 0.6:文档相关,包含部分有用信息 - 0.4:文档略相关,有少量参考价值 - 0.2以下:文档基本不相关,对回答没有帮助
严格评分,不要过于宽松。"""), ("human", "用户问题:{question}\n\n文档内容:\n{document}")])
grader_chain = GRADER_PROMPT | llm | JsonOutputParser()
def grade_documents(state: GraphState) -> dict: """文档评分节点""" question = state["question"] documents = state["documents"]
graded_docs = [] scores = []
for doc in documents: try: result = grader_chain.invoke({ "question": question, "document": doc.page_content[:2000] # 限制长度,避免超出上下文 }) score = float(result.get("score", 0)) graded_docs.append({ "document": doc, "score": score, "reason": result.get("reason", "") }) scores.append(score) except Exception: # 评分失败的文档给低分,不要崩溃整个流程 graded_docs.append({ "document": doc, "score": 0.3, "reason": "评分失败,给予默认低分" }) scores.append(0.3)
# 综合分:取评分最高的几个文档的平均值 top_scores = sorted(scores, reverse=True)[:3] avg_score = sum(top_scores) / len(top_scores) if top_scores else 0
# 过滤:只保留分数高于 0.5 的文档 filtered_docs = [ item["document"] for item in graded_docs if item["score"] >= 0.5 ]
return { "graded_documents": graded_docs, "relevance_score": avg_score, "final_documents": filtered_docs # 初始设为过滤后的文档 }节点 4:Web Search(web_search)
from langchain_community.tools.tavily_search import TavilySearchResultsfrom langchain_core.documents import Documentfrom config import TAVILY_API_KEYimport os
os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY
web_search_tool = TavilySearchResults( max_results=3, search_depth="advanced", include_answer=True, include_raw_content=False, include_images=False,)
def web_search(state: GraphState) -> dict: """Web Search 节点:当本地检索质量不足时触发""" query = state["query"]
try: results = web_search_tool.invoke({"query": query})
web_docs = [] for r in results: if isinstance(r, dict) and "content" in r: web_docs.append(Document( page_content=r["content"], metadata={ "source": r.get("url", "web_search"), "type": "web_search", "title": r.get("title", ""), } ))
return { "web_results": web_docs, "web_search_triggered": True, } except Exception as e: # Web Search 失败不应该让整个流程崩溃 return { "web_results": [], "web_search_triggered": True, }
def merge_docs(state: GraphState) -> dict: 本文作者:Elazer (石头)
原文链接:https://ss-data.cc/posts/kb-agentic-rag
版权声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。