动手学RAG

0 阅读24分钟

简单流程:

详细流程:

我们为什么要使用RAG:

  • 消除幻觉,提供事实依据: RAG 强制大模型在回答问题前,先去一个可靠的知识库里“查阅资料”。大模型的工作原理是“概率预测”而非“数据库检索”。所以一些幻觉是不可避免的。
  • 知识实时更新,成本极低: 当世界发生变化或公司政策更新时,你不需要重新训练大模型(这需要几百万美元和数月时间),你只需要把新的文档扔进 RAG 的向量数据库里即可。
  • 数据隐私与安全: 很多企业不敢把核心数据喂给公共大模型训练。通过 RAG,私有数据只存在于你本地的数据库中,大模型只是在运行时作为“阅读理解工具”来处理你临时喂给它的文本片段。
  • 可溯源性: RAG 系统可以告诉你,它的回答是基于知识库里的哪一篇文章、哪一页生成的,方便人类核实。

RAG的结构:

一个标准的 RAG 系统分为离线数据处理在线检索生成两条链路:

链路一:数据处理与入库(离线准备阶段)

  1. 文档解析 (Document Parsing): 将各种格式的数据(PDF、Word、HTML、内部 Wiki)提取为纯文本。
  2. 文本分块 (Chunking): 大模型的上下文窗口有限,不能把整本书塞进去。我们需要把长文本切分成几百个字的小段落(Chunks)。
  3. 向量化 (Embedding): 使用 Embedding 模型,将这些文字块转化成计算机能理解的“多维浮点数向量”。语义越相近的段落,在向量空间里的距离越近。
  4. 向量数据库 (Vector Database): 将转化好的向量以及对应的原始文本存储起来(如 Pinecone, Milvus, Chroma)。

链路二:检索与生成(在线问答阶段)

  1. 查询转换与重写 (Query Formulation): 用户提问往往很简短或指代不清(例如:“这会导致什么后果?”)。系统需要结合对话历史,把问题重写为一个清晰的独立问题。
  2. 检索 (Retrieval): 将用户的提问也通过 Embedding 模型转化为向量,然后在向量数据库中进行“相似度搜索”,找出最相关的 Top-K 个文档块。
  3. 重排序 (Reranking)(进阶模块): 向量检索有时不够精准,可以使用一个专门的 Reranker 模型对找出来的 Top-K 文档进行二次打分和精确排序,剔除无关干扰信息。
  4. 提示词组装 (Prompt Construction): 将用户的原始问题,加上检索出来的高质量文档块,组装成一个巨大的 Prompt。比如: “请根据以下参考资料回答用户问题。参考资料:[段落1]、[段落2]... 用户问题:[问题]”
  5. 大模型生成 (Generation): LLM 接收到这个“开卷考题”,阅读参考资料,总结并生成最终答案。

我们可以通过一个简单的项目来学习整个RAG的流程:

构建一个关于大模型技术及其发展趋势的知识图谱系统。

可以把相关的经典论文或技术博客作为数据源喂给它。

完整项目代码包括数据下载脚本 :github.com/ahuang0324/…

vanilla_rag_project/
├── pipeline_stage1.py   # 离线入库:PDF → ChromaDB + Neo4j
├── pipeline_stage2.py   # 在线检索:LangGraph 工作流
├── run_chat.py          # 交互式问答入口
├── run_test.py          # 批量测试脚本
├── data/                # 源 PDF(不纳入版本管理)
├── parsed_md/           # PDF 解析生成的中间 Markdown(不纳入版本管理)
├── chroma_db/           # ChromaDB 持久化文件(不纳入版本管理)
├── neo4j_data/          # Neo4j 持久化文件(不纳入版本管理)
├── output.txt           # 问答测试输出文件 
└── requirement.txt

GraphRAG(基于图谱的检索增强生成)系统

具体来说,这个项目的核心脉络可以概括为以下三个阶段:

  1. 高质量的“原材料”准备(即链路一): 把海量、排版复杂的大模型前沿学术论文、技术博客等非结构化数据,通过 MinerU 等工具清洗成干净的标准化文本,并进行合理的切块。
  2. 核心枢纽——知识抽取与图谱构建: 这是整个系统最见功底的地方。我们需要设计一套智能工作流(Agentic Workflow),让大模型阅读这些文本块,精准抽取出核心“实体”(比如某种具体的微调算法、某个具体的注意力机制架构)以及它们之间的“关系”(比如 A 模型基于 B 架构改进,或者 C 算法有效解决了 D 缺陷),并将其存入图数据库(如 Neo4j)中。
  3. 精准推理与问答: 最终,当用户询问某个复杂的大模型技术演进路径时,系统不仅能在 Chroma 等向量数据库中进行语义搜索,更能借助图数据库进行多跳(Multi-hop)的逻辑推理,给出极具深度和脉络的回答。

