Spring AI 使用 Redis 实现会话记忆

202 阅读1分钟

Spring AI 使用 Redis 实现会话记忆

代码

/**
 * {@link ChatMemory} implementation that keeps each conversation's messages in a Redis list.
 *
 * <p>Every conversation is stored under the key {@code chatmemory:<conversationId>}. Messages are
 * appended to the tail of the list as JSON strings, so list order equals insertion order.
 * Serialization is delegated to the project's {@code GSON} / {@code Json} helpers.
 */
@Component
@RequiredArgsConstructor
public class RedisChatMemory implements ChatMemory {
    private static final String REDIS_CHAT_MEMORY_KEY = "chatmemory";
    private final StringRedisTemplate redisTemplate;

    /**
     * Appends the given messages to the end of the conversation's list.
     *
     * @param conversationId identifier of the conversation
     * @param messages messages to append, in order; a {@code null} or empty list is a no-op
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        // Redis RPUSH needs at least one value, so there is nothing to send for an empty batch.
        if (messages == null || messages.isEmpty()) {
            return;
        }
        List<String> serializedMessages = messages.stream().map(this::serialize).toList();
        redisTemplate.opsForList().rightPushAll(keyFor(conversationId), serializedMessages);
    }

    /**
     * Returns the most recent messages of a conversation, oldest first.
     *
     * @param conversationId identifier of the conversation
     * @param lastN maximum number of trailing messages to return
     * @return up to {@code lastN} messages; an empty list when {@code lastN <= 0} or nothing is stored
     * @throws IllegalStateException if a stored entry has a missing or unknown message type
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        // -0 == 0, so without this guard lastN == 0 would query range(0, -1), i.e. the whole list.
        if (lastN <= 0) {
            return Collections.emptyList();
        }
        List<String> serializedMessages =
                redisTemplate.opsForList().range(keyFor(conversationId), -lastN, -1);
        if (serializedMessages == null || serializedMessages.isEmpty()) {
            return Collections.emptyList();
        }
        return serializedMessages.stream().map(this::deserialize).toList();
    }

    /**
     * Deletes the conversation's list from Redis.
     *
     * @param conversationId identifier of the conversation to clear
     */
    @Override
    public void clear(String conversationId) {
        redisTemplate.delete(keyFor(conversationId));
    }

    /** Builds the Redis key for a conversation: {@code chatmemory:<conversationId>}. */
    private String keyFor(String conversationId) {
        return REDIS_CHAT_MEMORY_KEY + ":" + conversationId;
    }

    /** Converts a message to its JSON string form. */
    private String serialize(Message message) {
        return GSON.toJson(message);
    }

    /**
     * Rebuilds a message from its JSON form, choosing the concrete class from the
     * {@code messageType} field.
     *
     * @throws IllegalStateException if {@code messageType} is absent or not one of
     *     USER, ASSISTANT, SYSTEM, TOOL
     */
    private Message deserialize(String serializedMessage) {
        String messageType = Json.get(serializedMessage, "messageType");
        // A switch on a null String throws a bare NullPointerException; fail with context instead.
        if (messageType == null) {
            throw new IllegalStateException("Missing messageType in stored message: " + serializedMessage);
        }
        return switch (messageType) {
            case "USER" -> GSON.toBean(serializedMessage, UserMessage.class);
            case "ASSISTANT" -> GSON.toBean(serializedMessage, AssistantMessage.class);
            case "SYSTEM" -> GSON.toBean(serializedMessage, SystemMessage.class);
            case "TOOL" -> GSON.toBean(serializedMessage, ToolResponseMessage.class);
            default -> throw new IllegalStateException("Unexpected value: " + messageType);
        };
    }
}

注意：我这里采用 Gson 进行序列化。代码中的 GSON、Json 是项目内自行封装的工具类（例如 toBean 并非 Gson 原生方法），使用时请自行实现或替换为等价调用。

使用

this.chatClient = chatClientBuilder
    .defaultAdvisors(
        new MessageChatMemoryAdvisor(chatMemory)
    )
    .build();