rag+agent主程序

3 阅读15分钟

第三部分:状态定义(工作流数据结构)

python

运行

# 定义消息状态类,使用TypedDict进行类型注解
class MessagesState(TypedDict):
    # 定义messages字段,类型为消息序列,使用add_messages处理追加
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # 定义relevance_score字段,用于存储文档相关性评分
    relevance_score: Annotated[Optional[str], "Relevance score of retrieved documents, 'yes' or 'no'"]
    # 定义rewrite_count字段,用于跟踪问题重写的次数,达到次数退出graph的递归循环
    rewrite_count: Annotated[int, "Number of times query has been rewritten"]

这是整个工作流的 “数据载体” ,包含 3 个关键字段:

  1. messages:对话历史(自动追加)
  2. relevance_score:文档是否相关(yes/no)
  3. rewrite_count:问题重写次数(最多 3 次,防止死循环)

第四部分:工具配置管理类

python

运行

# 定义工具配置管理类,用于管理工具及其路由配置
class ToolConfig:
    # 初始化方法,接收工具列表并设置相关属性
    def __init__(self, tools):
        self.tools = tools
        self.tool_names = {tool.name for tool in tools}
        self.tool_routing_config = self._build_routing_config(tools)
        logger.info(f"Initialized ToolConfig with tools: {self.tool_names}, routing: {self.tool_routing_config}")
  • 统一管理所有工具:工具列表、工具名、路由规则

python

运行

# 内部方法,用于根据工具定义动态构建路由配置
def _build_routing_config(self, tools):
    routing_config = {}
    for tool in tools:
        tool_name = tool.name.lower()
        # 检索类工具 → 先评分
        if "retrieve" in tool_name:
            routing_config[tool_name] = "grade_documents"
        # 非检索工具 → 直接生成答案
        else:
            routing_config[tool_name] = "generate"
    return routing_config

核心规则

  • 工具名带 retrieve(检索)→ 先去评分文档节点
  • 其他工具 → 直接去生成答案节点

后面还有 3 个简单方法:

python

运行

def get_tools(self): return self.tools
def get_tool_names(self): return self.tool_names
def get_tool_routing_config(self): return self.tool_routing_config
  • 提供外部访问接口

第五部分:结构化输出(文档评分)

python

运行

# 文档相关性评分
class DocumentRelevanceScore(BaseModel):
    # 定义binary_score字段,表示相关性评分,取值为"yes""no"
    binary_score: str = Field(description="Relevance score 'yes' or 'no'")
  • 强制大模型只返回 yes/no,用于判断文档是否有用

第六部分:自定义异常 + 并发工具节点

python

运行

# 自定义异常,表示数据库连接池初始化或状态异常
class ConnectionPoolError(Exception):
    """自定义异常,表示数据库连接池初始化或状态异常"""
    pass
  • 专门处理数据库连接池错误

python

运行

# 重定义ToolNode,支持并发处理工具调用
class ParallelToolNode(ToolNode):
    def __init__(self, tools, max_workers: int = 5):
        super().__init__(tools)
        self.max_workers = max_workers
  • 重写 LangGraph 自带工具节点,支持并发调用多个工具

python

运行

def _run_single_tool(self, tool_call: dict, tool_map: dict) -> ToolMessage:
    try:
        tool_name = tool_call["name"]
        tool = tool_map.get(tool_name)
        if not tool: raise ValueError(f"Tool {tool_name} not found")
        result = tool.invoke(tool_call["args"])
        return ToolMessage(content=str(result),tool_call_id=tool_call["id"],name=tool_name)
    except Exception as e:
        logger.error(f"Error executing tool {tool_call.get('name', 'unknown')}: {e}")
        return ToolMessage(content=f"Error: {str(e)}",tool_call_id=tool_call["id"],name=tool_call.get("name", "unknown"))
  • 执行单个工具,捕获异常,返回工具结果

python

运行

