16_LangChain自定义会话管理和Retriever

99 阅读9分钟

LangChain自定义会话管理和Retriever

引言

在构建基于LLM的对话应用时,会话管理和信息检索是两个核心功能。会话管理使应用能够记住与用户的对话历史,保持上下文连贯性;而自定义检索器(Retriever)则允许从特定数据源获取相关信息,增强LLM的回答能力。本教程将深入探讨如何在LangChain中实现自定义会话管理和检索器,帮助开发者构建更智能、更个性化的对话应用。

1. 自定义会话管理

会话管理的重要性

在实际的对话应用中,我们需要一种机制来持久化存储对话历史,并自动将其注入到每次交互中。这有几个重要原因:

  • 上下文连贯性:使模型能够理解对话的上下文,避免重复提问或无关回答
  • 个性化体验:根据用户之前的交互定制响应
  • 多用户支持:为不同用户维护独立的对话历史
  • 会话持久化:即使应用重启,也能恢复之前的对话状态

LangChain中的会话管理组件

LangChain提供了两个核心组件来实现会话管理:

  1. BaseChatMessageHistory:用于存储对话历史记录
  2. RunnableWithMessageHistory:LCEL链和BaseChatMessageHistory的包装器,负责自动注入和更新对话历史

会话管理架构

实现内存字典会话管理

下面是一个使用内存字典存储会话历史的示例:

from typing import Dict, List, Optional
from langchain.memory import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# 1. 创建一个字典来存储不同会话的历史记录
store: Dict[str, BaseChatMessageHistory] = {}

# 2. 定义一个函数,根据会话ID获取或创建历史记录
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

# 3. 创建一个简单的聊天链
prompt = ChatPromptTemplate.from_messages([
    ("system", "你是一个友好的AI助手,能够记住我们的对话。"),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{input}")
])

chain = prompt | ChatOpenAI(temperature=0)

# 4. 使用RunnableWithMessageHistory包装链,实现自动会话管理
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history"
)

# 5. 使用会话ID调用链
response1 = chain_with_history.invoke(
    {"input": "我的名字是张三"},
    config={"configurable": {"session_id": "user_123"}}
)
print(response1.content)

response2 = chain_with_history.invoke(
    {"input": "你还记得我的名字吗?"},
    config={"configurable": {"session_id": "user_123"}}
)
print(response2.content)

# 6. 使用不同的会话ID
response3 = chain_with_history.invoke(
    {"input": "你知道我是谁吗?"},
    config={"configurable": {"session_id": "user_456"}}
)
print(response3.content)

# 7. 查看存储的会话历史
print("\n会话历史 user_123:")
for message in store["user_123"].messages:
    print(f"{message.type}: {message.content}")

print("\n会话历史 user_456:")
for message in store["user_456"].messages:
    print(f"{message.type}: {message.content}")

运行结果分析

运行上述代码后,我们会看到:

  1. 第一次交互中,模型记住了用户名"张三"
  2. 第二次交互中,模型能够回忆起用户名,因为使用了相同的会话ID
  3. 第三次交互中,模型无法回忆起用户名,因为使用了不同的会话ID
  4. 我们可以查看存储在字典中的不同会话历史

自定义会话存储

在实际应用中,我们通常需要将会话历史存储在数据库中,而不是内存字典中。LangChain提供了多种会话存储实现:

  • ChatMessageHistory:内存存储,应用重启后会丢失
  • RedisChatMessageHistory:使用Redis存储会话历史
  • MongoDBChatMessageHistory:使用MongoDB存储会话历史
  • PostgresChatMessageHistory:使用PostgreSQL存储会话历史
  • 自定义存储:通过继承BaseChatMessageHistory实现

下面是一个使用Redis存储会话历史的示例:

from langchain.memory import RedisChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

# 创建Redis会话历史获取函数
def get_redis_history(session_id: str) -> BaseChatMessageHistory:
    return RedisChatMessageHistory(
        session_id=session_id,
        url="redis://localhost:6379/0"
    )

# 使用Redis会话历史包装链
chain_with_redis_history = RunnableWithMessageHistory(
    chain,
    get_redis_history,
    input_messages_key="input",
    history_messages_key="history"
)

2. 自定义Retriever

Retriever的作用

