SpringAI1.0.0正式版-使用MySQL整合对话记忆(chat memory)

602 阅读6分钟

前言

完整代码请查看github项目地址:ai-chat-demo

本文的工作是在之前的AI流式对话基础上实现的,如有疑问,可以参考:实现流式对话应用

阅读本文前应该有的知识基础

  • Java基础

  • Java Web基础

  • 数据库基础

  • 包管理 (Mavengradle等)

  • Spring基础、SSM整合、SpringBoot


大模型中的 chat memory(聊天记忆) ,指的是模型在多轮对话中「记住」之前对话内容,从而实现更连贯、上下文一致的回答能力。

自定义对话记忆实现


实现ChatMemory接口

SpringAI可以通过实现org.springframework.ai.chat.memory.ChatMemory接口以定制对话记忆的实现方式,灵活性很强。

这里使用Mybatis整合MySQL储存实现:

@Component
@Slf4j
@RequiredArgsConstructor
public class MybatisChatMemory implements ChatMemory {

    private final AiMessagePairMapper aiMessagePairMapper;

    /**
    新增对话记忆,不通过此处管理,置空
    */
    @Override
    public void add(String conversationId, List<Message> messages) {

    }

    @Override
    public List<Message> get(String conversationId) {
        // conversationId对应的问答对数据列表,经过处理作为记忆上下文
        List<AiMessagePair> pairs = aiMessagePairMapper.selectBySessionId(Long.valueOf(conversationId));
        // 对应封装的消息列表
        List<Message> messages = new ArrayList<>();

        for (AiMessagePair pair : pairs) {
            // 只处理状态正常(完成)的一对消息
            if (pair.getStatus() != null &&
                    pair.getStatus() == AiMessageStatusEnum.FINISHED.getCode()) {
                if (StrUtil.isNotBlank(pair.getUserContent())) {
                    messages.add(new UserMessage(pair.getUserContent()));
                }
                if (StrUtil.isNotBlank(pair.getAiContent())) {
                    messages.add(new AssistantMessage(pair.getAiContent()));
                }
            }
        }

        return messages;
    }

    /**
     清除对话记忆,不通过此处管理,置空
     */
    @Override
    public void clear(String conversationId) {

    }
}

实体类AiMessagePair的结构:

/**
 * 一轮问答记录表
 * &#064;TableName  ai_message_pair
 */
@Data
public class AiMessagePair implements Serializable {

    /**
     * 主键,自增ID
     */
    private Long id;
    /**
     * 会话ID
     */
    private Long sessionId;
    /**
     * SSE会话ID
     */
    private String sseSessionId;
    /**
     * 用户提问内容
     */
    private String userContent;
    /**
     * AI回复内容
     */
    private String aiContent;
    /**
     * 使用模型id
     */
    private Integer modelUsed;
    /**
     * 状态:0-生成中 1-完成 2-中断
     */
    private Integer status;
    /**
     * 本轮消耗的Token
     */
    private Integer tokens;
    /**
     * 用户提问时间
     */
    private Date createTime;
    /**
     * AI回复完成时间
     */
    private Date responseTime;

}

解析

  1. 创建Spring组件MybatisChatMemory,实现ChatMemory接口。
  2. 分别实现add/get/clear方法,对应新增记忆/获取记忆列表/清除记忆时执行的逻辑。一般来说,如果将问答信息插入数据库,只需要实现get方法即可,插入对话信息和删除对话信息的方法建议另外设计逻辑处理。
  3. 入参conversationId对应一个问答会话的id,在get方法查询对话记忆的时候,需要按照这个传入的会话id来返回消息列表。
  4. Message是一个接口,包括UserMessage/SystemMessage/AssistantMessageToolResponseMessage,即分别对应用户信息/系统信息/AI助手信息和工具调用响应信息。封装Message列表时,请根据自己实体类的实际结构来指定对应类型的上下文信息。
  5. 在这里的get方法实现中,通过注入的aiMessagePairMapper查询返回对应会话idAiMessagePair列表,再通过设计逻辑来封装Message列表返回。像这样返回的Message列表会被提供给模型上下文,以实现对话记忆功能。

封装带记忆的对话反应流

封装工具组件,传入提示词,返回对应的反应式流:

@Component
@RequiredArgsConstructor
public class ModelBuilderSpringAiWithMemo {

    // 注入自定义对话记忆实现
    private final MybatisChatMemory mybatisChatMemory;