def __call__(self, state: dict) -> dict:
    logger.info("ParallelToolNode processing tool calls")
    last_message = state["messages"][-1]
    tool_calls = getattr(last_message, "tool_calls", [])
    if not tool_calls: return {"messages": []}

    tool_map = {tool.name: tool for tool in self.tools}
    results = []

    with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
        future_to_tool = {executor.submit(self._run_single_tool, tool_call, tool_map): tool_call for tool_call in tool_calls}
        for future in as_completed(future_to_tool):
            try:
                result = future.result()
                results.append(result)
            except Exception as e:
                logger.error(f"Tool execution failed: {e}")
                tool_call = future_to_tool[future]
                results.append(ToolMessage(content=f"Unexpected error: {str(e)}",tool_call_id=tool_call["id"],name=tool_call.get("name", "unknown")))

    logger.info(f"Completed {len(results)} tool calls")
    return {"messages": results}

核心

  • 线程池并发执行多个工具(最多 5 个)
  • 所有工具执行完,统一返回结果

第七部分:通用工具函数

1. 获取最新用户问题

python

运行

def get_latest_question(state: MessagesState) -> Optional[str]:
    try:
        if not state.get("messages"): return None
        for message in reversed(state["messages"]):
            if message.__class__.__name__ == "HumanMessage":
                return message.content
        return None
    except Exception as e:
        logger.error(f"Error getting latest question: {e}")
        return None
  • 倒序找最近一条用户消息

2. 过滤消息(只保留用户 / AI 消息)

python

运行

def filter_messages(messages: list) -> list:
    filtered = [msg for msg in messages if msg.__class__.__name__ in ['AIMessage', 'HumanMessage']]
    return filtered[-5:] if len(filtered) > 5 else filtered
  • 只保留最近 5 轮对话,减少 token 消耗

3. 存储用户记忆

python

运行

def store_memory(question: BaseMessage, config: RunnableConfig, store: BaseStore) -> str:
    namespace = ("memories", config["configurable"]["user_id"])
    try:
        memories = store.search(namespace, query=str(question.content))
        user_info = "\n".join([d.value["data"] for d in memories])
        if "记住" in question.content.lower():
            memory = escape(question.content)
            store.put(namespace, str(uuid.uuid4()), {"data": memory})
            logger.info(f"Stored memory: {memory}")
        return user_info
    except Exception as e:
        logger.error(f"Error in store_memory: {e}")
        return ""
  • 用户说记住XXX → 存到数据库
  • 下次提问自动带出相关记忆

4. 创建 LLM 调用链(带缓存)

python

运行

def create_chain(llm_chat, template_file: str, structured_output=None):
    if not hasattr(create_chain, "prompt_cache"):
        create_chain.prompt_cache = {}
        create_chain.lock = threading.Lock()

    try:
        if template_file in create_chain.prompt_cache:
            prompt_template = create_chain.prompt_cache[template_file]
        else:
            with create_chain.lock:
                if template_file not in create_chain.prompt_cache:
                    create_chain.prompt_cache[template_file] = PromptTemplate.from_file(template_file, encoding="utf-8")
                prompt_template = create_chain.prompt_cache[template_file]

        prompt = ChatPromptTemplate.from_messages([("human", prompt_template.template)])
        return prompt | (llm_chat.with_structured_output(structured_output) if structured_output else llm_chat)
    except FileNotFoundError:
        logger.error(f"Template file {template_file} not found")
        raise
  • 从文件加载提示词
  • 线程安全缓存:不用重复读文件
  • 支持结构化输出

第八部分:数据库可靠性机制

1. 数据库连接测试(带重试)

python

运行

@retry(stop=stop_after_attempt(3),wait=wait_exponential(multiplier=1, min=2, max=10),retry=retry_if_exception_type(OperationalError))
def test_connection(db_connection_pool: ConnectionPool) -> bool:
    with db_connection_pool.getconn() as conn:
        with conn.cursor() as cursor:
            cursor.execute("SELECT 1")
            result = cursor.fetchone()
            if result != (1,):
                raise ConnectionPoolError("连接池测试查询失败,返回结果异常")
    return True
  • 失败自动重试 3 次,指数退避等待
  • 执行SELECT 1测试连接是否可用

