多轮对话的记忆心脏:ChatMemory 滑动窗口原理

0 阅读6分钟

Spring AI 源码解读 - 第 4 篇:ChatMemory 记忆管理

多轮对话的上下文维护机制

📖 开篇引言

多轮对话的关键是 AI 能够记住之前的对话内容。但如何高效地存储和检索这些消息?如何避免消息堆积导致的 Token 浪费?

本篇将深入 ChatMemory 的设计与实现,理解记忆管理的核心机制。


一、ChatMemory 接口设计

1.1 ChatMemory 接口

// org.springframework.ai.chat.memory.ChatMemory
public interface ChatMemory {
    
    // 添加消息到指定会话
    void add(ConversationId conversationId, Message message);
    
    // 获取指定会话的所有消息
    List<Message> get(ConversationId conversationId);
    
    // 获取指定会话的消息数量
    int getMessageCount(ConversationId conversationId);
    
    // 清空指定会话的消息
    void clear(ConversationId conversationId);
}

// 会话 ID 包装类
public record ConversationId(String id) {}

1.2 ChatMemory 的职责

ChatMemory
├── 存储:将消息持久化
├── 检索:根据会话 ID 获取消息
├── 管理:清空、统计消息
└── 隔离:不同会话的消息相互独立

二、MessageWindowChatMemory 实现

2.1 滑动窗口的概念

消息窗口大小 = 31 轮
┌─────────────────────┐
│ [msg1, msg2, msg3]  │ ← 窗口满
└─────────────────────┘

第 2 轮(添加 msg4)
┌─────────────────────┐
│ [msg2, msg3, msg4]  │ ← msg1 被移出
└─────────────────────┘

第 3 轮(添加 msg5)
┌─────────────────────┐
│ [msg3, msg4, msg5]  │ ← msg2 被移出
└─────────────────────┘

2.2 MessageWindowChatMemory 源码

// org.springframework.ai.chat.memory.MessageWindowChatMemory
public class MessageWindowChatMemory implements ChatMemory {
    
    private final int maxMessages;  // 最大消息数
    private final Map<ConversationId, List<Message>> conversationHistory;
    
    // 构造方法
    public MessageWindowChatMemory(int maxMessages) {
        this.maxMessages = maxMessages;
        this.conversationHistory = new ConcurrentHashMap<>();
    }
    
    // 工厂方法:创建默认实例(最多 100 条消息)
    public static MessageWindowChatMemory create() {
        return new MessageWindowChatMemory(100);
    }
    
    @Override
    public void add(ConversationId conversationId, Message message) {
        
        // 1. 获取或创建该会话的消息列表
        List<Message> messages = conversationHistory
            .computeIfAbsent(conversationId, k -> new ArrayList<>());
        
        // 2. 添加新消息
        messages.add(message);
        
        // 3. 如果超过窗口大小,移除最旧的消息
        if (messages.size() > this.maxMessages) {
            messages.remove(0);  // 移除第一条(最旧)
        }
    }
    
    @Override
    public List<Message> get(ConversationId conversationId) {
        // 返回该会话的所有消息(已在窗口内)
        return conversationHistory.getOrDefault(conversationId, List.of());
    }
    
    @Override
    public int getMessageCount(ConversationId conversationId) {
        return conversationHistory
            .getOrDefault(conversationId, List.of())
            .size();
    }
    
    @Override
    public void clear(ConversationId conversationId) {
        conversationHistory.remove(conversationId);
    }
}

2.3 滑动窗口的优缺点

优点缺点
实现简单可能丢失重要的早期消息
内存占用固定无法区分消息重要性
性能高对长对话支持不足

三、Token 计数与消息截断

3.1 为什么需要 Token 计数?

模型的上下文窗口大小是有限的
例如:Ollama qwen2.5:14b 的上下文窗口 = 32K tokens

如果消息总 Token 数超过上下文窗口,模型会报错

3.2 TokenTextSplitter 的 Token 计数

// org.springframework.ai.document.TokenTextSplitter
public class TokenTextSplitter implements TextSplitter {
    
    private final int chunkSize;      // 每个块的 Token 数
    private final int chunkOverlap;   // 块之间的重叠 Token 数
    private final Tokenizer tokenizer; // Token 计数器
    
    // 计算文本的 Token 数
    public int countTokens(String text) {
        return this.tokenizer.countTokens(text);
    }
    
    // 分割文本
    public List<String> split(String text) {
        List<String> chunks = new ArrayList<>();
        List<Integer> tokenCounts = new ArrayList<>();
        
        // 1. 计算每个块的 Token 数
        for (String chunk : text.split("\n")) {
            int tokens = countTokens(chunk);
            if (tokens > this.chunkSize) {
                // 如果单个块超过大小,继续分割
                chunks.addAll(splitLargeChunk(chunk));
            } else {
                chunks.add(chunk);
            }
        }
        
        return chunks;
    }
}