链路一:

原始文档 → 可读文本 → 小块知识 → 向量表示 → 入库保存

1. 文档解析 (Document Parsing)

这是将人类可读的复杂版面转化为机器可处理的纯文本和结构化数据的第一步。 这一步的关键是 要保留完整且正确的结构。所以核心可能是格式的提取。

  • 技术原理:( HTML / Wiki PDF Word / Docx )
    • 格式提取: 解析特定文件格式(如 PDF 的底层操作流、Word 的 XML 结构)来提取文本。
    • 版面分析 (Layout Analysis): 识别文档中的标题、段落、页眉、页脚、多列文本以及表格和图片的边界。
    • 光学字符识别 (OCR): 针对扫描版 PDF 或图片,将图像像素转化为文本。
  • 发展演进:
    • 早期: 基于规则和简单正则提取。遇到复杂的双栏 PDF 或带有大量表格的文档会直接崩溃,乱码或顺序错乱。
    • 中期: 引入传统的机器学习模型进行版面识别(如识别哪部分是图,哪部分是表)。
    • 现代: 多模态与视觉大模型 (VLM) 的介入。现在的解析工具开始能够“看懂”复杂的财报表格,甚至能提取图片中的信息转化为图文描述,实现真正的“所见即所得”解析。
  • 常用工具/框架:
    • Unstructured.io: RAG 领域最著名的开源解析库之一,支持几乎所有文件格式。
    • LlamaParse (by LlamaIndex): 专门针对复杂 PDF(特别是带复杂表格的文档)优化的解析服务。
    • PyMuPDF / PDFPlumber: 传统的 Python 库,处理纯文本 PDF 速度极快。
    • Marker / Surya: 近期非常火的开源高精度 PDF 转 Markdown 工具,版面还原度极高。
2. 文本分块 (Chunking)

由于大模型有 Context Window(上下文窗口)限制,且越长的文本由于“注意力稀释”越容易丢失细节,因此需要对解析后的长文本进行切片。

  • 技术原理:
    • 滑动窗口策略: 为了防止切分时把一句话或一个完整的意思从中间硬生生截断,相邻的 Chunk 之间通常会保留一定的重叠部分(Overlap)。
    • 粒度平衡: Chunk 太小,会丢失上下文语境(例如只是一句“他同意了”,检索出来毫无意义);Chunk 太大,会引入太多无关噪声,降低检索精度。
  • 发展演进:
    • 1.0 固定长度切分 (Fixed-size Chunking): 简单粗暴地按字符数(如每 500 字一块)切分,容易割裂语义。
    • 2.0 基于规则的切分 (Rule-based Chunking): 根据自然段、标点符号(句号、换行符)、Markdown 标题(H1, H2, H3)进行切分,保证了句子的完整性。
    • 3.0 语义切分 (Semantic Chunking): 计算相邻句子的向量相似度,如果相似度骤降,说明话题变了,就在这里切一刀。
    • 4.0 代理切分 (Agentic / LLM-based Chunking): 也是目前的前沿方向,让小参数量的 LLM 先通读文档,由 AI 自己判断哪里该切分,甚至为每个 Chunk 生成独立的摘要。
  • 常用工具/框架:
    • LangChain / LlamaIndex 自带的 Splitters:RecursiveCharacterTextSplitter(递归字符切分器,最常用、性价比最高)。
    • Semantic Chunker: 一些高级 RAG 框架中内置的基于相似度的切分器。
3. 向量化 (Embedding)

这是 RAG 系统的“翻译官”,它将人类语言翻译成计算机能进行数学计算的高维坐标(向量)。

它学到的是一种 语义映射能力: 主题相近的句子,向量更近 。 意思相反或无关的句子,向量更远

  • 技术原理:
    • 使用深度学习模型,将一段 Chunk 映射为一个高维的浮点数数组(例如 768 维或 1536 维)。
    • 核心逻辑是“语义相近,空间相邻”: 比如“苹果手机”和“iPhone”在字面上完全不同,但在高维空间里它们的向量距离会非常近;而“苹果手机”和“吃苹果”距离就会相对较远。
  • 发展演进:
    • 第一代: 基于词频的稀疏表示(TF-IDF, BM25),只能做字面匹配,不懂语义。
    • 第二代: 静态词向量(Word2Vec, GloVe),一个词只有一个固定向量,无法解决多义词问题(如“苹果”)。
    • 第三代: 动态上下文向量(BERT 系列),开始能够理解语境。
    • 现代: 专为检索优化的对比学习模型 (Contrastive Learning) 。不仅支持超长上下文(支持将整个长段落打成一个向量),还出现了多语言、多模态 Embedding(能把图片和文本打入同一个向量空间)。
  • 常用工具/模型:
    • 闭源 API: OpenAI 的 text-embedding-3-small/large(行业标杆,性价比高)、Cohere Embeddings。
    • 开源模型 (HuggingFace): BAAI(智源研究院)的 bge-m3 系列(中文及多语言效果极佳)、阿里巴巴的 GTE 系列、Jina AI 的向量模型。
