Problems and solutions from integrating spring-ai 2.x in practice


I. Integrating spring-ai

Version selection

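I originally chose version 1.1.2. After finishing the basic integration I checked the GitHub repository and found that ongoing development is based only on 2.x, which builds on Spring Boot 4.x and JDK 21, so I upgraded straight to 2.0.0-M1. The core Maven dependencies are: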

<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-mcp-server</artifactId>
    <version>2.0.0-M1</version>
</dependency>
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-mcp-server-webmvc</artifactId>
    <version>2.0.0-M1</version>
</dependency>

<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-mcp-client</artifactId>
    <version>2.0.0-M1</version>
</dependency>

II. Integration problems encountered

1. How do you get the request context when saving chat memory?

Take the following code:

return this.openapiChatClient.prompt()
        .user(input)
        .advisors(a -> a.param("key1", "value1"))
        .stream()
        .content();

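With this code, the key1=value1 context is not available when ChatMemoryRepository#saveAll runs, so any business logic or extension that depends on such parameters cannot be implemented there. How do we fix that?
Write a custom DecorateMessageChatMemoryAdvisor that decorates MessageChatMemoryAdvisor.
In its before and after methods, bind the request context to a ThreadLocal around the calls to MessageChatMemoryAdvisor#before and MessageChatMemoryAdvisor#after, then restore the previous value. Because the wrapped advisor saves memory synchronously inside those calls, the repository can read the context from the ThreadLocal.
How is it configured?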

    @Bean
    public ChatClient openapiChatClient(@Qualifier("openaiChatClientBuilder") ChatClient.Builder chatClientBuilder,
                                        ChatMemory chatMemory,
                                        ToolCallbackProvider[] mcpServerTools) {
        return chatClientBuilder
                .defaultAdvisors(new DecorateMessageChatMemoryAdvisor(chatMemory))
                .defaultToolCallbacks(mcpServerTools)
                .build();
    }

public class DecorateMessageChatMemoryAdvisor implements BaseChatMemoryAdvisor {
    /**
     * Chat memory
     */
    private final ChatMemory chatMemory;

    /**
     * The decorated target (the wrapped MessageChatMemoryAdvisor)
     */
    private final MessageChatMemoryAdvisor advisor;

    public DecorateMessageChatMemoryAdvisor(ChatMemory chatMemory) {
        this.chatMemory = chatMemory;
        this.advisor = MessageChatMemoryAdvisor.builder(chatMemory)
                .conversationId(Constant.DEFAULT_CONVERSATION_ID)
                .scheduler(Schedulers.fromExecutorService(DecorateExecutorService.decorateTrace(Executors.newThreadPerTaskExecutor(Thread.ofVirtual().name("vchat-memory").factory()))))
                .build();
    }

    @Override
    public String getName() {
        return this.advisor.getName();
    }

    @Override
    public Scheduler getScheduler() {
        return this.advisor.getScheduler();
    }

    @Override
    public String getConversationId(Map<String, Object> context, String defaultConversationId) {
        return this.advisor.getConversationId(context, defaultConversationId);
    }

    /**
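     * Copied from MessageChatMemoryAdvisor#adviseStream and overridden here: delegating to this.advisor.adviseStream would make `this` the wrapped advisor, so our before/after would never run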
     */
    @Override
    public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
        // trace_id
        String traceId = MDC.get(Constant.TRACE_ID);

        // Get the scheduler from BaseAdvisor
        Scheduler scheduler = this.getScheduler();

        // Process the request with the before method
        return Mono.just(chatClientRequest)
                .publishOn(scheduler)
                .map(request -> this.before(request, streamAdvisorChain))
                .flatMapMany(streamAdvisorChain::nextStream)
                .transform(flux -> new DecorateChatClientMessageAggregator(traceId, chatClientRequest).aggregateChatClientResponse(flux, response -> this.after(response, streamAdvisorChain)));
    }

    @Override
    public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
        Map<String, Object> context = chatClientRequest.context();
        context.put(ChatMemoryConfig.CHAT_TYPE, MessageType.USER);
        Map<String, Object> prev = ChatMemoryConfig.setChatContext(context);
        try {
            // MessageChatMemoryAdvisor does not save the system prompt, so save it manually here
            SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
            this.chatMemory.add(this.getConversationId(chatClientRequest.context(), Constant.DEFAULT_CONVERSATION_ID), systemMessage);

            return this.advisor.before(chatClientRequest, advisorChain);
        } finally {
            ChatMemoryConfig.setChatContext(prev);
        }
    }

    @Override
    public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
        Map<String, Object> context = chatClientResponse.context();
        context.put(ChatMemoryConfig.CHAT_TYPE, MessageType.ASSISTANT);
        Map<String, Object> prev = ChatMemoryConfig.setChatContext(context);
        try {
            return this.advisor.after(chatClientResponse, advisorChain);
        } finally {
            ChatMemoryConfig.setChatContext(prev);
        }
    }

    @Override
    public int getOrder() {
        return this.advisor.getOrder();
    }

    @Override
    public int hashCode() {
        return this.advisor.hashCode();
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj instanceof DecorateMessageChatMemoryAdvisor decorate) {
            return decorate.advisor.equals(this.advisor);
        }
        return false;
    }

    @Override
    public String toString() {
        return this.advisor.toString();
    }
}
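DecorateMessageChatMemoryAdvisor depends on a ChatMemoryConfig helper whose source the article does not show. A minimal sketch of what it could look like, assuming a plain ThreadLocal holder; the CHAT_TYPE value and the getChatContext method are hypothetical, not spring-ai APIs. The property that matters is that setChatContext returns the previously bound context, so the try/finally blocks above can restore it even when calls nest:

```java
import java.util.Map;

/**
 * Hypothetical sketch of the ChatMemoryConfig helper used by DecorateMessageChatMemoryAdvisor.
 * Only the context-holder part is shown; this is not a spring-ai class.
 */
public final class ChatMemoryConfig {

    /** Hypothetical key telling the repository whether a USER or ASSISTANT message is being saved. */
    public static final String CHAT_TYPE = "chatType";

    private static final ThreadLocal<Map<String, Object>> CHAT_CONTEXT = new ThreadLocal<>();

    private ChatMemoryConfig() {
    }

    /**
     * Binds the given context to the current thread and returns the previously bound one,
     * so callers can restore it in a finally block. Passing null clears the binding.
     */
    public static Map<String, Object> setChatContext(Map<String, Object> context) {
        Map<String, Object> prev = CHAT_CONTEXT.get();
        if (context == null) {
            CHAT_CONTEXT.remove();
        } else {
            CHAT_CONTEXT.set(context);
        }
        return prev;
    }

    /** Returns the context bound to the current thread, or an empty map if none is bound. */
    public static Map<String, Object> getChatContext() {
        Map<String, Object> context = CHAT_CONTEXT.get();
        return context == null ? Map.of() : context;
    }
}
```

A custom ChatMemoryRepository#saveAll could then call ChatMemoryConfig.getChatContext().get("key1") to read the value passed with a.param("key1", "value1"), and check CHAT_TYPE to tell user messages from assistant replies.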

2. How do you output the reasoning process?

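By default spring-ai does not output the reasoning process (or perhaps I just haven't found the setting). How can it be done? Again with the custom DecorateMessageChatMemoryAdvisor, plus a custom DecorateChatClientMessageAggregator.
In DecorateMessageChatMemoryAdvisor#adviseStream, use the custom aggregator to emit the reasoning content. The aggregator code follows; SSE_EMITTER_KEY must be put into the request context, for example via advisors(a -> a.param(SSE_EMITTER_KEY, emitterKey)):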

@Slf4j
@RequiredArgsConstructor
public class DecorateChatClientMessageAggregator extends ChatClientMessageAggregator {
    private final String traceId;
    private final ChatClientRequest chatClientRequest;

    @Override
    public Flux<ChatClientResponse> aggregateChatClientResponse(Flux<ChatClientResponse> chatClientResponses, Consumer<ChatClientResponse> aggregationHandler) {
        String emitterKey = Mapping.from(chatClientRequest.context()).notNullMap(e -> e.get(SSE_EMITTER_KEY)).notNullMap(e -> (String) e).get();
        return super.aggregateChatClientResponse(chatClientResponses.doOnNext(e -> sendReasoning(e, emitterKey)), aggregationHandler);
    }

    protected void sendReasoning(ChatClientResponse response, String emitterKey) {
        if (response.chatResponse() == null) {
            return;
        }
        List<Generation> results = response.chatResponse().getResults();
        for (Generation result : results) {
            Map<String, Object> metadata = result.getOutput().getMetadata();
            String reasoningContent = (String) metadata.get("reasoningContent");
            if (reasoningContent != null && !reasoningContent.isBlank()) {
                if (emitterKey == null) {
                    Logs.runOnTraceId(traceId, () -> log.info("reasoning: {}", reasoningContent));
                    continue;
                }
                Emitters.sendSseEvent(emitterKey, Constant.SseEvent.REASONING_CONTENT, reasoningContent);
            }
        }
    }
}
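Stripped of the Reactor and spring-ai types, sendReasoning is a per-chunk side effect: read the reasoningContent metadata entry, skip blanks, and send it to the SSE emitter if there is one, otherwise log it. A JDK-only sketch of that routing, where each chunk is modeled as a list of plain metadata maps and the emitter and logger are Consumer<String> stand-ins (the class and method names are hypothetical):

```java
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

/** Hypothetical, framework-free model of DecorateChatClientMessageAggregator#sendReasoning. */
public final class ReasoningRouter {

    /** Metadata key read by the aggregator above. */
    static final String REASONING_KEY = "reasoningContent";

    private final Consumer<String> emitter;     // stand-in for Emitters.sendSseEvent; null means no SSE client
    private final Consumer<String> fallbackLog; // stand-in for the trace-aware log.info call

    public ReasoningRouter(Consumer<String> emitter, Consumer<String> fallbackLog) {
        this.emitter = emitter;
        this.fallbackLog = fallbackLog;
    }

    /** Routes the reasoning text of every generation in one streamed chunk. */
    public void onChunk(List<Map<String, Object>> generationMetadata) {
        for (Map<String, Object> metadata : generationMetadata) {
            Object value = metadata.get(REASONING_KEY);
            if (!(value instanceof String reasoning) || reasoning.isBlank()) {
                continue; // chunks carrying only answer text have no (or blank) reasoning
            }
            if (emitter == null) {
                fallbackLog.accept("reasoning: " + reasoning);
            } else {
                emitter.accept(reasoning);
            }
        }
    }
}
```

Because the original code hooks this in with doOnNext, it stays a pure side effect: the chunks still reach super.aggregateChatClientResponse unchanged, so memory saving in after is unaffected.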

3. How do you propagate the MDC trace_id with streaming responses?

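spring-ai sends AI chat requests with the reactive WebClient, and the responses are processed on threads other than the caller's, so the thread-local MDC trace_id is lost and printing it in logs takes some extra configuration.
The solution is again the custom DecorateMessageChatMemoryAdvisor: DecorateMessageChatMemoryAdvisor#getScheduler returns a custom scheduler (built in the constructor from an executor wrapped by DecorateExecutorService.decorateTrace), and that wrapper makes each task run with the trace_id. In addition, adviseStream reads trace_id from MDC when the stream is assembled and hands it to DecorateChatClientMessageAggregator, which restores it when logging.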
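DecorateExecutorService.decorateTrace itself is not shown in the article. A minimal sketch of the idea, assuming it snapshots the submitting thread's trace_id and restores it on the worker. To stay runnable without slf4j, a ThreadLocal named TRACE stands in for MDC (with slf4j you would use MDC.get/MDC.put/MDC.remove instead); TraceExecutors is a hypothetical name. Extending AbstractExecutorService means submit and invokeAll all funnel through execute, so only one method has to wrap tasks:

```java
import java.util.List;
import java.util.concurrent.AbstractExecutorService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;

/** Hypothetical sketch of DecorateExecutorService.decorateTrace; TRACE stands in for slf4j's MDC. */
public final class TraceExecutors {

    /** Stand-in for MDC.get/put(Constant.TRACE_ID, ...). */
    public static final ThreadLocal<String> TRACE = new ThreadLocal<>();

    private TraceExecutors() {
    }

    /** Wraps an executor so each task runs with the trace_id of the thread that submitted it. */
    public static ExecutorService decorateTrace(ExecutorService delegate) {
        return new AbstractExecutorService() {
            @Override
            public void execute(Runnable task) {
                String traceId = TRACE.get(); // captured on the submitting thread
                delegate.execute(() -> {
                    String prev = TRACE.get(); // pooled workers may already carry a value
                    try {
                        if (traceId == null) {
                            TRACE.remove();
                        } else {
                            TRACE.set(traceId);
                        }
                        task.run();
                    } finally {
                        if (prev == null) {
                            TRACE.remove();
                        } else {
                            TRACE.set(prev);
                        }
                    }
                });
            }

            @Override
            public void shutdown() {
                delegate.shutdown();
            }

            @Override
            public List<Runnable> shutdownNow() {
                return delegate.shutdownNow(); // returns the wrapped runnables, not the originals
            }

            @Override
            public boolean isShutdown() {
                return delegate.isShutdown();
            }

            @Override
            public boolean isTerminated() {
                return delegate.isTerminated();
            }

            @Override
            public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
                return delegate.awaitTermination(timeout, unit);
            }
        };
    }
}
```

The article's executor creates a new virtual thread per task, but restoring the previous value keeps the wrapper safe for pooled executors as well. If you would rather decorate every Reactor task globally, Reactor's Schedulers.onScheduleHook(key, decorator) is another option worth checking.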

4. How do you propagate the MDC trace_id for local tool calls?

A local tool call means the tool is defined inside the mcp-client itself, so no remote mcp-server tool is involved. This case is fairly simple: put trace_id into the ToolContext:

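        String traceId = MDC.get(Constant.TRACE_ID);
        return this.openapiChatClient.prompt()
                .user(input)
                // use the same key that McpTools.runOnToolContext reads; Map.of rejects null values
                .toolContext(traceId == null ? Map.of() : Map.of(Constant.TRACE_ID, traceId))
                .stream()
                .content();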

Then write a McpTools helper; each tool method wraps its body in runOnToolContext(toolContext.getContext(), ...):

    public static <T> T runOnToolContext(Map<String, Object> context, Supplier<T> supplier) {
        String traceId = (String) context.get(Constant.TRACE_ID);

        if (traceId == null) {
            T result = supplier.get();
            log.info("Tool call completed: {}", result);
            return result;
        }

        return runOnTraceId(traceId, supplier, result -> log.info("Tool call completed: {}", result));
    }

    public static <T> T runOnTraceId(final String traceId, final Supplier<T> supplier, final Consumer<T> consumer) {
        String prev = MDC.get(Constant.TRACE_ID);
        try {
            MDC.put(Constant.TRACE_ID, traceId);
            T result = supplier.get();
            if (consumer != null) {
                consumer.accept(result);
            }
            return result;
        } finally {
            if (prev == null) {
                MDC.remove(Constant.TRACE_ID);
            } else {
                MDC.put(Constant.TRACE_ID, prev);
            }
        }
    }

5. How do you propagate trace_id to remote mcp-server tools?

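For remote tool calls, I assumed passing trace_id in the ToolContext would work as well. To my surprise, on the remote side the ToolContext contains only an exchange key, and the custom tool context entries are missing entirely! So I had to dig through the source code, and ended up with the following solution: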

Step 1: customize McpAsyncHttpClientRequestCustomizer in the mcp-client

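This extension interface is called back when a remote tool-call request is sent, letting developers customize the request. Here we can copy trace_id into a request header: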

@Component
public class MyMcpAsyncHttpClientRequestCustomizer implements McpAsyncHttpClientRequestCustomizer {

    @Override
    public Publisher<HttpRequest.Builder> customize(HttpRequest.Builder builder, String method, URI endpoint, String body, McpTransportContext context) {
        String traceId = Optional.ofNullable(JSON.parseObject(body))
                .map(e -> e.getJSONObject("params"))
                .map(e -> e.getJSONObject("_meta"))
                .map(e -> e.getString("traceId"))
                .orElse(null);
        if (StrUtil.isNotBlank(traceId)) {
            builder = builder.header("traceId", traceId);
        }
        return Mono.just(builder);
    }
}

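You might expect the tool context to be available in McpTransportContext. Unfortunately it isn't, which is why the code parses the request body itself and reads traceId from params._meta.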

Step 2: customize WebFluxStreamableServerTransportProvider in the mcp-server

The point of customizing WebFluxStreamableServerTransportProvider is to plug in a custom McpTransportContextExtractor, because spring-ai does not expose McpTransportContextExtractor as a bean:

@Configuration
public class StreamableServerTransportProviderConfig {

    @Bean
    @ConditionalOnMissingBean
    public WebFluxStreamableServerTransportProvider webFluxStreamableServerTransportProvider(@Qualifier("mcpServerObjectMapper") ObjectMapper objectMapper,
                                                                                             McpServerStreamableHttpProperties serverProperties) {

        return WebFluxStreamableServerTransportProvider.builder()
                .jsonMapper(new JacksonMcpJsonMapper(objectMapper))
                .messageEndpoint(serverProperties.getMcpEndpoint())
                .keepAliveInterval(serverProperties.getKeepAliveInterval())
                .disallowDelete(serverProperties.isDisallowDelete())
                .contextExtractor(new AgentMcpTransportContextExtractor())
                .build();
    }

    static class AgentMcpTransportContextExtractor implements McpTransportContextExtractor<ServerRequest> {

        @Override
        public McpTransportContext extract(ServerRequest request) {
            String traceId = request.headers().firstHeader("traceId");
            if (traceId == null) {
                return McpTransportContext.EMPTY;
            }
            return McpTransportContext.create(Map.of("traceId", traceId));
        }
    }
}

The traceId from the request header now ends up in the server-side McpTransportContext.
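Taken together, steps 1 and 2 simply carry trace_id across the process boundary in an HTTP header. The client half works on the JDK's java.net.http.HttpRequest.Builder, the same type the customizer receives, so the round trip can be tried without any MCP dependencies. A minimal sketch: the class and method names are hypothetical, the JSON-body parsing from step 1 is omitted, and the server half reads java.net.http.HttpHeaders instead of WebFlux's ServerRequest:

```java
import java.net.http.HttpHeaders;
import java.net.http.HttpRequest;
import java.util.Map;

/** Hypothetical sketch of the trace_id header hop between mcp-client and mcp-server. */
public final class TraceHeaderHop {

    static final String HEADER = "traceId";

    private TraceHeaderHop() {
    }

    /** Client side: same idea as MyMcpAsyncHttpClientRequestCustomizer, minus the body parsing. */
    static HttpRequest.Builder addTraceHeader(HttpRequest.Builder builder, String traceId) {
        return traceId == null || traceId.isBlank() ? builder : builder.header(HEADER, traceId);
    }

    /** Server side: same idea as AgentMcpTransportContextExtractor, returning a plain map. */
    static Map<String, Object> extract(HttpHeaders headers) {
        return headers.firstValue(HEADER)
                .<Map<String, Object>>map(id -> Map.of(HEADER, id))
                .orElse(Map.of());
    }
}
```

HTTP header lookup is case-insensitive, so an intermediary that rewrites the header to lower case does not break the extractor.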

Step 3: add another runOnToolContext method in the mcp-server

    // {mcp-client}.McpTools refers to the McpTools helper from section 4 (its package is elided)
    public static <T> T runOnToolContext(Map<String, Object> context, Supplier<T> supplier) {
        if (context.containsKey(Constant.TRACE_ID)) {
            return {mcp-client}.McpTools.runOnToolContext(context, supplier);
        }

        McpTransportContext transportContext;

        Object object = context.get(McpToolUtils.TOOL_CONTEXT_MCP_EXCHANGE_KEY);

        if (object instanceof McpSyncServerExchange exchange) {
            transportContext = exchange.transportContext();
        } else if (object instanceof McpAsyncServerExchange exchange) {
            transportContext = exchange.transportContext();
        } else {
            transportContext = null;
        }

        if (transportContext == null) {
            return {mcp-client}.McpTools.runOnToolContext(context, supplier);
        }

        Object traceId = transportContext.get(Constant.TRACE_ID);

        if (traceId == null) {
            return {mcp-client}.McpTools.runOnToolContext(context, supplier);
        }

        return {mcp-client}.McpTools.runOnToolContext(Map.of(Constant.TRACE_ID, traceId.toString()), supplier);
    }

With that, trace_id is passed down to the remote tool.

6. How do you intercept tool calls?

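I wanted to send the front end an SSE message whenever a tool is called. A tool interceptor is the best fit, since it works for both local and remote tools, but unfortunately spring-ai doesn't provide one yet.
If it doesn't exist, build it yourself!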

Step 1: customize DecorateToolCallingManager

The purpose of DecorateToolCallingManager is to wrap every ToolCallback so the interceptors run around it:

@Slf4j
@RequiredArgsConstructor
public class DecorateToolCallingManager implements ToolCallingManager {
    /**
     * The decorated target
     */
    private final ToolCallingManager decorate;

    /**
     * Tool-call interceptors
     */
    private volatile Collection<ToolCallbackInterceptor> interceptors;

    /**
     * Tool-execution result handlers
     */
    private volatile Collection<ToolExecutionResultHandler> handlers;

    @Override
    public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
        Collection<ToolCallbackInterceptor> interceptors = ensureInterceptors();
        List<ToolCallback> toolCallbacks = chatOptions.getToolCallbacks()
                .stream()
                .map(e -> e instanceof DecorateToolCallback ? e : new DecorateToolCallback(e.getToolDefinition().name(), e, interceptors))
                .toList();
        chatOptions.setToolCallbacks(toolCallbacks);
        return decorate.resolveToolDefinitions(chatOptions);
    }

    @Override
    public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
        ContextView context = ToolCallReactiveContextHolder.getContext();
        ToolExecutionResult result = decorate.executeToolCalls(prompt, chatResponse);
        for (ToolExecutionResultHandler handler : ensureHandlers()) {
            if (handler.support(prompt, chatResponse, context)) {
                if (handler.handle(prompt, chatResponse, result, context)) {
                    continue;
                }
                break;
            }
        }
        return result;
    }

    protected Collection<ToolCallbackInterceptor> ensureInterceptors() {
        if (this.interceptors == null) {
            synchronized (this) {
                if (this.interceptors == null) {
                    this.interceptors = $.getBeans(ToolCallbackInterceptor.class);
                }
            }
        }
        return this.interceptors;
    }

    protected Collection<ToolExecutionResultHandler> ensureHandlers() {
        if (this.handlers == null) {
            synchronized (this) {
                if (this.handlers == null) {
                    this.handlers = $.getBeans(ToolExecutionResultHandler.class);
                }
            }
        }
        return this.handlers;
    }
}

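Besides the tool-call interceptor, this also adds a tool-execution result handler. The handler is not part of the interceptor because the interceptor has no access to the original prompt or the original response.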

Step 2: customize ToolCallback

@RequiredArgsConstructor
public class DecorateToolCallback implements ToolCallback {
    private final String toolName;
    private final ToolCallback decorate;
    private final Collection<ToolCallbackInterceptor> interceptors;

    @Override
    public ToolDefinition getToolDefinition() {
        return decorate.getToolDefinition();
    }

    @Override
    public String call(String toolInput) {
        return call(toolInput, null);
    }

    @Override
    public ToolMetadata getToolMetadata() {
        return decorate.getToolMetadata();
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        for (ToolCallbackInterceptor interceptor : this.interceptors) {
            interceptor.beforeToolCallback(toolName, toolInput, toolContext, this);
        }

        Throwable ex = null;

        try {
            return decorate.call(toolInput, toolContext);
        } catch (Throwable e) {
            ex = e;
            throw e;
        } finally {
            for (ToolCallbackInterceptor interceptor : this.interceptors) {
                interceptor.afterToolCallback(toolName, toolInput, toolContext, this, ex);
            }
        }
    }
}

Step 3: make the custom decorators take effect

Just add a bean post-processor that wraps every ToolCallingManager bean:

@Configuration
public class DecorateBeanConfig implements BeanPostProcessor {

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        if (bean instanceof ToolCallingManager manager) {
            return new DecorateToolCallingManager(manager);
        }
        return bean;
    }
}

Step 4: implement the interceptors

Interceptor definition:

public interface ToolCallbackInterceptor {
    /**
     * Called before the tool is invoked
     */
    void beforeToolCallback(String toolName, String toolInput, ToolContext toolContext, ToolCallback callback);

    /**
     * Called after the tool is invoked
     *
     * @param throwable non-null if the tool call threw an exception
     */
    void afterToolCallback(String toolName, String toolInput, ToolContext toolContext, ToolCallback callback, Throwable throwable);
}

Result handler definition:

public interface ToolExecutionResultHandler {
    /**
     * Whether this handler supports the current tool call
     *
     * @return true/false
     */
    boolean support(Prompt prompt, ChatResponse chatResponse, ContextView context);

    /**
     * Handles the tool execution result
     *
     * @return true to continue with the rest of the chain; false to skip the remaining handlers
     */
    boolean handle(Prompt prompt, ChatResponse chatResponse, ToolExecutionResult result, ContextView context);
}

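Note: handlers cannot see the tool context, but they can read reactor-core's ContextView (roughly the reactive counterpart of Java's ThreadLocal), so write the data they need into the Flux's context, for example with Flux#contextWrite on the stream returned by the ChatClient.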
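To illustrate step 4, here is what an interceptor that notifies the front end might look like. To keep it runnable without spring-ai, SimpleToolInterceptor below is a hypothetical simplified copy of ToolCallbackInterceptor with the ToolContext and ToolCallback parameters dropped, and a Consumer<String> stands in for the SSE emitter (in the real interface the emitter key could, for example, also be put into the ToolContext). Because before and after are separate calls with no shared per-call object, the start time is kept in a ThreadLocal; this relies on DecorateToolCallback invoking both hooks on the same thread, which holds because its call method is synchronous:

```java
import java.util.function.Consumer;

/** Hypothetical simplified copy of ToolCallbackInterceptor, without the spring-ai parameter types. */
interface SimpleToolInterceptor {
    void beforeToolCallback(String toolName, String toolInput);

    void afterToolCallback(String toolName, String toolInput, Throwable throwable);
}

/** Hypothetical interceptor that reports tool start/end events to the front end. */
public final class ToolEventInterceptor implements SimpleToolInterceptor {

    private final ThreadLocal<Long> startNanos = new ThreadLocal<>();

    /** Stand-in for Emitters.sendSseEvent(emitterKey, event, data). */
    private final Consumer<String> sse;

    public ToolEventInterceptor(Consumer<String> sse) {
        this.sse = sse;
    }

    @Override
    public void beforeToolCallback(String toolName, String toolInput) {
        startNanos.set(System.nanoTime());
        sse.accept("tool_start:" + toolName);
    }

    @Override
    public void afterToolCallback(String toolName, String toolInput, Throwable throwable) {
        Long start = startNanos.get();
        startNanos.remove(); // don't leak per-call state into pooled threads
        long millis = start == null ? -1 : (System.nanoTime() - start) / 1_000_000;
        String status = throwable == null ? "ok" : "error";
        sse.accept("tool_end:" + toolName + ":" + status + ":" + millis + "ms");
    }
}
```

Since DecorateToolCallback runs afterToolCallback in a finally block, the front end receives a tool_end event even when the tool throws.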

7. Swagger documentation integration for Spring Boot 4.x

Integrating the API docs also took some effort, so I'm recording it here. Maven dependencies:

<!-- doc -->
<dependency>
    <groupId>com.github.xiaoymin</groupId>
    <artifactId>knife4j-openapi3-jakarta-spring-boot-starter</artifactId>
    <version>4.4.0</version>
    <exclusions>
        <exclusion>
            <groupId>org.springdoc</groupId>
            <artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
        </exclusion>
    </exclusions>
</dependency>

<dependency>
    <groupId>org.springdoc</groupId>
    <artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
    <version>3.0.0</version>
</dependency>

YAML configuration:

springdoc:
  group-configs:
    - group: 'default'
      packages-to-scan: com.aaa.bbb.controller

Then open /doc.html and you're done.