3.3 消息截断策略

// 在 ChatMemory 中实现 Token 限制
public class TokenLimitedChatMemory implements ChatMemory {
    
    private final int maxTokens;  // 最大 Token 数
    private final Tokenizer tokenizer;
    private final Map<ConversationId, List<Message>> conversationHistory;
    
    @Override
    public void add(ConversationId conversationId, Message message) {
        
        List<Message> messages = conversationHistory
            .computeIfAbsent(conversationId, k -> new ArrayList<>());
        
        messages.add(message);
        
        // 检查总 Token 数
        while (getTotalTokens(messages) > this.maxTokens) {
            // 移除最旧的消息
            messages.remove(0);
        }
    }
    
    // 计算消息列表的总 Token 数
    private int getTotalTokens(List<Message> messages) {
        return messages.stream()
            .mapToInt(msg -> tokenizer.countTokens(msg.getContent()))
            .sum();
    }
}

四、ConversationId 与会话隔离

4.1 ConversationId 的作用

// 不同用户的会话需要隔离
ChatMemory memory = new MessageWindowChatMemory(100);

// 用户 A 的会话
ConversationId sessionA = new ConversationId("user-a-session-1");
memory.add(sessionA, new UserMessage("我想学 Java"));
memory.add(sessionA, new AssistantMessage("Java 是..."));

// 用户 B 的会话
ConversationId sessionB = new ConversationId("user-b-session-1");
memory.add(sessionB, new UserMessage("我想学 Python"));
memory.add(sessionB, new AssistantMessage("Python 是..."));

// 获取消息时完全隔离
List<Message> messagesA = memory.get(sessionA);  // 只包含 A 的消息
List<Message> messagesB = memory.get(sessionB);  // 只包含 B 的消息

4.2 会话 ID 的生成策略

// 策略 1:基于线程 ID(单线程场景)
String sessionId = String.valueOf(Thread.currentThread().getId());

// 策略 2:基于用户 ID
String sessionId = "user-" + userId;

// 策略 3:基于 HTTP Session ID
String sessionId = httpSession.getId();

// 策略 4:基于 UUID(每次对话新建)
String sessionId = UUID.randomUUID().toString();

五、ChatMemory 在 Advisor 中的使用

5.1 MessageChatMemoryAdvisor 的完整流程

public class MessageChatMemoryAdvisor 
    implements CallAroundAdvisor, StreamAroundAdvisor {
    
    private final ChatMemory chatMemory;
    
    // 前置处理:注入历史消息
    @Override
    public AdvisedRequest before(AdvisedRequest advisedRequest) {
        
        // 1. 获取会话 ID
        String sessionId = getSessionId(advisedRequest);
        ConversationId conversationId = new ConversationId(sessionId);
        
        // 2. 从 ChatMemory 获取历史消息
        List<Message> historyMessages = this.chatMemory.get(conversationId);
        
        // 3. 将历史消息注入到请求中
        // 注入顺序:SystemMessage → HistoryMessages → CurrentUserMessage
        List<Message> allMessages = new ArrayList<>();
        
        // 添加系统消息(如果有)
        if (advisedRequest.getSystemText() != null) {
            allMessages.add(new SystemMessage(advisedRequest.getSystemText()));
        }
        
        // 添加历史消息
        allMessages.addAll(historyMessages);
        
        // 添加当前用户消息
        allMessages.addAll(advisedRequest.getUserMessage());
        
        // 4. 更新请求
        return AdvisedRequest.from(advisedRequest)
            .userMessage(allMessages)
            .build();
    }
    
    // 后置处理:保存消息到 ChatMemory
    @Override
    public ChatResponse after(AdvisedRequest advisedRequest, 
                               AdvisedResponse<ChatResponse> advisedResponse) {
        
        String sessionId = getSessionId(advisedRequest);
        ConversationId conversationId = new ConversationId(sessionId);
        
        // 1. 保存用户消息
        for (Message msg : advisedRequest.getUserMessage()) {
            if (msg instanceof UserMessage) {
                this.chatMemory.add(conversationId, msg);
            }
        }
        
        // 2. 保存 AI 回复
        ChatResponse response = advisedResponse.getChatResponse();
        AssistantMessage assistantMessage = new AssistantMessage(
            response.getResult().getOutput().getContent()
        );
        this.chatMemory.add(conversationId, assistantMessage);
        
        return response;
    }
}

5.2 记忆注入的完整示例