4. 向量数据库 (Vector Database)

当你有几百万甚至上亿个 Chunk 的向量时,把它们全放在内存里进行两两遍历对比(暴力搜索)是不现实的。这就需要专门的向量数据库。 “语义搜索引擎”。 用户问题来了之后:

1. 把问题转成向量 2. 去数据库里搜索最接近的若干个 chunk 向量 3. 返回 top-k 结果,例如返回最相关的 5 段文本。

  • 技术原理:
    • 核心算法 ANN (Approximate Nearest Neighbor): 向量数据库不追求 100% 找到绝对最近的那个点,而是通过构建特殊的索引(如图索引 HNSW、聚类倒排索引 IVFFlat),在极短的时间内(毫秒级)找到“足够近”的一批点,实现了精度与速度的权衡。
    • 标量混合检索: 现代向量数据库不仅能存向量,还能存元数据(Metadata,如文档作者、时间、类别)。在检索时,可以先通过时间线过滤(比如只找 2023 年之后的),再进行向量相似度搜索。
  • 发展演进:
    • 算法库时代: 以 Facebook 开源的 Faiss 为代表,只是一个算法库,没有数据库的增删改查、持久化和高可用特性。
    • 传统数据库插件时代: 像 PostgreSQL 的 pgvector 插件、Elasticsearch 的密向量检索支持。适合已经有成熟数据库基础设施的团队。
    • 原生向量数据库时代 (Native Vector DB): 专门为亿级甚至百亿级向量检索设计的云原生架构,支持分布式、动态扩缩容。
  • 常用工具/框架:
    • Milvus: 老牌开源向量数据库,架构复杂但极度成熟,适合企业级海量数据。
    • Pinecone: 闭源的 Serverless 云服务,开发者体验极好,免运维(AI 初创公司最爱)。
    • Chroma / Qdrant: 轻量级、易于上手,非常适合本地开发和中小型项目。
    • Weaviate: 同样优秀的开源向量数据库,对各种 RAG 框架兼容性极好。

第一个阶段我们采用的技术栈:MinerU (解析)+ BGE-M3 (向量化) + Neo4j (图存储) + Chroma (向量存储)

PDF 文件
   │
   ▼ Step 1: parse_pdf_with_mineru()
Markdown 文本
   │
   ▼ Step 2: chunk_markdown()
文本块列表 (Chunks)
   │
   ▼ Step 3&4: ingest_to_databases()
   ├──► ChromaDB(向量检索库)
   └──► Neo4j(图数据库)
import os
import uuid
from typing import List, Dict
from dotenv import load_dotenv

load_dotenv()

# LangChain 切分工具
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter

# 向量化模型
from FlagEmbedding import BGEM3FlagModel

# 数据库
import chromadb
from neo4j import GraphDatabase

