前言
完整代码请查看github项目地址:ai-chat-demo
本文的工作是在之前的AI流式对话基础上实现的,如有疑问,可以参考:实现流式对话应用
阅读本文前应该有的知识基础
-
Java基础 -
Java Web基础 -
数据库基础
-
包管理 (
Maven、gradle等) -
Spring基础、SSM整合、SpringBoot等
大模型中的
chat memory(聊天记忆) ,指的是模型在多轮对话中「记住」之前对话内容,从而实现更连贯、上下文一致的回答能力。
自定义对话记忆实现
实现ChatMemory接口
SpringAI可以通过实现org.springframework.ai.chat.memory.ChatMemory接口以定制对话记忆的实现方式,灵活性很强。
这里使用Mybatis整合MySQL储存实现:
@Component
@Slf4j
@RequiredArgsConstructor
public class MybatisChatMemory implements ChatMemory {
private final AiMessagePairMapper aiMessagePairMapper;
/**
新增对话记忆,不通过此处管理,置空
*/
@Override
public void add(String conversationId, List<Message> messages) {
}
@Override
public List<Message> get(String conversationId) {
// conversationId对应的问答对数据列表,经过处理作为记忆上下文
List<AiMessagePair> pairs = aiMessagePairMapper.selectBySessionId(Long.valueOf(conversationId));
// 对应封装的消息列表
List<Message> messages = new ArrayList<>();
for (AiMessagePair pair : pairs) {
// 只处理状态正常(完成)的一对消息
if (pair.getStatus() != null &&
pair.getStatus() == AiMessageStatusEnum.FINISHED.getCode()) {
if (StrUtil.isNotBlank(pair.getUserContent())) {
messages.add(new UserMessage(pair.getUserContent()));
}
if (StrUtil.isNotBlank(pair.getAiContent())) {
messages.add(new AssistantMessage(pair.getAiContent()));
}
}
}
return messages;
}
/**
清除对话记忆,不通过此处管理,置空
*/
@Override
public void clear(String conversationId) {
}
}
实体类AiMessagePair的结构:
/**
* 一轮问答记录表
* @TableName ai_message_pair
*/
@Data
public class AiMessagePair implements Serializable {
/**
* 主键,自增ID
*/
private Long id;
/**
* 会话ID
*/
private Long sessionId;
/**
* SSE会话ID
*/
private String sseSessionId;
/**
* 用户提问内容
*/
private String userContent;
/**
* AI回复内容
*/
private String aiContent;
/**
* 使用模型id
*/
private Integer modelUsed;
/**
* 状态:0-生成中 1-完成 2-中断
*/
private Integer status;
/**
* 本轮消耗的Token
*/
private Integer tokens;
/**
* 用户提问时间
*/
private Date createTime;
/**
* AI回复完成时间
*/
private Date responseTime;
}
解析
- 创建
Spring组件MybatisChatMemory,实现ChatMemory接口。 - 分别实现
add/get/clear方法,对应新增记忆/获取记忆列表/清除记忆时执行的逻辑。一般来说,如果将问答信息插入数据库,只需要实现get方法即可,插入对话信息和删除对话信息的方法建议另外设计逻辑处理。 - 入参
conversationId对应一个问答会话的id,在get方法查询对话记忆的时候,需要按照这个传入的会话id来返回消息列表。 Message是一个接口,包括UserMessage/SystemMessage/AssistantMessage和ToolResponseMessage,即分别对应用户信息/系统信息/AI助手信息和工具调用响应信息。封装Message列表时,请根据自己实体类的实际结构来指定对应类型的上下文信息。- 在这里的
get方法实现中,通过注入的aiMessagePairMapper查询返回对应会话id的AiMessagePair列表,再通过设计逻辑来封装Message列表返回。像这样返回的Message列表会被提供给模型上下文,以实现对话记忆功能。
封装带记忆的对话反应流
封装工具组件,传入提示词,返回对应的反应式流:
@Component
@RequiredArgsConstructor
public class ModelBuilderSpringAiWithMemo {
// 注入自定义对话记忆实现
private final MybatisChatMemory mybatisChatMemory;
/**
* @param aiConfig 传入的AI配置
* @param systemMsg 系统提示词
* @param userMsg 用户输入的问题
* @return Flux<ChatResponse> 反应式对话流
* @apiNote 创建一个OpenAi模型,流式返回结果
*/
public Flux<ChatResponse> buildModelStreamWithMemo(AiConfig aiConfig,
String systemMsg,
String userMsg, String conversationId) {
String apiDomain = aiConfig.getApiDomain().replace("/v1", "");
OpenAiApi openAiApi = OpenAiApi.builder()
// 填入自己的API KEY
.apiKey(aiConfig.getApiKey())
// 填入自己的API域名,如果是百炼,即为https://dashscope.aliyuncs.com/compatible-mode
// 注意:这里与langchain4j的配置不同,不需要在后面加/v1
.baseUrl(apiDomain)
.build();
// 模型选项
OpenAiChatOptions chatOptions = OpenAiChatOptions.builder()
// 模型生成的最大 tokens 数
.maxTokens(aiConfig.getMaxContextMsgs())
// 模型生成的 tokens 的概率质量范围,取值范围 0.0-1.0 越大的概率质量范围越大
.topP(aiConfig.getSimilarityTopP())
// 模型生成的 tokens 的随机度,取值范围 0.0-1.0 越大的随机度越大
.temperature(aiConfig.getTemperature())
// 模型名称
.model(aiConfig.getModelId())
// 打开流式对话token计数配置,默认为false
.streamUsage(true)
.build();
// 工具调用管理器 暂时为空
ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();
// 重试机制,设置最多3次
RetryTemplate retryTemplate = RetryTemplate.builder()
.maxAttempts(3)
.build();
// 观测数据收集器
ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
ChatModel model = new OpenAiChatModel(openAiApi,
chatOptions,
toolCallingManager,
retryTemplate,
observationRegistry);
// 创建一个ChatClient对象,用于调用模型进行带记忆对话
ChatClient chatClient = ChatClient.builder(model)
// TODO: 系统提示词和用户提示词混淆解决
// 系统提示词
// .defaultSystem(systemMsg)
.defaultAdvisors(MessageChatMemoryAdvisor.builder(mybatisChatMemory).build())
.build();
// 返回反应式对话流
return chatClient.prompt(userMsg)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.stream()
.chatResponse();
}
}
这里测试的时候发现系统提示词和用户提示词会混淆,即每次对话时大模型将.defaultSystem()中的systemMsg当作了用户提示词的一部分,在解决之前暂时先不使用系统提示词参数以防止这个问题。
在对话方法中调用
完整代码:ai-chat-demo
在ServiceImpl中注入ModelBuilderSpringAiWithMemo组件并进行使用即可,这里使用一个异步方法实现流式对话输出与聊天数据持久化:
@Service
@RequiredArgsConstructor
@Slf4j
public class AiChatServiceImpl implements AiChatService {
private final AiMessagePairMapper aiMessagePairMapper;
private final SseEmitterManager emitterManager;
private final AiChatMessageService aiChatMessageService;
private final AiConfigMapper aiConfigMapper;
private final ModelBuilderSpringAiWithMemo modelBuilderSpringAiWithMemo;
/**
* @param aiChatParamDTO 对话参数
* @param sessionId SSE会话ID
* @return 是否处理成功
* @apiNote AI对话请求,基于虚拟线程实现异步处理,SpringAI实现
*/
@Override
public CompletableFuture<Boolean> sendQuestionAsyncWithMemo(AiChatParamDTO aiChatParamDTO, String sessionId) {
return CompletableFuture.supplyAsync(() -> {
if (emitterManager.isOverLoad())
return false;
// 获取会话id对应的sseEmitter
SseEmitter emitter = emitterManager.getEmitter(sessionId);
// 先发送一次队列人数通知
emitterManager.notifyThreadCount();
// 没有则先创建一个sseEmitter
if (emitter == null) {
if (emitterManager.addEmitter(sessionId, new SseEmitter(0L))) {
emitter = emitterManager.getEmitter(sessionId);
} else {
// 创建失败,一般是由于队列已满,直接返回false
return false;
}
}
// 最终指向的emitter对象
SseEmitter finalEmitter = emitter;
StringBuffer sb = new StringBuffer();
// 开始对话,返回token流
// 封装插入的信息对象
AiMessagePair aiMessagePair = new AiMessagePair();
aiMessagePair.setSseSessionId(sessionId);
aiMessagePair.setSessionId(aiChatParamDTO.getChatSessionId());
// 从数据库获取配置
AiConfig aiConfig = aiConfigMapper.selectByPrimaryKey(aiChatParamDTO.getModelId());
// 获取conversationId
Long conversationId = aiChatParamDTO.getConversationId();
if(conversationId == null){
return false;
}
Flux<ChatResponse> chatResponseFlux = modelBuilderSpringAiWithMemo.buildModelStreamWithMemo(aiConfig,
"你是一个友善的AI助手",
aiChatParamDTO.getQuestion(),
String.valueOf(conversationId));
// 用于跟踪最后一个 ChatResponse
AtomicReference<ChatResponse> lastResponse = new AtomicReference<>();
chatResponseFlux.subscribe(
token -> {
// 获取当前输出内容片段
String text = "";
if (token.getResult() != null) {
text = token.getResult().getOutput().getText();
}
if (StrUtil.isNotBlank(text)) {
sb.append(text);
// log.info("当前段数据:{}", text);
// 换行符转义:token换行符转换成<br>
text = text.replace("\n", "<br>");
// 换行符转义:如果token以换行符为结尾,转换成<br>
text = text.replace(" ", " ");
}
// 发送返回的数据
try {
if (StrUtil.isNotBlank(text)) {
finalEmitter.send(SseEmitter.event().data(text));
}
} catch (IOException e) {
throw new RuntimeException(e);
}
// 更新最后一个响应
lastResponse.set(token);
},
// 反应式流在报错时会直接中断
e -> {
log.error("ai对话 流式输出报错:{}", e.getMessage());
int usageCount = 0;
ChatResponse chatResponse = lastResponse.get();
if (chatResponse != null) {
Usage usage = chatResponse.getMetadata().getUsage();
usageCount = usage.getTotalTokens();
} else {
log.warn("未获取到 Token 使用信息,可能模型未返回或配置未启用");
}
// 更新中断的状态
tryUpdateMessage(aiMessagePair,
sb.toString(),
true,
usageCount);
finalEmitter.completeWithError(e);
emitterManager.removeEmitter(sessionId); // 出错时也移除
}, // 错误处理
() -> {// 流结束
log.info("\n回答完毕!");
// 从最后一个响应中获取 Token 使用信息
ChatResponse chatResponse = lastResponse.get();
int usageCount = 0;
if (chatResponse != null) {
Usage usage = chatResponse.getMetadata().getUsage();
usageCount = usage.getTotalTokens();
} else {
log.warn("未获取到 Token 使用信息,可能模型未返回或配置未启用");
}
finalEmitter.complete();
emitterManager.removeEmitter(sessionId); // 只在流结束后移除
log.info("最终拼接的数据:\n{}", sb);
log.info("token使用:{}", usageCount);
// 更新正常完成的状态
tryUpdateMessage(aiMessagePair,
sb.toString(),
false,
usageCount);
});
return true;
}, Executors.newVirtualThreadPerTaskExecutor());
}
/**
* 尝试插入消息的方法
*/
private void tryUpdateMessage(AiMessagePair message,
String content,
boolean isInterrupted,
Integer tokenUsed) {
int status = isInterrupted ? AiMessageStatusEnum.STOPPED.getCode() : AiMessageStatusEnum.FINISHED.getCode();
message.setStatus(status);
message.setAiContent(content);
message.setTokens(tokenUsed);
message.setResponseTime(Date.from(Instant.now()));
aiMessagePairMapper.updateBySseIdSelective(message);
}
}
测试对话记忆
第一次对话:
询问刚才的对话内容:
🎉显然,AI正确地给出了刚才的对话内容,至此,我们的AI聊天就拥有了对话记忆的功能。