bug 出现位置
Marcode 青训营,LangChain 实战课程,AI 学中练代码部分,10_记忆 - 04_ConversationSummaryMemory.py
bug 内容
Traceback (most recent call last):
File "/cloudide/workspace/LangChain-shizhanke/10_记忆/04_ConversationSummaryMemory.py", line 45, in <module>
result = conversation("我姐姐明天要过生日,我需要一束生日花束。")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cloudide/.local/lib/python3.12/site-packages/langchain_core/_api/deprecation.py", line 180, in warning_emitting_wrapper
return wrapped(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/chains/base.py", line 381, in __call__
return self.invoke(
^^^^^^^^^^^^
File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/chains/base.py", line 164, in invoke
raise e
File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/chains/base.py", line 159, in invoke
final_outputs: Dict[str, Any] = self.prep_outputs(
^^^^^^^^^^^^^^^^^^
File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/chains/base.py", line 458, in prep_outputs
self.memory.save_context(inputs, outputs)
File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/memory/summary_buffer.py", line 82, in save_context
self.prune()
File "/home/cloudide/.local/lib/python3.12/site-packages/langchain/memory/summary_buffer.py", line 94, in prune
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cloudide/.local/lib/python3.12/site-packages/langchain_openai/chat_models/base.py", line 904, in get_num_tokens_from_messages
raise NotImplementedError(
NotImplementedError: get_num_tokens_from_messages() is not presently implemented for model cl100k_base. See https://platform.openai.com/docs/guides/text-generation/managing-tokens for information on how messages are converted to tokens.
bug 分析
根据阅读源代码,可以发现和查看报错 raise NotImplementedError( NotImplementedError: get_num_tokens_from_messages() is not presently implemented for model cl100k_base. 可以发现,其实就是 langchain_open 的库下对于豆包的大模型模型,设置这个方法,导致了错误。
解决方案
根据以上发现,我们其实只需要重写一个继承 ChatOpenAI 的类,并给他补一个 get_num_tokens_from_messages 方法即可,如下
# 原本代码
llm = ChatOpenAI(
temperature=0.5,
model=os.environ.get("LLM_MODELEND"),
)
# 替换后的代码
```python
class ChildChatOpenAI(ChatOpenAI):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
model, encoding = self._get_encoding_model()
# 注意:这里只是解决了计算的问题,具体计算方案具体准确取决于具体模型
if model.startswith("cl100k_base"):
# 调用祖父类的函数
return super(BaseChatOpenAI, self).get_num_tokens_from_messages(messages)
else:
return super().get_num_tokens_from_messages(messages)
llm = ChildChatOpenAI(
temperature=0.5,
model=os.environ.get("LLM_MODELEND"),
)
完整可跑通代码
"""
本文件是【记忆:通过 Memory 记住客户上次买花时的对话细节】章节的配套代码,课程链接:https://juejin.cn/book/7387702347436130304/section/7388070989826883621
您可以点击最上方的“运行“按钮,直接运行该文件;更多操作指引请参考Readme.md文件。
"""
# 设置OpenAI API密钥
import os
from typing import List
# 导入所需的库
from langchain_core.messages import BaseMessage
from langchain_openai import ChatOpenAI
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
from langchain_openai.chat_models.base import BaseChatOpenAI
# 初始化大语言型
# llm = ChatOpenAI(
# temperature=0.5,
# model=os.environ.get("LLM_MODELEND"),
# )
class ChildChatOpenAI(ChatOpenAI):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
model, encoding = self._get_encoding_model()
# 注意:这里只是解决了计算的问题,具体计算方案具体准确取决于具体模型
if model.startswith("cl100k_base"):
# 调用祖父类的函数
return super(BaseChatOpenAI, self).get_num_tokens_from_messages(messages)
else:
return super().get_num_tokens_from_messages(messages)
llm = ChildChatOpenAI(
temperature=0.5,
model=os.environ.get("LLM_MODELEND"),
)
# 初始化对话链
conversation = ConversationChain(
llm=llm, memory=ConversationSummaryBufferMemory(llm=llm, max_token_limit=300)
)
# 第一天的对话
# 回合1
result = conversation("我姐姐明天要过生日,我需要一束生日花束。")
print(result)
# 回合2
result = conversation("她喜欢粉色玫瑰,颜色是粉色的。")
# print("\n第二次对话后的记忆:\n", conversation.memory.buffer)
print(result)
# 第二天的对话
# 回合3
result = conversation("我又来了,还记得我昨天为什么要来买花吗?")
print(result)