class OfflineDataPipeline:
    def __init__(self, neo4j_uri, neo4j_user, neo4j_pwd, chroma_path=None, embedding_model=None):
        if chroma_path is None:
            chroma_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chroma_db")
        print("初始化流水线...")
        # 1. 初始化 BGE-M3 向量模型 (支持多语言和长文本)
        if embedding_model is not None:
            self.embedding_model = embedding_model
        else:
            self.embedding_model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
        
        # 2. 初始化 ChromaDB 向量数据库
        self.chroma_client = chromadb.PersistentClient(path=chroma_path)
        self.vector_collection = self.chroma_client.get_or_create_collection(name="llm_knowledge_chunks")
        
        # 3. 初始化 Neo4j 图数据库连接 (可选,不可用时跳过图谱写入)
        try:
            self.neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_pwd))
            self.neo4j_driver.verify_connectivity()
            self.neo4j_available = True
        except Exception as e:
            print(f"⚠️  Neo4j 不可用,跳过图谱写入: {e}")
            self.neo4j_driver = None
            self.neo4j_available = False
        
        print("所有组件初始化完成!")

    def parse_pdf_with_mineru(self, pdf_path: str, output_dir: str) -> str:
        """
        步骤 1: 使用 pymupdf (fitz) 将 PDF 解析为 Markdown 文本。
        """
        import fitz
        print(f"正在解析 PDF: {pdf_path}")
        base_name = os.path.basename(pdf_path).replace(".pdf", "")
        md_file_path = os.path.join(output_dir, "magic-pdf", base_name, "auto", f"{base_name}.md")
        
        if not os.path.exists(md_file_path):
            os.makedirs(os.path.dirname(md_file_path), exist_ok=True)
            doc = fitz.open(pdf_path)
            md_lines = []
            for page in doc:
                blocks = page.get_text("blocks")
                for b in sorted(blocks, key=lambda x: (x[1], x[0])):
                    text = b[4].strip()
                    if text:
                        md_lines.append(text)
                md_lines.append("")
            doc.close()
            md_text = "\n".join(md_lines)
            with open(md_file_path, "w", encoding="utf-8") as f:
                f.write(md_text)
        
        with open(md_file_path, 'r', encoding='utf-8') as f:
            return f.read()

    def chunk_markdown(self, markdown_text: str) -> List[Dict]:
        """
        步骤 2: 将 Markdown 文本按标题和长度进行切块 (Chunking)
        """
        print("正在进行文本切块...")
        # 首先按 Markdown 标题切分,保留逻辑结构
        headers_to_split_on = [
            ("#", "Header 1"),
            ("##", "Header 2"),
            ("###", "Header 3"),
        ]
        markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
        md_header_splits = markdown_splitter.split_text(markdown_text)
        
        # 如果某个标题下的内容依然过长,再用递归字符切分器进行细切
        chunk_size = 500
        chunk_overlap = 50
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        final_splits = text_splitter.split_documents(md_header_splits)
        
        chunks_data = []
        for i, split in enumerate(final_splits):
            chunk_id = str(uuid.uuid4())
            chunks_data.append({
                "chunk_id": chunk_id,
                "text": split.page_content,
                "metadata": split.metadata # 包含 Header 1/2/3 信息
            })
        return chunks_data

    def ingest_to_databases(self, doc_name: str, chunks_data: List[Dict]):
        """
        步骤 3 & 4: 向量化并进行双库写入 (Chroma + Neo4j)
        """
        print(f"正在向量化并入库,共 {len(chunks_data)} 个 Chunks...")
        texts = [c["text"] for c in chunks_data]
        ids = [c["chunk_id"] for c in chunks_data]
        metadatas = [{"doc_name": doc_name, **c["metadata"]} for c in chunks_data]
        
        # --- A. 向量化 (BGE-M3) ---
        # BGE-M3 返回一个字典,dense_vecs 是稠密向量
        embeddings = self.embedding_model.encode(texts, batch_size=12, max_length=1024)['dense_vecs'].astype('float32')
        
        # --- B. 写入 ChromaDB ---
        self.vector_collection.add(
            documents=texts,
            embeddings=embeddings.tolist(),
            metadatas=metadatas,
            ids=ids
        )
        
        # --- C. 写入 Neo4j (创建文档与块的层级关系) ---
        if not self.neo4j_available:
            print("⚠️  跳过 Neo4j 写入(服务不可用)")
            print(f"文档 {doc_name} 入库完成!")
            return
        doc_id = str(uuid.uuid4())
        with self.neo4j_driver.session() as session:
            # 1. 创建源文档节点
            session.run(
                "MERGE (d:Document {name: $doc_name}) SET d.id = $doc_id",
                doc_name=doc_name, doc_id=doc_id
            )
            # 2. 批量创建 Chunk 节点并连接到源文档
            session.run(
                """
                MATCH (d:Document {name: $doc_name})
                UNWIND $chunks AS chunk
                CREATE (c:Chunk {id: chunk.chunk_id, text: chunk.text})
                CREATE (c)-[:PART_OF]->(d)
                """,
                doc_name=doc_name,
                chunks=[{"chunk_id": c["chunk_id"], "text": c["text"]} for c in chunks_data]
            )
        print(f"文档 {doc_name} 入库完成!")

    def run(self, pdf_path: str):
        """执行完整流水线"""
        doc_name = os.path.basename(pdf_path)
        output_dir = "./parsed_md"
        
        # 1. 解析
        md_text = self.parse_pdf_with_mineru(pdf_path, output_dir)
        # 2. 切块
        chunks = self.chunk_markdown(md_text)
        # 3. 入库
        self.ingest_to_databases(doc_name, chunks)
        print("✅ 链路一处理完毕。")

    def close(self):
        if self.neo4j_driver is not None:
            self.neo4j_driver.close()

