记忆:通过Memory记住客户上次买花时的对话细节|豆包MarsCode AI刷题

213 阅读7分钟

默认情况下,无论是LLM还是代理都是无状态的,每次模型的调用都是独立于其他交互的。如果需要记住之前的对话,需要使用记忆(Memory)机制,记录之前的对话的上下文,并把这个上下文作为提示的一部分,在最新的调用中传递给模型。

使用ConversationChain

主要特点:提供了包含AI前缀和人类前缀的对话摘要格式,这个对话格式和记忆机制结合得非常紧密。

打印ConversationChain的内置提示模板

from langchain.chains import ConversationChain
import os
os.environ["DASHSCOPE_API_KEY"] ='阿里的DASHSCOPE_API_KEY'

# 创建聊天模型
from langchain_community.chat_models import ChatTongyi
llm = ChatTongyi(temperature=0)


# 初始化对话链
conv_chain = ConversationChain(llm=llm)

# 打印对话的模板
print(conv_chain.prompt.template)

输出

The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.

Current conversation:
{history}
Human: {input}
AI:

注意,此处若AI不知道答案,它会直接说不知道,而不是编造一些答案,减少了幻觉。

把历史对话信息存储在提示模板中,并作为新的提示内容在新一轮的对话过程中传递给模型。——记忆机制的原理

下面就让我们来在ConversationChain中加入记忆功能。

使用ConversationBufferMemory

在LangChain中,通过ConversationBufferMemory(缓冲记忆)可以实现最简单的记忆机制。

在对话链中引入ConversationBufferMemory。

from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferMemory
import os
os.environ["DASHSCOPE_API_KEY"] ='阿里的DASHSCOPE_API_KEY'

# 创建聊天模型
from langchain_community.chat_models import ChatTongyi
llm = ChatTongyi(temperature=0)

# 初始化对话链
conversation = ConversationChain(llm=llm, memory=ConversationBufferMemory())

# 第一天的对话
# 回合1
conversation("我姐姐明天要过生日,我需要一束生日花束。")
print("第一次对话后的记忆:", conversation.memory.buffer)

# 回合2
conversation("她喜欢粉色玫瑰,颜色是粉色的。")
print("第二次对话后的记忆:", conversation.memory.buffer)

# 回合3 (第二天的对话)
conversation("我又来了,还记得我昨天为什么要来买花吗?")
print("/n第三次对话后时提示:/n", conversation.prompt.template)
print("/n第三次对话后的记忆:/n", conversation.memory.buffer)

conversation.memory.buffer存了之前的对话记录。但是新输入中包含了更多的token(history),意味着响应时间变慢和更高的成本。而且,当达到LLM的令牌数(上下文窗口)限制时,太长的对话无法被记住。

下面来解决token太多、聊天历史记录过长的一些解决方案。

使用ConversationBufferWindowMemory

人类最新的经历最鲜活,也最重要。

ConversationBufferWindowMemory 是缓冲窗口记忆,它的思路就是只保存最新最近的几次人类和AI的互动。因此,它在之前的“缓冲记忆”基础上增加了一个窗口值 k。这意味着我们只保留一定数量的过去互动,然后“忘记”之前的互动。

from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferWindowMemory

import os
os.environ["DASHSCOPE_API_KEY"] ='阿里的DASHSCOPE_API_KEY'

# 创建聊天模型
from langchain_community.chat_models import ChatTongyi
llm = ChatTongyi(temperature=0)

# 初始化对话链
conversation = ConversationChain(llm=llm, memory=ConversationBufferWindowMemory(k=1))

# 第一天的对话
# 回合1
result = conversation("我姐姐明天要过生日,我需要一束生日花束。")
print(result)
# 回合2
result = conversation("她喜欢粉色玫瑰,颜色是粉色的。")
# print("\n第二次对话后的记忆:\n", conversation.memory.buffer)
print(result)

# 第二天的对话
# 回合3
result = conversation("我又来了,还记得我昨天为什么要来买花吗?")
print(result)

如果只需要记住最近的互动,缓冲窗口记忆是一个很好的选择。但是如果需要混合远期和近期的互动信息,还有其他选择。

使用ConversationSummaryMemory

ConversationSummaryMemory(对话总结记忆)的思路就是将对话历史进行汇总,然后再传递给 {history} 参数。这种方法旨在通过对之前的对话进行汇总来避免过度使用 Token。

