用 LangGraph 重构确定客户意图部分 | 豆包MarsCode AI刷题

384 阅读6分钟

LangGraph 与 LangChain 之间的关系

LangGraph 并不是一个独立于 LangChain 的新框架,而是在 LLM 和 LangChain 的基础之上构建的一个扩展库,可以于 LangChain 现有的链(Chain)等无缝协作。

LangGraph 能够协调多个 Chain、Agent、Tool 等共同协作,实现依赖外部工具、外部数据库且带有反馈的问答任务。

问题场景

假设鲜花运营智能客服 ChatBot 通常会接到两大类问题:

  1. 鲜花养护:保持花的健康、如何浇水、施肥等
  2. 鲜花装饰:如何搭配花、如何装饰场地等

整体框架

  1. 创建一个StateGraph对象(整个状态图的基础类),将整个系统的结构定义为“状态机”
  • 添加nodes表示工作流可以调用的 Chain、Agent 或函数
  • 添加edges表示从一个nodes跳转到下一个nodes的关系
  1. 添加工作节点
  2. 添加路由节点(根据输入选择下一个节点)
  3. 添加一个entry节点(入口点),告诉StateGraph每次运行的时候,从哪里开始
  4. 添加finish节点(结束节点),当StateGraph运行到该节点,说明本轮结束
  5. 编译 Graph
  6. 绘制 Graph 的结构图
  7. 运行 Graph

创建状态图基础类

首先,我们需要创建一个StateGraph对象(整个状态图的基础类)。后续我们将在这个状态图上添加节点和边来描述我们整个系统的工作流程。

class GraphState(TypedDict):
    question: str
    generation: str


# 1. 创建一个 StateGraph 对象
workflow = StateGraph(GraphState)

创建工作节点

接下来我们需要定义分别两个 Agent,一个专门用于回答鲜花养育问题,一个专门由于回答鲜花装饰问题。

对于这两个 Agent,分别用一个类进行定义。类的初始化涉及选择要调用的 LLM 模型,以及提示词的系统预设。不难看出这两个类的差距非常小,其实也可以简化成一个类,只需要在初始化的时候,选择不同的系统预设。

同时,这里我利用管道构建了 Agent 的整个工作链self.chain。首先,会根据输入的question字段值,补全prompt_template,然后将完整的prompt输入给 LLM 进行处理。最后,输出解析器StrOutputParser会接收 LLM 的输出,并将其解析成字符串的形式,返回给用户。

class Gardener:
    def __init__(self):
        self.llm = ChatOpenAI(
            model="internlm2.5-latest",
            api_key=os.getenv("INTERNLM_API_KEY"),
            base_url="https://internlm-chat.intern-ai.org.cn/puyu/api/v1/"
        )
        self.system = """你是一个经验丰富的园丁,擅长解答关于养花育花的问题。"""
        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", self.system),
                ("user", "{problem}"),
            ]
        )
        # StrOutputParser()用于将语言模型的输出解析为字符串
        self.chain = self.prompt_template | self.llm | StrOutputParser()

    def answer(self, problem):
        return self.chain.invoke({"problem": problem})
class FlowerArranger:
    def __init__(self):
        self.llm = ChatOpenAI(
            model="internlm2.5-latest",
            api_key=os.getenv("INTERNLM_API_KEY"),
            base_url="https://internlm-chat.intern-ai.org.cn/puyu/api/v1/"
        )
        self.system = """你是一位网红插花大师,擅长解答关于鲜花装饰的问题。"""
        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", self.system),
                ("user", "{problem}"),
            ]
        )
        # StrOutputParser()用于将语言模型的输出解析为字符串
        self.chain = self.prompt_template | self.llm | StrOutputParser()

    def answer(self, problem):
        return self.chain.invoke({"problem": problem})

定义好两个 Agent 类后,我们需要在状态图中定义这两个工作节点:

def gardener(state):
    """
    回答鲜花养育相关问题
    :param state: 当前状态字典state,包含用户问题
    :return: 更新state,添加 generation 键,即生成的回答
    """
    print("---进入鲜花养育节点---")
    question = state["question"]

    generation = gardener_agent.answer({"question": question})
    return {"question": question, "generation": generation}


def arranger(state):
    """
    回答鲜花装饰相关问题
    :param state: 当前状态字典state,包含用户问题
    :return: 更新state,添加 generation 键,即生成的回答
    """
    print("---进入鲜花装饰节点---")
    question = state["question"]

    generation = arranger_agent.answer({"question": question})
    return {"question": question, "generation": generation}


