踩坑 SpringAI 流式输出:前端中断后串流?一招彻底根治!

0 阅读6分钟

最近在学习Spring AI框架和Langchain4j框架的生产级应用,由于对于框架的不熟悉导致花了好几天的时间才脱坑。

本博文详细记录笔者在Spring AI(Langchain4j框架下同理,只是以Spring AI框架为例)框架下对接LLM大模型流式输出场景下的坑以及脱坑方案。

踩坑详解

背景介绍

Spring AI版本:

dependencies {
    implementation 'org.springframework.ai:spring-ai-starter-model-openai:1.1.4'
}

使用的ChatMemory模型是自己封装的基于Redis的聊天短期记忆模型:

@Component(value = "redisChatMemory")
public class RedisChatMemory implements ChatMemory {

    private static final String KEY_PREFIX = "ai:chat:memory:";
    private static final int MAX_HISTORY = 10;
    private final RedisTemplate<String, Message> redisTemplate;

    public RedisChatMemory(RedisTemplate<String, Message> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    @Override
    public void add(String conversationId, List<Message> messages) {
        if (CollUtil.isNotEmpty(messages)) {
            String key = KEY_PREFIX + conversationId;
            redisTemplate.opsForList().rightPushAll(key, messages);
            redisTemplate.opsForList().trim(key, -MAX_HISTORY, -1);
        }
    }

    @Override
    public List<Message> get(String conversationId) {
        String key = KEY_PREFIX + conversationId;
        return redisTemplate.opsForList().range(key, 0, -1);
    }

    @Override
    public void clear(String conversationId) {
        redisTemplate.delete(KEY_PREFIX + conversationId);
    }

}

功能描述

在我搭建一个类似于大模型门面的文本聊天页面的时候,想要实现以下功能:

  • 记录用户的聊天记录
  • 前端在发送了SSE请求之后可以随时中断
  • 前端在终端SSE之后发起新的请求,根据用户新的提示词生成新的响应文案流

就是在我实现前端中断了SSE请求之后,再次发送一个新的交互请求,发现后端给我的推送还是上一个提示词的响应文案,直接把我干蒙圈了。

试错经历

遇到这个现象,我首先并且坚定的认为是前端中断了SSE连接了之后,后端并没有实时中断和LLM的SSE连接,导致的响应流还是给后端不断地推送文案。前端的新的SSE请求过来的收,后端把挤压的旧的流的响应直接错乱,推给了新的请求连接的响应(在这里我给框架的开发人员郑重道歉,我真是个天才)。 之后就不断寻找解决方案,其中试过以下几种方案:

方案一

直接在ChatClient的响应流里面加一个订阅,用于获取响应流的Disposable对象,在新的SSE请求过来的时候先中断之前的SSE连接,在发起新的SSE连接。这种方案可以实现上述的理想效果,但是会抛一个莫名其妙的异常:

java.lang.IllegalStateException: No StreamAdvisors available to execute 
at org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain.lambda$nextStream$6(DefaultAroundAdvisorChain.java:127) 
at reactor.core.publisher.FluxDeferContextual.subscribe(FluxDeferContextual.java:49) 
at reactor.core.publisher.Flux.subscribe(Flux.java:8891) 
at reactor.core.publisher.MonoFlatMapMany$FlatMapManyMain.onNext(MonoFlatMapMany.java:196) 
at reactor.core.publisher.FluxMap$MapSubscriber.onNext(FluxMap.java:122) 
at reactor.core.publisher.FluxSubscribeOnValue$ScheduledScalar.run(FluxSubscribeOnValue.java:181) at reactor.core.scheduler.SchedulerTask.call(SchedulerTask.java:68) 
at reactor.core.scheduler.SchedulerTask.call(SchedulerTask.java:28)

代码书写:

Disposable disposable = ChatClient.create(openAiChatModel)
        .prompt()
        .user(messsage)
        .system(systemMessage)
        .advisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId(conversationId).build())
        .stream()
        .content()
        .timeout(Duration.ofMinutes(2))
        .subscribe();

之后的三个小时,我就把源码基本上扒了一遍,找到了抛异常的原因: org.springframework.ai.chat.client.DefaultChatClient.DefaultStreamResponseSpec#doGetObservableFluxChatResponse

