java2AI系列:模型的记忆功能

0 阅读8分钟

1. 引

首先,本文中提到的模型指LLM大语言模型,一个超大规模的神经网络,核心能力是理解,推理,生成,对话。

模型本身不具备记忆功能,但从我们使用豆包、千问的体感上来看,它们似乎记得我们之前的聊天对话内容。这种有“记忆”的假象源于模型上下文,也就是历史对话管理,我们每次对模型进行输入,系统会将我们之前的聊天内容打包在一起塞给模型,模型在生成答案前,看到的不仅有当前的问题,还有之前的聊天内容,以此有了所谓的记忆。

当然,一通会话的所有聊天内容不是一直叠加,而应该是一个滑动窗口(比如近20条聊天记录),不然累计到一定程度后,模型输入的Token量将难以处理。

在这里插入图片描述

2. LLM原理简介

补充一点LLM生成答案的原理,它并不是根据所训练的知识去获取一个匹配度最高的答案,或者是联网搜索一个相关性最高的结果,而是基于Transformer架构的逐字推理。

下面举例描述一下答案生成的过程。

第 1 步:输入:介绍一下 LLM模型预测下一个词:大 第 2 步:输入:介绍一下 LLM 大预测:语 第 3 步:输入:介绍一下 LLM 大语预测:言 第 4 步:输入:介绍一下 LLM 大语言预测:模 ……就这样一个字一个字拼到结束。

输入 → 分词 → 向量 → Transformer 算概率 → 逐字生成 → 循环 → 停止 → 输出。

面对相同的问题,模型可能输出不同的答案。在预测下个字元,也就是下个Token时,它不一定会选中概率最大的那个Token,这与我们设定的模型温度(temperature)有关,温度越低,得到相同答案的概率越高,当温度为0时,模型每次都会去概率最大的下个Token来生成答案。

输入:“我今天想吃”概率分布可能是:

火锅:40%
米饭:25%
面条:15%
烧烤:10%
其他:10%
temperature=0 → 必输出:火锅
temperature=0.7~1.0 → 可能火锅、米饭、面条随机
temperature=1.5+ → 可能出现奇怪答案

3. 记忆力测试

在前面的文章中,我们完成了SpringAI框架接入智谱大模型的工作,在浏览器第一次输入我是你的主人Rosen,得到了如下答案

在这里插入图片描述 接着再输入你知道我是谁吗,从输出答案来看,模型已经不记得我的名字了。

在这里插入图片描述

4. 操作步骤

4.1 请求日志

SpringAI框架为我们提供了许多开箱即用的功能,比如本文会用到的请求日志和上下文管理。我们先开启模型请求日志功能,仅需在初始化ChatClient对象时构建一个Advisor对象SimpleLoggerAdvisor,这样请求模型的输入和输出都会自动打印到控制台中。

@Configuration
public class CommonConfiguration {

    @Bean
    public ChatClient chatClient(ZhiPuAiChatModel model) {
        return ChatClient.builder(model)
                .defaultSystem("你是一个智能桌面助手铁铁,帮助缓解主人工作之余的疲惫和情绪情绪")
                .defaultAdvisors(new SimpleLoggerAdvisor())
                .build();
    }
}

在这里插入图片描述

对应的日志代码在org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor中,为DEBUG级别,对应的,我们在配置文件中设置一下这个目录的日志级别,不然也不会打印。

logging:
  level:
    root: info
    # AI对话的日志级别
    org.springframework.ai.chat.client.advisor: debug

在这里插入图片描述

从请求日志来看,提示词有两部分,一部分是用户输入的问题,另一部分是我们为模型设定的人设,分别对应用户提示词和系统提示词。

在这里插入图片描述

4.2 MessageWindowChatMemory

MessageWindowChatMemory是一个内存型支持保存最近N条记录的ChatMemory实现,是学习SpringAI记忆功能主键的首选。我们修改ChatClient这个Bean如下,增加一个MessageChatMemoryAdvisor来实现模型的记忆功能。

