Spring AI 使用 Redis 实现会话记忆
代码
/**
 * {@link ChatMemory} implementation that keeps conversation history in Redis.
 *
 * <p>Each conversation is stored as a Redis list under the key
 * {@code chatmemory:<conversationId>}. Every list element is one {@link Message} serialized to
 * JSON. New messages are pushed to the tail, so the list keeps insertion order.
 *
 * <p>NOTE(review): keys are written without a TTL, so a conversation's history grows until
 * {@link #clear(String)} is called. Consider an expiry if conversations are unbounded.
 */
@Component
@RequiredArgsConstructor
public class RedisChatMemory implements ChatMemory {

    /** Key prefix; the full key of a conversation is {@code chatmemory:<conversationId>}. */
    private static final String REDIS_CHAT_MEMORY_KEY = "chatmemory";

    private final StringRedisTemplate redisTemplate;

    /**
     * Appends {@code messages} to the end of the conversation's history.
     *
     * @param conversationId conversation identifier; becomes part of the Redis key
     * @param messages messages to append, in order; {@code null} or empty is a no-op
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        // RPUSH needs at least one value, and rightPushAll rejects an empty collection,
        // so skip the Redis call entirely when there is nothing to store.
        if (messages == null || messages.isEmpty()) {
            return;
        }
        List<String> serializedMessages = messages.stream().map(this::serialize).toList();
        redisTemplate.opsForList().rightPushAll(keyOf(conversationId), serializedMessages);
    }

    /**
     * Returns up to the last {@code lastN} messages of the conversation, oldest first.
     *
     * @param conversationId conversation identifier
     * @param lastN maximum number of most-recent messages to return; {@code <= 0} yields none
     * @return the messages, or an empty list when there is no history
     * @throws IllegalStateException if a stored entry has a missing or unknown message type
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        // LRANGE key -0 -1 would return the WHOLE list, and a negative lastN would turn into a
        // positive start index, so only a positive count is forwarded to Redis.
        if (lastN <= 0) {
            return Collections.emptyList();
        }
        List<String> serializedMessages =
                redisTemplate.opsForList().range(keyOf(conversationId), -lastN, -1);
        // range(...) returns null inside a pipeline/transaction; treat that like "no history".
        if (serializedMessages == null || serializedMessages.isEmpty()) {
            return Collections.emptyList();
        }
        return serializedMessages.stream().map(this::deserialize).toList();
    }

    /**
     * Deletes the whole history of the conversation.
     *
     * @param conversationId conversation identifier
     */
    @Override
    public void clear(String conversationId) {
        redisTemplate.delete(keyOf(conversationId));
    }

    /** Builds the Redis list key that holds the given conversation's messages. */
    private String keyOf(String conversationId) {
        return REDIS_CHAT_MEMORY_KEY + ":" + conversationId;
    }

    /** Serializes one message to JSON via the project's GSON helper. */
    private String serialize(Message message) {
        return GSON.toJson(message);
    }

    /**
     * Rebuilds a concrete {@link Message} subtype from its JSON form, dispatching on the
     * serialized {@code messageType} field.
     *
     * <p>NOTE(review): the case labels assume the enum is written by its constant name
     * (USER, ASSISTANT, ...), and that {@code Json.get} returns a top-level field as a
     * String. Confirm both against the project's GSON/Json helpers.
     */
    private Message deserialize(String serializedMessage) {
        String messageType = Json.get(serializedMessage, "messageType");
        // A String switch on null throws a bare NullPointerException; report it like any
        // other unrecognized type instead.
        if (messageType == null) {
            throw new IllegalStateException("Unexpected value: " + messageType);
        }
        return switch (messageType) {
            case "USER" -> GSON.toBean(serializedMessage, UserMessage.class);
            case "ASSISTANT" -> GSON.toBean(serializedMessage, AssistantMessage.class);
            case "SYSTEM" -> GSON.toBean(serializedMessage, SystemMessage.class);
            case "TOOL" -> GSON.toBean(serializedMessage, ToolResponseMessage.class);
            default -> throw new IllegalStateException("Unexpected value: " + messageType);
        };
    }
}
注意：这里使用 Gson 进行序列化。代码中的 `GSON`、`Json`、`IF` 是自定义的工具类（例如 `GSON.toBean` 并非 Gson 自带的 API），请替换为项目中对应的实现。
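使用：构建 ChatClient 时，把注入的 `ChatMemory`（即上面的 `RedisChatMemory`）通过 `MessageChatMemoryAdvisor` 注册为默认 Advisor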
this.chatClient = chatClientBuilder
.defaultAdvisors(
new MessageChatMemoryAdvisor(chatMemory)
)
.build();