1 轮调用
┌─────────────────────────────────────────┐
│ before()                                 │
│ ChatMemory.get(sessionId) → []           │
│ 注入消息:[SystemMessage, UserMessage]   │
└─────────────────────────────────────────┘
         ↓
┌─────────────────────────────────────────┐
│ doChat()                                 │
│ chatModel.call(prompt)                   │
│ 返回 AssistantMessage                    │
└─────────────────────────────────────────┘
         ↓
┌─────────────────────────────────────────┐
│ after()                                  │
│ ChatMemory.add(sessionId, UserMessage)   │
│ ChatMemory.add(sessionId, AssistantMsg)  │
└─────────────────────────────────────────┘

第 2 轮调用
┌─────────────────────────────────────────┐
│ before()                                 │
│ ChatMemory.get(sessionId) →              │
│   [UserMessage, AssistantMessage]        │
│ 注入消息:[SystemMessage,                ││   UserMessage(历史),                     ││   AssistantMessage(历史),                ││   UserMessage(当前)]                     │
└─────────────────────────────────────────┘

六、分布式记忆存储

6.1 为什么需要分布式记忆?

单机 ChatMemory 的问题:
- 应用重启后消息丢失
- 多实例部署时消息不共享
- 无法跨应用访问

解决方案:使用 Redis 等分布式存储

6.2 Redis 实现的 ChatMemory

// 基于 Redis 的 ChatMemory 实现
public class RedisChatMemory implements ChatMemory {
    
    private final RedisTemplate<String, Message> redisTemplate;
    private final String keyPrefix = "chat:memory:";
    
    @Override
    public void add(ConversationId conversationId, Message message) {
        
        // 1. 构建 Redis key
        String key = keyPrefix + conversationId.id();
        
        // 2. 将消息序列化后存储
        redisTemplate.opsForList().rightPush(key, message);
        
        // 3. 设置过期时间(24 小时)
        redisTemplate.expire(key, Duration.ofHours(24));
    }
    
    @Override
    public List<Message> get(ConversationId conversationId) {
        
        String key = keyPrefix + conversationId.id();
        
        // 从 Redis 获取所有消息
        return redisTemplate.opsForList()
            .range(key, 0, -1);
    }
    
    @Override
    public void clear(ConversationId conversationId) {
        
        String key = keyPrefix + conversationId.id();
        redisTemplate.delete(key);
    }
}

6.3 Redis 中的数据结构

Redis 中的存储结构:

chat:memory:user-a-session-1
├── [0] UserMessage("我想学 Java")
├── [1] AssistantMessage("Java 是...")
├── [2] UserMessage("它有哪些特点?")
└── [3] AssistantMessage("Java 的特点是...")

chat:memory:user-b-session-1
├── [0] UserMessage("我想学 Python")
└── [1] AssistantMessage("Python 是...")

七、ChatMemory 的生命周期

7.1 创建

// 方式 1:默认实现
ChatMemory memory = MessageWindowChatMemory.create();

// 方式 2:自定义大小
ChatMemory memory = new MessageWindowChatMemory(50);

// 方式 3:Redis 实现
ChatMemory memory = new RedisChatMemory(redisTemplate);

7.2 使用

// 在 ChatClient 中使用
ChatClient chatClient = ChatClient.builder(chatModel)
    .defaultAdvisors(new MessageChatMemoryAdvisor(memory))
    .build();

// 自动注入记忆
chatClient.prompt()
    .user("你好")
    .call();

7.3 清理

// 清空特定会话的记忆
memory.clear(new ConversationId("user-a-session-1"));

// 或者依赖过期时间自动清理(Redis)

八、小结

8.1 本篇要点

主题核心要点
ChatMemory 接口add / get / clear 三个核心方法
MessageWindowChatMemory滑动窗口实现,固定消息数量
Token 计数防止消息堆积导致的 Token 溢出
ConversationId会话隔离,不同用户消息独立
Advisor 集成before() 注入记忆,after() 保存消息
分布式存储Redis 实现跨应用、跨实例的记忆共享

8.2 关键类清单

类 / 接口职责
ChatMemory记忆接口
MessageWindowChatMemory滑动窗口实现
RedisChatMemoryRedis 分布式实现
ConversationId会话 ID 包装类
MessageChatMemoryAdvisor记忆注入拦截器
TokenTextSplitterToken 计数与分割

系列目录

  • 第 1 篇:整体架构与核心抽象
  • 第 2 篇:ChatClient 调用链路
  • 第 3 篇:Prompt 与 Message 体系
  • 第 4 篇:ChatMemory 记忆管理(本篇)

需要Spring AI系列学习代码的同学 欢迎关注公众号「AI日撰」,点击菜单「获取源码」获取完整代码(Gitee 仓库)。