@Configuration
public class CommonConfiguration {

    @Bean
    public ChatClient chatClient(ZhiPuAiChatModel model) {
        return ChatClient.builder(model)
                .defaultSystem("你是一个智能桌面助手铁铁,帮助缓解主人工作之余的疲惫和情绪情绪")
                .defaultAdvisors(new SimpleLoggerAdvisor(),
                        MessageChatMemoryAdvisor.builder(MessageWindowChatMemory.builder()
                                        .maxMessages(10)
                                        .build())
                                .build())
                .build();
    }
}

同时,我们在使用ChatClient调用模型的时候需要传递一个名为chat_memory_conversation_id的会话Id参数,修改/api/stream/chat接口如下,增加一个conversationId会话Id参数。当然,这里做的比较粗糙,直接图方便放到请求参数里面了。

    @RequestMapping(value = "/stream/chat", produces = "text/html;charset=UTF-8")
    public Flux<String> streamChat(@RequestParam String prompt, @RequestParam(name = "conversationId", required = false) String conversationId) {
        if (ObjectUtils.isEmpty(conversationId)) {
            conversationId = UUID.randomUUID().toString();
        }
        String finalConversationId = conversationId;
        return chatClient
                .prompt(prompt)
                .advisors(v -> v.param(ChatMemory.CONVERSATION_ID, finalConversationId))
                .stream()
                .content();
    }

修改后,重启服务。 第一次我们输入为我是你的主人Rosen

在这里插入图片描述

接着,我们又输入你知道我是谁吗

在这里插入图片描述

再次输入很高兴你还记得我

在这里插入图片描述

可以看到,从第二的回答中,模型还知道我们第一次的沟通内容。

我们在SimpleLoggerAdvisor输出请求日志的行打上断点,再次输入你很聪明,可以看到我们的请求信息中,promt中的message有8条,红色框中的6条为前三轮对话的问答,第7条为系统提示词,第8条为最新的问题。context中有名为chat_memory_conversation_id的参数,值为我们的对话Id。

在这里插入图片描述

4.3 RedisChatMemoryRepository + JdbcChatMemoryRepository

上面我们已经实现了对话的内存记忆功能,当然,这仅使用与测试环境或者本地学习,生产环境肯定不能这么玩儿。多数情况下,生产环境会选择分布式缓存+持久化组合的方案,在缓冲中仅保存滑动窗口内的上下文内容,提供给模型做上下文输入,历史会话均持久化到数据库中便于查看。

4.3.1 会话记录表

这里创建一张简单的表来保存会话的历史聊天记录。

CREATE TABLE chat_message (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    conversation_id VARCHAR(64) NOT NULL,
    role VARCHAR(16) NOT NULL, -- user / assistant
    content TEXT,
    create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
    INDEX idx_conversation_id (conversation_id)
);

4.3.2 增加相关依赖

