最近在学习Spring AI框架和Langchain4j框架的生产级应用,由于对于框架的不熟悉导致花了好几天的时间才脱坑。
本博文详细记录笔者在Spring AI(Langchain4j框架下同理,只是以Spring AI框架为例)框架下对接LLM大模型流式输出场景下的坑以及脱坑方案。
踩坑详解
背景介绍
Spring AI版本:
dependencies {
implementation 'org.springframework.ai:spring-ai-starter-model-openai:1.1.4'
}
使用的ChatMemory模型是自己封装的基于Redis的聊天短期记忆模型:
@Component(value = "redisChatMemory")
public class RedisChatMemory implements ChatMemory {
private static final String KEY_PREFIX = "ai:chat:memory:";
private static final int MAX_HISTORY = 10;
private final RedisTemplate<String, Message> redisTemplate;
public RedisChatMemory(RedisTemplate<String, Message> redisTemplate) {
this.redisTemplate = redisTemplate;
}
@Override
public void add(String conversationId, List<Message> messages) {
if (CollUtil.isNotEmpty(messages)) {
String key = KEY_PREFIX + conversationId;
redisTemplate.opsForList().rightPushAll(key, messages);
redisTemplate.opsForList().trim(key, -MAX_HISTORY, -1);
}
}
@Override
public List<Message> get(String conversationId) {
String key = KEY_PREFIX + conversationId;
return redisTemplate.opsForList().range(key, 0, -1);
}
@Override
public void clear(String conversationId) {
redisTemplate.delete(KEY_PREFIX + conversationId);
}
}
功能描述
在我搭建一个类似于大模型门面的文本聊天页面的时候,想要实现以下功能:
- 记录用户的聊天记录
- 前端在发送了SSE请求之后可以随时中断
- 前端在终端SSE之后发起新的请求,根据用户新的提示词生成新的响应文案流
就是在我实现前端中断了SSE请求之后,再次发送一个新的交互请求,发现后端给我的推送还是上一个提示词的响应文案,直接把我干蒙圈了。
试错经历
遇到这个现象,我首先并且坚定的认为是前端中断了SSE连接了之后,后端并没有实时中断和LLM的SSE连接,导致的响应流还是给后端不断地推送文案。前端的新的SSE请求过来的收,后端把挤压的旧的流的响应直接错乱,推给了新的请求连接的响应(在这里我给框架的开发人员郑重道歉,我真是个天才)。 之后就不断寻找解决方案,其中试过以下几种方案:
方案一
直接在ChatClient的响应流里面加一个订阅,用于获取响应流的Disposable对象,在新的SSE请求过来的时候先中断之前的SSE连接,在发起新的SSE连接。这种方案可以实现上述的理想效果,但是会抛一个莫名其妙的异常:
java.lang.IllegalStateException: No StreamAdvisors available to execute
at org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain.lambda$nextStream$6(DefaultAroundAdvisorChain.java:127)
at reactor.core.publisher.FluxDeferContextual.subscribe(FluxDeferContextual.java:49)
at reactor.core.publisher.Flux.subscribe(Flux.java:8891)
at reactor.core.publisher.MonoFlatMapMany$FlatMapManyMain.onNext(MonoFlatMapMany.java:196)
at reactor.core.publisher.FluxMap$MapSubscriber.onNext(FluxMap.java:122)
at reactor.core.publisher.FluxSubscribeOnValue$ScheduledScalar.run(FluxSubscribeOnValue.java:181) at reactor.core.scheduler.SchedulerTask.call(SchedulerTask.java:68)
at reactor.core.scheduler.SchedulerTask.call(SchedulerTask.java:28)
代码书写:
Disposable disposable = ChatClient.create(openAiChatModel)
.prompt()
.user(messsage)
.system(systemMessage)
.advisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId(conversationId).build())
.stream()
.content()
.timeout(Duration.ofMinutes(2))
.subscribe();
之后的三个小时,我就把源码基本上扒了一遍,找到了抛异常的原因: org.springframework.ai.chat.client.DefaultChatClient.DefaultStreamResponseSpec#doGetObservableFluxChatResponse
private Flux<ChatClientResponse> doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) {
return Flux.deferContextual(contextView -> {
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
.request(chatClientRequest)
.advisors(this.advisorChain.getStreamAdvisors())
.stream(true)
.build();
Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(
this.observationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
() -> observationContext, this.observationRegistry);
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null))
.start();
// @formatter:off
// Apply the advisor chain that terminates with the ChatModelStreamAdvisor.
// 这里触发advisorChain.nextStream方法
Flux<ChatClientResponse> chatClientResponse = this.advisorChain.nextStream(chatClientRequest)
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
// @formatter:on
return CHAT_CLIENT_MESSAGE_AGGREGATOR.aggregateChatClientResponse(chatClientResponse,
observationContext::setResponse);
});
}
org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain#nextStream
@Override
public Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest) {
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
return Flux.deferContextual(contextView -> {
// 手动订阅的话 直接上下文丢失 这里抛出异常
if (this.streamAdvisors.isEmpty()) {
return Flux.error(new IllegalStateException("No StreamAdvisors available to execute"));
}
var advisor = this.streamAdvisors.pop();
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
.advisorName(advisor.getName())
.chatClientRequest(chatClientRequest)
.order(advisor.getOrder())
.build();
var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(this.observationConvention,
DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
// @formatter:off
Flux<ChatClientResponse> chatClientResponse = Flux.defer(() -> advisor.adviseStream(chatClientRequest, this)
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)));
// @formatter:on
return CHAT_CLIENT_MESSAGE_AGGREGATOR.aggregateChatClientResponse(chatClientResponse,
observationContext::setChatClientResponse);
});
}
org.springframework.web.servlet.mvc.method.annotation.ReactiveTypeHandler#handleValue
@Nullable
public ResponseBodyEmitter handleValue(
Object returnValue, MethodParameter returnType, @Nullable MediaType presetContentType,
ModelAndViewContainer mav, NativeWebRequest request) throws Exception {
Assert.notNull(returnValue, "Expected return value");
Class<?> clazz = returnValue.getClass();
ReactiveAdapter adapter = this.adapterRegistry.getAdapter(clazz);
Assert.state(adapter != null, () -> "Unexpected return value type: " + clazz);
TaskDecorator taskDecorator = null;
if (isContextPropagationPresent) {
ContextSnapshotHelper helper = (ContextSnapshotHelper) this.contextSnapshotHelper;
Assert.notNull(helper, "No ContextSnapshotHelper");
returnValue = helper.writeReactorContext(returnValue);
taskDecorator = helper.getTaskDecorator();
}
ResolvableType elementType = ResolvableType.forMethodParameter(returnType).getGeneric();
Class<?> elementClass = elementType.toClass();
Collection<MediaType> mediaTypes = getMediaTypes(request, presetContentType);
Optional<MediaType> mediaType = mediaTypes.stream().filter(MimeType::isConcrete).findFirst();
if (adapter.isMultiValue()) {
// 主要看这里 是SSE响应直接在这里订阅返回的Flux对象,会触发Flux的订阅
if (mediaTypes.stream().anyMatch(MediaType.TEXT_EVENT_STREAM::includes) ||
ServerSentEvent.class.isAssignableFrom(elementClass)) {
SseEmitter emitter = new SseEmitter(STREAMING_TIMEOUT_VALUE);
new SseEmitterSubscriber(emitter, this.taskExecutor, taskDecorator).connect(adapter, returnValue);
return emitter;
}
if (CharSequence.class.isAssignableFrom(elementClass)) {
ResponseBodyEmitter emitter = getEmitter(mediaType.orElse(MediaType.TEXT_PLAIN));
new TextEmitterSubscriber(emitter, this.taskExecutor).connect(adapter, returnValue);
return emitter;
}
MediaType streamingResponseType = findConcreteJsonStreamMediaType(mediaTypes);
if (streamingResponseType != null) {
ResponseBodyEmitter emitter = getEmitter(streamingResponseType);
new JsonEmitterSubscriber(emitter, this.taskExecutor).connect(adapter, returnValue);
return emitter;
}
}
// Not streaming...
DeferredResult<Object> result = new DeferredResult<>();
new DeferredResultSubscriber(result, adapter, elementType).connect(adapter, returnValue);
WebAsyncUtils.getAsyncManager(request).startDeferredResultProcessing(result, mav);
return null;
}
以上代码就可以了解到异常抛出的原因:
- 在我们调用ChatClient的stream方法之后只是返回了一个Flux冷流,Flux冷流不被订阅的话不会执行初始化的动作
- 我们手动订阅的时候会触发Flux的初始化动作,但是此时初始化的时候上下文里面没有Spring内部管理的bean,导致报错
- 同时,还有一个原因是因为同一个Flux被订阅了两次,Spring MVC的返回值处理器会订阅一次,我们手动在订阅,就上下文丢失了
那显然这个方式报错没办法解决,也会破坏Spring的内部机制。因为本来就是需要Spring MVC的对应返回值处理器去订阅的。
方案二
既然我们不能破坏Spring AI的Flux流订阅,就想着在ChatClient返回响应之后直接手动订阅,然后把LLM给的Flux直接塞给我们手动创建的冷流里面,让SpringMVC去订阅我们手动创建的冷流,我们就可以拿到跟LLM交互的Flux的Disposable对象对象了。我们就可以在新请求过来的时候先停掉旧的连接了。
满心欢喜,试过之后发现没有什么蛋用,直接裂开。
心情不好就不贴代码了。
解决方案
在我快要崩溃的时候,我看了一眼聊天记忆的记录,发现了一个问题:如果是前端主动断开连接的话,后端生成的片段的响应文案并没有保存在聊天记忆里面!导致后续的新的请求带着新问题的时候,和老问题中间并没有AssistantMessage!
这样的提示词就会出现问题:LLM会解析到我们的老问题,以为那就是本次交互的问题,后面的用户提示词都只是老问题的补充文案!所以会响应老问题的答案!
所以问题根本就不是处在连接是不是断干净了,而是在交互的提示词链上出现了问题。
那为什么没有保存上LLM响应的片段呢?我们通过源码会了解: org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor#after
@Override
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
List<Message> assistantMessages = new ArrayList<>();
if (chatClientResponse.chatResponse() != null) {
assistantMessages = chatClientResponse.chatResponse()
.getResults()
.stream()
.map(g -> (Message) g.getOutput())
.toList();
}
// 这里写了,只有在LLM响应完成之后,才会把assistantMessages放在聊天记忆里面
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
assistantMessages);
return chatClientResponse;
}
也就是说,Spring AI处理聊天记忆会在开始和LLM交互之前,把用户的提示词存储到聊天记忆里面,在LLM正常完成响应之后才会把响应文案存储到聊天记忆里面! 那我们前端主动中断的场景根本就不会触发响应完成的事件,那响应的片段也就直接丢失了。
到这里,其实我们的解决方案就呼之欲出了:我们只需要订阅ChatClient返回的Flux的doOnCancel事件,在流正常推送的时候搞一下中间的缓存来缓存一下LLM已经生成的响应片段,之后再监听到前端主动断开连接的时候,手动把响应片段保存到聊天记忆里面就可以了。 简单示例:
ChatClient.create(openAiChatModel)
.prompt()
.user(finalMessage)
.system(AiServerCommonConstant.BLOG_ASSISTANT_SYSTEM_PROMPT)
.advisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId(conversationId).build())
.stream()
.content()
.map(token -> {
// 这里保存token到缓存中
chatResponseCacheUtil.append(taskId, token);
return token;
})
.timeout(Duration.ofMinutes(2))
.doOnCancel(() -> {
// 这里获取到缓存片段,手动保存到聊天记忆里面
String cacheResponseText = chatResponseCacheUtil.getCache(taskId);
savePartialResponse(conversationId, finalMessage, cacheResponseText);
chatResponseCacheUtil.removeCache(taskId);
})
// 后面做一些缓存资源的清除,防止内存泄漏
.doOnComplete(() -> chatResponseCacheUtil.removeCache(taskId))
.doOnTerminate(() -> chatResponseCacheUtil.removeCache(taskId))
.doOnError(error -> chatResponseCacheUtil.removeCache(taskId));
这样修改之后,我们的问题就解决了。