How to Trim Messages to Fit a Chat Model's Context Window


When working with chat models, the length of the conversation is something you need to manage actively. Every chat model has a finite context window, i.e. a limit on how many input tokens it can process. If your messages are long, or a chain/agent has accumulated a large message history, you must control how much of that history is passed to the model. This article shows how to trim messages with the trim_messages utility and provides practical code examples.

1. Introduction

As chat applications become widespread, making sure a conversation fits within the model's limited context window becomes especially important. This article walks through how to trim messages to fit a chat model's context window, provides practical code examples, and discusses common pitfalls and their solutions.

2. Main Content

2.1 Keeping the last max_tokens worth of messages

# pip install -U langchain-openai
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)
from langchain_openai import ChatOpenAI

messages = [
    SystemMessage("you're a good assistant, you always respond with a joke."),
    HumanMessage("i wonder why it's called langchain"),
    AIMessage(
        'Well, I guess they thought "WordRope" and "SentenceString" just didn\'t have the same ring to it!'
    ),
    HumanMessage("and who is harrison chasing anyways"),
    AIMessage(
        "Hmmm let me think.\n\nWhy, he's probably chasing after the last cup of coffee in the office!"
    ),
    HumanMessage("what do you call a speechless parrot"),
]

trimmed_messages = trim_messages(
    messages,
    max_tokens=45,
    strategy="last",
    token_counter=ChatOpenAI(model="gpt-4o"),
)
print(trimmed_messages)

In this example, strategy="last" keeps the most recent messages whose combined size stays within the 45-token budget; older messages at the start of the history are dropped first.
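As a quick sanity check, you can re-count the trimmed history with the model's own token counter and confirm it stays within the budget. A minimal sketch using get_num_tokens_from_messages, which ChatOpenAI exposes:

llm = ChatOpenAI(model="gpt-4o")
# Re-count the trimmed history; the total should not exceed max_tokens=45.
print(llm.get_num_tokens_from_messages(trimmed_messages))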

2.2 Keeping the initial system message

If we want to keep the initial system message, we can pass include_system=True:

trimmed_messages = trim_messages(
    messages,
    max_tokens=45,
    strategy="last",
    token_counter=ChatOpenAI(model="gpt-4o"),
    include_system=True,
)
print(trimmed_messages)

2.3 Allowing messages to be split

If we want to allow the content of a message to be split, so that part of a message at the cut-off point can still be kept, we can pass allow_partial=True:

trimmed_messages = trim_messages(
    messages,
    max_tokens=56,
    strategy="last",
    token_counter=ChatOpenAI(model="gpt-4o"),
    include_system=True,
    allow_partial=True,
)
print(trimmed_messages)

2.4 Keeping the first max_tokens worth of messages

We can also do the opposite and keep the first max_tokens worth of messages by passing strategy="first":

trimmed_messages = trim_messages(
    messages,
    max_tokens=45,
    strategy="first",
    token_counter=ChatOpenAI(model="gpt-4o"),
)
print(trimmed_messages)

2.5 Writing a custom token counter

We can also write a custom token-counting function that takes a list of messages and returns an int:

from typing import List
import tiktoken
from langchain_core.messages import BaseMessage, ToolMessage

def str_token_counter(text: str) -> int:
    enc = tiktoken.get_encoding("o200k_base")
    return len(enc.encode(text))

def tiktoken_counter(messages: List[BaseMessage]) -> int:
    num_tokens = 3  # every reply is primed with <|start|>assistant<|message|>
    tokens_per_message = 3
    tokens_per_name = 1
    for msg in messages:
        if isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        elif isinstance(msg, ToolMessage):
            role = "tool"
        elif isinstance(msg, SystemMessage):
            role = "system"
        else:
            raise ValueError(f"Unsupported messages type {msg.__class__}")
        num_tokens += (
            tokens_per_message
            + str_token_counter(role)
            + str_token_counter(msg.content)
        )
        if msg.name:
            num_tokens += tokens_per_name + str_token_counter(msg.name)
    return num_tokens

trimmed_messages = trim_messages(
    messages,
    max_tokens=45,
    strategy="last",
    token_counter=tiktoken_counter,
)
print(trimmed_messages)

3. Code Example

Below is a complete example that pulls the trimming approaches above together:

# pip install -U langchain-openai tiktoken
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
    trim_messages,
)
from langchain_openai import ChatOpenAI
from typing import List
import tiktoken

# Example messages
messages = [
    SystemMessage("you're a good assistant, you always respond with a joke."),
    HumanMessage("i wonder why it's called langchain"),
    AIMessage(
        'Well, I guess they thought "WordRope" and "SentenceString" just didn\'t have the same ring to it!'
    ),
    HumanMessage("and who is harrison chasing anyways"),
    AIMessage(
        "Hmmm let me think.\n\nWhy, he's probably chasing after the last cup of coffee in the office!"
    ),
    HumanMessage("what do you call a speechless parrot"),
]

# Custom token counter
def str_token_counter(text: str) -> int:
    enc = tiktoken.get_encoding("o200k_base")
    return len(enc.encode(text))

def tiktoken_counter(messages: List[BaseMessage]) -> int:
    num_tokens = 3  # every reply is primed with <|start|>assistant<|message|>
    tokens_per_message = 3
    tokens_per_name = 1
    for msg in messages:
        if isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        elif isinstance(msg, ToolMessage):
            role = "tool"
        elif isinstance(msg, SystemMessage):
            role = "system"
        else:
            raise ValueError(f"Unsupported messages type {msg.__class__}")
        num_tokens += (
            tokens_per_message
            + str_token_counter(role)
            + str_token_counter(msg.content)
        )
        if msg.name:
            num_tokens += tokens_per_name + str_token_counter(msg.name)
    return num_tokens

# Trim messages using the custom token counter defined above
trimmed_messages = trim_messages(
    messages,
    max_tokens=45,
    strategy="last",
    token_counter=tiktoken_counter,
    include_system=True,
    allow_partial=True,
)
print(trimmed_messages)
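When message history builds up inside a chain or agent, trim_messages can also be used declaratively: called without a message list, it returns a runnable trimmer that can be composed in front of the model. A minimal sketch, reusing the messages defined above:

llm = ChatOpenAI(model="gpt-4o")
# With no `messages` argument, trim_messages returns a runnable that trims
# whatever history is piped into it before the model sees it.
trimmer = trim_messages(
    max_tokens=45,
    strategy="last",
    token_counter=llm,
    include_system=True,
)
chain = trimmer | llm
response = chain.invoke(messages)
print(response.content)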

4. Common Issues and Solutions

Issue 1: Too much content is lost because only whole messages are kept

Solution: Set allow_partial=True so that part of a message's content at the cut-off point can be kept, while the total still stays within the token limit.

Issue 2: The system message gets trimmed away

Solution: Set include_system=True so that the initial system message is always kept in the trimmed result.

Issue 3: Token counts are inaccurate

Solution: Write a custom token counter (as in section 2.5) whose counting logic matches the tokenizer your application actually uses.
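If precise token counting matters less than simply capping the length of the history, trim_messages can also count messages instead of tokens by passing the built-in len as the counter, in which case max_tokens becomes a maximum number of messages. A small sketch of this approach:

# Count messages instead of tokens: each message costs 1, so max_tokens=3
# keeps at most 3 messages (the system message counts toward the budget).
trimmed_by_count = trim_messages(
    messages,
    max_tokens=3,
    strategy="last",
    token_counter=len,
    include_system=True,
)
print(trimmed_by_count)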

5. Summary and Further Learning

Trimming messages is essential when working with chat models. This article showed how to manage message length with the trim_messages utility, walked through practical code examples, and covered common issues and their solutions. To go deeper, see the official LangChain documentation on message trimming.


If this article helped you, feel free to like it and follow my blog. Your support is what keeps me writing!