这里主要增加了mysql,redis,mybatisplus相关的依赖,还有spring-ai-autoconfigure-model-chat-memory的starter,完整的pom文件如下:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.5.10</version>
        <relativePath/>
    </parent>
    <groupId>com.roswu</groupId>
    <artifactId>spring-ai-learning</artifactId>
    <version>1.0.1-SNAPSHOT</version>
    <name>spring-ai-learning</name>
    <description>spring-ai-learning</description>

    <properties>
        <java.version>17</java.version>
        <spring-ai.version>1.1.2</spring-ai.version>
        <mybatis-plus.version>3.5.16</mybatis-plus.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-starter-model-zhipuai</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-autoconfigure-model-chat-memory</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>
        <dependency>
            <groupId>com.baomidou</groupId>
            <artifactId>mybatis-plus-spring-boot3-starter</artifactId>
        </dependency>
        <dependency>
            <groupId>com.mysql</groupId>
            <artifactId>mysql-connector-j</artifactId>
            <scope>runtime</scope>
        </dependency>
    </dependencies>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.ai</groupId>
                <artifactId>spring-ai-bom</artifactId>
                <version>${spring-ai.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
            <dependency>
                <groupId>org.projectlombok</groupId>
                <artifactId>lombok</artifactId>
                <version>1.18.42</version>
            </dependency>
            <dependency>
                <groupId>com.baomidou</groupId>
                <artifactId>mybatis-plus-spring-boot3-starter</artifactId>
                <version>3.5.11</version>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>

4.3.3 基础对象创建

pojo

@Data
@TableName("chat_message")
public class ChatMessage {
    @TableId(type = IdType.AUTO)
    private Long id;
    private String conversationId;
    private String role;
    private String content;
    private LocalDateTime createTime;
}

mapper

@Mapper
public interface ChatMessageMapper extends BaseMapper<ChatMessage> {
}

mapper.xml和service此处非必须,就省略了。

4.3.4 RedisChatMemoryRepository

下面重点来了,自定义实现ChatMemoryRepository接口的对象RedisChatMemoryRepository,实现内容基本有ai帮我生成,我只做了部分逻辑的校对。

@RequiredArgsConstructor
public class RedisChatMemoryRepository implements ChatMemoryRepository {

    private final StringRedisTemplate redisTemplate;
    private final ChatMessageMapper chatMessageMapper;
    private final ObjectMapper objectMapper = new ObjectMapper();

    private static final String PREFIX = "spring:ai:chat:memory:";
    private static final int MAX_MESSAGES = 16; // 只保留最新16条
    private static final long EXPIRED_MINUTES = 30;

    @Override
    public List<String> findConversationIds() {
        Set<String> keys = redisTemplate.keys(PREFIX + "*");
        return keys.stream().map(k -> k.replace(PREFIX, "")).toList();
    }

    /**
     * 读取:Redis没有 → 从MySQL恢复最新16条
     */
    @Override
    public List<Message> findByConversationId(String conversationId) {
        String key = PREFIX + conversationId;

        // 1. 读Redis
        String json = redisTemplate.opsForValue().get(key);
        try {
            return objectMapper.readValue(json, new TypeReference<List<Message>>() {
            });
        } catch (Exception e) {
            redisTemplate.delete(key);
        }

        // 2. Redis失效 → 从MySQL恢复
        List<ChatMessage> dbMessages = chatMessageMapper.selectList(
                new com.baomidou.mybatisplus.core.conditions.query.QueryWrapper<ChatMessage>()
                        .eq("conversation_id", conversationId)
                        .orderByAsc("create_time")
                        .last("LIMIT " + MAX_MESSAGES));

        List<Message> messages = convertToAiMessage(dbMessages);

        // 3. 写回Redis
        try {
            redisTemplate.opsForValue().set(key, objectMapper.writeValueAsString(messages), EXPIRED_MINUTES, TimeUnit.MINUTES);
        } catch (Exception ignored) {
            // do nothing
        }

        return messages;
    }

    /**
     * 保存:双写 Redis + MySQL
     */
    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        String key = PREFIX + conversationId;

        // 1. 只保留最新16条
        if (messages.size() > MAX_MESSAGES) {
            messages = messages.subList(messages.size() - MAX_MESSAGES, messages.size());
        }

        // 2. 保存Redis
        try {
            redisTemplate.opsForValue().set(key, objectMapper.writeValueAsString(messages), EXPIRED_MINUTES, TimeUnit.MINUTES);
        } catch (Exception e) {
            e.printStackTrace();
        }

        // 3. 🔥 关键:同步保存到MySQL(全量历史)
        syncToMysql(conversationId, messages);
    }

    @Override
    public void deleteByConversationId(String conversationId) {
        redisTemplate.delete(PREFIX + conversationId);
    }

    // ===================== 工具方法 =====================

    /**
     * 把对话同步保存到MySQL(全量落库)
     */
    private void syncToMysql(String conversationId, List<Message> messages) {
        for (Message msg : messages) {
            String role = switch (msg.getMessageType()) {
                case USER -> "user";
                case ASSISTANT -> "assistant";
                default -> null;
            };

            if (role == null) continue;

            // 判断是否已存在,避免重复插入
            boolean exists = chatMessageMapper.exists(
                    new com.baomidou.mybatisplus.core.conditions.query.QueryWrapper<ChatMessage>()
                            .eq("conversation_id", conversationId)
                            .eq("role", role)
                            .eq("content", msg.getText())
            );

            if (!exists) {
                ChatMessage chatMessage = new ChatMessage();
                chatMessage.setConversationId(conversationId);
                chatMessage.setRole(role);
                chatMessage.setContent(msg.getText());
                chatMessage.setCreateTime(LocalDateTime.now());
                chatMessageMapper.insert(chatMessage);
            }
        }
    }

    /**
     * MySQL 转 Spring AI Message
     */
    private List<Message> convertToAiMessage(List<ChatMessage> list) {
        List<Message> res = new ArrayList<>();
        for (ChatMessage m : list) {
            if ("user".equals(m.getRole())) {
                res.add(new UserMessage(m.getContent()));
            } else if ("assistant".equals(m.getRole())) {
                res.add(new AssistantMessage(m.getContent()));
            }
        }
        return res;
    }
}