2. 数据库连接池监控

python

运行

def monitor_connection_pool(db_connection_pool: ConnectionPool, interval: int = 60):
    def _monitor():
        while not db_connection_pool.closed:
            try:
                stats = db_connection_pool.get_stats()
                active = stats.get("connections_in_use", 0)
                total = db_connection_pool.max_size
                logger.info(f"Connection db_connection_pool status: {active}/{total} connections in use")
                if active >= total * 0.8:
                    logger.warning(f"Connection db_connection_pool nearing capacity: {active}/{total}")
            except Exception as e:
                logger.error(f"Failed to monitor connection db_connection_pool: {e}")
            time.sleep(interval)

    monitor_thread = threading.Thread(target=_monitor, daemon=True)
    monitor_thread.start()
    return monitor_thread
  • 后台线程每 60 秒打印连接池状态
  • 连接数超过 80% 报警

第九部分:工作流节点(核心业务逻辑)

1. Agent 分诊节点

python

运行

def agent(state: MessagesState, config: RunnableConfig, *, store: BaseStore, llm_chat, tool_config: ToolConfig) -> dict:
    logger.info("Agent processing user query")
    namespace = ("memories", config["configurable"]["user_id"])
    try:
        question = state["messages"][-1]
        user_info = store_memory(question, config, store)
        messages = filter_messages(state["messages"])
        llm_chat_with_tool = llm_chat.bind_tools(tool_config.get_tools())
        agent_chain = create_chain(llm_chat_with_tool, Config.PROMPT_TEMPLATE_TXT_AGENT)
        response = agent_chain.invoke({"question": question,"messages": messages, "userInfo": user_info})
        return {"messages": [response]}
    except Exception as e:
        logger.error(f"Error in agent processing: {e}")
        return {"messages": [{"role": "system", "content": "处理请求时出错"}]}

作用

  • 大模型作为决策中枢
  • 判断:直接回答 / 调用工具

2. 文档评分节点

python

运行

def grade_documents(state: MessagesState, llm_chat) -> dict:
    logger.info("Grading documents for relevance")
    if not state.get("messages"): return {"messages": [{"role": "system", "content": "状态为空,无法评分"}], "relevance_score": None}
    try:
        question = get_latest_question(state)
        context = state["messages"][-1].content
        grade_chain = create_chain(llm_chat, Config.PROMPT_TEMPLATE_TXT_GRADE, DocumentRelevanceScore)
        scored_result = grade_chain.invoke({"question": question, "context": context})
        score = scored_result.binary_score
        logger.info(f"Document relevance score: {score}")
        return {"messages": state["messages"], "relevance_score": score}
    except Exception as e:
        logger.error(f"Unexpected error in grading: {e}")
        return {"messages": [{"role": "system", "content": "评分过程中出错"}], "relevance_score": None}

作用

  • 评估检索到的文档是否有用
  • 返回 yes /no

3. 查询重写节点

python

运行

def rewrite(state: MessagesState, llm_chat) -> dict:
    logger.info("Rewriting query")
    try:
        question = get_latest_question(state)
        rewrite_chain = create_chain(llm_chat, Config.PROMPT_TEMPLATE_TXT_REWRITE)
        response = rewrite_chain.invoke({"question": question})
        rewrite_count = state.get("rewrite_count", 0) + 1
        logger.info(f"Rewrite count: {rewrite_count}")
        return {"messages": [response], "rewrite_count": rewrite_count}
    except Exception as e:
        logger.error(f"Message access error in rewrite: {e}")
        return {"messages": [{"role": "system", "content": "无法重写查询"}]}

作用

  • 文档不相关 → 优化用户问题,重新检索
  • 最多重写 3 次