# --- 运行示例 ---
if __name__ == "__main__":
    # 配置你的 Neo4j 账号密码
    pipeline = OfflineDataPipeline(
        neo4j_uri=os.getenv("NEO4J_URI",  "bolt://localhost:7687"),
        neo4j_user=os.getenv("NEO4J_USER", "neo4j"),
        neo4j_pwd=os.getenv("NEO4J_PWD",  "password"),
    )
    
    # 假设你有一篇大模型论文的 PDF
    # pipeline.run("./data/attention_is_all_you_need.pdf")
    
    pipeline.close()

链路二:

1. 查询转换与重写 (Query Formulation)

用户提问往往极其口语化、上下文缺失,或者意图模糊。直接拿原问题去检索,效果通常很差。

  • 技术原理: 利用 LLM 作为“翻译官”,在检索前对用户的 Query 进行扩写、拆解或多义词补充。
  • 发展历史:
    • 早期: 关键词提取(去除停用词如“的”、“是”)。
    • 中期: 同义词扩展(Query Expansion)。
    • 现代(大模型时代): * HyDE (Hypothetical Document Embeddings): 让大模型先“盲答”生成一篇假想文档,然后用这篇假想文档的向量去数据库里搜索,这能极大提高语义命中率。
      • Multi-Query: 让 LLM 把一个问题改写成 3-5 个不同表述的问题,分别检索后再合并结果。
  • 常用工具: LangChain/LlamaIndex 中的 Query Rewrite 模块;自定义的 LLM Prompt。
  • ****核心补充(针对图谱项目):意图路由 (Query Routing) 在你的系统中,不能只有重写,还需要路由。当用户问“FedSDG 算法的提出机构是哪家?”时,系统需要判断:这个问题是应该去 Chroma 查向量(语义检索),还是去 Neo4j 查图谱关系(Cypher 查询),亦或是双路召回。这是一个典型的 Agent 决策节点。
2. 检索 (Retrieval)

将重写后的查询转化为机器语言,去数据库里“捞”数据。

  • 技术原理: 将 Query 向量化(使用与链路一相同的 Embedding 模型),计算 Query 向量与数据库中 Chunk 向量的余弦相似度(Cosine Similarity),取分数最高的前 K 个(Top-K)。
  • 发展历史:
    • 稀疏检索 (BM25): 传统的全文搜索,基于词频,极其精准但不懂语义。
    • 稠密检索 (Dense Retrieval): 纯向量搜索,懂语义但可能忽略核心专有名词。
    • 混合检索 (Hybrid Search): 目前工业界标配。BM25 保下限(保证核心词命中),向量检索提上限(保证语义泛化),最后用 RRF(倒数秩融合)算法将两路结果合并。
  • 常用工具: ChromaDB 的 .query() 方法;Elasticsearch (做混合检索)。
  • ****核心补充(图谱结合):GraphRAG 检索 在你的项目中,检索阶段不仅要拿 Top-K 的文本块,还要利用识别出的实体(如“Transformer”)作为起点,在 Neo4j 中向外拓展 1-2 跳(比如查出它相关的改进模型),将图谱关系转化为文本一并作为检索结果。
3. 重排序 (Reranking)(极其重要的进阶模块)

向量检索速度极快,但它是一个“粗排”过程(双塔模型),很容易把表面相似但逻辑无关的片段召回。

  • 技术原理: 引入一个专门的交叉编码器(Cross-Encoder)模型。它不仅看向量距离,而是把 Query 和召回的每一个 Document 放在一起深度阅读,输出一个 0-1 之间的精准相关性得分,然后重新排序,剔除低分项。
  • 发展历史: 从早期的业务规则打分,演进到基于 BERT 架构的专门排序模型,现在也有直接让大模型(LLM as a Judge)来进行打分的策略。
  • 常用工具 / 模型:
    • BGE-Reranker (智源研究院): 既然你在链路一用了 BGE-M3,强烈建议搭配 BGE-Reranker。它在学术文献、长文本重排上效果绝佳。
    • Cohere Rerank API: 闭源界最好用的重排 API。
4. 提示词组装 (Prompt Construction)

把提纯后的“参考资料”和用户的“原始问题”拼接成一个格式化模板,送给 LLM。

  • 技术原理: 运用 In-Context Learning(上下文学习),让大模型在有限的窗口内阅读提供的背景知识。
  • 常用工具: Jinja2(在 Python 后端极为常见)、LangChain 的 PromptTemplate
  • ****核心补充:防幻觉约束 Prompt 中必须加入强硬的护栏指令,例如: “请严格根据提供的参考资料回答。如果参考资料中没有相关信息,请直接回答‘知识库中暂无相关信息’,严禁编造。” 此外,可以让 LLM 在回答时带上引用来源(例如:“根据 [Chunk-2],该算法优化了...”),这对于技术前沿分析尤为重要。