    /**
     * @param aiConfig     传入的AI配置
     * @param systemMsg    系统提示词
     * @param userMsg      用户输入的问题
     * @return Flux<ChatResponse> 反应式对话流
     * @apiNote 创建一个OpenAi模型,流式返回结果
     */
    public Flux<ChatResponse> buildModelStreamWithMemo(AiConfig aiConfig,
                                                String systemMsg,
                                                String userMsg, String conversationId) {

        String apiDomain = aiConfig.getApiDomain().replace("/v1", "");
        OpenAiApi openAiApi = OpenAiApi.builder()
                // 填入自己的API KEY
                .apiKey(aiConfig.getApiKey())
                // 填入自己的API域名,如果是百炼,即为https://dashscope.aliyuncs.com/compatible-mode
                // 注意:这里与langchain4j的配置不同,不需要在后面加/v1
                .baseUrl(apiDomain)
                .build();

        // 模型选项
        OpenAiChatOptions chatOptions = OpenAiChatOptions.builder()
                // 模型生成的最大 tokens 数
                .maxTokens(aiConfig.getMaxContextMsgs())
                // 模型生成的 tokens 的概率质量范围,取值范围 0.0-1.0 越大的概率质量范围越大
                .topP(aiConfig.getSimilarityTopP())
                // 模型生成的 tokens 的随机度,取值范围 0.0-1.0 越大的随机度越大
                .temperature(aiConfig.getTemperature())
                // 模型名称
                .model(aiConfig.getModelId())
                // 打开流式对话token计数配置,默认为false
                .streamUsage(true)
                .build();

        // 工具调用管理器 暂时为空
        ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();

        // 重试机制,设置最多3次
        RetryTemplate retryTemplate = RetryTemplate.builder()
                .maxAttempts(3)
                .build();

        // 观测数据收集器
        ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

        ChatModel model = new OpenAiChatModel(openAiApi,
                chatOptions,
                toolCallingManager,
                retryTemplate,
                observationRegistry);

        // 创建一个ChatClient对象,用于调用模型进行带记忆对话
        ChatClient chatClient = ChatClient.builder(model)
                // TODO: 系统提示词和用户提示词混淆解决
                // 系统提示词
//                .defaultSystem(systemMsg)
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(mybatisChatMemory).build())
                .build();

        // 返回反应式对话流
        return chatClient.prompt(userMsg)
                .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
                .stream()
                .chatResponse();
    }
}

这里测试的时候发现系统提示词和用户提示词会混淆,即每次对话时大模型将.defaultSystem()中的systemMsg当作了用户提示词的一部分,在解决之前暂时先不使用系统提示词参数以防止这个问题。


在对话方法中调用

完整代码:ai-chat-demo

ServiceImpl中注入ModelBuilderSpringAiWithMemo组件并进行使用即可,这里使用一个异步方法实现流式对话输出与聊天数据持久化:

@Service
@RequiredArgsConstructor
@Slf4j
public class AiChatServiceImpl implements AiChatService {

    private final AiMessagePairMapper aiMessagePairMapper;

    private final SseEmitterManager emitterManager;

    private final AiChatMessageService aiChatMessageService;

    private final AiConfigMapper aiConfigMapper;

    private final ModelBuilderSpringAiWithMemo modelBuilderSpringAiWithMemo;

