Introduction
In Spring AI, we usually check whether an AI request involved a tool call like this:
chat.prompt().user(message)
        .tools(mcpTools) // tools must be registered before stream() is called
        .stream()
        .chatResponse().subscribe((s) -> {
            if (s.hasToolCalls()) {
                // ...have tool calls
            }
        });
However, while integrating DeepSeek I found that hasToolCalls stays false even when DeepSeek has called an MCP tool.
How can we listen for MCP calls in that situation?
Prerequisites
1. Tool Calls
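Spring AI offers several ways to register MCP tools, including:
- Annotate methods with @Tool, then register them with the tools() function
- Generate ToolCallbacks with ToolCallbacks.from, then register them with toolCallbacks()
- Build a ToolCallback programmatically, for example: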
Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
ToolCallback toolCallback = MethodToolCallback.builder()
        .toolDefinition(ToolDefinition.builder(method)
                .description("Get the current date and time in the user's timezone")
                .build())
        .toolMethod(method)
        .toolObject(new DateTimeTools())
        .build();
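To intercept MCP calls we need a hook inside ToolCallback, and it is not practical to parse the @Tool annotations by hand and build every ToolCallback programmatically (that is exactly the job ToolCallbacks.from already does).
So in this article I obtain the callbacks with ToolCallbacks.from and wrap them to listen for calls.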
2. Reactor
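Note that chatResponse() returns a Flux<T>. To keep the amount of code down, our chat function should return a Flux<T> as well.
Here is the definition from Reactor's GitHub page; interested readers can dig further on their own:
Non-Blocking Reactive Streams Foundation for the JVM both implementing a Reactive Extensions inspired API and efficient event streaming support.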
Implementation details
First we define a Service and write some boilerplate:
@Service
public class LLMJavaService {
    private final ChatClient chat;
    private final MCPService mcpService;

    public LLMJavaService(ChatClient.Builder builder, MCPService mcpService) {
        this.mcpService = mcpService;
        this.chat = builder.defaultSystem("You are an assistant. You can call the get-date MCP tool to get the time.").build();
    }

    public Flux<ModelResponse> chatAsync(String message) {
        // Placeholder: the real body is built step by step below
        return Flux.empty();
    }

    // Everything the caller can observe: model text, tool call, tool result, tool failure
    public sealed interface ModelResponse permits ModelResponse.Message, ModelResponse.ToolCall, ModelResponse.ToolResponse, ModelResponse.ToolError {
        @Data
        @AllArgsConstructor
        final class Message implements ModelResponse {
            private String content;
        }

        @Data
        @AllArgsConstructor
        final class ToolCall implements ModelResponse {
            private String callId;
            private String name;
            private String input;
        }

        @Data
        @AllArgsConstructor
        final class ToolResponse implements ModelResponse {
            private String callId;
            private String name;
            private String input;
            private String content;
        }

        @Data
        @AllArgsConstructor
        final class ToolError implements ModelResponse {
            private String callId;
            private String name;
            private String input;
            private Throwable error;
        }
    }
}

@Service
class MCPService {
    @Tool(name = "get-date", description = "Get the current date")
    public String getDate() {
        return LocalDateTime.now().toString();
    }
}
We need to fill in the body of public Flux<ModelResponse> chatAsync(String message). Start by building a stream of the model's text output:
Flux<ModelResponse> contentFlux = chat.prompt()
        .user(message)
        .toolCallbacks(loggerTools)
        .stream()
        .content()
        .map(ModelResponse.Message::new);
One stream is not enough, though: we also need to intercept MCP calls. So we write a delegate class that acts as a static proxy:
private static class DelegateToolCallback implements ToolCallback {
    private final ToolCallback toolCallback;
    private final Consumer<ModelResponse> listener;

    DelegateToolCallback(ToolCallback toolCallback, Consumer<ModelResponse> listener) {
        this.toolCallback = toolCallback;
        this.listener = listener;
    }

    @NotNull
    @Override
    public ToolDefinition getToolDefinition() {
        return toolCallback.getToolDefinition();
    }

    @NotNull
    @Override
    public ToolMetadata getToolMetadata() {
        return toolCallback.getToolMetadata();
    }

    @NotNull
    @Override
    public String call(@NotNull String toolInput, ToolContext toolContext) {
        // One id per invocation so the call, response, and error events can be correlated
        String uuid = UUID.randomUUID().toString().replace("-", "");
        listener.accept(new ModelResponse.ToolCall(uuid, this.getToolDefinition().name(), toolInput));
        String result;
        try {
            result = toolCallback.call(toolInput, toolContext);
        } catch (Throwable e) {
            // Report the failure, then let it propagate unchanged
            listener.accept(new ModelResponse.ToolError(uuid, this.getToolDefinition().name(), toolInput, e));
            throw e;
        }
        listener.accept(new ModelResponse.ToolResponse(uuid, this.getToolDefinition().name(), toolInput, result));
        return result;
    }

    @NotNull
    @Override
    public String call(@NotNull String toolInput) {
        return call(toolInput, null);
    }
}
The only job of this proxied tool is to record the arguments and report them before and after the call function is invoked.
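The same before/after reporting can be tried without Spring AI on the classpath. The sketch below is a minimal JDK-only analogue, not the Spring AI API: `SimpleTool`, `ListeningTool`, and `ListeningToolDemo` are hypothetical names invented for illustration, and events are plain strings instead of ModelResponse objects.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical stand-in for a tool callback: one String in, one String out.
interface SimpleTool {
    String call(String input);
}

// Static proxy: forwards to the wrapped tool and reports what happened around the call.
final class ListeningTool implements SimpleTool {
    private final SimpleTool delegate;
    private final Consumer<String> listener;

    ListeningTool(SimpleTool delegate, Consumer<String> listener) {
        this.delegate = delegate;
        this.listener = listener;
    }

    @Override
    public String call(String input) {
        listener.accept("call:" + input);
        String result;
        try {
            result = delegate.call(input);
        } catch (RuntimeException e) {
            // Report the failure, then rethrow so the caller still sees it
            listener.accept("error:" + e.getMessage());
            throw e;
        }
        listener.accept("response:" + result);
        return result;
    }
}

final class ListeningToolDemo {
    // Runs one successful and one failing call and returns the reported events in order.
    static List<String> run() {
        List<String> events = new ArrayList<>();
        SimpleTool upper = new ListeningTool(String::toUpperCase, events::add);
        upper.call("abc");

        SimpleTool failing = new ListeningTool(in -> {
            throw new IllegalStateException("boom");
        }, events::add);
        try {
            failing.call("x");
        } catch (IllegalStateException expected) {
            // The proxy rethrows after reporting, so the exception arrives here
        }
        return events;
    }
}
```

The wrapped tool never learns that it is being observed, which is why the approach works for callbacks produced by a factory you do not control.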
We also need a second stream, one that receives the ModelResponse events produced by the ToolCallback:
Sinks.Many<ModelResponse> toolSink = Sinks.many().unicast().onBackpressureBuffer();
ToolCallback[] loggerTools = Stream.of(ToolCallbacks.from(mcpService))
        .map(callback -> wrapperListeners(callback, toolSink::tryEmitNext))
        // The generator must return an array of the requested size, so use the constructor reference
        .toArray(ToolCallback[]::new);
Flux<ModelResponse> toolFlux = toolSink.asFlux();
Here we use Reactor's Sinks to create a sink and then turn it into a Flux with asFlux().
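The idea behind a sink is that imperative code (here, a callback) pushes values in on one side while a subscriber consumes them as a stream on the other. As a rough illustration only, the sketch below reproduces that shape with the JDK's own `java.util.concurrent.Flow` API instead of Reactor: `SubmissionPublisher` plays the role of the sink, and `CallbackBridgeDemo` is a hypothetical name. It is an analogy under those assumptions, not how Reactor's Sinks are implemented.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

final class CallbackBridgeDemo {
    // Pushes two events through a callback into a publisher and returns what the subscriber saw.
    static List<String> run() {
        List<String> received = new CopyOnWriteArrayList<>();
        CountDownLatch done = new CountDownLatch(1);

        try (SubmissionPublisher<String> sink = new SubmissionPublisher<>()) {
            sink.subscribe(new Flow.Subscriber<String>() {
                @Override public void onSubscribe(Flow.Subscription s) { s.request(Long.MAX_VALUE); }
                @Override public void onNext(String item) { received.add(item); }
                @Override public void onError(Throwable t) { done.countDown(); }
                @Override public void onComplete() { done.countDown(); }
            });

            // The producing side only sees a Consumer, like the listener handed to the proxy
            Consumer<String> listener = sink::submit;
            listener.accept("tool-call");
            listener.accept("tool-response");
        } // close() sends the completion signal to the subscriber

        try {
            done.await(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return received;
    }
}
```

The producer depends only on `Consumer<String>`, so the proxy class does not need to know which streaming library sits behind the listener.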
Finally, merge the two streams and return the result:
return Flux.merge(toolFlux, contentFlux);
Then write the test code (Kotlin version):
@SpringBootTest
class SpringMcpDemoApplicationTests {
    @Autowired
    private lateinit var llmService: LLMJavaService

    @Test
    fun contextLoads() {
        val resp = llmService.chatAsync("What is today's date?")
        val latch = CountDownLatch(1)
        // Print every signal and release the latch once the merged stream completes
        resp.doOnEach { println(it) }.doOnComplete { latch.countDown() }.subscribe()
        latch.await()
    }
}
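Run the test and the program hangs after it has printed all of its output. The reason is that a merged stream completes only after every one of its source streams has completed, and nothing in our code ever completes toolFlux. Let's complete it: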
Flux<ModelResponse> contentFlux = chat.prompt()
        .user(message)
        .toolCallbacks(loggerTools)
        .stream()
        .content()
        .doOnComplete(toolSink::tryEmitComplete) // newly added: complete the tool stream too
        .map(ModelResponse.Message::new);
Now the test finishes without any problem!
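The hang and its fix come down to one rule: a merged stream is done only when all of its sources are done. The sketch below models that rule with the JDK's `java.util.concurrent.Flow` API rather than Reactor, so it is an analogy and not Flux.merge itself; `MergeCompletionDemo`, `mergeInto`, and `await` are hypothetical names introduced for illustration.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;
import java.util.concurrent.TimeUnit;

final class MergeCompletionDemo {
    // Subscribes to every source; the returned latch opens only after all of them terminate.
    static CountDownLatch mergeInto(List<String> out, List<SubmissionPublisher<String>> sources) {
        CountDownLatch allDone = new CountDownLatch(sources.size());
        for (SubmissionPublisher<String> source : sources) {
            source.subscribe(new Flow.Subscriber<String>() {
                @Override public void onSubscribe(Flow.Subscription s) { s.request(Long.MAX_VALUE); }
                @Override public void onNext(String item) { out.add(item); }
                @Override public void onError(Throwable t) { allDone.countDown(); }
                @Override public void onComplete() { allDone.countDown(); }
            });
        }
        return allDone;
    }

    // Waits up to the given number of milliseconds; returns true if the latch opened.
    static boolean await(CountDownLatch latch, long millis) {
        try {
            return latch.await(millis, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    // Returns {completed before the tool source was closed, completed after it was closed}.
    static boolean[] run() {
        SubmissionPublisher<String> content = new SubmissionPublisher<>();
        SubmissionPublisher<String> tools = new SubmissionPublisher<>();
        List<String> out = new CopyOnWriteArrayList<>();
        CountDownLatch merged = mergeInto(out, List.of(content, tools));

        tools.submit("tool-call");
        content.submit("answer");
        content.close();                          // only the content source finishes
        boolean doneEarly = await(merged, 200);   // still waiting on the tool source

        tools.close();                            // the completion signal that was missing
        boolean doneLate = await(merged, 5000);
        return new boolean[] { doneEarly, doneLate };
    }
}
```

In this model the merged view stays open while the tool source is open, which mirrors why the test above blocked until the sink was completed.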
Complete code
package top.kagg886.mcpdemo.springmcpdemo.service;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;

import java.util.UUID;
import java.util.function.Consumer;
import java.util.stream.Stream;

@Service
public class LLMJavaService {
    private final ChatClient chat;
    private final MCPService mcpService;

    public LLMJavaService(ChatClient.Builder builder, MCPService mcpService) {
        this.mcpService = mcpService;
        this.chat = builder.defaultSystem("You are an assistant. You can call the get-date MCP tool to get the time.").build();
    }

    public Flux<ModelResponse> chatAsync(String message) {
        // Sink that receives the events reported by the wrapped tool callbacks
        Sinks.Many<ModelResponse> toolSink = Sinks.many().unicast().onBackpressureBuffer();
        ToolCallback[] loggerTools = Stream.of(ToolCallbacks.from(mcpService))
                .map(callback -> wrapperListeners(callback, toolSink::tryEmitNext))
                // The generator must return an array of the requested size
                .toArray(ToolCallback[]::new);
        Flux<ModelResponse> toolFlux = toolSink.asFlux();
        Flux<ModelResponse> contentFlux = chat.prompt()
                .user(message)
                .toolCallbacks(loggerTools)
                .stream()
                .content()
                // Complete the tool stream as well, otherwise the merged stream never completes
                .doOnComplete(toolSink::tryEmitComplete)
                .map(ModelResponse.Message::new);
        return Flux.merge(toolFlux, contentFlux);
    }

    private ToolCallback wrapperListeners(ToolCallback toolCallback, Consumer<ModelResponse> listener) {
        return new DelegateToolCallback(toolCallback, listener);
    }

    private static class DelegateToolCallback implements ToolCallback {
        private final ToolCallback toolCallback;
        private final Consumer<ModelResponse> listener;

        DelegateToolCallback(ToolCallback toolCallback, Consumer<ModelResponse> listener) {
            this.toolCallback = toolCallback;
            this.listener = listener;
        }

        @NotNull
        @Override
        public ToolDefinition getToolDefinition() {
            return toolCallback.getToolDefinition();
        }

        @NotNull
        @Override
        public ToolMetadata getToolMetadata() {
            return toolCallback.getToolMetadata();
        }

        @NotNull
        @Override
        public String call(@NotNull String toolInput, ToolContext toolContext) {
            // One id per invocation so the call, response, and error events can be correlated
            String uuid = UUID.randomUUID().toString().replace("-", "");
            listener.accept(new ModelResponse.ToolCall(uuid, this.getToolDefinition().name(), toolInput));
            String result;
            try {
                result = toolCallback.call(toolInput, toolContext);
            } catch (Throwable e) {
                // Report the failure, then let it propagate unchanged
                listener.accept(new ModelResponse.ToolError(uuid, this.getToolDefinition().name(), toolInput, e));
                throw e;
            }
            listener.accept(new ModelResponse.ToolResponse(uuid, this.getToolDefinition().name(), toolInput, result));
            return result;
        }

        @NotNull
        @Override
        public String call(@NotNull String toolInput) {
            return call(toolInput, null);
        }
    }

    public sealed interface ModelResponse permits ModelResponse.Message, ModelResponse.ToolCall, ModelResponse.ToolResponse, ModelResponse.ToolError {
        @Data
        @AllArgsConstructor
        final class Message implements ModelResponse {
            private String content;
        }

        @Data
        @AllArgsConstructor
        final class ToolCall implements ModelResponse {
            private String callId;
            private String name;
            private String input;
        }

        @Data
        @AllArgsConstructor
        final class ToolResponse implements ModelResponse {
            private String callId;
            private String name;
            private String input;
            private String content;
        }

        @Data
        @AllArgsConstructor
        final class ToolError implements ModelResponse {
            private String callId;
            private String name;
            private String input;
            private Throwable error;
        }
    }
}