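Original article: SpringAI Advanced: MCP Service Authentication
Tutorial notes
Note: this tutorial is based on the GA release published on May 20, 2025, and covers:
- Quick-start tutorials for the core modules
- Source-level walkthroughs of the core modules
- Quick-start tutorials plus source-level walkthroughs of the Spring AI Alibaba enhancements
Versions: JDK 21 + Spring Boot 3.4.5 + Spring AI 1.0.0 + Spring AI Alibaba 1.0.0.2
The chapters below will be completed over time. This article covers MCP service authentication, part of Chapter 7 (MCP usage patterns).
Source code: github.com/GTyingzi/sp…
Earlier walkthroughs published on WeChat: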
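Chapter 1
SpringAI (GA) chat: quick start + auto-injection source walkthrough
Chapter 2
SpringAI (GA): SQLite, MySQL and Redis message storage, quick start
Chapter 3
Chapter 4
Chapter 5
SpringAI (GA): in-memory, Redis and ES vector stores, quick start
SpringAI (GA): vector database theory and source walkthrough + Redis and ES integration source
Chapter 6
Chapter 7
SpringAI (GA): MCP source walkthrough in SpringAI
Chapter 8
Chapter 9
Chapter 10
Spring AI Alibaba Graph: parallel multi-node execution, quick start
Spring AI Alibaba Graph: node streaming pass-through example
Spring AI Alibaba Graph: assigning MCP to specific nodes
Spring AI Alibaba Graph: interrupts! Human feedback steps in and the flow completes smoothly
Paid service (currently 59.9 RMB): Feishu cloud docs with online formatting + an exclusive members group for tutorial code Q&A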
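MCP Service Authentication
[!TIP] When MCP is used inside an enterprise, or when a third-party MCP service is called, the server needs to identify who the caller on the client side is.
Working code: the mcp-auth-server and mcp-auth-client modules in the mcp directory of github.com/GTyingzi/sp…, plus the RESTful module that requires authentication (restful under other).
Restful service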
pom.xml
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>
TimeController
package com.spring.ai.tutorial.controller;
import com.spring.ai.tutorial.utils.ZoneUtils;
import jakarta.servlet.http.HttpServletRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.Collections;
@RestController
@RequestMapping("/time")
public class TimeController {
    private static final Logger logger = LoggerFactory.getLogger(TimeController.class);
    /**
     * Returns the current time in the given time zone.
     */
    @GetMapping("/city")
    public String getCiteTimeMethod(
            @RequestParam("timeZoneId") String timeZoneId,
            HttpServletRequest request) {
        // Log every request header, so we can check which headers the MCP server forwarded
        for (String headerName : Collections.list(request.getHeaderNames())) {
            logger.info("Header {}: {}", headerName, request.getHeader(headerName));
        }
        logger.info("The current time zone is {}", timeZoneId);
        return String.format("The current time zone is %s and the current time is %s", timeZoneId,
                ZoneUtils.getTimeByZoneId(timeZoneId));
    }
}
This service exposes the time endpoint on port 101.
MCP Client side
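TimeController calls `ZoneUtils.getTimeByZoneId`, which the tutorial does not show. A minimal sketch, assuming the helper only has to format the current time in the requested zone; the class below is a hypothetical stand-in and its output pattern is an assumption, not necessarily the author's:

```java
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;

// Hypothetical stand-in for com.spring.ai.tutorial.utils.ZoneUtils (not shown in the tutorial)
final class ZoneUtils {

    // Assumed output format, e.g. "2025-05-20 14:30:00"
    private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    private ZoneUtils() {
    }

    // Current time in the given zone; ZoneId.of throws a DateTimeException for unknown ids
    static String getTimeByZoneId(String zoneId) {
        return ZonedDateTime.now(ZoneId.of(zoneId)).format(FORMATTER);
    }
}
```

Letting `ZoneId.of` throw on a bad id surfaces the error to the caller instead of silently returning a default zone.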
pom.xml
<dependencies>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-openai</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-mcp-client-webflux</artifactId>
    </dependency>
</dependencies>
application.yml
server:
  port: 19101
spring:
  application:
    name: mcp-auth-client
  main:
    web-application-type: none
  ai:
    openai:
      api-key: ${DASHSCOPE_API_KEY}
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      chat:
        options:
          model: qwen-max
    mcp:
      client:
        enabled: true
        name: my-mcp-client
        version: 1.0.0
        request-timeout: 600s
        type: ASYNC # SYNC or ASYNC; ASYNC suits reactive applications
        sse:
          connections:
            server1:
              url: http://localhost:19000 # local MCP server
              headers:
                token-yingzi-1: yingzi-1
                token-yingzi-2: yingzi-2
McpSseClientProperties (a project-local copy of the Spring AI class; SseParameters gains a headers field)
package org.springframework.ai.mcp.client.autoconfigure.properties;
import java.util.HashMap;
import java.util.Map;
import org.springframework.boot.context.properties.ConfigurationProperties;
@ConfigurationProperties("spring.ai.mcp.client.sse")
public class McpSseClientProperties {
    public static final String CONFIG_PREFIX = "spring.ai.mcp.client.sse";
    private final Map<String, SseParameters> connections = new HashMap<>();
    public Map<String, SseParameters> getConnections() {
        return this.connections;
    }
    // headers: extra HTTP headers sent to this MCP server on every request
    public record SseParameters(String url, String sseEndpoint, Map<String, String> headers) {
    }
}
SseWebFluxTransportAutoConfiguration (project-local copy; applies each connection's headers to its WebClient)
package org.springframework.ai.mcp.client.autoconfigure;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.springframework.ai.mcp.client.autoconfigure.properties.McpClientCommonProperties;
import org.springframework.ai.mcp.client.autoconfigure.properties.McpSseClientProperties;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.web.reactive.function.client.WebClient;
@AutoConfiguration
@ConditionalOnClass({WebFluxSseClientTransport.class})
@EnableConfigurationProperties({McpSseClientProperties.class, McpClientCommonProperties.class})
@ConditionalOnProperty(
        prefix = "spring.ai.mcp.client",
        name = {"enabled"},
        havingValue = "true",
        matchIfMissing = true
)
public class SseWebFluxTransportAutoConfiguration {
    @Bean
    public List<NamedClientMcpTransport> webFluxClientTransports(McpSseClientProperties sseProperties,
            ObjectProvider<WebClient.Builder> webClientBuilderProvider,
            ObjectProvider<ObjectMapper> objectMapperProvider) {
        List<NamedClientMcpTransport> sseTransports = new ArrayList<>();
        WebClient.Builder webClientBuilderTemplate = webClientBuilderProvider.getIfAvailable(WebClient::builder);
        ObjectMapper objectMapper = objectMapperProvider.getIfAvailable(ObjectMapper::new);
        for (Map.Entry<String, McpSseClientProperties.SseParameters> serverParameters : sseProperties.getConnections().entrySet()) {
            McpSseClientProperties.SseParameters parameters = serverParameters.getValue();
            WebClient.Builder webClientBuilder = webClientBuilderTemplate.clone()
                    .baseUrl(parameters.url())
                    // Added: send this connection's configured headers with every request (SSE handshake and message POSTs)
                    .defaultHeaders(headers -> {
                        if (parameters.headers() != null) {
                            parameters.headers().forEach(headers::add);
                        }
                    });
            String sseEndpoint = parameters.sseEndpoint() != null ? parameters.sseEndpoint() : "/sse";
            WebFluxSseClientTransport transport = WebFluxSseClientTransport.builder(webClientBuilder)
                    .sseEndpoint(sseEndpoint)
                    .objectMapper(objectMapper)
                    .build();
            sseTransports.add(new NamedClientMcpTransport(serverParameters.getKey(), transport));
        }
        return sseTransports;
    }
}
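The key change is `.defaultHeaders(...)`: every request built from a connection's WebClient carries the headers configured under `headers:`. The same idea expressed with the JDK's `java.net.http.HttpRequest` builder, as a hypothetical sketch (`SseConnection` and `newRequest` are illustrative names, not Spring AI API):

```java
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.Map;

// Hypothetical per-connection settings, mirroring SseParameters(url, sseEndpoint, headers)
record SseConnection(String url, String sseEndpoint, Map<String, String> headers) {

    // Builds the SSE handshake request with this connection's headers applied,
    // analogous to WebClient.Builder#defaultHeaders in the auto-configuration
    HttpRequest newRequest() {
        String endpoint = sseEndpoint != null ? sseEndpoint : "/sse"; // same default as above
        HttpRequest.Builder builder = HttpRequest.newBuilder(URI.create(url + endpoint))
                .header("Accept", "text/event-stream")
                .GET();
        if (headers != null) {
            headers.forEach(builder::header);
        }
        return builder.build();
    }
}
```

Because each connection gets its own builder, `server1` and any other configured server can send different identity headers.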
AuthClientApplication
package com.spring.ai.tutorial.mcp.client;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.boot.CommandLineRunner;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;
import java.util.Scanner;
@SpringBootApplication
public class AuthClientApplication {
    public static void main(String[] args) {
        SpringApplication.run(AuthClientApplication.class, args);
    }
    @Bean
    public CommandLineRunner predefinedQuestions(ChatClient.Builder chatClientBuilder, ToolCallbackProvider tools,
            ConfigurableApplicationContext context) {
        return args -> {
            var chatClient = chatClientBuilder
                    .defaultToolCallbacks(tools.getToolCallbacks())
                    .build();
            Scanner scanner = new Scanner(System.in);
            while (true) {
                System.out.print("\n>>> QUESTION: ");
                String userInput = scanner.nextLine();
                if (userInput.equalsIgnoreCase("exit")) {
                    break;
                }
                System.out.println("\n>>> ASSISTANT: " + chatClient.prompt(userInput).call().content());
            }
            scanner.close();
            context.close();
        };
    }
}
MCP Server side
pom.xml
<dependencies>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-mcp-server-webflux</artifactId>
    </dependency>
</dependencies>
Tools
RestfulToolDefinition
package org.springframework.ai.mcp;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.http.HttpMethod;
import org.springframework.util.Assert;
public record RestfulToolDefinition(String name, String description, String inputSchema,
        String url, String method, String path, HttpMethod httpMethod) implements ToolDefinition {
    public RestfulToolDefinition {
        Assert.hasText(name, "name cannot be null or empty");
        Assert.hasText(description, "description cannot be null or empty");
        Assert.hasText(inputSchema, "inputSchema cannot be null or empty");
        Assert.hasText(url, "url cannot be null or empty");
        Assert.hasText(method, "method cannot be null or empty");
        Assert.hasText(path, "path cannot be null or empty");
        // Without an HTTP method the callback would silently return an empty result
        Assert.notNull(httpMethod, "httpMethod cannot be null");
    }
    public static Builder builder() {
        return new Builder();
    }
    public static class Builder {
        private String name;
        private String description;
        private String inputSchema;
        private String url;
        private String method;
        private String path;
        private HttpMethod httpMethod;
        public Builder name(String name) {
            this.name = name;
            return this;
        }
        public Builder description(String description) {
            this.description = description;
            return this;
        }
        public Builder inputSchema(String inputSchema) {
            this.inputSchema = inputSchema;
            return this;
        }
        public Builder url(String url) {
            this.url = url;
            return this;
        }
        public Builder method(String method) {
            this.method = method;
            return this;
        }
        public Builder path(String path) {
            this.path = path;
            return this;
        }
        public Builder httpMethod(HttpMethod httpMethod) {
            this.httpMethod = httpMethod;
            return this;
        }
        public RestfulToolDefinition build() {
            return new RestfulToolDefinition(name, description, inputSchema, url, method, path, httpMethod);
        }
    }
}
McpRestfulToolCallback
package org.springframework.ai.mcp;
import com.fasterxml.jackson.core.type.TypeReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.http.HttpMethod;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.client.WebClient;
import java.util.Map;
public class McpRestfulToolCallback implements ToolCallback {
    private static final Logger logger = LoggerFactory.getLogger(McpRestfulToolCallback.class);
    private final RestfulToolDefinition toolDefinition;
    // Headers forwarded to the RESTful service; set by the transport provider before each tools/call
    private volatile Map<String, String> headers = Map.of();
    public McpRestfulToolCallback(RestfulToolDefinition toolDefinition) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        this.toolDefinition = toolDefinition;
    }
    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }
    @Override
    public String call(String toolInput) {
        return this.call(toolInput, (ToolContext) null);
    }
    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");
        logger.debug("Starting execution of tool: {}", this.toolDefinition.name());
        Map<String, Object> toolArguments = JsonParser.fromJson(toolInput, new TypeReference<Map<String, Object>>() {
        });
        Map<String, String> forwardedHeaders = this.headers;
        String result = "";
        if (HttpMethod.GET.equals(toolDefinition.httpMethod())) {
            // Tool arguments become query parameters; the UriBuilder encodes them
            result = WebClient.builder().build().get()
                    .uri(toolDefinition.url() + toolDefinition.path(), uriBuilder -> {
                        toolArguments.forEach((key, value) -> uriBuilder.queryParam(key, value));
                        return uriBuilder.build();
                    })
                    .headers(httpHeaders -> forwardedHeaders.forEach(httpHeaders::add))
                    .retrieve()
                    .bodyToMono(String.class)
                    .block();
        } else if (HttpMethod.POST.equals(toolDefinition.httpMethod())) {
            // Tool arguments are sent as the JSON body to url + path
            result = WebClient.builder().build().post()
                    .uri(toolDefinition.url() + toolDefinition.path())
                    .headers(httpHeaders -> forwardedHeaders.forEach(httpHeaders::add))
                    .bodyValue(toolArguments)
                    .retrieve()
                    .bodyToMono(String.class)
                    .block();
        }
        logger.debug("Successful execution of tool: {}, result: {}", this.toolDefinition.name(), result);
        return result;
    }
    // Called by the transport provider right before a tools/call is handled.
    // This callback instance is shared by all sessions, so concurrent calls to the same tool can overwrite each other's headers.
    public void setHeaders(Map<String, String> headersMap) {
        this.headers = headersMap != null ? headersMap : Map.of();
    }
    public static Builder builder() {
        return new Builder();
    }
    public static class Builder {
        private RestfulToolDefinition toolDefinition;
        private Builder() {
        }
        public Builder toolDefinition(RestfulToolDefinition toolDefinition) {
            this.toolDefinition = toolDefinition;
            return this;
        }
        public McpRestfulToolCallback build() {
            return new McpRestfulToolCallback(this.toolDefinition);
        }
    }
}
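When tool arguments are turned into a GET query string, each key and value should be URL-encoded so characters such as spaces or `&` cannot break the query. A minimal sketch with the JDK's `URLEncoder`; `QueryStrings.buildGetUri` is a hypothetical helper, and inside a WebClient call Spring's `UriBuilder.queryParam` serves the same purpose:

```java
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.StringJoiner;

final class QueryStrings {

    private QueryStrings() {
    }

    // Hypothetical helper: url + path + "?" + encoded key=value pairs, with no trailing '&'
    static String buildGetUri(String url, String path, Map<String, ?> arguments) {
        if (arguments.isEmpty()) {
            return url + path;
        }
        StringJoiner query = new StringJoiner("&", "?", "");
        arguments.forEach((key, value) -> query.add(
                URLEncoder.encode(key, StandardCharsets.UTF_8) + "="
                        + URLEncoder.encode(String.valueOf(value), StandardCharsets.UTF_8)));
        return url + path + query;
    }
}
```

`URLEncoder` applies form encoding (a space becomes `+`), which servlet containers decode correctly in query strings.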
McpRestfulToolCallbackProvider
package org.springframework.ai.mcp;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.util.Assert;
public class McpRestfulToolCallbackProvider implements ToolCallbackProvider {
    private final ToolCallback[] toolCallbacks;
    private static final Logger logger = LoggerFactory.getLogger(McpRestfulToolCallbackProvider.class);
    public McpRestfulToolCallbackProvider(McpRestfulToolCallback... toolCallbacks) {
        Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
        this.toolCallbacks = toolCallbacks;
    }
    @Override
    public ToolCallback[] getToolCallbacks() {
        return this.toolCallbacks;
    }
    public static Builder builder() {
        return new McpRestfulToolCallbackProvider.Builder();
    }
    public static class Builder {
        private McpRestfulToolCallback[] toolCallbacks;
        private Builder() {
        }
        public Builder toolCallbacks(McpRestfulToolCallback... toolCallbacks) {
            this.toolCallbacks = toolCallbacks;
            return this;
        }
        public McpRestfulToolCallbackProvider build() {
            return new McpRestfulToolCallbackProvider(this.toolCallbacks);
        }
    }
}
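Modified WebFluxSseServerTransportProvider (kept in the MCP SDK's package io.modelcontextprotocol.server.transport so that it takes the place of the SDK class)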
package io.modelcontextprotocol.server.transport;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.spring.ai.tutorial.mcp.server.util.ApplicationContextHolder;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mcp.McpRestfulToolCallback;
import org.springframework.ai.mcp.McpRestfulToolCallbackProvider;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
public class WebFluxSseServerTransportProvider implements McpServerTransportProvider {
    private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransportProvider.class);
    public static final String MESSAGE_EVENT_TYPE = "message";
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";
    public static final String DEFAULT_BASE_URL = "";
    private final ObjectMapper objectMapper;
    private final String baseUrl;
    private final String messageEndpoint;
    private final String sseEndpoint;
    private final RouterFunction<?> routerFunction;
    private McpServerSession.Factory sessionFactory;
    private final ConcurrentHashMap<String, McpServerSession> sessions;
    // Added: headers captured from each session's SSE handshake, keyed by session id
    private final ConcurrentHashMap<String, Map<String, String>> session2headers;
    // Added: the RESTful tool callbacks that receive the forwarded headers
    private final McpRestfulToolCallbackProvider mcpRestfulToolCallbackProvider;
    private volatile boolean isClosing;
    public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
        this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
    }
    public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
        this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
    }
    public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) {
        Assert.notNull(objectMapper, "ObjectMapper must not be null");
        Assert.notNull(baseUrl, "Message base path must not be null");
        Assert.notNull(messageEndpoint, "Message endpoint must not be null");
        Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
        this.sessions = new ConcurrentHashMap<>();
        this.session2headers = new ConcurrentHashMap<>();
        // Looked up from the context because the MCP server auto-configuration creates this class with a fixed constructor
        this.mcpRestfulToolCallbackProvider = ApplicationContextHolder.getBean(McpRestfulToolCallbackProvider.class);
        this.isClosing = false;
        this.objectMapper = objectMapper;
        this.baseUrl = baseUrl;
        this.messageEndpoint = messageEndpoint;
        this.sseEndpoint = sseEndpoint;
        this.routerFunction = RouterFunctions.route()
                .GET(this.sseEndpoint, this::handleSseConnection)
                .POST(this.messageEndpoint, this::handleMessage)
                .build();
    }
    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }
    public Mono<Void> notifyClients(String method, Object params) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
        return Flux.fromIterable(this.sessions.values())
                .flatMap(session -> session.sendNotification(method, params)
                        .doOnError(e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()))
                        .onErrorComplete())
                .then();
    }
    public Mono<Void> closeGracefully() {
        return Flux.fromIterable(this.sessions.values())
                .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()))
                .flatMap(McpServerSession::closeGracefully)
                .then();
    }
    public RouterFunction<?> getRouterFunction() {
        return this.routerFunction;
    }
    private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
        }
        return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(Flux.create((FluxSink<ServerSentEvent<?>> sink) -> {
            WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink);
            McpServerSession session = this.sessionFactory.create(sessionTransport);
            String sessionId = session.getId();
            logger.debug("Created new SSE connection for session: {}", sessionId);
            this.sessions.put(sessionId, session);
            // Added: capture the client's request headers (e.g. token-yingzi-1) from the SSE handshake
            Map<String, String> headers = request.headers().asHttpHeaders().toSingleValueMap();
            logger.debug("sessionId: {} with headers: {}", sessionId, headers);
            this.session2headers.put(sessionId, headers);
            logger.debug("Sending initial endpoint event to session: {}", sessionId);
            sink.next(ServerSentEvent.builder()
                    .event(ENDPOINT_EVENT_TYPE)
                    .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId)
                    .build());
            sink.onCancel(() -> {
                logger.debug("Session {} cancelled", sessionId);
                this.sessions.remove(sessionId);
                // Also drop the captured headers so closed sessions do not leak memory
                this.session2headers.remove(sessionId);
            });
        }), ServerSentEvent.class);
    }
    private Mono<ServerResponse> handleMessage(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
        }
        if (request.queryParam("sessionId").isEmpty()) {
            return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint"));
        }
        String sessionId = request.queryParam("sessionId").get();
        McpServerSession session = this.sessions.get(sessionId);
        if (session == null) {
            return ServerResponse.status(HttpStatus.NOT_FOUND).bodyValue(new McpError("Session not found: " + sessionId));
        }
        return request.bodyToMono(String.class).flatMap(body -> {
            try {
                McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, body);
                // Added: for a tools/call request, hand this session's headers to the target RESTful tool
                if (message instanceof McpSchema.JSONRPCRequest jsonRpcRequest
                        && McpSchema.METHOD_TOOLS_CALL.equals(jsonRpcRequest.method())
                        && jsonRpcRequest.params() instanceof Map<?, ?> params) {
                    Object toolName = params.get("name");
                    Assert.notNull(toolName, "Tool name cannot be null");
                    Map<String, String> headers = this.session2headers.get(session.getId());
                    for (ToolCallback toolCallback : this.mcpRestfulToolCallbackProvider.getToolCallbacks()) {
                        if (toolCallback instanceof McpRestfulToolCallback restfulToolCallback
                                && toolName.equals(restfulToolCallback.getToolDefinition().name())) {
                            restfulToolCallback.setHeaders(headers);
                        }
                    }
                }
                return session.handle(message)
                        .flatMap(response -> ServerResponse.ok().build())
                        .onErrorResume(error -> {
                            logger.error("Error processing message: {}", error.getMessage());
                            return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).bodyValue(new McpError(error.getMessage()));
                        });
            } catch (IOException | IllegalArgumentException e) {
                logger.error("Failed to deserialize message: {}", e.getMessage());
                return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format"));
            }
        });
    }
    public static Builder builder() {
        return new Builder();
    }
    private class WebFluxMcpSessionTransport implements McpServerTransport {
        private final FluxSink<ServerSentEvent<?>> sink;
        public WebFluxMcpSessionTransport(FluxSink<ServerSentEvent<?>> sink) {
            this.sink = sink;
        }
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromSupplier(() -> {
                try {
                    return objectMapper.writeValueAsString(message);
                } catch (IOException e) {
                    throw Exceptions.propagate(e);
                }
            }).doOnNext(jsonText -> {
                ServerSentEvent<Object> event = ServerSentEvent.builder().event(MESSAGE_EVENT_TYPE).data(jsonText).build();
                this.sink.next(event);
            }).doOnError(e -> {
                Throwable exception = Exceptions.unwrap(e);
                this.sink.error(exception);
            }).then();
        }
        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return objectMapper.convertValue(data, typeRef);
        }
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(this.sink::complete);
        }
        public void close() {
            this.sink.complete();
        }
    }
    public static class Builder {
        private ObjectMapper objectMapper;
        private String baseUrl = DEFAULT_BASE_URL;
        private String messageEndpoint;
        private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "ObjectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }
        public Builder basePath(String baseUrl) {
            Assert.notNull(baseUrl, "basePath must not be null");
            this.baseUrl = baseUrl;
            return this;
        }
        public Builder messageEndpoint(String messageEndpoint) {
            Assert.notNull(messageEndpoint, "Message endpoint must not be null");
            this.messageEndpoint = messageEndpoint;
            return this;
        }
        public Builder sseEndpoint(String sseEndpoint) {
            Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
            this.sseEndpoint = sseEndpoint;
            return this;
        }
        public WebFluxSseServerTransportProvider build() {
            Assert.notNull(this.objectMapper, "ObjectMapper must be set");
            Assert.notNull(this.messageEndpoint, "Message endpoint must be set");
            return new WebFluxSseServerTransportProvider(this.objectMapper, this.baseUrl, this.messageEndpoint, this.sseEndpoint);
        }
    }
}
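The server-side change boils down to a per-session header map: store the headers when the SSE connection opens, look them up when a `tools/call` message arrives for that session, and ideally drop them when the connection is cancelled. A stdlib-only sketch of that lifecycle; `SessionHeaderRegistry` and its methods are hypothetical names used only for illustration:

```java
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical registry mirroring session2headers in the transport provider
final class SessionHeaderRegistry {

    private final Map<String, Map<String, String>> session2headers = new ConcurrentHashMap<>();

    // SSE handshake: keep an immutable copy of the client's headers
    void onConnect(String sessionId, Map<String, String> headers) {
        session2headers.put(sessionId, Map.copyOf(headers));
    }

    // JSON-RPC message: only "tools/call" needs the caller's headers
    Optional<Map<String, String>> headersFor(String sessionId, String method) {
        if (!"tools/call".equals(method)) {
            return Optional.empty();
        }
        return Optional.ofNullable(session2headers.get(sessionId));
    }

    // Connection cancelled: forget the session so its headers are not retained
    void onCancel(String sessionId) {
        session2headers.remove(sessionId);
    }

    int size() {
        return session2headers.size();
    }
}
```

A `ConcurrentHashMap` is used because handshakes and messages for different sessions arrive on different threads.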
Turning RESTful endpoints into tools
ParseRestful
package com.spring.ai.tutorial.mcp.server.parse;
import com.spring.ai.tutorial.mcp.server.model.Parameter;
import com.spring.ai.tutorial.mcp.server.model.RestfulModel;
import com.spring.ai.tutorial.mcp.server.util.JSONSchemaUtil;
import org.springframework.ai.mcp.McpRestfulToolCallback;
import org.springframework.ai.mcp.McpRestfulToolCallbackProvider;
import org.springframework.ai.mcp.RestfulToolDefinition;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
@Component
public class ParseRestful {
    // Wraps each RestfulModel in a RestfulToolDefinition and an McpRestfulToolCallback
    public McpRestfulToolCallbackProvider getRestfulToolCallbackProvider() {
        List<McpRestfulToolCallback> toolCallbacks = new ArrayList<>();
        getRestfulModels().forEach(restfulModel -> {
            RestfulToolDefinition restfulToolDefinition = RestfulToolDefinition.builder()
                    .name(restfulModel.name())
                    .description(restfulModel.description())
                    .inputSchema(restfulModel.inputSchema())
                    .url(restfulModel.url())
                    .method(restfulModel.method())
                    .path(restfulModel.path())
                    .httpMethod(restfulModel.httpMethod())
                    .build();
            McpRestfulToolCallback mcpRestfulToolCallback = McpRestfulToolCallback.builder().toolDefinition(restfulToolDefinition).build();
            toolCallbacks.add(mcpRestfulToolCallback);
        });
        return McpRestfulToolCallbackProvider.builder()
                .toolCallbacks(toolCallbacks.toArray(new McpRestfulToolCallback[0]))
                .build();
    }
    // Hard-coded description of the TimeController endpoint on port 101
    public List<RestfulModel> getRestfulModels() {
        Parameter parameter = Parameter.builder()
                .parameteNname("timeZoneId")
                .description("time zone id, such as Asia/Shanghai")
                .required(true)
                .type("string")
                .build();
        return List.of(
                new RestfulModel("getCiteTimeMethod", "Get the current time in the specified time zone",
                        JSONSchemaUtil.getInputSchema(List.of(parameter)),
                        "http://localhost:101", "getCiteTimeMethod", "/time/city", HttpMethod.GET)
        );
    }
}
RestfulModel
package com.spring.ai.tutorial.mcp.server.model;
import org.springframework.http.HttpMethod;
public record RestfulModel(String name, String description, String inputSchema, String url, String method, String path, HttpMethod httpMethod) {
}
Parameter
package com.spring.ai.tutorial.mcp.server.model;
public record Parameter(String parameteNname,
                        String description,
                        boolean required,
                        String type
) {
    public static Builder builder() {
        return new Builder();
    }
    public static class Builder {
        private String parameteNname;
        private String description;
        private boolean required;
        private String type;
        public Builder parameteNname(String parameteNname) {
            this.parameteNname = parameteNname;
            return this;
        }
        public Builder description(String description) {
            this.description = description;
            return this;
        }
        public Builder required(boolean required) {
            this.required = required;
            return this;
        }
        public Builder type(String type) {
            this.type = type;
            return this;
        }
        public Parameter build() {
            return new Parameter(parameteNname, description, required, type);
        }
    }
}
Helper classes
ApplicationContextHolder
package com.spring.ai.tutorial.mcp.server.util;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
@Component
public class ApplicationContextHolder implements ApplicationContextAware {
    private static ApplicationContext applicationContext;
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ApplicationContextHolder.applicationContext = applicationContext;
    }
    public static <T> T getBean(Class<T> clazz) {
        return applicationContext.getBean(clazz);
    }
}
JSONSchemaUtil
package com.spring.ai.tutorial.mcp.server.util;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.github.victools.jsonschema.generator.Module;
import com.github.victools.jsonschema.generator.Option;
import com.github.victools.jsonschema.generator.OptionPreset;
import com.github.victools.jsonschema.generator.SchemaGenerator;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfig;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
import com.github.victools.jsonschema.generator.SchemaVersion;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;
import com.github.victools.jsonschema.module.swagger2.Swagger2Module;
import com.spring.ai.tutorial.mcp.server.model.Parameter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
import org.springframework.ai.util.json.schema.SpringAiSchemaModule;
import org.springframework.util.StringUtils;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
public class JSONSchemaUtil {
    private static final SchemaGenerator SUBTYPE_SCHEMA_GENERATOR;
    private static final Logger logger = LoggerFactory.getLogger(JSONSchemaUtil.class);
    static {
        // Generator configured along the lines of Spring AI's JsonSchemaGenerator, without the $schema indicator for sub-schemas
        Module jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED);
        Module openApiModule = new Swagger2Module();
        Module springAiSchemaModule = new SpringAiSchemaModule();
        SchemaGeneratorConfigBuilder schemaGeneratorConfigBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, OptionPreset.PLAIN_JSON)
                .with(jacksonModule)
                .with(openApiModule)
                .with(springAiSchemaModule)
                .with(Option.EXTRA_OPEN_API_FORMAT_VALUES)
                .with(Option.PLAIN_DEFINITION_KEYS);
        SchemaGeneratorConfig subtypeSchemaGeneratorConfig = schemaGeneratorConfigBuilder
                .without(Option.SCHEMA_VERSION_INDICATOR)
                .build();
        SUBTYPE_SCHEMA_GENERATOR = new SchemaGenerator(subtypeSchemaGeneratorConfig);
    }
    // Builds an object schema with one property per parameter, a required list and additionalProperties=false
    public static String getInputSchema(List<Parameter> parameters) {
        ObjectNode schema = JsonParser.getObjectMapper().createObjectNode();
        schema.put("$schema", SchemaVersion.DRAFT_2020_12.getIdentifier());
        schema.put("type", "object");
        ObjectNode properties = schema.putObject("properties");
        List<String> required = new ArrayList<>();
        for (Parameter parameter : parameters) {
            String parameterName = parameter.parameteNname();
            if (parameter.required()) {
                required.add(parameterName);
            }
            // Map the JSON Schema type name to a Java type; fall back to String instead of passing null to the generator
            Type parameterType = String.class;
            if (parameter.type() != null) {
                try {
                    parameterType = getTypeFromString(parameter.type());
                } catch (ClassNotFoundException e) {
                    logger.error("Cannot convert type name to a Java Type: {}, falling back to string", parameter.type(), e);
                }
            } else {
                logger.warn("Parameter {} has no type, falling back to string", parameterName);
            }
            ObjectNode parameterNode = SUBTYPE_SCHEMA_GENERATOR.generateSchema(parameterType);
            if (StringUtils.hasText(parameter.description())) {
                parameterNode.put("description", parameter.description());
            }
            properties.set(parameterName, parameterNode);
        }
        ArrayNode requiredArray = schema.putArray("required");
        required.forEach(requiredArray::add);
        processSchemaOptions(new JsonSchemaGenerator.SchemaOption[0], schema);
        return schema.toPrettyString();
    }
    public static Type getTypeFromString(String typeString) throws ClassNotFoundException {
        return switch (typeString) {
            case "string" -> String.class;
            case "number" -> Number.class;
            case "integer" -> Integer.class;
            case "boolean" -> Boolean.class;
            default -> throw new ClassNotFoundException("Unsupported type: " + typeString);
        };
    }
    private static void processSchemaOptions(JsonSchemaGenerator.SchemaOption[] schemaOptions, ObjectNode schema) {
        if (Stream.of(schemaOptions).noneMatch(option -> option == JsonSchemaGenerator.SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT)) {
            schema.put("additionalProperties", false);
        }
        if (Stream.of(schemaOptions).anyMatch(option -> option == JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES)) {
            convertTypeValuesToUpperCase(schema);
        }
    }
    public static void convertTypeValuesToUpperCase(ObjectNode node) {
        node.fields().forEachRemaining(entry -> {
            JsonNode value = entry.getValue();
            if (value instanceof ObjectNode objectValue) {
                convertTypeValuesToUpperCase(objectValue);
            } else if (value.isArray()) {
                // Recurse only into object elements; casting an array element to ObjectNode would throw
                value.elements().forEachRemaining(element -> {
                    if (element instanceof ObjectNode objectElement) {
                        convertTypeValuesToUpperCase(objectElement);
                    }
                });
            } else if (value.isTextual() && "type".equals(entry.getKey())) {
                node.put("type", value.asText().toUpperCase());
            }
        });
    }
}
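For the single `timeZoneId` parameter, `getInputSchema` produces an object schema with one string property plus its description, a `required` list, and `additionalProperties: false` (along with a `$schema` entry). A dependency-free sketch of the same shape for flat parameters; `Param` and `SimpleInputSchema` are hypothetical, the sketch omits `$schema`, and its string escaping only covers quotes and backslashes:

```java
import java.util.List;
import java.util.StringJoiner;

// Hypothetical flat parameter, mirroring the tutorial's Parameter record
record Param(String name, String description, boolean required, String type) {
}

final class SimpleInputSchema {

    private SimpleInputSchema() {
    }

    // Minimal JSON string quoting: escapes only backslashes and double quotes
    private static String quote(String s) {
        return "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
    }

    static String of(List<Param> params) {
        StringJoiner properties = new StringJoiner(",", "{", "}");
        StringJoiner required = new StringJoiner(",", "[", "]");
        for (Param p : params) {
            properties.add(quote(p.name()) + ":{\"type\":" + quote(p.type())
                    + ",\"description\":" + quote(p.description()) + "}");
            if (p.required()) {
                required.add(quote(p.name()));
            }
        }
        return "{\"type\":\"object\",\"properties\":" + properties
                + ",\"required\":" + required + ",\"additionalProperties\":false}";
    }
}
```

The victools-based utility is the more robust choice in the real server; this sketch only makes the expected JSON shape visible.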
AuthServerApplication
package com.spring.ai.tutorial.mcp.server;
import com.spring.ai.tutorial.mcp.server.parse.ParseRestful;
import org.springframework.ai.mcp.McpRestfulToolCallbackProvider;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
@SpringBootApplication
public class AuthServerApplication {
    public static void main(String[] args) {
        SpringApplication.run(AuthServerApplication.class, args);
    }
    // Declared with the concrete type so that ApplicationContextHolder.getBean(McpRestfulToolCallbackProvider.class)
    // does not depend on bean creation order; it is still a ToolCallbackProvider, so the MCP server registers its tools
    @Bean
    public McpRestfulToolCallbackProvider mcpRestfulToolCallbackProvider(ParseRestful parseRestful) {
        return parseRestful.getRestfulToolCallbackProvider();
    }
}
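Result
- First start the Restful service, which exposes the time service on port 101
- Then start the MCP Server, which parses the Restful methods and exposes them as tools
- Finally start the MCP Client, which sends the configured request headers
When a tool is triggered, the TimeController logs show that the request headers configured on the client side travel through the MCP Server and reach the Restful service.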
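The server forwards every header captured from the SSE handshake, which also includes transport headers such as `host` or `accept`. If only identity headers should reach the RESTful service, filtering them before `setHeaders` is one option. A minimal sketch, assuming identity headers share the `token-` prefix used in the client config (`ForwardedHeaders` and the prefix rule are hypothetical):

```java
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

final class ForwardedHeaders {

    private ForwardedHeaders() {
    }

    // Keeps only headers whose name starts with the prefix (case-insensitive),
    // so transport headers like host, accept or content-length are not forwarded
    static Map<String, String> filter(Map<String, String> headers, String prefix) {
        String lowerPrefix = prefix.toLowerCase(Locale.ROOT);
        return headers.entrySet().stream()
                .filter(e -> e.getKey().toLowerCase(Locale.ROOT).startsWith(lowerPrefix))
                .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue));
    }
}
```

An allow-list keeps the forwarding behavior predictable even when clients or proxies add headers of their own.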
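Learning community
Hi, I'm 影子 (Yingzi). I have worked at 🐻, a new-energy company and 老铁, and I am now an AI R&D engineer and a Committer of the Spring AI Alibaba open-source community. I have started a discussion group: one person walks fast, but a group walks far. I also maintain a set of Feishu cloud-doc notes with systematic interview material on backend and big-data topics; send me a private message to get them for free.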