private Flux<ChatClientResponse> doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) {
    return Flux.deferContextual(contextView -> {

       ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
          .request(chatClientRequest)
          .advisors(this.advisorChain.getStreamAdvisors())
          .stream(true)
          .build();

       Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(
             this.observationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
             () -> observationContext, this.observationRegistry);

       observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null))
          .start();

       // @formatter:off
       // Apply the advisor chain that terminates with the ChatModelStreamAdvisor.
       // 这里触发advisorChain.nextStream方法
       Flux<ChatClientResponse> chatClientResponse = this.advisorChain.nextStream(chatClientRequest)
             .doOnError(observation::error)
             .doFinally(s -> observation.stop())
             .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
       // @formatter:on
       return CHAT_CLIENT_MESSAGE_AGGREGATOR.aggregateChatClientResponse(chatClientResponse,
             observationContext::setResponse);
    });
}

org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain#nextStream

@Override
public Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest) {
    Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");

    return Flux.deferContextual(contextView -> {
        // 手动订阅的话  直接上下文丢失 这里抛出异常
       if (this.streamAdvisors.isEmpty()) {
          return Flux.error(new IllegalStateException("No StreamAdvisors available to execute"));
       }

       var advisor = this.streamAdvisors.pop();

       AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
          .advisorName(advisor.getName())
          .chatClientRequest(chatClientRequest)
          .order(advisor.getOrder())
          .build();

       var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(this.observationConvention,
             DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);

       observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

       // @formatter:off
       Flux<ChatClientResponse> chatClientResponse = Flux.defer(() -> advisor.adviseStream(chatClientRequest, this)
                .doOnError(observation::error)
                .doFinally(s -> observation.stop())
                .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)));
       // @formatter:on
       return CHAT_CLIENT_MESSAGE_AGGREGATOR.aggregateChatClientResponse(chatClientResponse,
             observationContext::setChatClientResponse);
    });
}

org.springframework.web.servlet.mvc.method.annotation.ReactiveTypeHandler#handleValue

@Nullable
public ResponseBodyEmitter handleValue(
       Object returnValue, MethodParameter returnType, @Nullable MediaType presetContentType,
       ModelAndViewContainer mav, NativeWebRequest request) throws Exception {

    Assert.notNull(returnValue, "Expected return value");
    Class<?> clazz = returnValue.getClass();
    ReactiveAdapter adapter = this.adapterRegistry.getAdapter(clazz);
    Assert.state(adapter != null, () -> "Unexpected return value type: " + clazz);

    TaskDecorator taskDecorator = null;
    if (isContextPropagationPresent) {
       ContextSnapshotHelper helper = (ContextSnapshotHelper) this.contextSnapshotHelper;
       Assert.notNull(helper, "No ContextSnapshotHelper");
       returnValue = helper.writeReactorContext(returnValue);
       taskDecorator = helper.getTaskDecorator();
    }

    ResolvableType elementType = ResolvableType.forMethodParameter(returnType).getGeneric();
    Class<?> elementClass = elementType.toClass();

    Collection<MediaType> mediaTypes = getMediaTypes(request, presetContentType);
    Optional<MediaType> mediaType = mediaTypes.stream().filter(MimeType::isConcrete).findFirst();

    if (adapter.isMultiValue()) {
        // 主要看这里 是SSE响应直接在这里订阅返回的Flux对象,会触发Flux的订阅
       if (mediaTypes.stream().anyMatch(MediaType.TEXT_EVENT_STREAM::includes) ||
             ServerSentEvent.class.isAssignableFrom(elementClass)) {
          SseEmitter emitter = new SseEmitter(STREAMING_TIMEOUT_VALUE);
          new SseEmitterSubscriber(emitter, this.taskExecutor, taskDecorator).connect(adapter, returnValue);
          return emitter;
       }
       if (CharSequence.class.isAssignableFrom(elementClass)) {
          ResponseBodyEmitter emitter = getEmitter(mediaType.orElse(MediaType.TEXT_PLAIN));
          new TextEmitterSubscriber(emitter, this.taskExecutor).connect(adapter, returnValue);
          return emitter;
       }
       MediaType streamingResponseType = findConcreteJsonStreamMediaType(mediaTypes);
       if (streamingResponseType != null) {
          ResponseBodyEmitter emitter = getEmitter(streamingResponseType);
          new JsonEmitterSubscriber(emitter, this.taskExecutor).connect(adapter, returnValue);
          return emitter;
       }
    }

    // Not streaming...
    DeferredResult<Object> result = new DeferredResult<>();
    new DeferredResultSubscriber(result, adapter, elementType).connect(adapter, returnValue);
    WebAsyncUtils.getAsyncManager(request).startDeferredResultProcessing(result, mav);

    return null;
}