4. 生成答案节点

python

运行

def generate(state: MessagesState, llm_chat) -> dict:
    logger.info("Generating final response")
    try:
        question = get_latest_question(state)
        context = state["messages"][-1].content
        generate_chain = create_chain(llm_chat, Config.PROMPT_TEMPLATE_TXT_GENERATE)
        response = generate_chain.invoke({"context": context, "question": question})
        return {"messages": [response]}
    except Exception as e:
        logger.error(f"Message access error in generate: {e}")
        return {"messages": [{"role": "system", "content": "无法生成回复"}]}

作用

  • 用检索到的内容 + 用户问题 → 生成最终回答

第十部分:路由函数(节点跳转规则)

1. 工具调用后路由

python

运行

def route_after_tools(state: MessagesState, tool_config: ToolConfig) -> Literal["generate", "grade_documents"]:
    if not state.get("messages"): return "generate"
    try:
        last_message = state["messages"][-1]
        if not hasattr(last_message, "name"): return "generate"
        tool_name = last_message.name
        if tool_name not in tool_config.get_tool_names(): return "generate"
        target = tool_config.get_tool_routing_config().get(tool_name, "generate")
        return target
    except Exception as e:
        logger.error(f"Unexpected error in route_after_tools: {e}")
        return "generate"
  • 检索工具 → 评分
  • 其他工具 → 生成

2. 评分后路由

python

运行

def route_after_grade(state: MessagesState) -> Literal["generate", "rewrite"]:
    if not isinstance(state, dict): return "rewrite"
    relevance_score = state.get("relevance_score")
    rewrite_count = state.get("rewrite_count", 0)

    if rewrite_count >= 3: return "generate"
    if relevance_score and relevance_score.lower() == "yes": return "generate"
    return "rewrite"
  • 相关 → 生成
  • 不相关 / 重写 < 3 次 → 重写
  • 重写满 3 次 → 直接生成

第十一部分:流程图构建 + 主程序

1. 保存流程图

python

运行

def save_graph_visualization(graph: StateGraph, filename: str = "graph.png") -> None:
    try:
        with open(filename, "wb") as f:
            f.write(graph.get_graph().draw_mermaid_png())
        logger.info(f"Graph visualization saved as {filename}")
    except IOError as e:
        logger.warning(f"Failed to save graph visualization: {e}")
  • 把工作流导出成图片

2. 创建工作流图

python

运行

def create_graph(db_connection_pool: ConnectionPool, llm_chat, llm_embedding, tool_config: ToolConfig) -> StateGraph:
    # 检查连接池
    # 初始化检查点/存储
    workflow = StateGraph(MessagesState)
    # 添加节点
    workflow.add_node("agent", ...)
    workflow.add_node("call_tools", ParallelToolNode(...))
    workflow.add_node("rewrite", ...)
    workflow.add_node("generate", ...)
    workflow.add_node("grade_documents", ...)

    # 连接节点
    workflow.add_edge(START, "agent")
    workflow.add_conditional_edges("agent", tools_condition, {"tools": "call_tools", END: END})
    workflow.add_conditional_edges("call_tools", lambda state: route_after_tools(...))
    workflow.add_conditional_edges("grade_documents", route_after_grade)
    workflow.add_edge("generate", END)
    workflow.add_edge("rewrite", "agent")

    return workflow.compile(checkpointer=checkpointer, store=store)

这是整个程序的 “流程图绘制”

plaintext

开始 → Agent → 调用工具 → 评分/生成 → 结束
                ↓
             不相关 → 重写 → 回到Agent

3. 响应输出函数

python

运行

def graph_response(graph: StateGraph, user_input: str, config: dict, tool_config: ToolConfig) -> None:
    events = graph.stream(...)
    for event in events:
        for value in event.values():
            last_message = value["messages"][-1]
            # 打印工具输出 / AI回答
  • 流式输出结果
  • 区分工具调用和 AI 回答

4. 主函数