检索器(Retriever)是LangChain中的核心组件,负责从外部数据源检索与用户查询相关的文档。在RAG(检索增强生成)应用中,检索器扮演着关键角色:

  1. 接收用户查询
  2. 从数据源中检索相关文档
  3. 将检索到的文档格式化为提示
  4. 输入LLM生成回答

Retriever接口

要创建自定义检索器,需要继承BaseRetriever类并实现以下方法:

方法描述必需/可选
_get_relevant_documents获取与查询相关的文档必需
_aget_relevant_documents实现以提供异步本机支持可选

通过继承BaseRetriever,您的检索器将自动成为LangChain Runnable,获得标准的Runnable功能。

自定义Retriever的优势

将检索器实现为BaseRetriever而不是RunnableLambda有以下优势:

  1. 专门的监控:一些监控工具为检索器实现了专门的行为
  2. 标准事件:在API中,检索器会触发特定事件,如on_retriever_start而不是on_chain_start
  3. 更好的集成:与LangChain生态系统的其他组件更好地集成

实现简单的文本匹配检索器

下面是一个简单的文本匹配检索器示例,它返回所有包含用户查询文本的文档:

from typing import List
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

class SimpleTextMatchRetriever(BaseRetriever):
    """一个简单的文本匹配检索器,返回包含查询文本的文档。"""
    
    def __init__(self, documents: List[Document]):
        """初始化检索器。"""
        super().__init__()
        self.documents = documents
        
    def _get_relevant_documents(self, query: str) -> List[Document]:
        """获取与查询相关的文档。"""
        # 简单的文本匹配逻辑
        return [doc for doc in self.documents if query.lower() in doc.page_content.lower()]

# 创建一些示例文档
documents = [
    Document(page_content="猫是一种常见的家养宠物,性格独立。", metadata={"animal": "cat"}),
    Document(page_content="狗是人类最忠诚的朋友,非常友好。", metadata={"animal": "dog"}),
    Document(page_content="兔子有长长的耳朵和蓬松的尾巴。", metadata={"animal": "rabbit"}),
    Document(page_content="熊猫是中国的国宝,主要吃竹子。", metadata={"animal": "panda"}),
]

# 创建检索器实例
retriever = SimpleTextMatchRetriever(documents)

# 测试检索器
query = "猫"
results = retriever.invoke(query)
print(f"查询: '{query}'")
print(f"找到 {len(results)} 个相关文档:")
for doc in results:
    print(f"- {doc.page_content} (metadata: {doc.metadata})")

创建高级自定义检索器

下面是一个更复杂的检索器示例,它结合了文本匹配和元数据过滤:

from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

class AdvancedRetriever(BaseRetriever):
    """高级检索器,支持文本匹配和元数据过滤。"""
    
    def __init__(self, documents: List[Document]):
        """初始化检索器。"""
        super().__init__()
        self.documents = documents
        
    def _get_relevant_documents(
        self, 
        query: str, 
        metadata_filter: Optional[Dict[str, Any]] = None
    ) -> List[Document]:
        """获取与查询相关的文档。
        
        Args:
            query: 用户查询
            metadata_filter: 可选的元数据过滤条件
            
        Returns:
            相关文档列表
        """
        # 先进行文本匹配
        matched_docs = [
            doc for doc in self.documents 
            if query.lower() in doc.page_content.lower()
        ]
        
        # 如果有元数据过滤条件,进一步过滤
        if metadata_filter:
            filtered_docs = []
            for doc in matched_docs:
                match = True
                for key, value in metadata_filter.items():
                    if key not in doc.metadata or doc.metadata[key] != value:
                        match = False
                        break
                if match:
                    filtered_docs.append(doc)
            return filtered_docs
        
        return matched_docs
    
    def invoke(self, query: str, metadata_filter: Optional[Dict[str, Any]] = None) -> List[Document]:
        """重写invoke方法以支持元数据过滤。"""
        return self._get_relevant_documents(query, metadata_filter)

# 创建检索器实例
advanced_retriever = AdvancedRetriever(documents)

# 测试检索器 - 带元数据过滤
query = "宠物"
metadata_filter = {"animal": "dog"}
results = advanced_retriever.invoke(query, metadata_filter)
print(f"\n查询: '{query}' 带元数据过滤 {metadata_filter}")
print(f"找到 {len(results)} 个相关文档:")
for doc in results:
    print(f"- {doc.page_content} (metadata: {doc.metadata})")

