Spring AI 聊天记忆(ChatMemory)源码分析以及实战

2,638 阅读6分钟

#2.Spring AI 接入OpenAI大模型实现同步和流式对话 文章中实现了聊天功能,其实现方式是一问一答,此方式无法让AI具有聊天记忆力能力,那如何能够实现大模型聊天记忆能力?本篇文章详细介绍其实现原理及实践。

场景引入

通过使用#2.Spring AI 接入OpenAI大模型实现同步和流式对话实现的方式,给AI大模型连续发送如下两个Prompt,我们观察结果看效果回答效果。

  • Hi,My name is ivygeek
  • What's my name?

image.png

image.png 当发送第二prompt时,大模型无法记住之前的聊天上下文,导致无法回答出正确的答案。要实现一个可以让大模型具有聊天记忆能力,根据过去的聊天信息进行回答,我们如何实现呢?

ChatGPT是一个基于预训练语言大模型,可以根据输入的Prompt推断出回答结果,所以我们如果把交互的上下文内容一并传给大模型就可以了,想想也挺简单的。

聊天记忆原理

将聊天信息包括大模型回复信息依次存储在一个队列中发送给ChatGPT,然后ChatGPT会根据整个聊天信息对回复内容进行判断。

Spring AI Message类型 image.png

  • UserMessage:用户消息,指用户输入的消息,比如提问的问题。
  • SystemMessage:系统限制性消息,这种消息比较特殊,权重很大,AI会优先依据SystemMessage里的内容进行回复。在设定Chat角色时,可以用的到,下篇文章分析
  • AssistantMessage:大模型回复的消息。
  • FunctionMessage:函数调用消息,开发中一般使用不到,一般无需关心。

所以我们会将 UserMessage、SystemMessage、AssistantMessage 、FunctionMessage放在一个队列中,然后将整个队列发送给ChatGPT,然后ChatGPT就会根据整个聊天信息对回复内容进行判断。

大家思考一下如下两个问题:

  • 对于System Message类型的消息该如何处理呢?

  • 对于Function Message 如何处理?

代码实现

我们将使用两种方式实现Chat上下文记忆能力,第一种使用简单的List<Message>记录用户输入提示词以及大模型返回的消息并分析此种实现方式存在的问题。根据存在问题,引出使用 SpringAI 框架封装ChatMemory实现。

使用List<Message>实现Chat上下文记忆能力

package com.ivy.controller;

import jakarta.annotation.Resource;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.ArrayList;
import java.util.List;

@RestController
public class ChatMemoryController {
    private final List<Message> historyMessage = new ArrayList<>();

    @Resource
    private OpenAiChatModel openAiChatModel;

    @GetMapping("/chatWithList")
    public String chatWithList(String prompt) {
        // 将用户消息添加到历史消息列表中
        historyMessage.add(new UserMessage(prompt));
        Generation result = openAiChatModel.call(new Prompt(historyMessage)).getResult();
        // 将AI消息添加到历史消息列表中
        AssistantMessage assistantMessage = result.getOutput();
        historyMessage.add(assistantMessage);
        return assistantMessage.getContent();
    }
}

image.png

image.png 通过验证,以上代码可以实现chat上下文记忆能力。但是请大家思考一下这么实现会有哪些问题?思考一分钟在往下看?

问题一:如果聊天内容很多,会超过ChatGPT窗口大小,比如想GPT-3限制4000多token,很容易导致大模型无法回答内容。

问题二:聊天内容可能中间会有一些不相关的文本,如果一同传过去会消耗更多的成本。ChatGPT是按照token收费的。token越多,一次交互的成本就越高。

问题三:历史消息没有和会话进行关联,应该是每次会话一个历史消息。

所以Spring AI针对于以上问题给出了很好的解决方案,下面将简单分析Spring AI框架实现源码相关细节。 在Spring AI 1.0.0版本之前,还没有实现上下文记忆的能力,1.0.0版本开始框架才提供。

Spring AI 关于上下文记忆实现源码分析

ChatMemory

package org.springframework.ai.chat.memory;

import java.util.List;
import org.springframework.ai.chat.messages.Message;

public interface ChatMemory {
    default void add(String conversationId, Message message) {
        this.add(conversationId, List.of(message));
    }
    
    void add(String conversationId, List<Message> messages);

    List<Message> get(String conversationId, int lastN);

    void clear(String conversationId);
}

该接口表示聊天对话历史记录的存储。它提供了将消息添加到对话、从对话中检索消息以及清除对话历史记录的方法。主要提供如下方法:

add(String conversationId, Message message)

  • 参数:conversationId:会话ID, message:消息(包括用户消息和回复消息)
  • 作用:将消息添加到会话中