    /**
     * @param aiChatParamDTO 对话参数
     * @param sessionId      SSE会话ID
     * @return 是否处理成功
     * @apiNote AI对话请求,基于虚拟线程实现异步处理,SpringAI实现
     */
    @Override
    public CompletableFuture<Boolean> sendQuestionAsyncWithMemo(AiChatParamDTO aiChatParamDTO, String sessionId) {
        return CompletableFuture.supplyAsync(() -> {
            if (emitterManager.isOverLoad())
                return false;
            // 获取会话id对应的sseEmitter
            SseEmitter emitter = emitterManager.getEmitter(sessionId);
            // 先发送一次队列人数通知
            emitterManager.notifyThreadCount();
            // 没有则先创建一个sseEmitter
            if (emitter == null) {
                if (emitterManager.addEmitter(sessionId, new SseEmitter(0L))) {
                    emitter = emitterManager.getEmitter(sessionId);
                } else {
                    // 创建失败,一般是由于队列已满,直接返回false
                    return false;
                }
            }
            // 最终指向的emitter对象
            SseEmitter finalEmitter = emitter;
            StringBuffer sb = new StringBuffer();
            // 开始对话,返回token流
            // 封装插入的信息对象
            AiMessagePair aiMessagePair = new AiMessagePair();
            aiMessagePair.setSseSessionId(sessionId);
            aiMessagePair.setSessionId(aiChatParamDTO.getChatSessionId());

            // 从数据库获取配置
            AiConfig aiConfig = aiConfigMapper.selectByPrimaryKey(aiChatParamDTO.getModelId());
            // 获取conversationId
            Long conversationId = aiChatParamDTO.getConversationId();
            if(conversationId == null){
                return false;
            }
            Flux<ChatResponse> chatResponseFlux = modelBuilderSpringAiWithMemo.buildModelStreamWithMemo(aiConfig,
                    "你是一个友善的AI助手",
                    aiChatParamDTO.getQuestion(),
                    String.valueOf(conversationId));
            // 用于跟踪最后一个 ChatResponse
            AtomicReference<ChatResponse> lastResponse = new AtomicReference<>();
            chatResponseFlux.subscribe(
                    token -> {
                        // 获取当前输出内容片段
                        String text = "";
                        if (token.getResult() != null) {
                            text = token.getResult().getOutput().getText();
                        }
                        if (StrUtil.isNotBlank(text)) {
                            sb.append(text);
                            // log.info("当前段数据:{}", text);
                            // 换行符转义:token换行符转换成<br>
                            text = text.replace("\n", "<br>");
                            // 换行符转义:如果token以换行符为结尾,转换成<br>
                            text = text.replace(" ", "&nbsp;");
                        }
                        // 发送返回的数据
                        try {
                            if (StrUtil.isNotBlank(text)) {
                                finalEmitter.send(SseEmitter.event().data(text));
                            }
                        } catch (IOException e) {
                            throw new RuntimeException(e);
                        }
                        // 更新最后一个响应
                        lastResponse.set(token);
                    },
                    // 反应式流在报错时会直接中断
                    e -> {
                        log.error("ai对话 流式输出报错:{}", e.getMessage());
                        int usageCount = 0;
                        ChatResponse chatResponse = lastResponse.get();
                        if (chatResponse != null) {
                            Usage usage = chatResponse.getMetadata().getUsage();
                            usageCount = usage.getTotalTokens();
                        } else {
                            log.warn("未获取到 Token 使用信息,可能模型未返回或配置未启用");
                        }
                        // 更新中断的状态
                        tryUpdateMessage(aiMessagePair,
                                sb.toString(),
                                true,
                                usageCount);
                        finalEmitter.completeWithError(e);
                        emitterManager.removeEmitter(sessionId); // 出错时也移除
                    }, // 错误处理
                    () -> {// 流结束
                        log.info("\n回答完毕!");
                        // 从最后一个响应中获取 Token 使用信息
                        ChatResponse chatResponse = lastResponse.get();
                        int usageCount = 0;
                        if (chatResponse != null) {
                            Usage usage = chatResponse.getMetadata().getUsage();
                            usageCount = usage.getTotalTokens();
                        } else {
                            log.warn("未获取到 Token 使用信息,可能模型未返回或配置未启用");
                        }
                        finalEmitter.complete();
                        emitterManager.removeEmitter(sessionId); // 只在流结束后移除
                        log.info("最终拼接的数据:\n{}", sb);
                        log.info("token使用:{}", usageCount);
                        // 更新正常完成的状态
                        tryUpdateMessage(aiMessagePair,
                                sb.toString(),
                                false,
                                usageCount);
                    });
            return true;
        }, Executors.newVirtualThreadPerTaskExecutor());
    }

    /**
     * 尝试插入消息的方法
     */
    private void tryUpdateMessage(AiMessagePair message,
                                  String content,
                                  boolean isInterrupted,
                                  Integer tokenUsed) {
        int status = isInterrupted ? AiMessageStatusEnum.STOPPED.getCode() : AiMessageStatusEnum.FINISHED.getCode();
        message.setStatus(status);
        message.setAiContent(content);
        message.setTokens(tokenUsed);
        message.setResponseTime(Date.from(Instant.now()));
        aiMessagePairMapper.updateBySseIdSelective(message);
    }

}

测试对话记忆

第一次对话:

image.png

询问刚才的对话内容:

image.png


🎉显然,AI正确地给出了刚才的对话内容,至此,我们的AI聊天就拥有了对话记忆的功能。