异步检索器实现

对于需要异步支持的应用,可以实现_aget_relevant_documents方法:

import asyncio
from typing import List
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

class AsyncRetriever(BaseRetriever):
    """支持异步操作的检索器。"""
    
    def __init__(self, documents: List[Document]):
        """初始化检索器。"""
        super().__init__()
        self.documents = documents
        
    def _get_relevant_documents(self, query: str) -> List[Document]:
        """同步方法实现。"""
        return [doc for doc in self.documents if query.lower() in doc.page_content.lower()]
    
    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        """异步方法实现。"""
        # 模拟异步操作
        await asyncio.sleep(0.1)
        return [doc for doc in self.documents if query.lower() in doc.page_content.lower()]

# 创建异步检索器
async_retriever = AsyncRetriever(documents)

# 异步调用示例
async def test_async_retriever():
    query = "宠物"
    results = await async_retriever.ainvoke(query)
    print(f"\n异步查询: '{query}'")
    print(f"找到 {len(results)} 个相关文档:")
    for doc in results:
        print(f"- {doc.page_content}")

# 运行异步测试
import asyncio
asyncio.run(test_async_retriever())

3. 将自定义会话管理与检索器结合

在实际应用中,我们通常需要将会话管理与检索器结合使用,创建一个能够记住对话历史并从外部数据源检索信息的智能对话系统。

创建RAG应用示例

下面是一个结合了自定义会话管理和检索器的RAG应用示例:

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.memory import ChatMessageHistory

# 1. 创建检索器
retriever = SimpleTextMatchRetriever(documents)

# 2. 创建会话存储
sessions = {}
def get_session(session_id):
    if session_id not in sessions:
        sessions[session_id] = ChatMessageHistory()
    return sessions[session_id]

# 3. 创建RAG提示模板
prompt = ChatPromptTemplate.from_messages([
    ("system", "你是一个友好的AI助手,能够回答关于动物的问题。使用以下检索到的信息来回答用户问题:\n\n{context}"),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{question}")
])

# 4. 创建RAG链
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | ChatOpenAI(temperature=0)
    | StrOutputParser()
)

# 5. 添加会话管理
rag_chain_with_history = RunnableWithMessageHistory(
    rag_chain,
    get_session,
    input_messages_key="question",
    history_messages_key="history"
)

# 6. 测试应用
session_id = "user_789"
response1 = rag_chain_with_history.invoke(
    "告诉我关于猫的信息",
    config={"configurable": {"session_id": session_id}}
)
print("回答1:", response1)

response2 = rag_chain_with_history.invoke(
    "它们有什么特点?",
    config={"configurable": {"session_id": session_id}}
)
print("回答2:", response2)

4. 最佳实践

会话管理最佳实践

  1. 选择合适的存储:根据应用需求选择适当的会话存储方式(内存、数据库等)
  2. 会话过期策略:实现会话过期机制,避免存储过多历史消息
  3. 消息压缩:在长对话中考虑使用消息压缩或摘要技术
  4. 多用户隔离:确保不同用户的会话完全隔离
  5. 安全性考虑:加密敏感的会话内容

检索器最佳实践

  1. 相关性评分:实现相关性评分机制,只返回最相关的文档
  2. 多样性考虑:确保返回的文档具有多样性,避免冗余
  3. 错误处理:妥善处理检索过程中可能出现的错误
  4. 性能优化:对于大型数据集,考虑使用向量数据库或其他高效索引
  5. 动态更新:支持数据源的动态更新,确保检索最新信息

总结

自定义会话管理和检索器是构建高质量LLM应用的关键组件。通过LangChain提供的灵活接口,开发者可以根据具体需求实现个性化的会话管理和信息检索功能。

会话管理使应用能够记住与用户的对话历史,提供连贯的交互体验;而自定义检索器则使应用能够从特定数据源获取相关信息,增强LLM的回答能力。将这两个组件结合使用,可以构建出既能记住上下文又能利用外部知识的智能对话系统。

随着应用规模的增长,可以考虑使用更高级的存储解决方案和更复杂的检索算法,进一步提升系统的性能和用户体验。