python

运行

def main():
    # 加载模型/工具
    llm_chat, llm_embedding = get_llm(Config.LLM_TYPE)
    tools = get_tools(llm_embedding)
    tool_config = ToolConfig(tools)

    # 初始化数据库连接池
    db_connection_pool = ConnectionPool(...)
    db_connection_pool.open()
    monitor_connection_pool(db_connection_pool)

    # 创建工作流
    graph = create_graph(...)

    # 循环对话
    while True:
        user_input = input("User: ")
        if user_input in ["quit","exit","q"]: break
        graph_response(graph, user_input, config, tool_config)
  • 程序入口
  • 初始化所有资源
  • 循环接收用户输入

整体总结

这是一个企业级、高可靠、带记忆的检索增强对话机器人,核心特点:

  1. LangGraph 工作流:Agent 分诊 → 工具调用 → 评分 → 重写 → 生成
  2. 并发工具调用:多工具同时执行,速度更快
  3. PostgreSQL 持久化:会话、记忆、日志全落库
  4. 高可用机制:数据库重试、连接池监控、自动切割日志
  5. 用户记忆:支持记住用户信息,长期对话

一、class ToolConfig 工具配置管理类(超级详细)

python

运行

class ToolConfig:

含义:定义一个工具配置类,专门统一管理所有 AI 工具,包括:

  • 工具列表
  • 工具名字
  • 工具执行完后跳去哪个节点(路由规则)作用:让代码不乱、不重复、可扩展。

python

运行

    def __init__(self, tools):

含义:构造函数,创建 ToolConfig 对象时必须传入工具列表

python

运行

        self.tools = tools

含义:把外部传进来的 tools(工具列表)存到当前对象里,以后随时能用。

python

运行

        self.tool_names = {tool.name for tool in tools}

含义

  • 遍历所有工具
  • 把每个工具的 .name 拿出来
  • 放进**集合(set)**作用:快速判断 “某个名字是不是工具”。

python

运行

        self.tool_routing_config = self._build_routing_config(tools)

含义:调用内部方法,自动生成路由规则

  • 检索工具 → 去评分
  • 其他工具 → 直接生成

python

运行

        logger.info(...)

含义:打印日志,方便调试。


python

运行

    def _build_routing_config(self, tools):

含义内部私有方法,生成 “工具名 → 下一步节点” 的映射字典。

python

运行

        routing_config = {}

含义:空字典,用来存:工具名: 要跳去的节点名

python

运行

        for tool in tools:

含义:遍历每一个工具。

python

运行

            tool_name = tool.name.lower()

含义:转小写,避免大小写敏感导致匹配失败。

python

运行

            if "retrieve" in tool_name:

含义:判断是不是检索类工具(知识库搜索)。

python

运行

                routing_config[tool_name] = "grade_documents"

含义:检索工具执行完 → 必须先评分,判断文档有没有用。

python

运行

            else:
                routing_config[tool_name] = "generate"

含义:不是检索工具 → 不用评分,直接生成答案

python

运行

        if not routing_config:
            logger.warning(...)

含义:如果没有任何工具,打印警告。

python

运行

        return routing_config

含义:返回最终路由规则字典。


python

运行

    def get_tools(self):
        return self.tools

    def get_tool_names(self):
        return self.tool_names

    def get_tool_routing_config(self):
        return self.tool_routing_config

含义:三个getter 方法,外部只能读,不能改内部数据,这叫封装


二、class DocumentRelevanceScore(BaseModel)

python

运行

class DocumentRelevanceScore(BaseModel):
    binary_score: str = Field(description="...")

超级详细解释

  • 这是 Pydantic 模型
  • 作用:强制大模型必须返回固定格式
  • 只能返回:"yes""no"
  • 代码拿到结果后可以直接用 scored_result.binary_score

为什么要写?因为大模型说话乱七八糟,必须结构化输出才能代码自动判断。


三、class ConnectionPoolError(Exception)