5. 大模型生成 (Generation)

LLM 接收拼装好的开卷考题,输出最终结果。

  • 技术原理: 基于 Transformer 的自回归生成(预测下一个 Token)。
  • 常用工具: * API 形式: OpenAI、Claude、DeepSeek 等。
    • 本地部署: 使用 Ollama 或 vLLM 部署 Llama-3、Qwen 等开源模型。
  • 核心补充:流式输出 (Streaming) 与 Agent 整合 对于用户体验来说,生成必须是流式的(逐字吐出)。同时,在这个阶段,你完全可以利用 LangGraph 构建一个包含“反思机制(Self-RAG)”的工作流:如果 LLM 生成的答案被自身的评分节点判定为“不准确”或“偏离主题”,可以让流程回退,重新调整 Query 进行二次检索。
import os
import operator
from typing import TypedDict, Annotated, List
from dotenv import load_dotenv

load_dotenv()

from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from FlagEmbedding import BGEM3FlagModel, FlagReranker
import chromadb
from neo4j import GraphDatabase

# ==========================================
# 1. 定义整个工作流的“状态” (State)
# ==========================================
class GraphRAGState(TypedDict):
    original_query: str                                        # 用户原始问题
    rewritten_query: str                                       # LLM重写后的问题
    route_type: str                                            # 路由决策
    retrieved_contexts: Annotated[List[str], operator.add]     # 多节点结果自动合并
    retrieved_metadatas: Annotated[List[dict], operator.add]   # 对应的元数据(文档名等)
    final_answer: str                                          # 最终生成的答案

# ==========================================
# 2. 初始化核心组件 (大模型、数据库、Reranker)
# ==========================================
# 建议在实际项目中把这些封装成类,这里为了展示逻辑直观呈现
llm = ChatOpenAI(
    temperature=0,
    api_key=os.getenv("API_KEY"),
    base_url=os.getenv("BASE_URL"),
    model=os.getenv("MODEL_NAME"),
)
embedding_model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True, devices="cuda:0")  # 与链路一保持一致
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True, devices="cuda:0")

# 连接链路一建好的 Chroma 向量库
_RAG_DIR = os.path.dirname(os.path.abspath(__file__))
chroma_client = chromadb.PersistentClient(path=os.path.join(_RAG_DIR, "chroma_db"))
_vector_collection = None

def _get_vector_collection():
    global _vector_collection
    if _vector_collection is None:
        _vector_collection = chroma_client.get_collection(name="llm_knowledge_chunks")
    return _vector_collection

# 连接 Neo4j 图数据库 (可选)
try:
    neo4j_driver = GraphDatabase.driver(
        os.getenv("NEO4J_URI",  "bolt://localhost:7687"),
        auth=(os.getenv("NEO4J_USER", "neo4j"), os.getenv("NEO4J_PWD", "password")),
    )
    neo4j_driver.verify_connectivity()
    _neo4j_available = True
except Exception as _neo4j_err:
    print(f"⚠️  Neo4j 不可用,图谱检索将返回空结果: {_neo4j_err}")
    neo4j_driver = None
    _neo4j_available = False

# ==========================================
# 3. 定义 LangGraph 的各个节点 (Nodes)
# ==========================================

def query_analysis_node(state: GraphRAGState) -> GraphRAGState:
    """节点 1:查询分析与路由决策"""
    print("--- [节点] 查询分析与路由 ---")
    query = state["original_query"]
    
    # 简单的 Prompt,让 LLM 决定怎么查
    prompt = ChatPromptTemplate.from_template(
        "你是一个大模型领域的知识路由专家。请分析以下用户问题:\n"
        "问题:{query}\n"
        "如果问题是关于概念原理、技术细节、长文本描述,请回复 'vector_only'。\n"
        "如果问题是关于实体关系(如谁提出了什么、某模型基于什么架构、属于哪个机构),请回复 'graph_only'。\n"
        "如果是复杂的综合问题,回复 'hybrid'。\n"
        "仅输出这三个词中的一个。"
    )
    chain = prompt | llm
    route_result = chain.invoke({"query": query}).content.strip().lower()
    if route_result not in ("vector_only", "graph_only", "hybrid"):
        route_result = "vector_only"
    
    print(f"  路由决策: {route_result}")
    # 这里也可以顺便做 Query Rewrite,为了代码简洁暂略
    return {"route_type": route_result, "rewritten_query": query}

