Spring AI 自定义Redis持久化ChatMemory

311 阅读4分钟

若要实现自定义MySQL持久化ChatMemory,可以参考这篇文章 Spring AI 自定义数据库持久化的ChatMemory

自定义实现ChatMemory

Spring AI的对话记忆实现非常巧妙,解耦了"储存"和"记忆算法",

  • 存储:ChatMemory:我们可以单独修改ChatMemory存储来改变对话记忆的保存位置,而无需修改保存对话记忆的流程。

  • 记忆算法:ChatMemory Advisor,advisor可以理解为拦截器,在调用大模型时的前或后执行一些操作

    • MessageChatMemoryAdvisor: 从记忆中(ChatMemory)检索历史对话,并将其作为消息集合添加到提示词中。常用。能更好的保持上下文连贯性。
    • PromptChatMemoryAdvisor: 从记忆中检索历史对话,并将其添加到提示词的系统文本中。可以理解为没有结构性的纯文本。
    • VectorStoreChatMemoryAdvisor: 可以用向量数据库来存储检索历史对话。

我们可以单独修改ChatMemory储存来改变对话记忆的保存位置,而无需修改保存对话记忆的流程.

虽然官方文档没有给我们自定义ChatMemory实现的示例,但是我们可以直接去阅读默认实现类 InMemoryChatMemory 的源码

基于内存持久化的ChatMemory

其本质是实现了ChatMemory的增删查接口

ChatMemory

所以我们想实现自己的持久化,修改对应的储存实现就行了.

参考 InMemoryChatMemory 的源码,其实就是通过 ConcurrentHashMap 来维护对话信息,key 是对话 id(相当于房间号),value 是该对话 id 对应的消息列表。

自定义Redis持久化ChatMemory

由于List<Message>中Message是一个接口,虽然需要实现的接口不多,但是实现起来还是有一定复杂度的,一个最主要的问题是 消息和文本的转换。我们在保存消息时,要将消息从 Message 对象转为文件内的文本;读取消息时,要将文件内的文本转换为 Message 对象。也就是对象的序列化和反序列化。

我们本能地会想到通过 JSON 进行序列化,但实际操作中,我们发现这并不容易。原因是:

  1. 要持久化的 Message 是一个接口,有很多种不同的子类实现(比如 UserMessage、SystemMessage 等)
  2. 每种子类所拥有的字段都不一样,结构不统一
  3. 子类没有无参构造函数,而且没有实现 Serializable 序列化接口

在这里我们使用Kryo的序列化库来实现序列化

1)引入redis依赖

这里使用的是Spring 3.4.4 ,Java 21

        <!-- Redis -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>

2)修改配置


# redis 配置
data:
  redis:
    port: 6379
    host: localhost
    database: 0

3)配置redis的bean注入

@Configuration
public class RedisTemplateConfig {

    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        RedisTemplate<String, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(connectionFactory);
        template.setKeySerializer(RedisSerializer.string());
        return template;
    }

}

4)实现序列化

  1. 引入依赖
		<!-- 自定义持久化的序列化库-->
		<dependency>
    		<groupId>com.esotericsoftware</groupId>
    		<artifactId>kryo</artifactId>
    		<version>5.6.2</version>
		</dependency>
  1. 创建序列化实现工具类
@Component
public class MessageSerializer {

    // ⚠️ 静态 Kryo 实例(线程不安全,建议改用局部实例)
    private static final Kryo kryo = new Kryo();

    static {
        kryo.setRegistrationRequired(false);
        // 设置实例化策略(需确保兼容所有 Message 实现类)
        kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());
    }

    /**
     * 使用 Kryo 将 Message 序列化为 Base64 字符串
     */
    public static String serialize(Message message) {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             Output output = new Output(baos)) {
            kryo.writeClassAndObject(output, message);  // ⚠️ 依赖动态注册和实例化策略
            output.flush();
            return Base64.getEncoder().encodeToString(baos.toByteArray());
        } catch (IOException e) {
            throw new RuntimeException("序列化失败", e);
        }
    }

    /**
     * 使用 Kryo 将 Base64 字符串反序列化为 Message 对象
     */
    public static Message deserialize(String base64) {
        try (ByteArrayInputStream bais = new ByteArrayInputStream(Base64.getDecoder().decode(base64));
             Input input = new Input(bais)) {
            return (Message) kryo.readClassAndObject(input);  // ⚠️ 依赖动态注册和实例化策略
        } catch (IOException e) {
            throw new RuntimeException("反序列化失败", e);
        }
    }
}

5)实现自定义ChatMemory

/**
 * 自定义Redis持久化
 */
@Service
@Slf4j
public class RedisChatMemory implements ChatMemory {

    private RedisTemplate<String, Object> redisTemplate;

    public RedisChatMemory(RedisTemplate<String, Object> objectRedisTemplate){
        this.redisTemplate = objectRedisTemplate;
    }


            

    /**
     * 添加一个数据到Redis中
     * @param conversationId
     * @param message
     */
    @Override
    public void add(String conversationId, Message message) {
        setToRedis(conversationId,List.of(message));


    }

    /**
     * 添加多条数据到Redis中
     * 先从redis中提取数据,如果不存在则创建
     * @param conversationId
     * @param messages
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        List<Message> messageList = getFromRedis(conversationId);
        messageList.addAll(messages);

        setToRedis(conversationId,messages);


    }

    /**
     * 从Redis中获取数据,
     * 从Redis中获取倒数lastN条数据
     * @param conversationId
     * @param lastN
     * @return
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        List<Message> messageList = getFromRedis(conversationId);
        return messageList.stream()
                .skip(Math.max(0, messageList.size() - lastN))
                .toList();
    }

    /**
     * 清空数据
     * @param conversationId
     */
    @Override
    public void clear(String conversationId) {
        redisTemplate.delete(conversationId);

    }

/**
     * 从Redis获取数据工具方法
     * @param conversationId
     * @return
     */
    private List<Message> getFromRedis(String conversationId){
        Object obj =  redisTemplate.opsForValue().get(conversationId);
        List<Message> messageList  = new ArrayList<>();
        if(obj != null){
            List<String> list = Convert.convert(new TypeReference<List<String>>() {
            }, obj);

            for (String s : list) {
                Message message = MessageSerializer.deserialize(s);
                messageList.add(message);
            }
        }
        return messageList;
    }


    /**
     * 将数据存入Redis工具方法
     * @param conversationId
     * @param messages
     */
    private void setToRedis(String conversationId,List<Message> messages){
        List<String> stringList = new ArrayList<>();
        for (Message message : messages) {
            String serialize = MessageSerializer.serialize(message);
            stringList.add(serialize);
        }
        redisTemplate.opsForValue().set(conversationId,stringList);
    }


}



6)配置到自己的APP里面

this.chatClient = ChatClient.builder(dashscopeChatModel)
        .defaultSystem(SYSTEM_PROMPT)
        .defaultAdvisors(
                new MessageChatMemoryAdvisor(redisChatMemory),
                //自定义日志拦截器,可按需开启
                new MyLoggerAdvisor(),
                //权限校验
                new AuthAdvisor(),
                //违禁词校验
                new BannedWordsAdvisor()
        )
        .build();
代码测试
  1. 先与AI对话

在这里插入图片描述

  1. 重启项目重新对话,询问记录的消息

在这里插入图片描述

成功!!!