4.3.5 ChatClient

最后,是将自定义的MessageChatMemoryAdvisor注入到ChatClient对象中。

@Configuration
public class CommonConfiguration {

    @Bean
    public RedisChatMemoryRepository redisChatMemoryRepository(
            StringRedisTemplate redisTemplate,
            ChatMessageMapper chatMessageMapper) {
        return new RedisChatMemoryRepository(redisTemplate, chatMessageMapper);
    }

    @Bean
    public ChatMemory redisChatMemory(RedisChatMemoryRepository redisChatMemoryRepository) {
        // Redis存储
        // 滑动窗口:保留最近8轮(16条消息)
        return MessageWindowChatMemory.builder()
                .chatMemoryRepository(redisChatMemoryRepository)
                .maxMessages(16) // 8轮问答
                .build();
    }

    @Bean
    public ChatClient chatClient(ZhiPuAiChatModel model, ChatMemory chatMemory) {
        return ChatClient.builder(model)
                .defaultSystem("你是一个智能桌面助手铁铁,帮助缓解主人工作之余的疲惫和情绪情绪")
                .defaultAdvisors(new SimpleLoggerAdvisor(),
//                        MessageChatMemoryAdvisor.builder(MessageWindowChatMemory.builder()
//                                        .maxMessages(10)
//                                        .build())
//                                .build())
                        MessageChatMemoryAdvisor
                                .builder(chatMemory)
                                .build())
                .build();
    }
}

4.3.6 添加配置

差点把配置文件给忘了,这里引入了redis缓存中间件和mysql数据库,需要增加对应的配置,启动类上也注意添加@MapperScan,完整的配置文件更新如下:

server:
  port: 8088

spring:
  application:
    name: spring-ai-learning
  ai:
    zhipuai:
      base-url: https://open.bigmodel.cn/api/paas
      api-key: ${OPENAI_API_KEY}
      chat:
        model: GLM-4.7
        temperature: 0.7
  # 数据库配置
  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    url: jdbc:mysql://localhost:3306/springai?useUnicode=true&characterEncoding=utf8&serverTimezone=Asia/Shanghai&allowMultiQueries=true
    username: root
    password: roswu
    type: com.zaxxer.hikari.HikariDataSource
    hikari:
      maximum-pool-size: 10
      minimum-idle: 5
      idle-timeout: 30000
      connection-timeout: 20000
  data:
    redis:
      host: localhost
      port: 6379
      password: 123456
      database: 0
      timeout: 10000ms
      lettuce:
        pool:
          max-active: 8
          max-idle: 8
          min-idle: 0
          max-wait: -1ms


logging:
  level:
    root: info
    # AI对话的日志级别
    org.springframework.ai.chat.client.advisor: debug