def vector_retrieval_node(state: GraphRAGState) -> GraphRAGState:
    """节点 2A:向量检索 (从 ChromaDB 捞取文本块)"""
    print("--- [节点] 向量数据库检索 ---")
    query = state["rewritten_query"]
    
    # 用 BGE-M3 编码 query,与链路一写入时保持一致
    query_vec = embedding_model.encode([query])['dense_vecs'].astype('float32').tolist()
    results = _get_vector_collection().query(query_embeddings=query_vec, n_results=5)
    contexts = results['documents'][0] if results['documents'] else []
    metadatas = results['metadatas'][0] if results['metadatas'] else [{} for _ in contexts]
    
    print(f"  向量检索命中 {len(contexts)} 条,来源文档:")
    for i, (ctx, meta) in enumerate(zip(contexts, metadatas)):
        doc = meta.get('doc_name', '未知')
        print(f"    [{i+1}] 📄 {doc} | {ctx[:80].replace(chr(10), ' ')}...")
    
    return {"retrieved_contexts": contexts, "retrieved_metadatas": metadatas}

def graph_retrieval_node(state: GraphRAGState) -> GraphRAGState:
    """节点 2B:图谱检索 (用 LLM 提取关键词,在 Neo4j 中做全文匹配)"""
    print("--- [节点] 图数据库检索 ---")
    query = state["rewritten_query"]
    contexts = []
    metadatas = []

    if not _neo4j_available:
        print("⚠️  Neo4j 不可用,图谱检索返回空结果")
        return {"retrieved_contexts": contexts, "retrieved_metadatas": metadatas}

    # Step 1: 用 LLM 从 query 中提取 2~4 个核心关键词(英文/中文均可)
    kw_prompt = ChatPromptTemplate.from_template(
        "请从以下问题中提取 2~4 个最核心的英文技术关键词,用于在论文数据库中做全文检索。"
        "只输出关键词列表,用英文逗号分隔,不要输出其他内容。\n问题:{query}"
    )
    kw_result = (kw_prompt | llm).invoke({"query": query}).content.strip()
    keywords = [k.strip() for k in kw_result.replace(",", ",").split(",") if k.strip()][:4]
    print(f"  提取关键词: {keywords}")

    # Step 2: 对每个关键词在 Neo4j 的 Chunk.text 中做 CONTAINS 匹配,取相关片段
    seen = set()
    with neo4j_driver.session() as session:
        for kw in keywords:
            cypher = (
                "MATCH (c:Chunk)-[:PART_OF]->(d:Document) "
                "WHERE toLower(c.text) CONTAINS toLower($kw) "
                "RETURN c.text, d.name LIMIT 3"
            )
            result = session.run(cypher, kw=kw)
            for record in result:
                text = record['c.text']
                doc_name = record['d.name']
                key = text[:80]
                if key in seen:
                    continue
                seen.add(key)
                contexts.append(text)
                metadatas.append({"doc_name": doc_name, "source": "graph", "keyword": kw})

    print(f"  图谱检索命中 {len(contexts)} 条,来源文档:")
    for i, (ctx, meta) in enumerate(zip(contexts, metadatas)):
        doc = meta.get('doc_name', '未知')
        kw = meta.get('keyword', '')
        print(f"    [{i+1}] 🔑 {kw} | 📄 {doc} | {ctx[:80].replace(chr(10), ' ')}...")

    return {"retrieved_contexts": contexts, "retrieved_metadatas": metadatas}

def rerank_node(state: GraphRAGState) -> GraphRAGState:
    """节点 3:重排序 (去除噪声,保留最相关的信息)"""
    print("--- [节点] BGE-Reranker 重排序 ---")
    query = state["rewritten_query"]
    contexts = state.get("retrieved_contexts", [])
    metadatas = state.get("retrieved_metadatas", [{} for _ in contexts])
    
    if not contexts:
        return state
        
    # 构建 pairs 给 BGE-Reranker 打分
    pairs = [[query, ctx] for ctx in contexts]
    scores = reranker.compute_score(pairs)
    if not isinstance(scores, list):
        scores = [scores]
    
    # 根据得分从高到低排序,过滤掉得分过低的噪声,这里保留 Top 3
    scored = sorted(zip(contexts, metadatas, scores), key=lambda x: x[2], reverse=True)
    
    print(f"  重排打分结果 (共 {len(scored)} 条):")
    for i, (ctx, meta, score) in enumerate(scored):
        doc = meta.get('doc_name', '未知')
        marker = "✅" if i < 3 else "❌"
        print(f"    {marker} [{i+1}] score={score:.4f} | 📄 {doc} | {ctx[:60].replace(chr(10), ' ')}...")
    
    top_contexts = [ctx for ctx, meta, score in scored[:3]]
    top_metadatas = [meta for ctx, meta, score in scored[:3]]
    
    print(f"  重排后保留了 {len(top_contexts)} 条高质量参考资料。")
    return {"retrieved_contexts": top_contexts, "retrieved_metadatas": top_metadatas}