核心特点:

  1. 汇总对话:此方法不是保存整个对话历史,而是每次新的互动发生时对其进行汇总,然后将其添加到之前所有互动的“运行汇总”中。
  2. 使用LLM进行汇总:该汇总功能由另一个LLM驱动,这意味着对话的汇总实际上是由AI自己进行的。
  3. 适合长对话:对于长对话,此方法的优势尤为明显。虽然最初使用的 Token 数量较多,但随着对话的进展,汇总方法的增长速度会减慢。与此同时,常规的缓冲内存模型会继续线性增长。
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationSummaryBufferMemory

import os
os.environ["DASHSCOPE_API_KEY"] ='阿里的DASHSCOPE_API_KEY'

# 创建聊天模型
from langchain_community.chat_models import ChatTongyi
llm = ChatTongyi(temperature=0)

# 初始化对话链
conversation = ConversationChain(
    llm=llm, memory=ConversationSummaryBufferMemory(llm=llm, max_token_limit=300)
)

# 第一天的对话
# 回合1
result = conversation("我姐姐明天要过生日,我需要一束生日花束。")
print(result)
# 回合2
result = conversation("她喜欢粉色玫瑰,颜色是粉色的。")
# print("\n第二次对话后的记忆:\n", conversation.memory.buffer)
print(result)

# 第二天的对话
# 回合3
result = conversation("我又来了,还记得我昨天为什么要来买花吗?")
print(result)

不仅利用LLM来回答每轮问题,还利用LLM来对之前的对话纪念性总结性的陈述,以节约token数量。其中回答问题和总结对话的大模型可以是不同的。

对于长对话来说,可以减少使用的token数量,但对于较短的对话,可能会导致更高的token的使用。对话的历史记忆依赖于中间汇总LLM的能力,还需要为汇总LLM使用token,增加成本,且不限制对话长度。

通过历史对话的汇总来优化和管理token的使用,ConversationSummaryMemory 为那些预期会有多轮的、长时间对话的场景提供了一种很好的方法。然而,这种方法仍然受到 Token 数量的限制。在一段时间后,我们仍然会超过大模型的上下文窗口限制。

主要问题:

总结的过程中并没有区分近期的对话和长期的对话(通常情况下近期的对话更重要),所以我们还要继续寻找新的记忆管理方法。

注意:遇到了huggingface.co无法加载gpt2的问题,可以手动clone gpt2然后修改langchain_core\language_models\base.py 文件中的

@cache  # Cache the tokenizer
def get_tokenizer() -> Any:
    """Get a GPT-2 tokenizer instance.

    This function is cached to avoid re-loading the tokenizer
    every time it is called.
    """
    try:
        from transformers import GPT2TokenizerFast  # type: ignore[import]
    except ImportError as e:
        msg = (
            "Could not import transformers python package. "
            "This is needed in order to calculate get_token_ids. "
            "Please install it with `pip install transformers`."
        )
        raise ImportError(msg) from e
    # create a GPT-2 tokenizer instance
    # 指定本地路径
    local_path = 'model/gpt2/gpt2'
    # return GPT2TokenizerFast.from_pretrained("gpt2")
    return GPT2Tokenizer.from_pretrained(local_path)

使用ConversationSummaryBufferMemory

对话总结缓冲记忆,是一种混合记忆模型,结合了上述各种记忆机制,包括ConversationSummaryMemory 和 ConversationBufferWindowMemory的特点。模型旨在在对话中总结早期的互动,同时尽量保留最近互动中的原始内容。

它是通过max_token_limit这个参数做到这一点的。当最新的对话文字长度在300字之内的时候,LangChain会记忆原始对话内容;当之前的对话文字超出了这个参数的长度,那么模型就会把所有超过预设长度的内容进行总结,以节省Token数量。

from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
from transformers import GPT2Tokenizer, GPT2Model

import os
os.environ["DASHSCOPE_API_KEY"] ='阿里的DASHSCOPE_API_KEY'

# 创建聊天模型
from langchain_community.chat_models import ChatTongyi
llm = ChatTongyi(temperature=0)

# 初始化对话链
conversation = ConversationChain(
    llm=llm, memory=ConversationSummaryBufferMemory(llm=llm, max_token_limit=300)
)

# 第一天的对话
# 回合1
result = conversation("我姐姐明天要过生日,我需要一束生日花束。")
print(result)
# 回合2
result = conversation("她喜欢粉色玫瑰,颜色是粉色的。")
# print("\n第二次对话后的记忆:\n", conversation.memory.buffer)
print(result)

# 第二天的对话
# 回合3
result = conversation("我又来了,还记得我昨天为什么要来买花吗?")
print(result)