mybatis-plus:
  mapper-locations: classpath:mapper/*.xml
  type-aliases-package: com.roswu.springailearning.model
  configuration:
    map-underscore-to-camel-case: true
    log-impl: org.apache.ibatis.logging.stdout.StdOutImpl

4.4 测试

服务重新启动,浏览器先输入我是你的主人Rosen,服务端响应正常。

在这里插入图片描述

查看数据库聊天数据落库正常

在这里插入图片描述

查看redis缓存状态正常

在这里插入图片描述

由于我们滑动窗口设置的是16,中间省略部分聊天内容,来到第8次输入后的状态

数据库

在这里插入图片描述

缓存

在这里插入图片描述

从缓存结果来看,目前是保存着近16条聊天记录的,也就是8轮对话,到达了我们设定的阈值,理论上,下次输入后,将会把第一次的问答内容给挤出缓存。

请求参数

在这里插入图片描述

缓存数据

在这里插入图片描述

这里的缓存数据不大符合预期,挤掉的是第一次用户问题和最后一次用户问题,理论上应该是挤掉第一次的问答才对。经排查,主要是findByConversationId里面的两个问题导致,一是反序列化为List一直在报错,然后删除了缓冲,重新查询了数据库。

com.fasterxml.jackson.databind.exc.InvalidDefinitionException: Cannot construct instance of `org.springframework.ai.chat.messages.Message` (no Creators, like default constructor, exist): abstract types either need to be mapped to concrete types, have custom deserializer, or contain additional type information
 at [Source: REDACTED (`StreamReadFeature.INCLUDE_SOURCE_IN_LOCATION` disabled); line: 1, column: 2] (through reference chain: java.util.ArrayList[0])

查询数据库这里的逻辑也有问题,滑动窗口最新16条的应该是desc,然后再调整为正序。调整后的findByConversationId方法如下:

public List<Message> findByConversationId(String conversationId) {
        String key = PREFIX + conversationId;

        // 1. 读Redis
        String json = redisTemplate.opsForValue().get(key);
        try {
//            return objectMapper.readValue(json, new TypeReference<List<Message>>() {
//            });
            if (!ObjectUtils.isEmpty(json)) {
                JsonNode array = objectMapper.readTree(json);
                List<Message> messages = new ArrayList<>();

                for (JsonNode node : array) {
                    String type = node.get("messageType").asText();
                    String text = node.get("text").asText();

                    if ("USER".equals(type)) {
                        messages.add(new UserMessage(text));
                    } else if ("ASSISTANT".equals(type)) {
                        messages.add(new AssistantMessage(text));
                    } else if ("SYSTEM".equals(type)) {
                        messages.add(new SystemMessage(text));
                    }
                }
                return messages;
            }
        } catch (Exception e) {
            redisTemplate.delete(key);
        }

        // 2. Redis失效 → 从MySQL恢复
        List<ChatMessage> dbMessages = chatMessageMapper.selectList(
                new QueryWrapper<ChatMessage>()
                        .eq("conversation_id", conversationId)
                        .orderByDesc("create_time") // 最新的在前
                        .last("LIMIT 16") // 取最新16条
        );

        // 反转 → 变成【旧 → 新】正序,模型才能正确识别
        dbMessages = dbMessages.stream().sorted(Comparator.comparing(ChatMessage::getId)).collect(Collectors.toList());

        List<Message> messages = convertToAiMessage(dbMessages);

        // 3. 写回Redis
        try {
            redisTemplate.opsForValue().set(key, objectMapper.writeValueAsString(messages), EXPIRED_MINUTES, TimeUnit.MINUTES);
        } catch (Exception ignored) {
            // do nothing
        }

        return messages;
    }

调整后重启服务,清除缓存,增加新的问题,触发阈值,查看缓存数据,整体流程就全都正常了。遗留的问题就是那个聊天内容的重复性校验,使用content进行全匹配效率太低了,可以尝试将内容hash或者转md5后先进行一轮比较,如果相等则在进行一次全匹配,应该要好得多,肝太晚了,就先到这里了。