[Spring AI] 拦截MCP调用

27 阅读4分钟

引子

在 Spring AI中,我们一般通过这样的方式来判断是否含有在一次AI请求中是否含有工具调用:

 chat.prompt().user(message)
     .stream()
     .tools(mcpTools)
     .chatResponse().subscribe((s)-> {
        if (s.hasToolCalls()) {
            //...have tool calls
        }
     });

但是我在接入DeepSeek的过程中,即使deepseek调用了MCP工具,hasToolCalls依然为false。

在这种情况下,我们该如何监听MCP调用呢?

前置知识

1. Tool Calls

Spring AI 为我们提供了很多方式来注册 mcp-tools,其中包括:

  • 添加 @Tool 注解,然后使用 tools() 函数注册

  • 使用 ToolCallbacks.from 生成 ToolCallback,然后使用 toolCallbacks()注册

  • 编程式生成 ToolCallback,例如:

    Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
    ToolCallback toolCallback = MethodToolCallback.builder()
        .toolDefinition(ToolDefinition.builder(method)
                .description("Get the current date and time in the user's timezone")
                .build())
        .toolMethod(method)
        .toolObject(new DateTimeTools())
        .build();
    

如果我们要拦截MCP的调用,那肯定需要在ToolCallback中进行hook,而我们不可能手动解析@Tool注解然后编程式生成ToolCallback(其实这个任务就是由ToolCallbacks.from函数完成的)。

综上所述,在本文中我决定使用 ToolCallbacks.from 的方式进行监听。

2. Reactor

注意到 chatResponse()的返回值为 Flux<T>,为了减少代码量,我们的对话函数肯定也得返回Flux<T>

在这里直接贴一下github的定义,感兴趣的童鞋可以自己翻:

面向JVM的非阻塞响应式流基础框架,既实现了受Reactive Extensions启发的API,又提供了高效的事件流支持。

实现细节

我们先定义一个Service,写一些样板代码:

@Service
public class LLMJavaService {
    private final ChatClient chat;
    private final MCPService mcpService;
​
    public LLMJavaService(ChatClient.Builder builder, MCPService mcpService) {
        this.mcpService = mcpService;
        this.chat = builder.defaultSystem("你是一个助手。你可以调用 get-date 的mcp工具获取时间。").build();
    }
​
    public Flux<ModelResponse> chatAsync(String message) {
        
    }
​
    public sealed interface ModelResponse permits ModelResponse.Message, ModelResponse.ToolCall, ModelResponse.ToolResponse, ModelResponse.ToolError {
        @Data
        @AllArgsConstructor
        final class Message implements ModelResponse {
            private String content;
        }
​
        @Data
        @AllArgsConstructor
        final class ToolCall implements ModelResponse {
            private String callId;
            private String name;
            private String input;
        }
​
        @Data
        @AllArgsConstructor
        final class ToolResponse implements ModelResponse {
            private String callId;
            private String name;
            private String input;
            private String content;
        }
​
        @Data
        @AllArgsConstructor
        final class ToolError implements ModelResponse {
            private String callId;
            private String name;
            private String input;
            private Throwable error;
        }
    }
}
​
@Service
class MCPService {
    @Tool(name = "get-date", description = "获取当前日期")
    public String getDate() {
        return LocalDateTime.now().toString();
    }
}

我们需要实现 public Flux<ModelResponse> chatAsync(String message)里的内容,先在这里返回一个流:

Flux<ModelResponse> contentFlux = chat.prompt()
        .user(message)
        .toolCallbacks(loggerTools)
        .stream()
        .content()
        .map(ModelResponse.Message::new);

但是只有一个流怎么行呢?我们还需要拦截MCP调用。因此编写一个委托类来进行一个静态代理:

private static class DelegateToolCallback implements ToolCallback {
        private final ToolCallback toolCallback;
        private final Consumer<ModelResponse> listener;
​
        DelegateToolCallback(ToolCallback toolCallback, Consumer<ModelResponse> listener) {
            this.toolCallback = toolCallback;
            this.listener = listener;
        }
​
        @NotNull
        @Override
        public ToolDefinition getToolDefinition() {
            return toolCallback.getToolDefinition();
        }
​
        @NotNull
        @Override
        public ToolMetadata getToolMetadata() {
            return toolCallback.getToolMetadata();
        }
​
        @NotNull
        @Override
        public String call(@NotNull String toolInput, ToolContext toolContext) {
            String uuid = UUID.randomUUID().toString().replace("-", "");
            listener.accept(new ModelResponse.ToolCall(uuid, this.getToolDefinition().name(), toolInput));
            String result;
            try {
                result = toolCallback.call(toolInput, toolContext);
            } catch (Throwable e) {
                listener.accept(new ModelResponse.ToolError(uuid, this.getToolDefinition().name(), toolInput, e));
                throw e;
            }
​
            listener.accept(new ModelResponse.ToolResponse(uuid, this.getToolDefinition().name(), toolInput, result));
            return result;
        }
​
        @NotNull
        @Override
        public String call(@NotNull String toolInput) {
            return call(toolInput, null);
        }
    }

这个经过代理的工具唯一作用就是在 call函数被调用之前/之后,记录传参并回报。

我们还需要再产生一个流,这个流负责接收ToolCallback所产生的ModelResponse:

Sinks.Many<ModelResponse> toolSink = Sinks.many().unicast().onBackpressureBuffer();
​
ToolCallback[] loggerTools = Stream.of(ToolCallbacks.from(mcpService))
        .map(callback -> wrapperListeners(callback, toolSink::tryEmitNext))
        .toArray(value -> new ToolCallback[0]);
​
Flux<ModelResponse> toolFlux = toolSink.asFlux();

在这里我们直接使用Reactor的Sinks来生成Sink,然后使用asFlux()将其转成Flux。

最后将两个流合并后返回:

return Flux.merge(toolFlux, contentFlux);

然后编写测试代码(Kotlin.ver):

@SpringBootTest
class SpringMcpDemoApplicationTests {
​
    @Autowired
    private lateinit var lLMService: LLMService
​
    @Test
    fun contextLoads() {
        val resp = lLMService.chatAsync("今天的日期是多少?")
        val latch = CountDownLatch(1)
        resp.doOnEach { println(it) }.doOnComplete { latch.countDown() }.subscribe()
        latch.await()
    }
​
}

点击测试,发现程序在输出完毕后卡死了?哦,原来merge的流要所有子流全部关闭后才能关闭,而我们没有任何关闭toolFlux的代码,让我们关闭一下:

Flux<ModelResponse> contentFlux = chat.prompt()
        .user(message)
        .toolCallbacks(loggerTools)
        .stream()
        .content()
        .doOnComplete(toolSink::tryEmitComplete) //这里是新加的
        .map(ModelResponse.Message::new);

现在测试就没有问题了!

完整代码

package top.kagg886.mcpdemo.springmcpdemo.service;
​
import lombok.AllArgsConstructor;
import lombok.Data;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;
​
import java.util.UUID;
import java.util.function.IntFunction;
import java.util.stream.Stream;
​
@Service
public class LLMJavaService {
    private final ChatClient chat;
    private final MCPService mcpService;
​
    public LLMJavaService(ChatClient.Builder builder, MCPService mcpService) {
        this.mcpService = mcpService;
        this.chat = builder.defaultSystem("你是一个助手。你可以调用 get-date 的mcp工具获取时间。").build();
    }
​
    public Flux<ModelResponse> chatAsync(String message) {
        Sinks.Many<ModelResponse> toolSink = Sinks.many().unicast().onBackpressureBuffer();
​
        ToolCallback[] loggerTools = Stream.of(ToolCallbacks.from(mcpService))
                .map(callback -> wrapperListeners(callback, toolSink::tryEmitNext))
                .toArray(value -> new ToolCallback[0]);
​
        Flux<ModelResponse> toolFlux = toolSink.asFlux();
​
        Flux<ModelResponse> contentFlux = chat.prompt()
                .user(message)
                .toolCallbacks(loggerTools)
                .stream()
                .content()
                .doOnComplete(toolSink::tryEmitComplete)
                .map(ModelResponse.Message::new);
​
        return Flux.merge(toolFlux, contentFlux);
    }
​
    private ToolCallback wrapperListeners(ToolCallback toolCallback, java.util.function.Consumer<ModelResponse> listener) {
        return new DelegateToolCallback(toolCallback, listener);
    }
​
    private static class DelegateToolCallback implements ToolCallback {
        private final ToolCallback toolCallback;
        private final Consumer<ModelResponse> listener;
​
        DelegateToolCallback(ToolCallback toolCallback, Consumer<ModelResponse> listener) {
            this.toolCallback = toolCallback;
            this.listener = listener;
        }
​
        @NotNull
        @Override
        public ToolDefinition getToolDefinition() {
            return toolCallback.getToolDefinition();
        }
​
        @NotNull
        @Override
        public ToolMetadata getToolMetadata() {
            return toolCallback.getToolMetadata();
        }
​
        @NotNull
        @Override
        public String call(@NotNull String toolInput, ToolContext toolContext) {
            String uuid = UUID.randomUUID().toString().replace("-", "");
            listener.accept(new ModelResponse.ToolCall(uuid, this.getToolDefinition().name(), toolInput));
            String result;
            try {
                result = toolCallback.call(toolInput, toolContext);
            } catch (Throwable e) {
                listener.accept(new ModelResponse.ToolError(uuid, this.getToolDefinition().name(), toolInput, e));
                throw e;
            }
​
            listener.accept(new ModelResponse.ToolResponse(uuid, this.getToolDefinition().name(), toolInput, result));
            return result;
        }
​
        @NotNull
        @Override
        public String call(@NotNull String toolInput) {
            return call(toolInput, null);
        }
    }
​
    public sealed interface ModelResponse permits ModelResponse.Message, ModelResponse.ToolCall, ModelResponse.ToolResponse, ModelResponse.ToolError {
        @Data
        @AllArgsConstructor
        final class Message implements ModelResponse {
            private String content;
        }
​
        @Data
        @AllArgsConstructor
        final class ToolCall implements ModelResponse {
            private String callId;
            private String name;
            private String input;
        }
​
        @Data
        @AllArgsConstructor
        final class ToolResponse implements ModelResponse {
            private String callId;
            private String name;
            private String input;
            private String content;
        }
​
        @Data
        @AllArgsConstructor
        final class ToolError implements ModelResponse {
            private String callId;
            private String name;
            private String input;
            private Throwable error;
        }
    }
}