python

运行

class ConnectionPoolError(Exception):
    """自定义异常,表示数据库连接池初始化或状态异常"""
    pass

含义:自定义一个异常类型,专门标记数据库连接池出问题。好处:

  • 能精确捕获
  • 日志清晰
  • 代码可读性极高

四、class ParallelToolNode(ToolNode) 并发工具节点(超详细)

python

运行

class ParallelToolNode(ToolNode):

含义:继承 LangGraph 自带的 ToolNode重写它,让工具可以并发执行。原生 ToolNode 是串行,这个改成并行


python

运行

    def __init__(self, tools, max_workers: int = 5):
        super().__init__(tools)
        self.max_workers = max_workers

含义

  • super().__init__(tools):调用父类构造
  • max_workers=5:最多同时跑 5 个工具作用:工具多的时候速度极快。

python

运行

    def _run_single_tool(self, tool_call: dict, tool_map: dict) -> ToolMessage:

含义只跑一个工具,内部方法。

python

运行

        try:

作用:捕获工具执行崩溃,不影响整个流程。

python

运行

            tool_name = tool_call["name"]

从 AI 返回的工具调用里取出名字。

python

运行

            tool = tool_map.get(tool_name)

从字典里找到真正的工具对象。

python

运行

            if not tool:
                raise ValueError(...)

找不到就抛异常。

python

运行

            result = tool.invoke(tool_call["args"])

真正执行工具

python

运行

            return ToolMessage(...)

返回标准格式,给 LangGraph 使用。

python

运行

        except Exception as e:
            logger.error(...)
            return ToolMessage(错误信息)

工具崩溃也返回消息,不中断流程。


python

运行

    def __call__(self, state: dict) -> dict:

含义:让类实例可以像函数一样调用,是 LangGraph 节点的标准格式。

python

运行

        last_message = state["messages"][-1]

取最后一条消息(里面包含 AI 决定调用的工具)。

python

运行

        tool_calls = getattr(last_message, "tool_calls", [])

取出所有要调用的工具列表

python

运行

        if not tool_calls:
            return {"messages": []}

没有工具可调用,直接返回。

python

运行

        tool_map = {tool.name: tool for tool in self.tools}

构建快速查找字典。


python

运行

        with ThreadPoolExecutor(...) as executor:

创建线程池,实现并发

python

运行

            future_to_tool = {
                executor.submit(...)
            }

把所有工具调用提交到线程池后台运行

python

运行

            for future in as_completed(future_to_tool):

按完成顺序获取结果

python

运行

                result = future.result()

获取工具返回值。

python

运行

        return {"messages": results}

把所有工具结果返回给状态。


五、get_latest_question 获取最新用户问题

python

运行

def get_latest_question(state: MessagesState) -> Optional[str]:

作用:从消息历史里倒着找,找到最近的用户问题。

python

运行

        for message in reversed(state["messages"]):
            if message.__class__.__name__ == "HumanMessage":
                return message.content

含义:倒序遍历 → 找到用户消息 → 返回内容。


六、filter_messages 过滤消息

python

运行

def filter_messages(messages: list) -> list:
    filtered = [msg for msg in messages if ...]
    return filtered[-5:]

作用

  • 只保留用户消息 + AI 消息
  • 只保留最近 5 轮目的:省 token、提速、防上下文溢出

七、store_memory 长期记忆存储

python

运行

def store_memory(question: BaseMessage, config, store) -> str:

作用

  • 识别用户说 “记住 XXX”
  • 存到数据库
  • 下次提问自动把相关记忆带进去实现长期记忆

python

运行

        if "记住" in question.content.lower():
            memory = escape(question.content)
            store.put(..., memory)

escape 防止注入攻击。


八、create_chain 创建 LLM 调用链(带缓存)

python

运行

def create_chain(llm_chat, template_file, structured_output=None):

作用

  • 从文件读取提示词
  • 线程安全缓存
  • 绑定模型
  • 返回可执行链

