LangChain 实战课 AI Learn-by-Doing Bug Log | 豆包MarsCode AI刷题


Where the bug occurs

MarsCode 青训营, LangChain 实战课 hands-on course, AI learn-by-doing code section, 10_记忆 - 04_ConversationSummaryMemory.py


Bug details

```
Traceback (most recent call last):
  File "/cloudide/workspace/LangChain-shizhanke/10_记忆/04_ConversationSummaryMemory.py", line 45, in <module>
    result = conversation("我姐姐明天要过生日,我需要一束生日花束。")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cloudide/.local/lib/python3.12/site-packages/langchain_core/_api/deprecation.py", line 180, in warning_emitting_wrapper
    return wrapped(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/chains/base.py", line 381, in __call__
    return self.invoke(
           ^^^^^^^^^^^^
  File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/chains/base.py", line 164, in invoke
    raise e
  File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/chains/base.py", line 159, in invoke
    final_outputs: Dict[str, Any] = self.prep_outputs(
                                    ^^^^^^^^^^^^^^^^^^
  File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/chains/base.py", line 458, in prep_outputs
    self.memory.save_context(inputs, outputs)
  File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/memory/summary_buffer.py", line 82, in save_context
    self.prune()
  File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/memory/summary_buffer.py", line 94, in prune
    curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cloudide/.local/lib/python3.12/site-packages/langchain_openai/chat_models/base.py", line 904, in get_num_tokens_from_messages
    raise NotImplementedError(
NotImplementedError: get_num_tokens_from_messages() is not presently implemented for model cl100k_base. See https://platform.openai.com/docs/guides/text-generation/managing-tokens for information on how messages are converted to tokens.
```

Bug analysis

Reading the relevant source code together with the error message (`NotImplementedError: get_num_tokens_from_messages() is not presently implemented for model cl100k_base`), the cause becomes clear: the langchain_openai library has no token-counting implementation for the Doubao model. Because the model name is not a known OpenAI model, `ChatOpenAI` falls back to the generic `cl100k_base` encoding, and its `get_num_tokens_from_messages()` method raises `NotImplementedError` for that case, which is what crashes the chain when the memory tries to prune its buffer.
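
To make the failure path concrete, here is a minimal reproduction sketch. It assumes the same environment as the course code (an OpenAI-compatible API key and base URL for Doubao, with LLM_MODELEND holding the model name); calling the token counter directly, without going through the chain, raises the same error:

```python
# Minimal reproduction sketch (assumes the course environment: an OpenAI-compatible
# API key/base URL for Doubao and LLM_MODELEND set to the Doubao model name).
import os

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(temperature=0.5, model=os.environ.get("LLM_MODELEND"))

# The model name is not a known OpenAI model, so ChatOpenAI falls back to the
# cl100k_base encoding and raises NotImplementedError here -- the same error
# that ConversationSummaryBufferMemory.prune() hits inside the chain.
llm.get_num_tokens_from_messages([HumanMessage(content="我姐姐明天要过生日")])
```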


Solution

Based on this finding, all we need to do is write a subclass of ChatOpenAI and give it a get_num_tokens_from_messages method that handles this case, as follows:

```python
# Original code
llm = ChatOpenAI(
    temperature=0.5,
    model=os.environ.get("LLM_MODELEND"),
)
```

```python
# Replacement code
class ChildChatOpenAI(ChatOpenAI):
    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        model, encoding = self._get_encoding_model()
        # Note: this only makes the token-counting call work; how accurate the
        # count is still depends on the actual model's tokenizer.
        if model.startswith("cl100k_base"):
            # Call the grandparent class's method (skip BaseChatOpenAI's
            # tiktoken-model-based implementation).
            return super(BaseChatOpenAI, self).get_num_tokens_from_messages(messages)
        else:
            return super().get_num_tokens_from_messages(messages)


llm = ChildChatOpenAI(
    temperature=0.5,
    model=os.environ.get("LLM_MODELEND"),
)
```
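
As a quick sanity check (a sketch with an arbitrary message, reusing the environment assumptions above), the overridden method now returns a number instead of raising: `super(BaseChatOpenAI, self)` skips BaseChatOpenAI's tiktoken-model lookup and falls back to the generic counting logic inherited from the base language-model class.

```python
import os

from langchain_core.messages import HumanMessage

llm = ChildChatOpenAI(
    temperature=0.5,
    model=os.environ.get("LLM_MODELEND"),
)
# Before the override this raised NotImplementedError; now it returns an int.
print(llm.get_num_tokens_from_messages([HumanMessage(content="我姐姐明天要过生日")]))
```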

Complete runnable code

""" 
本文件是【记忆:通过 Memory 记住客户上次买花时的对话细节】章节的配套代码,课程链接:https://juejin.cn/book/7387702347436130304/section/7388070989826883621
您可以点击最上方的“运行“按钮,直接运行该文件;更多操作指引请参考Readme.md文件。
"""
# 设置OpenAI API密钥
import os
from typing import List

# 导入所需的库
from langchain_core.messages import BaseMessage
from langchain_openai import ChatOpenAI
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
from langchain_openai.chat_models.base import BaseChatOpenAI

# 初始化大语言型
# llm = ChatOpenAI(
#     temperature=0.5,
#     model=os.environ.get("LLM_MODELEND"),
# )

class ChildChatOpenAI(ChatOpenAI):
    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        model, encoding = self._get_encoding_model()
        # 注意:这里只是解决了计算的问题,具体计算方案具体准确取决于具体模型
        if model.startswith("cl100k_base"):
            # 调用祖父类的函数
            return super(BaseChatOpenAI, self).get_num_tokens_from_messages(messages)
        else:
            return super().get_num_tokens_from_messages(messages)


llm = ChildChatOpenAI(
    temperature=0.5,
    model=os.environ.get("LLM_MODELEND"),
)

# 初始化对话链
conversation = ConversationChain(
    llm=llm, memory=ConversationSummaryBufferMemory(llm=llm, max_token_limit=300)
)

# 第一天的对话
# 回合1
result = conversation("我姐姐明天要过生日,我需要一束生日花束。")
print(result)
# 回合2
result = conversation("她喜欢粉色玫瑰,颜色是粉色的。")
# print("\n第二次对话后的记忆:\n", conversation.memory.buffer)
print(result)

# 第二天的对话
# 回合3
result = conversation("我又来了,还记得我昨天为什么要来买花吗?")
print(result)
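
With the fix in place, the exact code path from the traceback (save_context() -> prune() -> llm.get_num_tokens_from_messages()) can also be exercised directly. Here is a small sketch with made-up inputs, reusing the llm and imports from the file above, just to confirm it no longer raises:

```python
# Exercise the path from the traceback: save_context() -> prune()
# -> llm.get_num_tokens_from_messages(). With ChildChatOpenAI this no longer raises.
memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=300)
memory.save_context(
    {"input": "我姐姐明天要过生日,我需要一束生日花束。"},
    {"output": "好的,我可以为您推荐一束粉色的生日花束。"},
)
print(memory.load_memory_variables({}))
```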