以上代码就可以了解到异常抛出的原因:

  1. 在我们调用ChatClient的stream方法之后只是返回了一个Flux冷流,Flux冷流不被订阅的话不会执行初始化的动作
  2. 我们手动订阅的时候会触发Flux的初始化动作,但是此时初始化的时候上下文里面没有Spring内部管理的bean,导致报错
  3. 同时,还有一个原因是因为同一个Flux被订阅了两次,Spring MVC的返回值处理器会订阅一次,我们手动在订阅,就上下文丢失了

那显然这个方式报错没办法解决,也会破坏Spring的内部机制。因为本来就是需要Spring MVC的对应返回值处理器去订阅的。

方案二

既然我们不能破坏Spring AI的Flux流订阅,就想着在ChatClient返回响应之后直接手动订阅,然后把LLM给的Flux直接塞给我们手动创建的冷流里面,让SpringMVC去订阅我们手动创建的冷流,我们就可以拿到跟LLM交互的Flux的Disposable对象对象了。我们就可以在新请求过来的时候先停掉旧的连接了。 满心欢喜,试过之后发现没有什么蛋用,直接裂开。

心情不好就不贴代码了。

解决方案

在我快要崩溃的时候,我看了一眼聊天记忆的记录,发现了一个问题:如果是前端主动断开连接的话,后端生成的片段的响应文案并没有保存在聊天记忆里面!导致后续的新的请求带着新问题的时候,和老问题中间并没有AssistantMessage!

这样的提示词就会出现问题:LLM会解析到我们的老问题,以为那就是本次交互的问题,后面的用户提示词都只是老问题的补充文案!所以会响应老问题的答案!

所以问题根本就不是处在连接是不是断干净了,而是在交互的提示词链上出现了问题。

那为什么没有保存上LLM响应的片段呢?我们通过源码会了解: org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor#after

@Override
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
    List<Message> assistantMessages = new ArrayList<>();
    if (chatClientResponse.chatResponse() != null) {
       assistantMessages = chatClientResponse.chatResponse()
          .getResults()
          .stream()
          .map(g -> (Message) g.getOutput())
          .toList();
    }
    // 这里写了,只有在LLM响应完成之后,才会把assistantMessages放在聊天记忆里面
    this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
          assistantMessages);
    return chatClientResponse;
}

也就是说,Spring AI处理聊天记忆会在开始和LLM交互之前,把用户的提示词存储到聊天记忆里面,在LLM正常完成响应之后才会把响应文案存储到聊天记忆里面! 那我们前端主动中断的场景根本就不会触发响应完成的事件,那响应的片段也就直接丢失了。

到这里,其实我们的解决方案就呼之欲出了:我们只需要订阅ChatClient返回的Flux的doOnCancel事件,在流正常推送的时候搞一下中间的缓存来缓存一下LLM已经生成的响应片段,之后再监听到前端主动断开连接的时候,手动把响应片段保存到聊天记忆里面就可以了。 简单示例:

ChatClient.create(openAiChatModel)
        .prompt()
        .user(finalMessage)
        .system(AiServerCommonConstant.BLOG_ASSISTANT_SYSTEM_PROMPT)
        .advisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId(conversationId).build())
        .stream()
        .content()
        .map(token -> {
            // 这里保存token到缓存中
            chatResponseCacheUtil.append(taskId, token);
            return token;
        })
        .timeout(Duration.ofMinutes(2))
        .doOnCancel(() -> {
            // 这里获取到缓存片段,手动保存到聊天记忆里面
            String cacheResponseText = chatResponseCacheUtil.getCache(taskId);
            savePartialResponse(conversationId, finalMessage, cacheResponseText);
            chatResponseCacheUtil.removeCache(taskId);
        })
        // 后面做一些缓存资源的清除,防止内存泄漏
        .doOnComplete(() -> chatResponseCacheUtil.removeCache(taskId))
        .doOnTerminate(() -> chatResponseCacheUtil.removeCache(taskId))
        .doOnError(error -> chatResponseCacheUtil.removeCache(taskId));

这样修改之后,我们的问题就解决了。