def generation_node(state: GraphRAGState) -> GraphRAGState:
    """节点 4:最终大模型生成"""
    print("--- [节点] 组装 Prompt 与生成答案 ---")
    query = state["original_query"]
    contexts_list = state["retrieved_contexts"]
    metadatas_list = state.get("retrieved_metadatas", [{} for _ in contexts_list])
    
    # 构建带编号和来源标注的参考资料块
    numbered_contexts = []
    for i, (ctx, meta) in enumerate(zip(contexts_list, metadatas_list), 1):
        doc = meta.get('doc_name', '未知来源')
        numbered_contexts.append(f"[{i}] 来源:{doc}\n{ctx}")
    contexts_str = "\n\n".join(numbered_contexts)
    
    print(f"  输入参考资料 {len(numbered_contexts)} 条:")
    for i, (ctx, meta) in enumerate(zip(contexts_list, metadatas_list), 1):
        doc = meta.get('doc_name', '未知来源')
        print(f"    [{i}] 📄 {doc}")
        print(f"        {ctx[:120].replace(chr(10), ' ')}...")
    
    prompt = ChatPromptTemplate.from_template(
        "你是一个顶尖的计算机科学与大模型技术研究助手。请严格根据以下带编号的参考资料回答用户问题。\n"
        "要求:\n"
        "1. 回答中每个关键论断必须以 [编号] 的形式标注来源,例如 [1][2]。\n"
        "2. 尽量引用原文中的关键表述,用引号括起来。\n"
        "3. 如果参考资料中无法得出答案,请诚实说明,不要编造。\n\n"
        "【参考资料】:\n{contexts}\n\n"
        "【用户问题】: {query}\n"
        "【你的回答】:"
    )
    chain = prompt | llm

    # 流式输出
    print("  ⏳ 生成中(流式):")
    full_answer = ""
    for chunk in chain.stream({"contexts": contexts_str, "query": query}):
        token = chunk.content
        print(token, end="", flush=True)
        full_answer += token
    print()  # 换行

    return {"final_answer": full_answer}

# ==========================================
# 4. 定义条件路由逻辑 (Conditional Edges)
# ==========================================
def route_query(state: GraphRAGState) -> str:
    """根据路由分析结果,决定走向哪个检索节点"""
    route = state["route_type"]
    if route == "graph_only":
        return "graph_retrieval_node"
    else:
        # vector_only 和 hybrid 都先走向量检索
        return "vector_retrieval_node"

# ==========================================
# 5. 构建与编译 LangGraph 工作流
# ==========================================
workflow = StateGraph(GraphRAGState)

# 添加节点
workflow.add_node("query_analysis", query_analysis_node)
workflow.add_node("vector_retrieval_node", vector_retrieval_node)
workflow.add_node("graph_retrieval_node", graph_retrieval_node)
workflow.add_node("rerank_node", rerank_node)
workflow.add_node("generation_node", generation_node)

# 定义边 (流程流转)
workflow.set_entry_point("query_analysis")

# 动态路由:分析完问题后,去查哪个库?
workflow.add_conditional_edges(
    "query_analysis",
    route_query,
    {
        "vector_retrieval_node": "vector_retrieval_node",
        "graph_retrieval_node": "graph_retrieval_node",
    }
)

# hybrid:向量检索后继续走图谱检索;vector_only/graph_only 直接汇聚到重排
workflow.add_conditional_edges(
    "vector_retrieval_node",
    lambda s: "graph_retrieval_node" if s["route_type"] == "hybrid" else "rerank_node",
    {"graph_retrieval_node": "graph_retrieval_node", "rerank_node": "rerank_node"}
)
workflow.add_edge("graph_retrieval_node", "rerank_node")

# 重排后生成最终答案,然后结束
workflow.add_edge("rerank_node", "generation_node")
workflow.add_edge("generation_node", END)

# 编译图
app = workflow.compile()

# ==========================================
# 运行示例
# ==========================================
if __name__ == "__main__":
    test_query = "目前主流的多模态对齐算法有哪些?它们分别优化了哪些缺陷?"
    
    print(f"用户提问: {test_query}\n")
    
    # 运行 LangGraph
    final_state = app.invoke({"original_query": test_query})
    
    print("\n================ 最终回答 ================\n")
    print(final_state["final_answer"])