掌握消息修剪技巧:优化聊天模型的上下文窗口管理

93 阅读4分钟

引言

在使用语言模型进行聊天交互时,有一个关键限制:上下文窗口的长度。每个模型都有一个最大令牌限制,当消息过长时,可能导致信息丢失或性能下降。因此,管理和修剪消息成为了一项重要任务。在这篇文章中,我们将探索如何使用 trim_messages 工具来优化消息长度,以便更好地适配语言模型的上下文窗口。

主要内容

修剪消息的基本策略

trim_messages 提供了几种基本策略来修剪消息列表,使其符合特定的令牌长度要求。通过使用不同的策略和配置选项,我们可以更好地管理消息的长度。

获取最后的 max_tokens

要获取消息列表中的最后 max_tokens,可以设置 strategy="last"。可以传入一个语言模型作为 token_counter,因为它们具有消息令牌计数的方法。以下是一个例子:

from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)
from langchain_openai import ChatOpenAI

messages = [
    SystemMessage("you're a good assistant, you always respond with a joke."),
    HumanMessage("i wonder why it's called langchain"),
    AIMessage(
        'Well, I guess they thought "WordRope" and "SentenceString" just didn\'t have the same ring to it!'
    ),
    HumanMessage("and who is harrison chasing anyways"),
    AIMessage(
        "Hmmm let me think.\n\nWhy, he's probably chasing after the last cup of coffee in the office!"
    ),
    HumanMessage("what do you call a speechless parrot"),
]

trim_messages(
    messages,
    max_tokens=45,
    strategy="last",
    token_counter=ChatOpenAI(model="gpt-4o"),
)

上面的代码将通过保留最近的消息,使其不超过45个令牌。

始终保留系统消息

如果希望始终保留初始系统消息,可以指定 include_system=True

trim_messages(
    messages,
    max_tokens=45,
    strategy="last",
    token_counter=ChatOpenAI(model="gpt-4o"),
    include_system=True,
)

允许部分消息

如果允许拆分消息的内容,可以指定 allow_partial=True

trim_messages(
    messages,
    max_tokens=56,
    strategy="last",
    token_counter=ChatOpenAI(model="gpt-4o"),
    include_system=True,
    allow_partial=True,
)

指定消息类型

如果需要确保第一个消息(不包括系统消息)始终是特定类型,可以指定 start_on

trim_messages(
    messages,
    max_tokens=60,
    strategy="last",
    token_counter=ChatOpenAI(model="gpt-4o"),
    include_system=True,
    start_on="human",
)

获取最开始的 max_tokens

通过指定 strategy="first" 可以获取消息的最开始部分:

trim_messages(
    messages,
    max_tokens=45,
    strategy="first",
    token_counter=ChatOpenAI(model="gpt-4o"),
)

自定义令牌计数器

可以编写自定义令牌计数函数来适应不同的需求。下面是如何使用 tiktoken 编写一个自定义令牌计数器的示例:

import tiktoken
from typing import List
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage

def str_token_counter(text: str) -> int:
    enc = tiktoken.get_encoding("o200k_base")
    return len(enc.encode(text))

def tiktoken_counter(messages: List[BaseMessage]) -> int:
    num_tokens = 3  # every reply is primed with <|start|>assistant<|message|>
    tokens_per_message = 3
    tokens_per_name = 1
    for msg in messages:
        if isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        elif isinstance(msg, SystemMessage):
            role = "system"
        else:
            raise ValueError(f"Unsupported messages type {msg.__class__}")
        num_tokens += (
            tokens_per_message
            + str_token_counter(role)
            + str_token_counter(msg.content)
        )
        if msg.name:
            num_tokens += tokens_per_name + str_token_counter(msg.name)
    return num_tokens

使用聊天历史修剪

修剪消息在处理聊天历史时尤其有用,这些历史可能会变得非常长。使用 InMemoryChatMessageHistory 可以方便地管理历史记录:

from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_openai import ChatOpenAI

chat_history = InMemoryChatMessageHistory(messages=messages[:-1])

llm = ChatOpenAI(model="gpt-4o")

trimmer = trim_messages(
    max_tokens=45,
    strategy="last",
    token_counter=llm,
    include_system=True,
)

chain_with_history = trimmer | llm
chain_with_history.invoke(
    [HumanMessage("what do you call a speechless parrot")],
)

常见问题和解决方案

  1. 如果令牌计数不准确怎么办? 自定义令牌计数器可以解决大多数情况下的计数不准确问题。
  2. 如何处理网络访问问题? 由于某些地区的网络限制,开发者可能需要考虑使用API代理服务以提高访问稳定性。

总结和进一步学习资源

有效管理消息长度对于提升语言模型的性能和提高上下文相关性至关重要。利用 trim_messages 工具,可以实现灵活的消息修剪策略,并针对不同的场景优化聊天体验。以下是一些推荐的学习资源:

参考资料

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---