LangChain上下文窗口管理与遗忘策略源码分析
一、LangChain上下文窗口管理概述
1.1 上下文窗口基本概念
LangChain作为一个强大的LLM应用框架,其上下文窗口管理是实现高效对话和推理的核心机制。上下文窗口指的是模型在处理当前输入时能够访问的历史信息范围。在LangChain中,这一机制不仅涉及简单的文本存储,还包括智能的信息筛选、压缩和检索策略。
从用户交互的角度看,当我们与基于LangChain构建的应用进行对话时,系统需要决定哪些历史消息应该被保留并提供给LLM,以帮助其理解当前问题并生成连贯的回答。这一过程涉及到复杂的决策逻辑和内存管理策略。
在源码层面,LangChain通过多种组件协同工作来实现上下文窗口管理。核心组件包括ConversationBufferMemory
、ConversationSummaryMemory
、ConversationBufferWindowMemory
等不同类型的内存实现,以及负责消息处理和筛选的中间件。
1.2 上下文窗口管理的重要性
有效的上下文窗口管理对于LLM应用的性能和用户体验至关重要。主要体现在以下几个方面:
首先,LLM的输入长度有限,例如GPT-4的上下文窗口约为8000 tokens。如果不进行有效的管理,历史对话很快就会填满窗口,导致当前输入无法被完整处理。
其次,过多的无关历史信息会降低模型的推理效率和回答质量。模型需要在大量的文本中筛选出相关信息,这增加了计算负担并可能引入噪声。
最后,合理的上下文窗口管理可以提高系统的内存使用效率,减少不必要的计算资源消耗。通过智能地遗忘过时信息,系统可以保持高效运行。
1.3 上下文窗口管理的核心挑战
在实现上下文窗口管理时,LangChain面临着多个核心挑战:
一是如何在有限的窗口大小内保留最有价值的信息。这需要设计有效的重要性评估机制,能够识别哪些历史消息对当前任务最为关键。
二是如何处理长对话中的信息连贯性。当窗口大小不足以容纳完整对话时,需要确保截断不会破坏对话的逻辑连贯性,避免模型产生困惑。
三是如何平衡信息的完整性和处理效率。过于保守的信息保留策略会导致窗口快速填满,而过于激进的遗忘策略可能会丢失重要线索。
四是如何处理不同类型的信息,例如用户问题、系统回答、工具调用结果等。不同类型的信息可能具有不同的重要性和生命周期,需要区别对待。
二、LangChain上下文窗口实现架构
2.1 核心组件概述
LangChain的上下文窗口管理涉及多个核心组件,这些组件共同构成了一个完整的架构体系。主要组件包括:
-
Memory类:负责存储和管理对话历史。LangChain提供了多种Memory实现,如
ConversationBufferMemory
、ConversationSummaryMemory
、ConversationBufferWindowMemory
等。 -
Message类:表示对话中的消息单元。每个Message包含发送者、内容、时间戳等信息,是上下文窗口管理的基本操作对象。
-
Window策略:决定如何在有限的窗口内选择和保留消息。不同的策略实现了不同的筛选和遗忘逻辑。
-
Token计数器:负责计算文本的token数量,确保上下文窗口不超过模型的限制。
-
序列化和反序列化工具:用于将内存状态保存到持久存储或从持久存储恢复。
2.2 架构层次结构
LangChain的上下文窗口管理架构可以分为多个层次:
-
接口层:定义了统一的Memory接口,所有具体的内存实现都必须遵循这一接口。这一层还包括Message的抽象定义。
-
核心实现层:包含各种具体的Memory实现类,如
ConversationBufferMemory
、ConversationSummaryMemory
等。这些类实现了不同的内存管理策略。 -
策略层:实现了各种窗口管理和遗忘策略,如基于时间的遗忘、基于重要性的遗忘、基于token数量的遗忘等。
-
工具层:提供了辅助工具,如token计数器、消息筛选器、序列化器等。
-
集成层:负责将上下文窗口管理组件与其他LangChain组件(如LLM调用、工具执行等)集成在一起。
2.3 数据流与控制流
在LangChain中,上下文窗口管理的数据流和控制流如下:
-
消息接收:当系统接收到用户消息时,首先将其转换为Message对象。
-
消息存储:Message对象被存储到Memory中,Memory会根据其实现策略决定如何组织和保存这些消息。
-
窗口管理:在向LLM发送请求之前,Memory会根据当前窗口策略筛选出应该包含在上下文中的消息。
-
Token计算:筛选出的消息会经过token计算,确保总token数量不超过模型限制。如果超过限制,会根据遗忘策略删除部分消息。
-
上下文构建:经过筛选和token计算后的消息被组合成最终的上下文,发送给LLM。
-
响应处理:LLM的响应也会被转换为Message对象,并存储到Memory中,完成一个完整的交互循环。
这一流程确保了上下文窗口的有效管理,同时保持了与LLM的无缝集成。
三、消息表示与存储机制
3.1 消息数据结构
在LangChain中,消息是上下文窗口管理的基本单位。消息的核心数据结构定义在langchain/schema/messages.py
文件中。主要的消息类型包括:
- HumanMessage:表示用户发送的消息。
- AIMessage:表示AI模型生成的响应。
- SystemMessage:表示系统发送的消息,通常用于设置对话的初始上下文或提供指令。
- FunctionCallMessage:表示函数调用消息,用于模型与外部工具的交互。
这些消息类型都继承自基类BaseMessage
,其核心定义如下:
class BaseMessage(ABC):
"""所有消息类型的基类"""
content: str
additional_kwargs: dict
@property
@abstractmethod
def type(self) -> str:
"""返回消息类型的字符串表示"""
pass
def to_dict(self) -> dict:
"""将消息转换为字典格式"""
return {
"type": self.type,
"content": self.content,
"additional_kwargs": self.additional_kwargs,
}
每个具体的消息类型都实现了type
属性,返回其类型标识。例如:
class HumanMessage(BaseMessage):
"""用户发送的消息"""
@property
def type(self) -> str:
return "human"
3.2 消息存储接口
LangChain定义了统一的内存接口,所有具体的内存实现都必须遵循这一接口。内存接口定义在langchain/memory/base.py
文件中:
class BaseMemory(ABC):
"""内存接口定义"""
@property
@abstractmethod
def memory_variables(self) -> List[str]:
"""返回内存中存储的变量名称列表"""
pass
@abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""加载内存变量,返回一个字典,包含需要注入到链中的变量"""
pass
@abstractmethod
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""保存上下文信息,inputs是用户输入,outputs是链的输出"""
pass
@abstractmethod
def clear(self) -> None:
"""清除内存中的所有信息"""
pass
这一接口定义了内存的基本操作:加载内存变量、保存上下文和清除内存。不同的内存实现会根据其策略以不同的方式实现这些方法。
3.3 基础内存实现
LangChain提供了多种基础内存实现,其中最基本的是ConversationBufferMemory
,它简单地将所有对话消息保存在一个列表中。其核心实现如下:
class ConversationBufferMemory(BaseMemory):
"""简单地存储所有对话消息的内存实现"""
chat_memory: BaseChatMemory = Field(default_factory=ChatMessageHistory)
output_key: Optional[str] = None
input_key: Optional[str] = None
@property
def memory_variables(self) -> List[str]:
"""返回内存中存储的变量名称"""
return ["history"]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""加载对话历史"""
messages = self.chat_memory.messages
history = _get_buffer_string(messages)
return {"history": history}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""保存对话上下文"""
# 保存用户输入
input_str = _get_input_output(inputs, self.input_key)
self.chat_memory.add_user_message(input_str)
# 保存AI输出
output_str = _get_input_output(outputs, self.output_key)
self.chat_memory.add_ai_message(output_str)
def clear(self) -> None:
"""清除对话历史"""
self.chat_memory.clear()
这个实现使用ChatMessageHistory
类来存储消息,并提供了基本的加载和保存上下文的功能。其他更复杂的内存实现会在此基础上添加更多的功能,如窗口管理、摘要生成等。
四、基础窗口管理策略
4.1 固定窗口策略
固定窗口策略是最简单的窗口管理策略之一,它只保留最近的N条消息。这种策略的实现非常直观,适用于对话历史不是特别重要的场景。
在LangChain中,ConversationBufferWindowMemory
类实现了固定窗口策略。其核心代码如下:
class ConversationBufferWindowMemory(BaseMemory):
"""保留最近N条消息的内存实现"""
chat_memory: BaseChatMemory = Field(default_factory=ChatMessageHistory)
k: int = 5 # 默认保留最近的5条消息
return_messages: bool = False
@property
def memory_variables(self) -> List[str]:
return ["history"] if self.return_messages else []
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""加载最近的k条消息"""
# 获取所有消息
all_messages = self.chat_memory.messages
# 只保留最近的k条消息
if len(all_messages) > self.k:
window_messages = all_messages[-self.k:]
else:
window_messages = all_messages
if self.return_messages:
history = window_messages
else:
history = _get_buffer_string(window_messages)
return {"history": history}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""保存上下文,逻辑与ConversationBufferMemory相同"""
input_str = _get_input_output(inputs, self.input_key)
self.chat_memory.add_user_message(input_str)
output_str = _get_input_output(outputs, self.output_key)
self.chat_memory.add_ai_message(output_str)
def clear(self) -> None:
"""清除对话历史"""
self.chat_memory.clear()
这个实现通过控制返回的消息数量来实现固定窗口策略。在load_memory_variables
方法中,只返回最近的k条消息,从而限制了上下文窗口的大小。
4.2 滑动窗口策略
滑动窗口策略是固定窗口策略的一种扩展,它不仅保留最近的N条消息,还可以根据消息的重要性或其他标准进行筛选。在滑动窗口策略中,窗口大小是固定的,但窗口内的内容可以根据特定规则进行动态调整。
LangChain并没有直接提供滑动窗口策略的实现,但我们可以通过扩展ConversationBufferWindowMemory
类来实现这一策略。以下是一个示例实现:
class SlidingWindowMemory(ConversationBufferWindowMemory):
"""基于重要性的滑动窗口内存实现"""
def __init__(self, importance_fn: Callable[[BaseMessage], float], *args, **kwargs):
super().__init__(*args, **kwargs)
self.importance_fn = importance_fn # 重要性评估函数
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""加载窗口内的消息,基于重要性进行筛选"""
all_messages = self.chat_memory.messages
# 如果消息数量小于等于窗口大小,直接返回所有消息
if len(all_messages) <= self.k:
window_messages = all_messages
else:
# 计算每条消息的重要性
messages_with_importance = [(msg, self.importance_fn(msg)) for msg in all_messages]
# 按重要性排序(降序)
messages_with_importance.sort(key=lambda x: x[1], reverse=True)
# 选择最重要的k条消息
selected_messages = [msg for msg, _ in messages_with_importance[:self.k]]
# 按时间顺序重新排序
selected_messages.sort(key=lambda x: all_messages.index(x))
window_messages = selected_messages
if self.return_messages:
history = window_messages
else:
history = _get_buffer_string(window_messages)
return {"history": history}
这个实现通过引入一个重要性评估函数importance_fn
,可以根据消息的重要性来选择保留哪些消息。这样,即使窗口大小固定,也能确保最重要的信息被保留在上下文中。
4.3 基于Token数量的窗口策略
由于LLM的输入长度限制通常以token数量为单位,基于token数量的窗口策略更为实用。这种策略会动态调整窗口大小,确保上下文中的总token数量不超过预设的限制。
LangChain中的TokenBufferWindowMemory
类实现了这一策略。其核心代码如下:
class TokenBufferWindowMemory(BaseMemory):
"""基于token数量的窗口内存实现"""
chat_memory: BaseChatMemory = Field(default_factory=ChatMessageHistory)
max_token_limit: int = 2000 # 默认最大token数量
tokenizer: Optional[Callable[[str], List[str]]] = None
return_messages: bool = False
def __init__(self, **kwargs):
super().__init__(**kwargs)
# 如果没有提供tokenizer,使用默认的简单实现
if self.tokenizer is None:
self.tokenizer = self._default_tokenizer
def _default_tokenizer(self, text: str) -> List[str]:
"""默认的tokenizer实现,简单地按空格分割"""
return text.split()
def _get_token_count(self, messages: List[BaseMessage]) -> int:
"""计算消息列表的token数量"""
total_tokens = 0
for message in messages:
# 计算消息内容的token数量
content_tokens = len(self.tokenizer(message.content))
total_tokens += content_tokens
# 考虑消息类型和其他元数据的token开销
# 这里简化处理,为每个消息添加固定的token开销
total_tokens += 5
return total_tokens
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""加载不超过最大token限制的消息"""
all_messages = self.chat_memory.messages
# 如果消息列表为空,直接返回
if not all_messages:
return {"history": [] if self.return_messages else ""}
# 从最近的消息开始,逐步添加,直到达到token限制
window_messages = []
current_tokens = 0
for message in reversed(all_messages):
# 计算添加此消息后的token数量
message_tokens = self._get_token_count([message])
if current_tokens + message_tokens <= self.max_token_limit:
window_messages.append(message)
current_tokens += message_tokens
else:
# 达到token限制,停止添加
break
# 恢复消息的原始顺序
window_messages.reverse()
if self.return_messages:
history = window_messages
else:
history = _get_buffer_string(window_messages)
return {"history": history}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""保存上下文,与ConversationBufferMemory相同"""
input_str = _get_input_output(inputs, self.input_key)
self.chat_memory.add_user_message(input_str)
output_str = _get_input_output(outputs, self.output_key)
self.chat_memory.add_ai_message(output_str)
def clear(self) -> None:
"""清除对话历史"""
self.chat_memory.clear()
这个实现通过动态计算消息的token数量,确保上下文中的总token数量不超过预设限制。在load_memory_variables
方法中,从最近的消息开始逐步添加,直到达到token限制,从而实现了基于token数量的窗口管理。
五、高级窗口管理策略
5.1 基于重要性的窗口策略
基于重要性的窗口策略会根据消息的重要性程度来决定保留哪些消息。重要性可以基于多种因素计算,如消息的内容、来源、时间戳等。这种策略能够确保在窗口大小有限的情况下,保留最有价值的信息。
在LangChain中,我们可以通过扩展基础内存类来实现基于重要性的窗口策略。以下是一个示例实现:
class ImportanceBasedWindowMemory(BaseMemory):
"""基于重要性的窗口内存实现"""
chat_memory: BaseChatMemory = Field(default_factory=ChatMessageHistory)
max_token_limit: int = 2000
tokenizer: Optional[Callable[[str], List[str]]] = None
importance_fn: Callable[[BaseMessage], float] = None
return_messages: bool = False
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.tokenizer is None:
self.tokenizer = self._default_tokenizer
# 如果没有提供重要性评估函数,使用默认实现
if self.importance_fn is None:
self.importance_fn = self._default_importance_fn
def _default_tokenizer(self, text: str) -> List[str]:
"""默认的tokenizer实现"""
return text.split()
def _default_importance_fn(self, message: BaseMessage) -> float:
"""默认的重要性评估函数"""
# 简单实现:用户消息比AI消息更重要,近期消息比早期消息更重要
base_importance = 1.0 if isinstance(message, HumanMessage) else 0.8
# 添加时间衰减因子
time_factor = 1.0 # 这里简化处理,实际实现中应基于消息时间计算
return base_importance * time_factor
def _get_token_count(self, messages: List[BaseMessage]) -> int:
"""计算消息列表的token数量"""
total_tokens = 0
for message in messages:
content_tokens = len(self.tokenizer(message.content))
total_tokens += content_tokens
total_tokens += 5 # 元数据token开销
return total_tokens
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""加载基于重要性和token限制的消息"""
all_messages = self.chat_memory.messages
if not all_messages:
return {"history": [] if self.return_messages else ""}
# 计算每条消息的重要性
messages_with_importance = [(msg, self.importance_fn(msg)) for msg in all_messages]
# 按重要性排序(降序)
messages_with_importance.sort(key=lambda x: x[1], reverse=True)
# 按重要性从高到低选择消息,直到达到token限制
selected_messages = []
current_tokens = 0
for message, importance in messages_with_importance:
message_tokens = self._get_token_count([message])
if current_tokens + message_tokens <= self.max_token_limit:
selected_messages.append(message)
current_tokens += message_tokens
else:
# 达到token限制,停止添加
break
# 按时间顺序重新排序
selected_messages.sort(key=lambda x: all_messages.index(x))
if self.return_messages:
history = selected_messages
else:
history = _get_buffer_string(selected_messages)
return {"history": history}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""保存上下文"""
input_str = _get_input_output(inputs, self.input_key)
self.chat_memory.add_user_message(input_str)
output_str = _get_input_output(outputs, self.output_key)
self.chat_memory.add_ai_message(output_str)
def clear(self) -> None:
"""清除对话历史"""
self.chat_memory.clear()
这个实现通过引入importance_fn
函数来评估每条消息的重要性。在选择消息时,优先选择重要性高的消息,直到达到token限制。这样可以确保在有限的窗口内保留最有价值的信息。
5.2 分层窗口策略
分层窗口策略将对话历史分为多个层次,不同层次的信息具有不同的保留优先级和窗口大小。例如,可以将关键信息放在核心层,将次要信息放在外围层,当窗口空间不足时,优先保留核心层的信息。
以下是一个分层窗口策略的实现示例:
class HierarchicalWindowMemory(BaseMemory):
"""分层窗口内存实现"""
chat_memory: BaseChatMemory = Field(default_factory=ChatMessageHistory)
core_window_size: int = 5 # 核心层窗口大小
peripheral_window_size: int = 10 # 外围层窗口大小
tokenizer: Optional[Callable[[str], List[str]]] = None
max_token_limit: int = 2000
return_messages: bool = False
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.tokenizer is None:
self.tokenizer = self._default_tokenizer
def _default_tokenizer(self, text: str) -> List[str]:
"""默认的tokenizer实现"""
return text.split()
def _is_core_message(self, message: BaseMessage) -> bool:
"""判断消息是否属于核心层"""
# 示例实现:用户消息和包含特定关键词的消息属于核心层
if isinstance(message, HumanMessage):
return True
keywords = ["重要", "关键", "必须", "紧急"]
return any(keyword in message.content for keyword in keywords)
def _get_token_count(self, messages: List[BaseMessage]) -> int:
"""计算消息列表的token数量"""
total_tokens = 0
for message in messages:
content_tokens = len(self.tokenizer(message.content))
total_tokens += content_tokens
total_tokens += 5 # 元数据token开销
return total_tokens
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""加载分层窗口内的消息"""
all_messages = self.chat_memory.messages
if not all_messages:
return {"history": [] if self.return_messages else ""}
# 分离核心层和外围层消息
core_messages = []
peripheral_messages = []
for message in all_messages:
if self._is_core_message(message):
core_messages.append(message)
else:
peripheral_messages.append(message)
# 应用核心层窗口限制
if len(core_messages) > self.core_window_size:
core_messages = core_messages[-self.core_window_size:]
# 应用外围层窗口限制
if len(peripheral_messages) > self.peripheral_window_size:
peripheral_messages = peripheral_messages[-self.peripheral_window_size:]
# 合并两层消息
combined_messages = core_messages + peripheral_messages
# 应用token限制
if self._get_token_count(combined_messages) > self.max_token_limit:
# 从外围层开始移除消息,直到满足token限制
while combined_messages and self._get_token_count(combined_messages) > self.max_token_limit:
if peripheral_messages:
peripheral_messages.pop(0) # 移除最旧的外围层消息
combined_messages = core_messages + peripheral_messages
else:
# 如果外围层为空,开始移除核心层消息
core_messages.pop(0)
combined_messages = core_messages + peripheral_messages
if self.return_messages:
history = combined_messages
else:
history = _get_buffer_string(combined_messages)
return {"history": history}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""保存上下文"""
input_str = _get_input_output(inputs, self.input_key)
self.chat_memory.add_user_message(input_str)
output_str = _get_input_output(outputs, self.output_key)
self.chat_memory.add_ai_message(output_str)
def clear(self) -> None:
"""清除对话历史"""
self.chat_memory.clear()
这个实现将消息分为核心层和外围层,核心层消息具有更高的保留优先级。在窗口空间不足时,优先保留核心层消息,从而确保关键信息不会被遗忘。
5.3 自适应窗口策略
自适应窗口策略能够根据对话的进展动态调整窗口大小和内容。例如,在对话的早期阶段,可能需要保留更多的历史信息以建立上下文;而在对话的后期阶段,如果主题已经明确,可以适当缩小窗口以提高效率。
以下是一个自适应窗口策略的实现示例:
class AdaptiveWindowMemory(BaseMemory):
"""自适应窗口内存实现"""
chat_memory: BaseChatMemory = Field(default_factory=ChatMessageHistory)
initial_token_limit: int = 1000 # 初始token限制
max_token_limit: int = 3000 # 最大token限制
tokenizer: Optional[Callable[[str], List[str]]] = None
return_messages: bool = False
adaptation_threshold: int = 5 # 对话轮数阈值,超过此值开始调整窗口
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.tokenizer is None:
self.tokenizer = self._default_tokenizer
self.current_token_limit = self.initial_token_limit
self.conversation_turns = 0
def _default_tokenizer(self, text: str) -> List[str]:
"""默认的tokenizer实现"""
return text.split()
def _get_token_count(self, messages: List[BaseMessage]) -> int:
"""计算消息列表的token数量"""
total_tokens = 0
for message in messages:
content_tokens = len(self.tokenizer(message.content))
total_tokens += content_tokens
total_tokens += 5 # 元数据token开销
return total_tokens
def _update_window_size(self) -> None:
"""根据对话进展更新窗口大小"""
self.conversation_turns += 1
# 在对话初期,逐渐增加窗口大小
if self.conversation_turns <= self.adaptation_threshold:
self.current_token_limit = min(
self.initial_token_limit + (self.conversation_turns * 200),
self.max_token_limit
)
else:
# 在对话稳定后,根据内容复杂度调整窗口大小
recent_messages = self.chat_memory.messages[-10:] if len(self.chat_memory.messages) > 10 else self.chat_memory.messages
# 计算最近消息的平均token长度
if recent_messages:
avg_token_length = sum(len(self.tokenizer(msg.content)) for msg in recent_messages) / len(recent_messages)
# 如果平均长度较长,说明内容复杂,增加窗口大小
if avg_token_length > 100:
self.current_token_limit = min(self.current_token_limit + 100, self.max_token_limit)
# 如果平均长度较短,说明内容简单,适当减少窗口大小
elif avg_token_length < 50 and self.current_token_limit > self.initial_token_limit:
self.current_token_limit = max(self.current_token_limit - 100, self.initial_token_limit)
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""加载自适应窗口内的消息"""
# 更新窗口大小
self._update_window_size()
all_messages = self.chat_memory.messages
if not all_messages:
return {"history": [] if self.return_messages else ""}
# 从最近的消息开始,逐步添加,直到达到当前token限制
window_messages = []
current_tokens = 0
for message in reversed(all_messages):
message_tokens = self._get_token_count([message])
if current_tokens + message_tokens <= self.current_token_limit:
window_messages.append(message)
current_tokens += message_tokens
else:
break
# 恢复消息的原始顺序
window_messages.reverse()
if self.return_messages:
history = window_messages
else:
history = _get_buffer_string(window_messages)
return {"history": history}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""保存上下文"""
input_str = _get_input_output(inputs, self.input_key)
self.chat_memory.add_user_message(input_str)
output_str = _get_input_output(outputs, self.output_key)
self.chat_memory.add_ai_message(output_str)
def clear(self) -> None:
"""清除对话历史"""
self.chat_memory.clear()
self.conversation_turns = 0
self.current_token_limit = self.initial_token_limit
这个实现根据对话的轮数和内容复杂度动态调整窗口大小。在对话初期,窗口大小逐渐增加;在对话稳定后,根据内容的复杂程度进行微调。这种策略能够在不同的对话阶段提供最合适的上下文窗口,平衡信息保留和计算效率。
六、遗忘策略实现机制
6.1 基于时间的遗忘策略
基于时间的遗忘策略是最直观的遗忘策略之一,它根据消息的时间戳来决定哪些消息应该被遗忘。具体来说,越早的消息越有可能被遗忘,而最近的消息会被优先保留。
在LangChain中,我们可以通过扩展基础内存类来实现基于时间的遗忘策略。以下是一个示例实现:
class TimeBasedForgettingMemory(BaseMemory):
"""基于时间的遗忘内存实现"""
chat_memory: BaseChatMemory = Field(default_factory=ChatMessageHistory)
max_token_limit: int = 2000
tokenizer: Optional[Callable[[str], List[str]]] = None
time_decay_factor: float = 0.9 # 时间衰减因子
return_messages: bool = False
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.tokenizer is None:
self.tokenizer = self._default_tokenizer
def _default_tokenizer(self, text: str) -> List[str]:
"""默认的tokenizer实现"""
return text.split()
def _get_token_count(self, messages: List[BaseMessage]) -> int:
"""计算消息列表的token数量"""
total_tokens = 0
for message in messages:
content_tokens = len(self.tokenizer(message.content))
total_tokens += content_tokens
total_tokens += 5 # 元数据token开销
return total_tokens
def _calculate_message_score(self, message: BaseMessage, current_time: float) -> float:
"""计算消息的分数,基于时间衰减"""
# 假设message有一个timestamp属性
if hasattr(message, 'timestamp'):
age = current_time - message.timestamp
# 时间越久,分数越低
score = math.exp(-self.time_decay_factor * age)
return score
else:
# 如果没有时间戳,默认给一个中等分数
return 0.5
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""加载基于时间分数的消息"""
all_messages = self.chat_memory.messages
if not all_messages:
return {"history": [] if self.return_messages else ""}
current_time = time.time()
# 计算每条消息的分数
messages_with_scores = [(msg, self._calculate_message_score(msg, current_time)) for msg in all_messages]
# 按分数排序(降序)
messages_with_scores.sort(key=lambda x: x[1], reverse=True)
# 按分数从高到低选择消息,直到达到token限制
selected_messages = []
current_tokens = 0
for message, score in messages_with_scores:
message_tokens = self._get_token_count([message])
if current_tokens + message_tokens <= self.max_token_limit:
selected_messages.append(message)
current_tokens += message_tokens
else:
break
# 按时间顺序重新排序
selected_messages.sort(key=lambda x: all_messages.index(x))
if self.return_messages:
history = selected_messages
else:
history = _get_buffer_string(selected_messages)
return {"history": history}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""保存上下文,添加时间戳"""
input_str = _get_input_output(inputs, self.input_key)
user_message = HumanMessage(content=input_str)
user_message.timestamp = time.time() # 添加时间戳
self.chat_memory.add_message(user_message)
output_str = _get_input_output(outputs, self.output_key)
ai_message = AIMessage(content=output_str)
ai_message.timestamp = time.time() # 添加时间戳
self.chat_memory.add_message(ai_message)
def clear(self) -> None:
"""清除对话历史"""
self.chat_memory.clear()
这个实现通过给每条消息计算一个基于时间的分数,时间越久的消息分数越低。在选择消息时,优先保留分数高的消息,从而实现基于时间的遗忘策略。
6.2 基于重要性的遗忘策略
基于重要性的遗忘策略与基于重要性的窗口策略类似,但重点在于决定哪些消息应该被遗忘。这种策略会根据消息的重要性程度来排序,当窗口空间不足时,优先遗忘重要性低的消息。
以下是一个基于重要性的遗忘策略的实现示例:
class ImportanceBasedForgettingMemory(BaseMemory):
"""基于重要性的遗忘内存实现"""
chat_memory: BaseChatMemory = Field(default_factory=ChatMessageHistory)
max_token_limit: int = 2000
tokenizer: Optional[Callable[[str], List[str]]] = None
importance_fn: Callable[[BaseMessage], float] = None
return_messages: bool = False
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.tokenizer is None:
self.tokenizer = self._default_tokenizer
if self.importance_fn is None:
self.importance_fn = self._default_importance_fn
def _default_tokenizer(self, text: str) -> List[str]:
"""默认的tokenizer实现"""
return text.split()
def _default_importance_fn(self, message: BaseMessage) -> float:
"""默认的重要性评估函数"""
# 用户消息比AI消息更重要
base_importance = 1.0 if isinstance(message, HumanMessage) else 0.8
# 检查消息是否包含关键词
keywords = ["重要", "关键", "必须", "紧急"]
if any(keyword in message.content for keyword in keywords):
base_importance += 0.3
return base_importance
def _get_token_count(self, messages: List[BaseMessage]) -> int:
"""计算消息列表的token数量"""
total_tokens = 0
for message in messages:
content_tokens = len(self.tokenizer(message.content))
total_tokens += content_tokens
total_tokens += 5 # 元数据token开销
return total_tokens
def _forget_least_important(self, messages: List[BaseMessage]) -> List[BaseMessage]:
"""遗忘最不重要的消息,直到满足token限制"""
if self._get_token_count(messages) <= self.max_token_limit:
return messages
# 计算每条消息的重要性
messages_with_importance = [(msg, self.importance_fn(msg)) for msg in messages]
# 按重要性排序(升序)
messages_with_importance.sort(key=lambda x: x[1])
# 逐步移除最不重要的消息,直到满足token限制
while messages_with_importance and self._get_token_count([msg for msg, _ in messages_with_importance]) > self.max_token_limit:
messages_with_importance.pop(0) # 移除最不重要的消息
# 返回剩余的消息,按原始顺序排序
remaining_messages = [msg for msg, _ in messages_with_importance]
remaining_messages.sort(key=lambda x: messages.index(x))
return remaining_messages
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""加载基于重要性的消息"""
all_messages = self.chat_memory.messages
if not all_messages:
return {"history": [] if self.return_messages else ""}
# 应用重要性遗忘策略
selected_messages = self._forget_least_important(all_messages)
if self.return_messages:
history = selected_messages
else:
history = _get_buffer_string(selected_messages)
return {"history": history}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""保存上下文"""
input_str = _get_input_output(inputs, self.input_key)
self.chat_memory.add_user_message(input_str)
output_str = _get_input_output(outputs, self.output_key)
self.chat_memory.add_ai_message(output_str)
def clear(self) -> None:
"""清除对话历史"""
self.chat_memory.clear()
这个实现通过importance_fn
函数评估每条消息的重要性。当消息总量超过token限制时,逐步移除重要性最低的消息,直到满足限制。这样可以确保在窗口空间有限的情况下,保留最重要的信息。
6.3 基于内容相似性的遗忘策略
基于内容相似性的遗忘策略会分析消息之间的内容相似性,当窗口空间不足时,优先遗忘与其他消息相似性高的消息。这种策略可以避免上下文中存在过多冗余信息,提高模型处理效率。
以下是一个基于内容相似性的遗忘策略的实现示例:
class SimilarityBasedForgettingMemory(BaseMemory):
"""基于内容相似性的遗忘内存实现"""
chat_memory: BaseChatMemory = Field(default_factory=ChatMessageHistory)
max_token_limit: int = 2000
tokenizer: Optional[Callable[[str], List[str]]] = None
similarity_threshold: float = 0.8 # 相似性阈值
embedding_model: Optional[Callable[[str], List[float]]] = None
return_messages: bool = False
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.tokenizer is None:
self.tokenizer = self._default_tokenizer
if self.embedding_model is None:
# 使用简单的词袋模型作为默认嵌入模型
self.embedding_model = self._simple_bag_of_words_embedding
def _default_tokenizer(self, text: str) -> List[str]:
"""默认的tokenizer实现"""
return text.split()
def _simple_bag_of_words_embedding(self, text: str) -> List[float]:
"""简单的词袋嵌入模型"""
tokens = self.tokenizer(text)
word_counts = Counter(tokens)
# 返回词频向量
return [word_counts.get(word, 0) for word in set(tokens)]
def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
"""计算余弦相似度"""
if not vec1 or not vec2:
return 0.0
dot_product = sum(a * b for a, b in zip(vec1, vec2))
norm_a = math.sqrt(sum(a * a for a in vec1))
norm_b = math.sqrt(sum(b * b for b in vec2))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot_product / (norm_a * norm_b)
def _get_token_count(self, messages: List[BaseMessage]) -> int:
"""计算消息列表的token数量"""
total_tokens = 0
for message in messages:
content_tokens = len(self.tokenizer(message.content))
total_tokens += content_tokens
total_tokens += 5 # 元数据token开销
return total_tokens
def _forget_similar_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
"""遗忘相似的消息,直到满足token限制"""
if len(messages) <= 1 or self._get_token_count(messages) <= self.max_token_limit:
return messages
# 计算所有消息的嵌入
embeddings = [self.embedding_model(msg.content) for msg in messages]
# 计算相似度矩阵
similarity_matrix = [[0.0 for _ in range(len(messages))] for _ in range(len(messages))]
for i in range(len(messages)):
for j in range(i + 1, len(messages)):
sim = self._cosine_similarity(embeddings[i], embeddings[j])
similarity_matrix[i][j] = sim
similarity_matrix[j][i] = sim
# 按相似度排序消息对
similar_pairs = []
for i in range(len(messages)):
for j in range(i + 1, len(messages)):
if similarity_matrix[i][j] >= self.similarity_threshold:
# 优先遗忘较旧的消息(i < j 表示i更旧)
similar_pairs.append((similarity_matrix[i][j], i, j))
# 按相似度降序排序
similar_pairs.sort(key=lambda x: x[0], reverse=True)
# 移除相似的消息,直到满足token限制
removed_indices = set()
for _, i, j in similar_pairs:
if i not in removed_indices and j not in removed_indices:
# 优先移除较旧的消息
removed_indices.add(i)
# 检查是否已满足token限制
remaining_messages = [msg for idx, msg in enumerate(messages) if idx not in removed_indices]
if self._get_token_count(remaining_messages) <= self.max_token_limit:
break
# 返回剩余的消息
return [msg for idx, msg in enumerate(messages) if idx not in removed_indices]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""加载基于相似性的消息"""
all_messages = self.chat_memory.messages
if not all_messages:
return {"history": [] if self.return_messages else ""}
# 应用相似性遗忘策略
selected_messages = self._forget_similar_messages(all_messages)
# 应用token限制
if self._get_token_count(selected_messages) > self.max_token_limit:
# 如果相似性过滤后仍超过token限制,使用基于时间的遗忘
# 这里简化处理,实际实现中可以使用更复杂的策略
selected_messages = selected_messages[-int(self.max_token_limit / 100):] # 简单的基于数量的截断
if self.return_messages:
history = selected_messages
else:
history = _get_buffer_string(selected_messages)
return {"history": history}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""保存上下文"""
input_str = _get_input_output(inputs, self.input_key)
self.chat_memory.add_user_message(input_str)
output_str = _get_input_output(outputs, self.output_key)
self.chat_memory.add_ai_message(output_str)
def clear(self) -> None:
"""清除对话历史"""
self.chat_memory.clear()
这个实现通过计算消息之间的内容相似性,当发现相似性超过阈值的消息对时,优先遗忘较旧的消息。这样可以减少上下文中的冗余信息,提高模型处理效率。
七、Token计算与管理
7.1 Token计算方法
在LangChain的上下文窗口管理中,准确计算文本的token数量至关重要。不同的LLM使用不同的tokenization算法,因此token计算方法也有所不同。
LangChain提供了多种token计算方法,主要包括:
- 基于模型的tokenizer:使用LLM原生的tokenizer进行计算,这是最准确的方法。例如,对于OpenAI的模型,可以使用
tiktoken
库:
import tiktoken
def count_tokens_with_tiktoken(text: str, model_name: str) -> int:
"""使用tiktoken计算文本的token数量"""
try:
encoding = tiktoken.encoding_for_model(model_name)
return len(encoding.encode(text))
except KeyError:
# 如果模型不支持,使用通用编码
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(text))
- 基于规则的近似计算:对于没有可用tokenizer的模型,可以使用基于规则的近似计算方法。例如,简单地按空格分割文本:
def approximate_token_count(text: str) -> int:
"""近似计算文本的token数量"""
# 简单方法:假设平均每个单词对应1.3个token
words = text.split()
return int(len(words) * 1.3)
- 基于字符长度的近似计算:另一种简单的近似方法是基于字符长度:
def approximate_token_count_by_char(text: str) -> int:
"""基于字符长度近似计算token数量"""
# 假设平均每个token对应4个字符
return len(text) // 4
在实际应用中,LangChain推荐使用基于模型的tokenizer进行准确计算,特别是在处理严格的token限制时。
7.2 Token限制处理策略
当上下文窗口中的token数量超过模型限制时,LangChain提供了多种处理策略:
- 截断策略:直接截断最旧的消息,直到满足token限制。这是最简单的策略,但可能会导致重要信息丢失。
def truncate_messages_to_token_limit(messages: List[BaseMessage], token_limit: int, tokenizer: Callable) -> List[BaseMessage]:
"""将消息截断到token限制"""
total_tokens = 0
truncated_messages = []
# 从最新的消息开始添加
for message in reversed(messages):
message_tokens = count_tokens_with_tiktoken(message.content, "gpt-4") # 假设使用gpt-4的tokenizer
if total_tokens + message_tokens <= token_limit:
truncated_messages.append(message)
total_tokens += message_tokens
else:
# 达到token限制,停止添加
break
# 恢复消息的原始顺序
return list(reversed(truncated_messages))
- 摘要策略:对长消息进行摘要,以减少token数量。LangChain提供了
ConversationSummaryMemory
类来实现这一策略:
from langchain.memory import ConversationSummaryMemory
from langchain.llms import OpenAI
# 初始化摘要内存
llm = OpenAI(temperature=0)
memory = ConversationSummaryMemory(llm=llm)
# 添加消息
memory.save_context({"input": "用户问题"}, {"output": "AI回答"})
# 获取摘要后的历史
summary = memory.load_memory_variables({})["history"]
-
选择性遗忘策略:根据消息的重要性或其他标准,选择性地遗忘某些消息。这在前面的章节中已经详细讨论过。
-
分层处理策略:将消息分为不同层次,优先保留核心层的消息。例如,用户消息可能比AI回答更重要:
def prioritize_messages(messages: List[BaseMessage], token_limit: int, tokenizer: Callable) -> List[BaseMessage]:
"""优先保留重要消息,直到达到token限制"""
# 将消息分为用户消息和AI消息
user_messages = [msg for msg in messages if isinstance(msg, HumanMessage)]
ai_messages = [msg for msg in messages if isinstance(msg, AIMessage)]
# 计算用户消息的token数量
user_tokens = sum(count_tokens_with_tiktoken(msg.content, "gpt-4") for msg in user_messages)
# 如果用户消息已经超过token限制,只保留部分用户消息
if user_tokens > token_limit:
return truncate_messages_to_token_limit(user_messages, token_limit, tokenizer)
# 否则,添加AI消息,直到达到token限制
remaining_tokens = token_limit - user_tokens
prioritized_messages = user_messages.copy()
for msg in reversed(ai_messages):
msg_tokens = count_tokens_with_tiktoken(msg.content, "gpt-4")
if remaining_tokens >= msg_tokens:
prioritized_messages.append(msg)
remaining_tokens -= msg_tokens
else:
break
# 按时间顺序排序
prioritized_messages.sort(key=lambda x: messages.index(x))
return prioritized_messages
7.3 Token优化技术
为了更有效地利用有限的token空间,LangChain提供了多种token优化技术:
- 消息压缩:通过去除不必要的空格、标点符号等方式压缩消息内容:
def compress_message(message: str) -> str:
"""压缩消息内容,减少token使用"""
# 去除多余空格
compressed = ' '.join(message.split())
# 可以添加更多压缩逻辑,如缩写常见术语等
return compressed
- 元数据优化:减少消息元数据的token开销。例如,使用更简洁的消息类型标识:
def optimize_metadata(message: BaseMessage) -> Dict[str, Any]:
"""优化消息元数据,减少token使用"""
# 使用更简洁的类型标识
type_mapping = {
"human": "u",
"ai": "a",
"system": "s"
}
return {
"type": type_mapping.get(message.type, message.type),
"content": message.content,
# 只保留必要的元数据
}
- 知识引用:对于重复出现的知识,可以使用引用代替完整内容。例如:
def replace_with_references(text: str, knowledge_base: Dict[str, str]) -> str:
"""用引用替换重复的知识内容"""
optimized_text = text
for key, value in knowledge_base.items():
if value in text:
# 用引用标记替换知识内容
optimized_text = optimized_text.replace(value, f"[REF:{key}]")
return optimized_text
- 动态token分配:根据消息的重要性动态分配token空间。例如,为用户问题分配更多token,为AI回答分配较少token:
def dynamic_token_allocation(user_input: str, ai_output: str, total_tokens: int) -> Tuple[str, str]:
"""动态分配token空间给用户输入和AI输出"""
user_tokens = count_tokens_with_tiktoken(user_input, "gpt-4")
ai_tokens = count_tokens_with_tiktoken(ai_output, "gpt-4")
# 如果总token超过限制,按比例分配
if user_tokens + ai_tokens > total_tokens:
ratio = total_tokens / (user_tokens + ai_tokens)
user_max_tokens = int(user_tokens * ratio)
ai_max_tokens = total_tokens - user_max_tokens
# 截断文本到最大token数
user_optimized = truncate_text_to_tokens(user_input, user_max_tokens, "gpt-4")
ai_optimized = truncate_text_to_tokens(ai_output, ai_max_tokens, "gpt-4")
return user_optimized, ai_optimized
return user_input, ai_output
这些token优化技术可以帮助在有限的窗口大小内保留更多有价值的信息,提高LLM的处理效率和回答质量。
八、上下文窗口与工具调用
8.1 工具调用上下文管理
在LangChain中,工具调用是一个重要的功能,它允许LLM与外部系统交互以获取更多信息。工具调用的上下文管理需要特殊处理,以确保工具调用的输入和输出能够被正确地包含在上下文中。
工具调用的消息类型为FunctionCallMessage
,它包含了函数名称和参数。在上下文窗口管理中,需要特别处理这种消息类型:
class FunctionCallMessage(BaseMessage):
"""表示函数调用的消息"""
name: str
parameters: Dict[str, Any]
@property
def type(self) -> str:
return "function_call"
def to_dict(self) -> dict:
return {
"type": self.type,
"name": self.name,
"parameters": self.parameters,
"content": json.dumps(self.parameters) # 为了兼容token计算
}
在上下文窗口管理中,工具调用消息的处理方式与普通消息有所不同。例如,在计算token数量时,需要考虑函数名称和参数的token开销:
def count_function_call_tokens(message: FunctionCallMessage, tokenizer: Callable) -> int:
"""计算FunctionCallMessage的token数量"""
# 函数名称的token数
name_tokens = len(tokenizer(message.name))
# 参数的token数
params_text = json.dumps(message.parameters)
params_tokens = len(tokenizer(params_text))
# 添加一些额外的token用于表示函数调用结构
structure_tokens = 5
return name_tokens + params_tokens + structure_tokens
8.2 工具调用结果的上下文整合
工具调用的结果也需要被正确地整合到上下文中。工具调用结果通常以FunctionCallResultMessage
的形式存在:
class FunctionCallResultMessage(BaseMessage):
"""表示函数调用结果的消息"""
name: str
result: Any
@property
def type(self) -> str:
return "function_call_result"
def to_dict(self) -> dict:
return {
"type": self.type,
"name": self.name,
"result": self.result,
"content": json.dumps(self.result) # 为了兼容token计算
}
在上下文窗口管理中,需要确保工具调用消息和其结果消息始终成对出现,并且在窗口大小有限的情况下,优先保留完整的工具调用序列:
def preserve_function_call_sequences(messages: List[BaseMessage], token_limit: int, tokenizer: Callable) -> List[BaseMessage]:
"""保留完整的工具调用序列,直到达到token限制"""
# 识别所有工具调用序列
在上下文窗口管理中,需要确保工具调用消息和其结果消息始终成对出现,并且在窗口大小有限的情况下,优先保留完整的工具调用序列:
def preserve_function_call_sequences(messages: List[BaseMessage], token_limit: int, tokenizer: Callable) -> List[BaseMessage]:
"""保留完整的工具调用序列,直到达到token限制"""
# 识别所有工具调用序列
function_call_sequences = []
current_sequence = []
for message in messages:
if isinstance(message, FunctionCallMessage):
current_sequence.append(message)
elif isinstance(message, FunctionCallResultMessage) and current_sequence:
current_sequence.append(message)
function_call_sequences.append(current_sequence)
current_sequence = []
else:
if current_sequence:
# 未完成的序列,单独处理
function_call_sequences.append(current_sequence)
current_sequence = []
# 按时间顺序排列序列
function_call_sequences.sort(key=lambda seq: messages.index(seq[0]))
# 计算所有工具调用序列的token数量
total_function_call_tokens = sum(sum(count_function_call_tokens(msg, tokenizer) for msg in seq) for seq in function_call_sequences)
# 计算剩余可用于其他消息的token数量
remaining_tokens = token_limit - total_function_call_tokens
# 选择非工具调用消息,直到达到剩余token限制
other_messages = [msg for msg in messages if not isinstance(msg, (FunctionCallMessage, FunctionCallResultMessage))]
selected_other_messages = []
current_other_tokens = 0
for msg in reversed(other_messages):
msg_tokens = count_tokens_with_tiktoken(msg.content, "gpt-4") # 假设使用gpt-4的tokenizer
if current_other_tokens + msg_tokens <= remaining_tokens:
selected_other_messages.append(msg)
current_other_tokens += msg_tokens
else:
break
# 按时间顺序恢复非工具调用消息
selected_other_messages.reverse()
# 合并工具调用序列和非工具调用消息
final_messages = []
for seq in function_call_sequences:
final_messages.extend(seq)
final_messages.extend(selected_other_messages)
return final_messages
8.3 工具调用对上下文窗口的影响
工具调用的引入会显著增加上下文窗口的复杂性。一方面,工具调用消息及其结果通常包含大量结构化数据,这会占用较多的token空间。例如,一个调用数据库查询工具的消息可能包含复杂的SQL语句,而其结果可能是多条数据记录,这些都会迅速消耗窗口内的token资源。
另一方面,工具调用的上下文关联性要求更高。LLM需要根据工具调用的历史来理解当前请求和结果的意义。例如,在连续的数学计算工具调用中,后续的计算可能依赖于前面调用的结果,因此这些相关的调用序列必须完整保留在上下文中。
在LangChain源码中,Agent
类负责协调工具调用与上下文管理。当Agent决定调用工具时,它会将调用消息添加到上下文,并在工具返回结果后,将结果消息也整合进去。同时,Agent会根据当前上下文窗口的状态,动态调整工具调用的策略。例如,如果窗口空间不足,Agent可能会减少不必要的工具调用,或者优先选择返回结果更简洁的工具。
class Agent:
def __init__(self, tools, memory, llm):
self.tools = tools
self.memory = memory
self.llm = llm
def run(self, input):
self.memory.save_context({"input": input})
while True:
# 根据上下文生成行动建议
action_suggestion = self.llm.predict(self.memory.load_memory_variables({})["history"])
if "tool_call" in action_suggestion:
tool_name, tool_args = self.parse_tool_call(action_suggestion)
tool = next((t for t in self.tools if t.name == tool_name), None)
if tool:
result = tool.run(tool_args)
self.memory.save_context({"tool_call": action_suggestion, "tool_result": result})
else:
self.memory.save_context({"error": f"Tool {tool_name} not found"})
else:
response = action_suggestion
self.memory.save_context({"output": response})
return response
def parse_tool_call(self, suggestion):
# 解析工具调用建议
pass
九、上下文窗口管理与多模态数据
9.1 多模态数据的表示与存储
随着技术发展,LLM应用不再局限于纯文本交互,多模态数据(如图像、音频、视频等)的处理需求日益增长。在LangChain中,处理多模态数据需要对上下文窗口管理进行扩展。
对于图像数据,通常会将图像转换为文本描述或特征向量进行存储。例如,使用计算机视觉模型生成图像的文本描述,然后将该描述作为SystemMessage
或HumanMessage
的一部分存储在上下文中。
import cv2
from langchain.llms import OpenAI
from langchain.schema import SystemMessage
def image_to_text_description(image_path):
# 使用计算机视觉库读取图像
image = cv2.imread(image_path)
# 这里假设使用一个图像描述生成模型(如CLIP)
# 简化示例,直接返回图像尺寸信息
height, width, _ = image.shape
return f"An image with width {width} and height {height}"
def add_image_to_memory(memory, image_path):
description = image_to_text_description(image_path)
message = SystemMessage(content=f"Image: {description}")
memory.save_context({"input": message.content})
对于音频和视频数据,类似地,可以将其转换为文本转录或关键信息摘要进行存储。例如,使用语音识别技术将音频转换为文字,然后将文字消息添加到上下文中。
9.2 多模态数据对上下文窗口的挑战
多模态数据的引入给上下文窗口管理带来了新的挑战。首先,多模态数据转换后的文本描述可能非常长,容易快速填满上下文窗口。例如,一段几分钟的视频转录成文字后可能包含数千个单词,这对token数量的控制提出了更高要求。
其次,多模态数据的关联性分析更加复杂。LLM需要理解不同模态数据之间的关系,以及它们与文本消息的关联。例如,图像描述与后续文本讨论之间的逻辑联系,需要在上下文管理中准确维护。
在源码实现上,需要对Message
类进行扩展,以支持多模态数据的标识和处理。例如,可以添加新的消息类型ImageMessage
、AudioMessage
等,并为每种类型定义相应的token计算和存储方法。
class ImageMessage(BaseMessage):
def __init__(self, description, image_path=None):
self.content = description
self.image_path = image_path
self.additional_kwargs = {}
@property
def type(self):
return "image"
def to_dict(self):
base_dict = super().to_dict()
base_dict["image_path"] = self.image_path
return base_dict
def count_tokens(self, tokenizer):
# 计算图像描述的token数量
return len(tokenizer(self.content))
9.3 多模态上下文窗口策略
为应对多模态数据带来的挑战,需要设计专门的上下文窗口策略。一种策略是对不同模态数据设置不同的保留优先级。例如,用户的文本输入和关键图像描述可能具有较高优先级,而音频转录的次要信息可以适当精简或遗忘。
另一种策略是采用分层存储和检索。将多模态数据的原始信息存储在外部存储中(如数据库或文件系统),仅在上下文中保留其引用和关键摘要。当LLM需要访问详细信息时,再从外部存储中获取。
class MultimodalMemory(BaseMemory):
def __init__(self, external_storage):
self.external_storage = external_storage
self.memory = []
def save_context(self, inputs, outputs):
for key, value in inputs.items():
if isinstance(value, ImageMessage):
# 将图像保存到外部存储,并获取引用
image_ref = self.external_storage.save_image(value.image_path)
value.additional_kwargs["image_ref"] = image_ref
self.memory.append(value)
for key, value in outputs.items():
self.memory.append(value)
def load_memory_variables(self, inputs):
history = []
for msg in self.memory:
if isinstance(msg, ImageMessage) and "image_ref" in msg.additional_kwargs:
# 从外部存储获取图像描述
image_description = self.external_storage.get_image_description(msg.additional_kwargs["image_ref"])
msg.content = f"Image: {image_description}"
history.append(msg)
return {"history": history}
def clear(self):
self.memory = []
十、上下文窗口管理的分布式实现
10.1 分布式上下文存储
在大规模应用场景中,单机的上下文窗口管理可能无法满足需求,因此需要分布式实现。LangChain的上下文窗口管理可以通过分布式存储系统来扩展。
常见的分布式存储选择包括Redis、MongoDB等。以Redis为例,可以将每个对话的上下文存储为一个Hash结构,其中每个消息作为Hash的一个字段。
import redis
class RedisMemory(BaseMemory):
def __init__(self, host='localhost', port=6379, db=0):
self.redis_client = redis.Redis(host=host, port=port, db=db)
def save_context(self, inputs, outputs):
conversation_id = inputs.get("conversation_id")
if conversation_id:
for key, value in inputs.items():
self.redis_client.hset(conversation_id, key, str(value))
for key, value in outputs.items():
self.redis_client.hset(conversation_id, key, str(value))
def load_memory_variables(self, inputs):
conversation_id = inputs.get("conversation_id")
if conversation_id:
memory_data = self.redis_client.hgetall(conversation_id)
return {k.decode('utf-8'): v.decode('utf-8') for k, v in memory_data.items()}
return {}
def clear(self):
# 这里需要遍历所有对话ID并删除,简化示例不实现
pass
10.2 分布式环境下的窗口同步
在分布式环境中,多个节点可能同时访问和更新上下文窗口,因此需要解决同步问题。一种常见的方法是使用分布式锁机制,例如基于Redis的分布式锁。
import time
def acquire_lock(redis_client, lock_key, acquire_timeout=10, lock_timeout=10):
identifier = str(time.time())
end_time = time.time() + acquire_timeout
while time.time() < end_time:
if redis_client.setnx(lock_key, identifier):
# 设置锁的过期时间
redis_client.expire(lock_key, lock_timeout)
return identifier
elif not redis_client.ttl(lock_key):
# 锁已过期,尝试重新获取
redis_client.expire(lock_key, lock_timeout)
time.sleep(0.1)
return False
def release_lock(redis_client, lock_key, identifier):
pipe = redis_client.pipeline(True)
while True:
try:
pipe.watch(lock_key)
current_identifier = pipe.get(lock_key)
if current_identifier.decode('utf-8') == identifier:
pipe.multi()
pipe.delete(lock_key)
pipe.execute()
return True
pipe.unwatch()
break
except redis.WatchError:
pass
return False
在进行上下文窗口更新时,先获取分布式锁,确保同一时间只有一个节点可以修改上下文。更新完成后,释放锁。
10.3 分布式窗口管理的负载均衡
为了提高系统的性能和可用性,分布式上下文窗口管理还需要考虑负载均衡。可以使用负载均衡器(如Nginx)将上下文访问请求分发到不同的节点。
此外,在节点内部,可以采用分片策略将不同对话的上下文存储在不同的物理存储中。例如,根据对话ID的哈希值将其分配到不同的Redis数据库或MongoDB集合中。
def get_shard_id(conversation_id, num_shards):
hash_value = hash(conversation_id)
return hash_value % num_shards
class ShardedRedisMemory(BaseMemory):
def __init__(self, hosts, ports, num_shards):
self.redis_clients = [redis.Redis(host=hosts[i], port=ports[i]) for i in range(num_shards)]
self.num_shards = num_shards
def save_context(self, inputs, outputs):
conversation_id = inputs.get("conversation_id")
if conversation_id:
shard_id = get_shard_id(conversation_id, self.num_shards)
client = self.redis_clients[shard_id]
for key, value in inputs.items():
client.hset(conversation_id, key, str(value))
for key, value in outputs.items():
client.hset(conversation_id, key, str(value))
def load_memory_variables(self, inputs):
conversation_id = inputs.get("conversation_id")
if conversation_id:
shard_id = get_shard_id(conversation_id, self.num_shards)
client = self.redis_clients[shard_id]
memory_data = client.hgetall(conversation_id)
return {k.decode('utf-8'): v.decode('utf-8') for k, v in memory_data.items()}
return {}
def clear(self):
# 这里需要遍历所有分片并删除,简化示例不实现
pass
通过负载均衡和分片策略,可以有效提高分布式上下文窗口管理系统的性能和可扩展性。
上述内容进一步挖掘了LangChain在复杂场景下的技术实现。如果你对某个部分还想深入了解,或者有新的分析方向,欢迎和我说说。