Spring AI 源码解读 - 第 4 篇:ChatMemory 记忆管理
多轮对话的上下文维护机制
📖 开篇引言
多轮对话的关键是 AI 能够记住之前的对话内容。但如何高效地存储和检索这些消息?如何避免消息堆积导致的 Token 浪费?
本篇将深入 ChatMemory 的设计与实现,理解记忆管理的核心机制。
一、ChatMemory 接口设计
1.1 ChatMemory 接口
// org.springframework.ai.chat.memory.ChatMemory
public interface ChatMemory {
// 添加消息到指定会话
void add(ConversationId conversationId, Message message);
// 获取指定会话的所有消息
List<Message> get(ConversationId conversationId);
// 获取指定会话的消息数量
int getMessageCount(ConversationId conversationId);
// 清空指定会话的消息
void clear(ConversationId conversationId);
}
// 会话 ID 包装类
public record ConversationId(String id) {}
1.2 ChatMemory 的职责
ChatMemory
├── 存储:将消息持久化
├── 检索:根据会话 ID 获取消息
├── 管理:清空、统计消息
└── 隔离:不同会话的消息相互独立
二、MessageWindowChatMemory 实现
2.1 滑动窗口的概念
消息窗口大小 = 3
第 1 轮
┌─────────────────────┐
│ [msg1, msg2, msg3] │ ← 窗口满
└─────────────────────┘
第 2 轮(添加 msg4)
┌─────────────────────┐
│ [msg2, msg3, msg4] │ ← msg1 被移出
└─────────────────────┘
第 3 轮(添加 msg5)
┌─────────────────────┐
│ [msg3, msg4, msg5] │ ← msg2 被移出
└─────────────────────┘
2.2 MessageWindowChatMemory 源码
// org.springframework.ai.chat.memory.MessageWindowChatMemory
public class MessageWindowChatMemory implements ChatMemory {
private final int maxMessages; // 最大消息数
private final Map<ConversationId, List<Message>> conversationHistory;
// 构造方法
public MessageWindowChatMemory(int maxMessages) {
this.maxMessages = maxMessages;
this.conversationHistory = new ConcurrentHashMap<>();
}
// 工厂方法:创建默认实例(最多 100 条消息)
public static MessageWindowChatMemory create() {
return new MessageWindowChatMemory(100);
}
@Override
public void add(ConversationId conversationId, Message message) {
// 1. 获取或创建该会话的消息列表
List<Message> messages = conversationHistory
.computeIfAbsent(conversationId, k -> new ArrayList<>());
// 2. 添加新消息
messages.add(message);
// 3. 如果超过窗口大小,移除最旧的消息
if (messages.size() > this.maxMessages) {
messages.remove(0); // 移除第一条(最旧)
}
}
@Override
public List<Message> get(ConversationId conversationId) {
// 返回该会话的所有消息(已在窗口内)
return conversationHistory.getOrDefault(conversationId, List.of());
}
@Override
public int getMessageCount(ConversationId conversationId) {
return conversationHistory
.getOrDefault(conversationId, List.of())
.size();
}
@Override
public void clear(ConversationId conversationId) {
conversationHistory.remove(conversationId);
}
}
2.3 滑动窗口的优缺点
| 优点 | 缺点 |
|---|---|
| 实现简单 | 可能丢失重要的早期消息 |
| 内存占用固定 | 无法区分消息重要性 |
| 性能高 | 对长对话支持不足 |
三、Token 计数与消息截断
3.1 为什么需要 Token 计数?
模型的上下文窗口大小是有限的
例如:Ollama qwen2.5:14b 的上下文窗口 = 32K tokens
如果消息总 Token 数超过上下文窗口,模型会报错
3.2 TokenTextSplitter 的 Token 计数
// org.springframework.ai.document.TokenTextSplitter
public class TokenTextSplitter implements TextSplitter {
private final int chunkSize; // 每个块的 Token 数
private final int chunkOverlap; // 块之间的重叠 Token 数
private final Tokenizer tokenizer; // Token 计数器
// 计算文本的 Token 数
public int countTokens(String text) {
return this.tokenizer.countTokens(text);
}
// 分割文本
public List<String> split(String text) {
List<String> chunks = new ArrayList<>();
List<Integer> tokenCounts = new ArrayList<>();
// 1. 计算每个块的 Token 数
for (String chunk : text.split("\n")) {
int tokens = countTokens(chunk);
if (tokens > this.chunkSize) {
// 如果单个块超过大小,继续分割
chunks.addAll(splitLargeChunk(chunk));
} else {
chunks.add(chunk);
}
}
return chunks;
}
}
3.3 消息截断策略
// 在 ChatMemory 中实现 Token 限制
public class TokenLimitedChatMemory implements ChatMemory {
private final int maxTokens; // 最大 Token 数
private final Tokenizer tokenizer;
private final Map<ConversationId, List<Message>> conversationHistory;
@Override
public void add(ConversationId conversationId, Message message) {
List<Message> messages = conversationHistory
.computeIfAbsent(conversationId, k -> new ArrayList<>());
messages.add(message);
// 检查总 Token 数
while (getTotalTokens(messages) > this.maxTokens) {
// 移除最旧的消息
messages.remove(0);
}
}
// 计算消息列表的总 Token 数
private int getTotalTokens(List<Message> messages) {
return messages.stream()
.mapToInt(msg -> tokenizer.countTokens(msg.getContent()))
.sum();
}
}
四、ConversationId 与会话隔离
4.1 ConversationId 的作用
// 不同用户的会话需要隔离
ChatMemory memory = new MessageWindowChatMemory(100);
// 用户 A 的会话
ConversationId sessionA = new ConversationId("user-a-session-1");
memory.add(sessionA, new UserMessage("我想学 Java"));
memory.add(sessionA, new AssistantMessage("Java 是..."));
// 用户 B 的会话
ConversationId sessionB = new ConversationId("user-b-session-1");
memory.add(sessionB, new UserMessage("我想学 Python"));
memory.add(sessionB, new AssistantMessage("Python 是..."));
// 获取消息时完全隔离
List<Message> messagesA = memory.get(sessionA); // 只包含 A 的消息
List<Message> messagesB = memory.get(sessionB); // 只包含 B 的消息
4.2 会话 ID 的生成策略
// 策略 1:基于线程 ID(单线程场景)
String sessionId = String.valueOf(Thread.currentThread().getId());
// 策略 2:基于用户 ID
String sessionId = "user-" + userId;
// 策略 3:基于 HTTP Session ID
String sessionId = httpSession.getId();
// 策略 4:基于 UUID(每次对话新建)
String sessionId = UUID.randomUUID().toString();
五、ChatMemory 在 Advisor 中的使用
5.1 MessageChatMemoryAdvisor 的完整流程
public class MessageChatMemoryAdvisor
implements CallAroundAdvisor, StreamAroundAdvisor {
private final ChatMemory chatMemory;
// 前置处理:注入历史消息
@Override
public AdvisedRequest before(AdvisedRequest advisedRequest) {
// 1. 获取会话 ID
String sessionId = getSessionId(advisedRequest);
ConversationId conversationId = new ConversationId(sessionId);
// 2. 从 ChatMemory 获取历史消息
List<Message> historyMessages = this.chatMemory.get(conversationId);
// 3. 将历史消息注入到请求中
// 注入顺序:SystemMessage → HistoryMessages → CurrentUserMessage
List<Message> allMessages = new ArrayList<>();
// 添加系统消息(如果有)
if (advisedRequest.getSystemText() != null) {
allMessages.add(new SystemMessage(advisedRequest.getSystemText()));
}
// 添加历史消息
allMessages.addAll(historyMessages);
// 添加当前用户消息
allMessages.addAll(advisedRequest.getUserMessage());
// 4. 更新请求
return AdvisedRequest.from(advisedRequest)
.userMessage(allMessages)
.build();
}
// 后置处理:保存消息到 ChatMemory
@Override
public ChatResponse after(AdvisedRequest advisedRequest,
AdvisedResponse<ChatResponse> advisedResponse) {
String sessionId = getSessionId(advisedRequest);
ConversationId conversationId = new ConversationId(sessionId);
// 1. 保存用户消息
for (Message msg : advisedRequest.getUserMessage()) {
if (msg instanceof UserMessage) {
this.chatMemory.add(conversationId, msg);
}
}
// 2. 保存 AI 回复
ChatResponse response = advisedResponse.getChatResponse();
AssistantMessage assistantMessage = new AssistantMessage(
response.getResult().getOutput().getContent()
);
this.chatMemory.add(conversationId, assistantMessage);
return response;
}
}
5.2 记忆注入的完整示例
第 1 轮调用
┌─────────────────────────────────────────┐
│ before() │
│ ChatMemory.get(sessionId) → [] │
│ 注入消息:[SystemMessage, UserMessage] │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ doChat() │
│ chatModel.call(prompt) │
│ 返回 AssistantMessage │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ after() │
│ ChatMemory.add(sessionId, UserMessage) │
│ ChatMemory.add(sessionId, AssistantMsg) │
└─────────────────────────────────────────┘
第 2 轮调用
┌─────────────────────────────────────────┐
│ before() │
│ ChatMemory.get(sessionId) → │
│ [UserMessage, AssistantMessage] │
│ 注入消息:[SystemMessage, ││ UserMessage(历史), ││ AssistantMessage(历史), ││ UserMessage(当前)] │
└─────────────────────────────────────────┘
六、分布式记忆存储
6.1 为什么需要分布式记忆?
单机 ChatMemory 的问题:
- 应用重启后消息丢失
- 多实例部署时消息不共享
- 无法跨应用访问
解决方案:使用 Redis 等分布式存储
6.2 Redis 实现的 ChatMemory
// 基于 Redis 的 ChatMemory 实现
public class RedisChatMemory implements ChatMemory {
private final RedisTemplate<String, Message> redisTemplate;
private final String keyPrefix = "chat:memory:";
@Override
public void add(ConversationId conversationId, Message message) {
// 1. 构建 Redis key
String key = keyPrefix + conversationId.id();
// 2. 将消息序列化后存储
redisTemplate.opsForList().rightPush(key, message);
// 3. 设置过期时间(24 小时)
redisTemplate.expire(key, Duration.ofHours(24));
}
@Override
public List<Message> get(ConversationId conversationId) {
String key = keyPrefix + conversationId.id();
// 从 Redis 获取所有消息
return redisTemplate.opsForList()
.range(key, 0, -1);
}
@Override
public void clear(ConversationId conversationId) {
String key = keyPrefix + conversationId.id();
redisTemplate.delete(key);
}
}
6.3 Redis 中的数据结构
Redis 中的存储结构:
chat:memory:user-a-session-1
├── [0] UserMessage("我想学 Java")
├── [1] AssistantMessage("Java 是...")
├── [2] UserMessage("它有哪些特点?")
└── [3] AssistantMessage("Java 的特点是...")
chat:memory:user-b-session-1
├── [0] UserMessage("我想学 Python")
└── [1] AssistantMessage("Python 是...")
七、ChatMemory 的生命周期
7.1 创建
// 方式 1:默认实现
ChatMemory memory = MessageWindowChatMemory.create();
// 方式 2:自定义大小
ChatMemory memory = new MessageWindowChatMemory(50);
// 方式 3:Redis 实现
ChatMemory memory = new RedisChatMemory(redisTemplate);
7.2 使用
// 在 ChatClient 中使用
ChatClient chatClient = ChatClient.builder(chatModel)
.defaultAdvisors(new MessageChatMemoryAdvisor(memory))
.build();
// 自动注入记忆
chatClient.prompt()
.user("你好")
.call();
7.3 清理
// 清空特定会话的记忆
memory.clear(new ConversationId("user-a-session-1"));
// 或者依赖过期时间自动清理(Redis)
八、小结
8.1 本篇要点
| 主题 | 核心要点 |
|---|---|
| ChatMemory 接口 | add / get / clear 三个核心方法 |
| MessageWindowChatMemory | 滑动窗口实现,固定消息数量 |
| Token 计数 | 防止消息堆积导致的 Token 溢出 |
| ConversationId | 会话隔离,不同用户消息独立 |
| Advisor 集成 | before() 注入记忆,after() 保存消息 |
| 分布式存储 | Redis 实现跨应用、跨实例的记忆共享 |
8.2 关键类清单
| 类 / 接口 | 职责 |
|---|---|
ChatMemory | 记忆接口 |
MessageWindowChatMemory | 滑动窗口实现 |
RedisChatMemory | Redis 分布式实现 |
ConversationId | 会话 ID 包装类 |
MessageChatMemoryAdvisor | 记忆注入拦截器 |
TokenTextSplitter | Token 计数与分割 |
系列目录:
- 第 1 篇:整体架构与核心抽象
- 第 2 篇:ChatClient 调用链路
- 第 3 篇:Prompt 与 Message 体系
- 第 4 篇:ChatMemory 记忆管理(本篇)
需要Spring AI系列学习代码的同学 欢迎关注公众号「AI日撰」,点击菜单「获取源码」获取完整代码(Gitee 仓库)。