python

运行

    if not hasattr(create_chain, "prompt_cache"):
        create_chain.prompt_cache = {}
        create_chain.lock = threading.Lock()

函数静态变量,只初始化一次。

python

运行

        prompt = ChatPromptTemplate.from_messages(...)
        return prompt | llm

LangChain 表达式语言 LCEL,链式调用。


九、test_connection 数据库连接测试

python

运行

@retry(...)
def test_connection(db_connection_pool):

带重试的连接检查

  • 重试 3 次
  • 指数退避
  • 只捕获数据库错误

python

运行

        cursor.execute("SELECT 1")

标准 PostgreSQL 存活检查。


十、monitor_connection_pool 连接池监控

python

运行

def monitor_connection_pool(...):
    def _monitor():
        while True:
            stats = pool.get_stats()
            ...
            time.sleep(60)

后台守护线程:每 60 秒打印连接数,超过 80% 报警。高并发系统必备。


十一、agent 节点(大脑)

python

运行

def agent(...):

整个系统的大脑

  • 读取用户问题
  • 读取记忆
  • 过滤消息
  • 绑定工具
  • 调用 LLM 判断:直接回答 or 调用工具

python

运行

        llm_chat_with_tool = llm_chat.bind_tools(...)

让大模型知道自己能调用哪些工具

python

运行

        agent_chain = create_chain(...)
        response = agent_chain.invoke(...)

调用提示词 + 模型。


十二、grade_documents 文档评分

python

运行

def grade_documents(...):

判断检索内容是否有用

  • 输入:问题 + 文档
  • 输出:yes /no

python

运行

        grade_chain = create_chain(..., DocumentRelevanceScore)

强制结构化输出。


十三、rewrite 查询重写

python

运行

def rewrite(...):

文档不相关 → 优化问题 → 重新检索。最多重写 3 次。


十四、generate 生成最终答案

python

运行

def generate(...):

用工具结果 + 问题 → 生成自然语言回答。


十五、route_after_tools 工具调用后路由

python

运行

def route_after_tools(...):

根据工具名字决定下一步:

  • 检索 → grade_documents
  • 其他 → generate

十六、route_after_grade 评分后路由

python

运行

def route_after_grade(...):
  • yes → generate
  • no → rewrite
  • 重写满 3 次 → 强制生成

十七、save_graph_visualization 保存流程图

python

运行

def save_graph_visualization(...):

把工作流导出成图片。


十八、create_graph 创建工作流(核心)

python

运行

def create_graph(...):

真正画流程图

python

运行

    workflow = StateGraph(MessagesState)

创建状态图。

python

运行

    workflow.add_node("agent", ...)
    workflow.add_node("call_tools", ...)
    workflow.add_node("rewrite", ...)
    workflow.add_node("generate", ...)
    workflow.add_node("grade_documents", ...)

添加 5 个节点。

python

运行

    workflow.add_edge(START, "agent")

开始 → 大脑

python

运行

    workflow.add_conditional_edges(...)

条件跳转。

python

运行

    return workflow.compile(...)

编译成可运行图。


十九、graph_response 响应输出

python

运行

def graph_response(...):
  • 流式运行
  • 打印工具输出
  • 打印 AI 回答

二十、main 主函数

python

运行

def main():

程序入口:

  • 加载模型
  • 加载工具
  • 初始化数据库
  • 启动监控
  • 创建工作流
  • 循环对话
  • 退出关闭资源

最终超级总结(你一定会懂)

你给的这段代码,是一个完整企业级 RAG 对话机器人,核心流程:

开始 → Agent(大脑) → 调用工具 → 并发执行 → 评分 → 相关 / 不相关 → 重写 / 生成 → 结束

具备:

  • 并发工具
  • 长期记忆
  • 数据库持久化
  • 连接池监控
  • 自动重试
  • 日志切割
  • 线程安全
  • 提示词缓存
  • 查询优化
  • 文档评分