# 2. 定义工作节点
workflow.add_node("gardener", gardener)
workflow.add_node("arranger", arranger)

创建路由节点

为了满足问题场景,只有两个工作节点还不够,我们还需要定义一个路由节点,根据问题类型,判断将该问题路由到负责鲜花养护的节点还是路由到负责鲜花装饰的节点。

首先,我们需要定义路由节点输出的数据模型,确保路由节点只会输出flower_careflower_decoration两种结果。

class RouteQuestionType(BaseModel):
    """
    定义数据模型
    question_type字段的值只能是 flower_care 或 flower_decoration
    """

    question_type: Literal["flower_care", "flower_decoration"] = Field(
        ...,
        description="给定用户问题,选择将其路由到flower_care或flower_decoration。",
    )

接下来,创建路由节点类。路由节点类还是通过 LLM 实现,并利用with_structured_output方法来结构化模型的输出(该方法的核心思想是复用了 Tool use 的功能,将给定的模式 schema 看成一个工具,要求 LLM 根据 schema 的描述,输出符合要求的“请求格式”)

class QuestionRouter:
    def __init__(self, model_name: str = "llama3.2:3b"):
        """
        初始化语言模型和路由逻辑
        """
        # 初始化语言模型
        self.llm = ChatOllama(model=model_name)
        # 添加结构化输出
        # 将 llm 模型与 RouteQuestionType 数据模型关联起来
        self.structured_llm_router = self.llm.with_structured_output(RouteQuestionType)
        # Prompt
        self.system = """您是将用户问题路由到鲜花护理或鲜花装饰的专家。\n\n
        如果你认为该问题是关于鲜花护理的,请输出 flower_care\n\n
        如果你认为该问题是关于鲜花装饰的,请输出 flower_decoration\n\n
        """
        self.route_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system),
                ("human", "{question}"),
            ]
        )
        # 将 route_prompt 和 structured_llm_router 结合起来创建的
        # 根据用户的问题决定将问题路由到不同 Agent
        self.question_router = self.route_prompt | self.structured_llm_router

    def route_question(self, question: str) -> Union[dict, BaseModel]:
        """
        根据用户的问题决定将问题路由到不同的 Agent
        """
        return self.question_router.invoke({"question": question})

最后,将路由节点引入到状态图的工作流中:

def route_question(state):
    """
    决定问题应该发送给哪个Agent
    :param state: 包含用户问题
    :return: 返回下一个要调用的节点名称
    """
    print("---问题路由---")
    question = state["question"]
    source = question_router.route_question(question)
    if source.question_type == "flower_care":
        print("---路由到鲜花养育 Agent---")
        return "flower_care"
    elif source.question_type == "flower_decoration":
        print("---路由到鲜花装饰 Agent---")
        return "flower_decoration"

定义边

在 LangGraph 框架中,开始节点和结束节点以及预定义好了,分别为STARTEND,用于告诉 StateGraph 本轮从哪里开始,从哪里结束。

在 LangGraph 中添加edges表示从一个nodes跳转到下一个nodes的关系。接下来我们会利用边,将上面定义的所有节点都连接起来,构建出最终完整的状态图。

首先,当一轮工作流从START节点开始时,会进入路由节点,路由节点根据输入的问题,判断该问题是鲜花养育问题还是鲜花装饰问题,然后路由到对应的工作节点。

workflow.add_conditional_edges(
    START,
    route_question,
    {
        "flower_care": "gardener",
        "flower_decoration": "arranger",
    },
)

工作节点处理完成后,输出最终的回答结果,然后进入END节点,本轮工作流结束。

workflow.add_edge("gardener", END)
workflow.add_edge("arranger", END)

编译并测试

状态图定义完成后,我们需要编译状态图,并可以通过输出可视化结果来检查状态图是否正确:

# 4. 编译状态图
app = workflow.compile()

# 5. 可视化状态图
app.get_graph().draw_mermaid_png(output_file_path="graph.png")

graph.png

测试:

    inputs = {
        "question": "如何为玫瑰浇水?",
        # "question": "如何为婚礼场地装饰花朵?"
    }
    for output in app.stream(inputs):
        for key, value in output.items():
            # Node
            print(f"Node '{key}':")
            # Optional: print full state at each node
            # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
        print("\n---\n")

    # Final generation
    print(value["generation"])