get(String conversationId, int lastN)

  • 参数:conversationId:会话ID, lastN:取最新的几条数据,以此可以控制一次会话窗口的大小,比如对于GPT3.5的窗口大小限制4096
  • 作用:从会话中检索N条最新消息

clear(String conversationId)

  • 参数:conversationId:会话ID
  • 作用:清除对话历史消息

InMemoryChatMemory

ChatMemory的实现类 表示为聊天对话历史记录提供内存中存储。源码如下; image.png 实现还是比较简单的,其中关键在于 Map<String, List<Message>> conversationHistory = new ConcurrentHashMap(); 用于存储会话对应的历史消息。

Advisor

Spring AI 框架提供了三种Advisor来使用ChatMeomry。 image.png

  • MessageChatMemoryAdvisor:查询对象会话ID的历史消息添加到提示词文本中,核心代码如下;
@Override
public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {

		String conversationId = this.doGetConversationId(context);

		int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(context);

		// 1. Retrieve the chat memory for the current conversation.
		List<Message> memoryMessages = this.getChatMemoryStore().get(conversationId, chatMemoryRetrieveSize);

		// 2. Advise the request messages list.
		List<Message> advisedMessages = new ArrayList<>(request.messages());
		advisedMessages.addAll(memoryMessages);

		// 3. Create a new request with the advised messages.
		AdvisedRequest advisedRequest = AdvisedRequest.from(request).withMessages(advisedMessages).build();

		// 4. Add the new user input to the conversation memory.
		UserMessage userMessage = new UserMessage(request.userText(), request.media());
		this.getChatMemoryStore().add(this.doGetConversationId(context), userMessage);

		return advisedRequest;
	}

  • PromptChatMemoryAdvisor:检索到的内存中的历史消息将添加到提示的系统文本中。
@Override
	public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {

		// 1. Advise system parameters.
		List<Message> memoryMessages = this.getChatMemoryStore()
			.get(this.doGetConversationId(context), this.doGetChatMemoryRetrieveSize(context));

		String memory = (memoryMessages != null) ? memoryMessages.stream()
			.filter(m -> m.getMessageType() != MessageType.SYSTEM)
			.map(m -> m.getMessageType() + ":" + m.getContent())
			.collect(Collectors.joining(System.lineSeparator())) : "";

		Map<String, Object> advisedSystemParams = new HashMap<>(request.systemParams());
		advisedSystemParams.put("memory", memory);

		// 2. Advise the system text.
		String advisedSystemText = request.systemText() + System.lineSeparator() + this.systemTextAdvise;

		// 3. Create a new request with the advised system text and parameters.
		AdvisedRequest advisedRequest = AdvisedRequest.from(request)
			.withSystemText(advisedSystemText)
			.withSystemParams(advisedSystemParams)
			.build();

		// 4. Add the new user input to the conversation memory.
		UserMessage userMessage = new UserMessage(request.userText(), request.media());
		this.getChatMemoryStore().add(this.doGetConversationId(context), userMessage);

		return advisedRequest;
	}
  • VectorStoreChatMemoryAdvisor:检索向量数据库中的历史消息将添加到提示的系统文本中。

使用 Spring AI 框架实现Chat上下文记忆能力

package com.ivy.controller;

import jakarta.annotation.Resource;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY;
import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY;

@RestController
public class ChatMemoryController {
    @Resource
    private OpenAiChatModel openAiChatModel;

    private final ChatMemory chatMemory = new InMemoryChatMemory();
    @GetMapping("/chatWithChatMemory")
    public Flux<String> chatWithChatMemory(String chatId, String prompt) {
        ChatClient chatClient = ChatClient.builder(openAiChatModel)
                .defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory))
                .build();

        return chatClient.prompt()
                .user(prompt)
                .advisors(a -> a
                        .param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
                        .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100)
                )
                .stream()
                .content();
    }
}

image.png

image.png 从回答的结果上看,已经实现了上下文记忆的能力。如果把参数chatId改为11111,AI的回答如下图所示

image.png

对代码解释: 其中参数chatId 表示会话ID,实现上下文与会话绑定。CHAT_MEMORY_RETRIEVE_SIZE_KEY 表示历史会话最多100条发给AI。

总结

本文通过两种方式实现了Chat记忆能力,但是第一种根据其原理简单粗暴的实现了上下文记忆的能力,但是存在诸多问题。Spring AI框架从1.0.0版本开始提供了上下文记忆的能力,通过Spring AI 框架可以非常简单的实现。

Github:github.com/fangjieDevp…