Advanced SpringAI: MCP Service Authentication


Original article: Advanced SpringAI: MCP Service Authentication

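Tutorial notes

Note: this tutorial is based on the GA release published on May 20, 2025, and covers:

  1. Quick-start tutorials for the core modules
  2. Source-level walkthroughs of the core modules
  3. Quick-start tutorials and source-level walkthroughs of the Spring AI Alibaba enhancements

Versions: JDK21 + SpringBoot3.4.5 + SpringAI 1.0.0 + SpringAI Alibaba 1.0.0.2

The chapters below will be completed over time. This chapter covers MCP service authentication, part of Chapter 7 (MCP usage patterns).

The code is open source at: github.com/GTyingzi/sp…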

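Previous articles in this series (published on WeChat):

Chapter 1

SpringAI (GA) chat: quick start + source walkthrough of auto-configuration

SpringAI (GA): walkthrough of the ChatClient call chain

Chapter 2

SpringAI Advisors: quick start + source walkthrough

SpringAI (GA): quick start with SQLite, MySQL, and Redis message storage

Chapter 3

SpringAI (GA): tool integration, quick start

SpringAI (GA): tool source code + walkthrough of the tool-triggering call chain

Chapter 4

SpringAI (GA): structured output, quick start + source walkthrough

Chapter 5

SpringAI (GA): in-memory, Redis, and ES vector stores, quick start

SpringAI (GA): vector database theory and source walkthrough + Redis and ES integration source

Chapter 6

SpringAI (GA): RAG quick start + modular walkthrough

SpringAI (GA): ETL in RAG, quick start

SpringAI (GA): ETL in RAG, source walkthrough

Chapter 7

SpringAI (GA): distributed MCP with Nacos 2

SpringAI (GA): distributed MCP with Nacos 3

SpringAI (GA): MCP source walkthrough

SpringAI (GA): MCP source walkthrough within SpringAI

Chapter 8

SpringAI (GA): multi-model evaluation

Chapter 9

SpringAI (GA): observability, quick start + source walkthrough

Chapter 10

Spring AI Alibaba Graph: quick start

Spring AI Alibaba Graph: parallel multi-node execution, quick start

Spring AI Alibaba Graph: streaming pass-through across nodes

Spring AI Alibaba Graph: assigning MCP to specific nodes

Spring AI Alibaba Graph: interrupts! Human feedback steps in and the flow runs through smoothly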

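Paid service (currently 59.9 CNY): the tutorials as online Feishu cloud documents with better formatting, plus an exclusive member group for Q&A on the tutorial code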

MCP Service Authentication

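[!TIP] When MCP is used in an enterprise setting, or when calling a third-party MCP service, the server needs to identify who the caller on the client side is.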

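The working code is in the mcp directory of github.com/GTyingzi/sp…: the mcp-auth-server and mcp-auth-client modules, plus the module that provides the endpoint requiring authentication (restful, under other).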

RESTful service

pom.xml
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>
TimeController
package com.spring.ai.tutorial.controller;

import com.spring.ai.tutorial.utils.ZoneUtils;
import jakarta.servlet.http.HttpServletRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.Collections;

@RestController
@RequestMapping("/time")
public class TimeController {

    private static final Logger logger = LoggerFactory.getLogger(TimeController.class);

    /**
     * 获取指定时区的时间
     */
    @GetMapping("/city")
    public String getCiteTimeMethod(
            @RequestParam("timeZoneId") String timeZoneId,
            HttpServletRequest request) {
        // 打印请求头信息
        for (String headerName : Collections.list(request.getHeaderNames())) {
            logger.info("Header {}: {}", headerName, request.getHeader(headerName));
        }
        logger.info("The current time zone is {}", timeZoneId);
        return String.format("The current time zone is %s and the current time is %s", timeZoneId,
                ZoneUtils.getTimeByZoneId(timeZoneId));
    }
}

The service listens on port 101 and exposes the time lookup to other services.
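ZoneUtils.getTimeByZoneId is called by TimeController, but its source is not part of this section. A minimal sketch, assuming it simply formats the current time in the requested IANA zone id; the class body and the output pattern below are illustrative, not the repository's code:

```java
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;

// Hypothetical stand-in for com.spring.ai.tutorial.utils.ZoneUtils (the real class is not shown).
public class ZoneUtils {

    // Assumed output pattern, e.g. "2025-05-20 10:15:30 CST"
    private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss z");

    /**
     * Returns the current time in the given zone, e.g. "Asia/Shanghai".
     * ZoneId.of throws a DateTimeException for unknown or malformed zone ids.
     */
    public static String getTimeByZoneId(String zoneId) {
        ZonedDateTime now = ZonedDateTime.now(ZoneId.of(zoneId));
        return now.format(FORMATTER);
    }

    public static void main(String[] args) {
        System.out.println(getTimeByZoneId("Asia/Shanghai"));
    }
}
```

With this shape, an invalid timeZoneId from the model surfaces as an exception in the controller rather than as a wrong time.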

MCP Client side

pom.xml
<dependencies>

    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-openai</artifactId>
    </dependency>

    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
    </dependency>

    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-mcp-client-webflux</artifactId>
    </dependency>

</dependencies>
application.yml
server:
  port: 19101

spring:
  application:
    name: mcp-auth-client
  main:
    web-application-type: none
  ai:
    openai:
      api-key: ${DASHSCOPE_API_KEY}
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      chat:
        options:
          model: qwen-max
    mcp:
      client:
        enabled: true
        name: my-mcp-client
        version: 1.0.0
        request-timeout: 600s
        type: ASYNC  # SYNC or ASYNC; ASYNC suits reactive applications
        sse:
          connections:
            server1:
              url: http://localhost:19000 # local MCP server
              headers:
                token-yingzi-1: yingzi-1
                token-yingzi-2: yingzi-2
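In this tutorial the client forwards the headers by redefining two Spring AI classes in their original packages: SseParameters in McpSseClientProperties gains a headers map, and SseWebFluxTransportAutoConfiguration applies that map to each connection's WebClient.

McpSseClientProperties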
package org.springframework.ai.mcp.client.autoconfigure.properties;

import java.util.HashMap;
import java.util.Map;
import org.springframework.boot.context.properties.ConfigurationProperties;

@ConfigurationProperties("spring.ai.mcp.client.sse")
public class McpSseClientProperties {
    public static final String CONFIG_PREFIX = "spring.ai.mcp.client.sse";
    private final Map<String, SseParameters> connections = new HashMap<>();

    public Map<String, SseParameters> getConnections() {
        return this.connections;
    }

    public static record SseParameters(String url, String sseEndpoint, Map<String, String> headers) {
    }
}
SseWebFluxTransportAutoConfiguration
package org.springframework.ai.mcp.client.autoconfigure;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.springframework.ai.mcp.client.autoconfigure.properties.McpClientCommonProperties;
import org.springframework.ai.mcp.client.autoconfigure.properties.McpSseClientProperties;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.util.CollectionUtils;
import org.springframework.web.reactive.function.client.WebClient;

@AutoConfiguration
@ConditionalOnClass({WebFluxSseClientTransport.class})
@EnableConfigurationProperties({McpSseClientProperties.class, McpClientCommonProperties.class})
@ConditionalOnProperty(
        prefix = "spring.ai.mcp.client",
        name = {"enabled"},
        havingValue = "true",
        matchIfMissing = true
)
public class SseWebFluxTransportAutoConfiguration {
    @Bean
    public List<NamedClientMcpTransport> webFluxClientTransports(McpSseClientProperties sseProperties,
            ObjectProvider<WebClient.Builder> webClientBuilderProvider,
            ObjectProvider<ObjectMapper> objectMapperProvider) {
        List<NamedClientMcpTransport> sseTransports = new ArrayList<>();
        WebClient.Builder webClientBuilderTemplate = webClientBuilderProvider.getIfAvailable(WebClient::builder);
        ObjectMapper objectMapper = objectMapperProvider.getIfAvailable(ObjectMapper::new);

        for (Map.Entry<String, McpSseClientProperties.SseParameters> serverParameters : sseProperties.getConnections().entrySet()) {
            McpSseClientProperties.SseParameters params = serverParameters.getValue();
            WebClient.Builder webClientBuilder = webClientBuilderTemplate.clone()
                    .baseUrl(params.url())
                    // Added: attach the configured headers to every request sent to this MCP server
                    .defaultHeaders(headers -> {
                        if (params.headers() != null) {
                            params.headers().forEach(headers::add);
                        }
                    });
            String sseEndpoint = params.sseEndpoint() != null ? params.sseEndpoint() : "/sse";
            WebFluxSseClientTransport transport = WebFluxSseClientTransport.builder(webClientBuilder)
                    .sseEndpoint(sseEndpoint)
                    .objectMapper(objectMapper)
                    .build();
            sseTransports.add(new NamedClientMcpTransport(serverParameters.getKey(), transport));
        }

        return sseTransports;
    }
}
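The functional change in this auto-configuration is the defaultHeaders call: every request the SSE transport sends through that WebClient carries the headers configured under spring.ai.mcp.client.sse.connections.<name>.headers, including the initial GET on the SSE endpoint. A JDK-only sketch of that handshake request, using java.net.http.HttpRequest in place of the WebClient above; SseHandshakeRequest and build are illustrative names:

```java
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.LinkedHashMap;
import java.util.Map;

// Illustrative only: builds (but does not send) the SSE handshake request with the configured headers.
public class SseHandshakeRequest {

    static HttpRequest build(String baseUrl, String sseEndpoint, Map<String, String> headers) {
        HttpRequest.Builder builder = HttpRequest.newBuilder(URI.create(baseUrl + sseEndpoint))
                .header("Accept", "text/event-stream") // an SSE client asks for an event stream
                .GET();
        headers.forEach(builder::header); // same idea as defaultHeaders(...) in the auto-configuration
        return builder.build();
    }

    public static void main(String[] args) {
        // Values taken from the application.yml of mcp-auth-client
        Map<String, String> headers = new LinkedHashMap<>();
        headers.put("token-yingzi-1", "yingzi-1");
        headers.put("token-yingzi-2", "yingzi-2");
        HttpRequest request = build("http://localhost:19000", "/sse", headers);
        System.out.println(request.method() + " " + request.uri() + " " + request.headers().map());
    }
}
```

The server side reads the headers from this handshake request, which is why they must be present on the GET to /sse and not only on later message POSTs.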
AuthClientApplication
package com.spring.ai.tutorial.mcp.client;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.boot.CommandLineRunner;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;

import java.util.Scanner;

@SpringBootApplication
public class AuthClientApplication {

    public static void main(String[] args) {
        SpringApplication.run(AuthClientApplication.class, args);
    }

    @Bean
    public CommandLineRunner predefinedQuestions(ChatClient.Builder chatClientBuilder, ToolCallbackProvider tools,
                                                 ConfigurableApplicationContext context) {

        return args -> {
            var chatClient = chatClientBuilder
                    .defaultToolCallbacks(tools.getToolCallbacks())
                    .build();

            Scanner scanner = new Scanner(System.in);
            while (true) {
                System.out.print("\n>>> QUESTION: ");
                String userInput = scanner.nextLine();
                if (userInput.equalsIgnoreCase("exit")) {
                    break;
                }
                System.out.println("\n>>> ASSISTANT: " + chatClient.prompt(userInput).call().content());
            }
            scanner.close();
            context.close();
        };
    }
}

MCP Server side

pom.xml
<dependencies>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-mcp-server-webflux</artifactId>
    </dependency>
</dependencies>
Tools
RestfulToolDefinition
package org.springframework.ai.mcp;

import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.http.HttpMethod;
import org.springframework.util.Assert;

public record RestfulToolDefinition(String name, String description, String inputSchema,
                                    String url, String method, String path, HttpMethod httpMethod) implements ToolDefinition {

    public RestfulToolDefinition {
        Assert.hasText(name, "name cannot be null or empty");
        Assert.hasText(description, "description cannot be null or empty");
        Assert.hasText(inputSchema, "inputSchema cannot be null or empty");
        Assert.hasText(url, "url cannot be null or empty");
        Assert.hasText(method, "method cannot be null or empty");
        Assert.hasText(path, "path cannot be null or empty");
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private String name;
        private String description;
        private String inputSchema;

        private String url;
        private String method;
        private String path;
        private HttpMethod httpMethod;

        public Builder name(String name) {
            this.name = name;
            return this;
        }

        public Builder description(String description) {
            this.description = description;
            return this;
        }

        public Builder inputSchema(String inputSchema) {
            this.inputSchema = inputSchema;
            return this;
        }

        public Builder url(String url) {
            this.url = url;
            return this;
        }

        public Builder method(String method) {
            this.method = method;
            return this;
        }

        public Builder path(String path) {
            this.path = path;
            return this;
        }

        public Builder httpMethod(HttpMethod httpMethod) {
            this.httpMethod = httpMethod;
            return this;
        }

        public RestfulToolDefinition build() {
            return new RestfulToolDefinition(name, description, inputSchema, url, method, path, httpMethod);
        }
    }
}
McpRestfulToolCallback
package org.springframework.ai.mcp;

import com.fasterxml.jackson.core.type.TypeReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.http.HttpMethod;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.client.WebClient;

import java.util.HashMap;
import java.util.Map;

public class McpRestfulToolCallback implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(McpRestfulToolCallback.class);

    private final RestfulToolDefinition toolDefinition;
    private Map<String, String> headers = new HashMap<>();

    public McpRestfulToolCallback(RestfulToolDefinition toolDefinition) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        this.toolDefinition = toolDefinition;
    }

    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }

    @Override
    public String call(String toolInput) {
        return this.call(toolInput, (ToolContext) null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");
        logger.debug("Starting execution of tool: {}", this.toolDefinition.name());

        Map<String, Object> toolArguments = JsonParser.fromJson(toolInput, new TypeReference<Map<String, Object>>() {
        });
        String result = "";
        if (HttpMethod.GET.equals(toolDefinition.httpMethod())) {
            StringBuilder uriBuilder = new StringBuilder().append(toolDefinition.path()).append("?");
            toolArguments.forEach((key, value) -> {
                uriBuilder.append(key).append("=").append(value).append("&");
            });
            String uri = uriBuilder.toString();

            result = WebClient.builder().build().get()
                    .uri(toolDefinition.url() + uri)
                    .headers(headers -> {
                        this.headers.forEach(headers::add);
                    })
                    .retrieve()
                    .bodyToMono(String.class)
                    .block();
        } else if (HttpMethod.POST.equals(toolDefinition.httpMethod())) {
            result = WebClient.builder().build().post()
                    .uri(toolDefinition.url())
                    .headers(headers -> {
                        this.headers.forEach(headers::add);
                    })
                    .bodyValue(toolArguments)
                    .retrieve()
                    .bodyToMono(String.class)
                    .block();

        }
        logger.debug("Successful execution of tool: {}, result: {}", this.toolDefinition.name(), result);
        return result;
    }

    public void setHeaders(Map<String, String> headersMap) {
        this.headers = headersMap;
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {

        private RestfulToolDefinition toolDefinition;

        private Builder() {
        }

        public Builder toolDefinition(RestfulToolDefinition toolDefinition) {
            this.toolDefinition = toolDefinition;
            return this;
        }


        public McpRestfulToolCallback build() {
            return new McpRestfulToolCallback(this.toolDefinition);
        }

    }
}
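For GET tools, McpRestfulToolCallback assembles the query string by hand: argument values are appended without percent-encoding and a trailing & is left at the end. A hedged JDK-only variant of just that step (QueryStringBuilder and buildUri are illustrative names, and encoding the values is a change from the original code):

```java
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.StringJoiner;

public class QueryStringBuilder {

    // Turn the tool arguments (parsed from the model's JSON input) into "path?k1=v1&k2=v2",
    // percent-encoding keys and values and leaving no trailing '&'.
    static String buildUri(String path, Map<String, Object> toolArguments) {
        if (toolArguments.isEmpty()) {
            return path;
        }
        StringJoiner query = new StringJoiner("&", path + "?", "");
        toolArguments.forEach((key, value) -> query.add(
                URLEncoder.encode(key, StandardCharsets.UTF_8) + "="
                        + URLEncoder.encode(String.valueOf(value), StandardCharsets.UTF_8)));
        return query.toString();
    }

    public static void main(String[] args) {
        Map<String, Object> arguments = new LinkedHashMap<>();
        arguments.put("timeZoneId", "Asia/Shanghai");
        System.out.println(buildUri("/time/city", arguments)); // /time/city?timeZoneId=Asia%2FShanghai
    }
}
```

If a string like this is used inside the callback, passing it to WebClient as a java.net.URI (uri(URI.create(...))) rather than as a URI template string avoids the risk of the % escapes being encoded a second time.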
McpRestfulToolCallbackProvider
package org.springframework.ai.mcp;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.util.Assert;

public class McpRestfulToolCallbackProvider implements ToolCallbackProvider {

    private final ToolCallback[] toolCallbacks;

    private static final Logger logger = LoggerFactory.getLogger(McpRestfulToolCallbackProvider.class);

    public McpRestfulToolCallbackProvider(McpRestfulToolCallback... toolCallbacks) {
        Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
        this.toolCallbacks = toolCallbacks;
    }

    @Override
    public ToolCallback[] getToolCallbacks() {
        return this.toolCallbacks;
    }

    public static Builder builder() {
        return new McpRestfulToolCallbackProvider.Builder();
    }

    public static class Builder {
        private McpRestfulToolCallback[] toolCallbacks;

        private Builder() {
        }

        public Builder toolCallbacks(McpRestfulToolCallback... toolCallbacks) {
            this.toolCallbacks = toolCallbacks;
            return this;
        }

        public McpRestfulToolCallbackProvider build() {
            return new McpRestfulToolCallbackProvider(this.toolCallbacks);
        }
    }
}
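WebFluxSseServerTransportProvider modifications

The server redefines this MCP SDK class in its own io.modelcontextprotocol.server.transport package and changes two places: handleSseConnection records the headers of each SSE handshake per sessionId, and handleMessage, on a tools/call request, passes that session's headers to the McpRestfulToolCallback whose tool name matches, which then forwards them to the RESTful service.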
package io.modelcontextprotocol.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.spring.ai.tutorial.mcp.server.util.ApplicationContextHolder;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mcp.McpRestfulToolCallback;
import org.springframework.ai.mcp.McpRestfulToolCallbackProvider;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;

public class WebFluxSseServerTransportProvider implements McpServerTransportProvider {
    private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransportProvider.class);
    public static final String MESSAGE_EVENT_TYPE = "message";
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";
    public static final String DEFAULT_BASE_URL = "";
    private final ObjectMapper objectMapper;
    private final String baseUrl;
    private final String messageEndpoint;
    private final String sseEndpoint;
    private final RouterFunction<?> routerFunction;
    private McpServerSession.Factory sessionFactory;
    private final ConcurrentHashMap<String, McpServerSession> sessions;
    private final ConcurrentHashMap<String, Map<String, String>> session2headers;
    private final McpRestfulToolCallbackProvider mcpRestfulToolCallbackProvider;
    private volatile boolean isClosing;

    public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
        this(objectMapper, messageEndpoint, "/sse");
    }

    public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
        this(objectMapper, "", messageEndpoint, sseEndpoint);
    }

    public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) {
        this.sessions = new ConcurrentHashMap();
        this.session2headers = new ConcurrentHashMap<>();
        this.mcpRestfulToolCallbackProvider = ApplicationContextHolder.getBean(McpRestfulToolCallbackProvider.class);
        this.isClosing = false;
        Assert.notNull(objectMapper, "ObjectMapper must not be null");
        Assert.notNull(baseUrl, "Message base path must not be null");
        Assert.notNull(messageEndpoint, "Message endpoint must not be null");
        Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
        this.objectMapper = objectMapper;
        this.baseUrl = baseUrl;
        this.messageEndpoint = messageEndpoint;
        this.sseEndpoint = sseEndpoint;
        this.routerFunction = RouterFunctions.route().GET(this.sseEndpoint, this::handleSseConnection).POST(this.messageEndpoint, this::handleMessage).build();
    }

    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    public Mono<Void> notifyClients(String method, Object params) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        } else {
            logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
            return Flux.fromIterable(this.sessions.values()).flatMap((session) -> session.sendNotification(method, params).doOnError((e) -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())).onErrorComplete()).then();
        }
    }

    public Mono<Void> closeGracefully() {
        return Flux.fromIterable(this.sessions.values()).doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size())).flatMap(McpServerSession::closeGracefully).then();
    }

    public RouterFunction<?> getRouterFunction() {
        return this.routerFunction;
    }

    private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
        return this.isClosing ? ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down") : ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(Flux.create((FluxSink<ServerSentEvent<?>> sink) -> {
            WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink);
            McpServerSession session = this.sessionFactory.create(sessionTransport);
            String sessionId = session.getId();
            logger.debug("Created new SSE connection for session: {}", sessionId);
            this.sessions.put(sessionId, session);
            // Capture this session's SSE handshake headers (e.g. the token-* headers configured on the client)
            Map<String, String> headers = request.headers().asHttpHeaders().toSingleValueMap();
            logger.debug("sessionId: {} with headers: {}", sessionId, headers);
            session2headers.put(sessionId, headers);

            logger.debug("Sending initial endpoint event to session: {}", sessionId);
            sink.next(ServerSentEvent.builder().event("endpoint").data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId).build());
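            sink.onCancel(() -> {
                logger.debug("Session {} cancelled", sessionId);
                this.sessions.remove(sessionId);
                // Also drop the headers captured for this session so the map does not grow without bound
                this.session2headers.remove(sessionId);
            });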
        }), ServerSentEvent.class);
    }

    private Mono<ServerResponse> handleMessage(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
        } else if (request.queryParam("sessionId").isEmpty()) {
            return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint"));
        } else {
            McpServerSession session = (McpServerSession)this.sessions.get(request.queryParam("sessionId").get());
            return session == null ? ServerResponse.status(HttpStatus.NOT_FOUND).bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get())) : request.bodyToMono(String.class).flatMap((body) -> {
                try {
                    McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, body);

                    if (message instanceof McpSchema.JSONRPCRequest) {
                        String method = ((McpSchema.JSONRPCRequest) message).method();
                        if (McpSchema.METHOD_TOOLS_CALL.equals(method)) {
                            // A tool call: look up the headers captured for this session and hand them to the matching tool callback
                            Map<String, String> headers = this.session2headers.get(session.getId());

                            @SuppressWarnings("unchecked")
                            Map<String, Object> params = (Map<String, Object>) ((McpSchema.JSONRPCRequest) message).params();
                            String toolName = (String) params.get("name");
                            Assert.notNull(toolName, "Tool name cannot be null");
                            for (McpRestfulToolCallback toolCallback : (McpRestfulToolCallback[]) mcpRestfulToolCallbackProvider.getToolCallbacks()) {
                                if (toolName.equals(toolCallback.getToolDefinition().name())) {
                                    toolCallback.setHeaders(headers);
                                }
                            }
                        }
                    }

                    return session.handle(message).flatMap((response) -> ServerResponse.ok().build()).onErrorResume((error) -> {
                        logger.error("Error processing message: {}", error.getMessage());
                        return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).bodyValue(new McpError(error.getMessage()));
                    });
                } catch (IOException | IllegalArgumentException e) {
                    logger.error("Failed to deserialize message: {}", ((Exception)e).getMessage());
                    return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format"));
                }
            });
        }
    }

    public static Builder builder() {
        return new Builder();
    }

    private class WebFluxMcpSessionTransport implements McpServerTransport {
        private final FluxSink<ServerSentEvent<?>> sink;

        public WebFluxMcpSessionTransport(FluxSink<ServerSentEvent<?>> sink) {
            this.sink = sink;
        }

        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromSupplier(() -> {
                try {
                    return WebFluxSseServerTransportProvider.this.objectMapper.writeValueAsString(message);
                } catch (IOException e) {
                    throw Exceptions.propagate(e);
                }
            }).doOnNext((jsonText) -> {
                ServerSentEvent<Object> event = ServerSentEvent.builder().event("message").data(jsonText).build();
                this.sink.next(event);
            }).doOnError((e) -> {
                Throwable exception = Exceptions.unwrap(e);
                this.sink.error(exception);
            }).then();
        }

        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return (T)WebFluxSseServerTransportProvider.this.objectMapper.convertValue(data, typeRef);
        }

        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(this.sink::complete);
        }

        public void close() {
            this.sink.complete();
        }
    }

    public static class Builder {
        private ObjectMapper objectMapper;
        private String baseUrl = "";
        private String messageEndpoint;
        private String sseEndpoint = "/sse";

        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "ObjectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        public Builder basePath(String baseUrl) {
            Assert.notNull(baseUrl, "basePath must not be null");
            this.baseUrl = baseUrl;
            return this;
        }

        public Builder messageEndpoint(String messageEndpoint) {
            Assert.notNull(messageEndpoint, "Message endpoint must not be null");
            this.messageEndpoint = messageEndpoint;
            return this;
        }

        public Builder sseEndpoint(String sseEndpoint) {
            Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
            this.sseEndpoint = sseEndpoint;
            return this;
        }

        public WebFluxSseServerTransportProvider build() {
            Assert.notNull(this.objectMapper, "ObjectMapper must be set");
            Assert.notNull(this.messageEndpoint, "Message endpoint must be set");
            return new WebFluxSseServerTransportProvider(this.objectMapper, this.baseUrl, this.messageEndpoint, this.sseEndpoint);
        }
    }
}
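One caveat in the provider as written: the headers are pushed into a shared McpRestfulToolCallback via setHeaders, so two sessions calling the same tool at the same time can overwrite each other's headers before the HTTP call runs. Also, toSingleValueMap() copies every handshake header, including ones such as Host and Accept, and all of them are forwarded to the RESTful service. A JDK-only sketch of per-session bookkeeping that addresses both, assuming the tool call can look headers up by session id at call time; SessionHeaderRegistry and the "token-" prefix allowlist are hypothetical:

```java
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

public class SessionHeaderRegistry {

    // sessionId -> headers captured on the SSE handshake (the role session2headers plays above)
    private final Map<String, Map<String, String>> sessionHeaders = new ConcurrentHashMap<>();

    // Hypothetical allowlist rule: forward only auth headers, not Host, Accept, Content-Length, etc.
    private static boolean isForwardable(String headerName) {
        return headerName.toLowerCase(Locale.ROOT).startsWith("token-");
    }

    // Called once when the SSE connection for a session is opened.
    void register(String sessionId, Map<String, String> requestHeaders) {
        Map<String, String> filtered = requestHeaders.entrySet().stream()
                .filter(e -> isForwardable(e.getKey()))
                .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue));
        sessionHeaders.put(sessionId, filtered);
    }

    // Looked up for each tools/call, so every call uses its own session's headers
    // instead of a value stored on a shared callback object.
    Map<String, String> headersFor(String sessionId) {
        return sessionHeaders.getOrDefault(sessionId, Map.of());
    }

    // Called when the SSE connection is cancelled.
    void remove(String sessionId) {
        sessionHeaders.remove(sessionId);
    }

    public static void main(String[] args) {
        SessionHeaderRegistry registry = new SessionHeaderRegistry();
        registry.register("s1", Map.of("token-yingzi-1", "yingzi-1", "Host", "localhost:19000"));
        registry.register("s2", Map.of("token-yingzi-1", "other-user"));
        System.out.println(registry.headersFor("s1")); // {token-yingzi-1=yingzi-1}
        registry.remove("s1");
        System.out.println(registry.headersFor("s1")); // {}
    }
}
```

How the session id reaches the code that actually performs the HTTP call depends on the surrounding Spring AI setup and is left open here; the point of the sketch is only that headers are read per call rather than stored on a shared object.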
Turning the RESTful endpoint into tools
ParseRestful
package com.spring.ai.tutorial.mcp.server.parse;

import com.spring.ai.tutorial.mcp.server.model.Parameter;
import com.spring.ai.tutorial.mcp.server.model.RestfulModel;
import com.spring.ai.tutorial.mcp.server.util.JSONSchemaUtil;
import org.springframework.ai.mcp.McpRestfulToolCallback;
import org.springframework.ai.mcp.McpRestfulToolCallbackProvider;
import org.springframework.ai.mcp.RestfulToolDefinition;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;

@Component
public class ParseRestful {

    public McpRestfulToolCallbackProvider getRestfulToolCallbackProvider() {
        List<McpRestfulToolCallback> toolCallbacks = new ArrayList<>();
        getRestfulModels().forEach(
                restfulModel -> {
                    RestfulToolDefinition restfulToolDefinition = RestfulToolDefinition.builder()
                            .name(restfulModel.name())
                            .description(restfulModel.description())
                            .inputSchema(restfulModel.inputSchema())
                            .url(restfulModel.url())
                            .method(restfulModel.method())
                            .path(restfulModel.path())
                            .httpMethod(restfulModel.httpMethod())
                            .build();
                    McpRestfulToolCallback mcpRestfulToolCallback = McpRestfulToolCallback.builder().toolDefinition(restfulToolDefinition).build();

                    toolCallbacks.add(mcpRestfulToolCallback);
                });
        return McpRestfulToolCallbackProvider.builder()
                .toolCallbacks(toolCallbacks.toArray(new McpRestfulToolCallback[0]))
                .build();
    }


    public List<RestfulModel> getRestfulModels() {

        Parameter parameter = Parameter.builder()
                .parameteNname("timeZoneId")
                .description("time zone id, such as Asia/Shanghai")
                .required(true)
                .type("string")
                .build();
        return List.of(
                new RestfulModel("getCiteTimeMethod", "获取指定时区的时间", JSONSchemaUtil.getInputSchema(List.of(parameter)), "http://localhost:101", "getCiteTimeMethod", "/time/city", HttpMethod.GET)
        );
    }

}
RestfulModel
package com.spring.ai.tutorial.mcp.server.model;

import org.springframework.http.HttpMethod;

public record RestfulModel(String name, String description, String inputSchema, String url, String method, String path, HttpMethod httpMethod) {

}
Parameter
package com.spring.ai.tutorial.mcp.server.model;

public record Parameter(String parameteNname,
                        String description,
                        boolean required,
                        String type
) {

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private String parameteNname;
        private String description;
        private boolean required;
        private String type;

        public Builder parameteNname(String parameteNname) {
            this.parameteNname = parameteNname;
            return this;
        }

        public Builder description(String description) {
            this.description = description;
            return this;
        }

        public Builder required(boolean required) {
            this.required = required;
            return this;
        }

        public Builder type(String type) {
            this.type = type;
            return this;
        }

        public Parameter build() {
            return new Parameter(parameteNname, description, required, type);
        }
        
    }

}
Helper utility classes
ApplicationContextHolder
package com.spring.ai.tutorial.mcp.server.util;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

@Component
public class ApplicationContextHolder implements ApplicationContextAware {

    private static ApplicationContext applicationContext;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ApplicationContextHolder.applicationContext = applicationContext;
    }

    public static <T> T getBean(Class<T> clazz) {
        return applicationContext.getBean(clazz);
    }
}
JSONSchemaUtil
package com.spring.ai.tutorial.mcp.server.util;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.github.victools.jsonschema.generator.Module;
import com.github.victools.jsonschema.generator.Option;
import com.github.victools.jsonschema.generator.OptionPreset;
import com.github.victools.jsonschema.generator.SchemaGenerator;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfig;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
import com.github.victools.jsonschema.generator.SchemaVersion;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;
import com.github.victools.jsonschema.module.swagger2.Swagger2Module;
import com.spring.ai.tutorial.mcp.server.model.Parameter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
import org.springframework.ai.util.json.schema.SpringAiSchemaModule;
import org.springframework.util.StringUtils;

import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;

public class JSONSchemaUtil {

    private static final SchemaGenerator SUBTYPE_SCHEMA_GENERATOR;
    private static final Logger logger = LoggerFactory.getLogger(JSONSchemaUtil.class);

    static {
        Module jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED);
        Module openApiModule = new Swagger2Module();
        Module springAiSchemaModule = new SpringAiSchemaModule();
        SchemaGeneratorConfigBuilder schemaGeneratorConfigBuilder =
                new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, OptionPreset.PLAIN_JSON)
                        .with(jacksonModule)
                        .with(openApiModule)
                        .with(springAiSchemaModule)
                        .with(Option.EXTRA_OPEN_API_FORMAT_VALUES)
                        .with(Option.PLAIN_DEFINITION_KEYS);
        // Per-property schemas are embedded in the parent object, so omit the "$schema" keyword
        SchemaGeneratorConfig subtypeSchemaGeneratorConfig = schemaGeneratorConfigBuilder
                .without(Option.SCHEMA_VERSION_INDICATOR)
                .build();
        SUBTYPE_SCHEMA_GENERATOR = new SchemaGenerator(subtypeSchemaGeneratorConfig);
    }

    public static String getInputSchema(List<Parameter> parameters) {
        ObjectNode schema = JsonParser.getObjectMapper().createObjectNode();
        schema.put("$schema", SchemaVersion.DRAFT_2020_12.getIdentifier());
        schema.put("type", "object");
        ObjectNode properties = schema.putObject("properties");
        List<String> required = new ArrayList<>();

        for (Parameter parameter : parameters) {
            String parameterName = parameter.parameteNname();
            Type parameterType = null;
            if (parameter.required()) {
                required.add(parameterName);
            }
            try {
                if (parameter.type() != null) {
                    // Map the JSON Schema type name (e.g. "string") to a Java Type
                    parameterType = getTypeFromString(parameter.type());
                } else {
                    logger.warn("Parameter {} has no type; skipping type conversion", parameterName);
                }
            } catch (ClassNotFoundException e) {
                logger.error("Unable to convert type string to Type: {}", parameter.type(), e);
            }

            // Fall back to an empty (unconstrained) schema when the type could not be resolved,
            // instead of passing null to the schema generator
            ObjectNode parameterNode = parameterType != null
                    ? SUBTYPE_SCHEMA_GENERATOR.generateSchema(parameterType)
                    : JsonParser.getObjectMapper().createObjectNode();
            if (StringUtils.hasText(parameter.description())) {
                parameterNode.put("description", parameter.description());
            }
            properties.set(parameterName, parameterNode);
        }
        ArrayNode requiredArray = schema.putArray("required");
        required.forEach(requiredArray::add);

        JsonSchemaGenerator.SchemaOption[] schemaOptions = new JsonSchemaGenerator.SchemaOption[0];
        processSchemaOptions(schemaOptions, schema);
        return schema.toPrettyString();
    }

    public static Type getTypeFromString(String typeString) throws ClassNotFoundException {
        // Only the JSON Schema primitive type names are supported
        return switch (typeString) {
            case "string" -> String.class;
            case "number" -> Number.class;
            case "integer" -> Integer.class;
            case "boolean" -> Boolean.class;
            default -> throw new ClassNotFoundException("Unsupported type: " + typeString);
        };
    }

    private static void processSchemaOptions(JsonSchemaGenerator.SchemaOption[] schemaOptions, ObjectNode schema) {
        // Disallow extra properties unless explicitly allowed
        if (Stream.of(schemaOptions)
                .noneMatch(option -> option == JsonSchemaGenerator.SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT)) {
            schema.put("additionalProperties", false);
        }

        if (Stream.of(schemaOptions)
                .anyMatch(option -> option == JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES)) {
            convertTypeValuesToUpperCase(schema);
        }
    }

    public static void convertTypeValuesToUpperCase(ObjectNode node) {
        node.fields().forEachRemaining(entry -> {
            JsonNode value = entry.getValue();
            if (value.isObject()) {
                convertTypeValuesToUpperCase((ObjectNode) value);
            } else if (value.isArray()) {
                convertArrayElements(value);
            } else if (value.isTextual() && entry.getKey().equals("type")) {
                node.put("type", value.asText().toUpperCase());
            }
        });
    }

    // Recurse into array elements: only object elements are cast to ObjectNode,
    // nested arrays are walked instead of being cast (which would throw ClassCastException)
    private static void convertArrayElements(JsonNode array) {
        array.elements().forEachRemaining(element -> {
            if (element.isObject()) {
                convertTypeValuesToUpperCase((ObjectNode) element);
            } else if (element.isArray()) {
                convertArrayElements(element);
            }
        });
    }

}
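For the single `timeZoneId` parameter defined in `getRestfulModels()`, `getInputSchema` should yield an object schema with one string property, a `required` array and `additionalProperties: false`. The stdlib-only sketch below hand-builds that shape so the mapping is easy to see. It is an illustration under assumptions, not the tutorial's implementation: `InputSchemaSketch` and `Param` are hypothetical names, it skips victools and Jackson entirely, and the real output is pretty-printed, so whitespace differs.

```java
import java.util.ArrayList;
import java.util.List;

final class InputSchemaSketch {

    // Simplified stand-in for the tutorial's Parameter record
    record Param(String name, String description, boolean required, String type) {}

    // The same four JSON Schema primitive names that getTypeFromString accepts
    private static final List<String> SUPPORTED = List.of("string", "number", "integer", "boolean");

    static String toSchema(List<Param> params) {
        List<String> properties = new ArrayList<>();
        List<String> required = new ArrayList<>();
        for (Param p : params) {
            if (p.type() == null || !SUPPORTED.contains(p.type())) {
                // Unlike JSONSchemaUtil, which only logs, this sketch rejects unknown types
                throw new IllegalArgumentException("Unsupported type: " + p.type());
            }
            // Mirror StringUtils.hasText: only emit a description when it is non-blank
            String description = (p.description() == null || p.description().isBlank())
                    ? "" : ",\"description\":" + quote(p.description());
            properties.add(quote(p.name()) + ":{\"type\":" + quote(p.type()) + description + "}");
            if (p.required()) {
                required.add(quote(p.name()));
            }
        }
        return "{\"$schema\":\"https://json-schema.org/draft/2020-12/schema\""
                + ",\"type\":\"object\""
                + ",\"properties\":{" + String.join(",", properties) + "}"
                + ",\"required\":[" + String.join(",", required) + "]"
                + ",\"additionalProperties\":false}";
    }

    // Naive JSON string quoting: escapes backslashes and double quotes only
    private static String quote(String s) {
        return "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
    }

    public static void main(String[] args) {
        System.out.println(toSchema(List.of(
                new Param("timeZoneId", "time zone id, such as Asia/Shanghai", true, "string"))));
    }
}
```

The real `JSONSchemaUtil` delegates per-property schemas to victools and lets Jackson handle escaping, which scales to richer types far better than string concatenation; the sketch only makes the resulting structure visible.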
AuthServerApplication
package com.spring.ai.tutorial.mcp.server;

import com.spring.ai.tutorial.mcp.server.parse.ParseRestful;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

@SpringBootApplication
public class AuthServerApplication {

    public static void main(String[] args) {
        SpringApplication.run(AuthServerApplication.class, args);
    }

    @Bean
    public ToolCallbackProvider mcpRestfulToolCallbackProvider(ParseRestful parseRestful) {
        return parseRestful.getRestfulToolCallbackProvider();
    }
}

Final result

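  1. Start the Restful service first; it exposes the time service on port 101
  2. Then start the MCP Server, which parses the Restful methods and exposes them as tools
  3. Finally start the MCP Client with the corresponding request-header configuration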

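When a tool is triggered, you can observe that the request headers configured on the client side are carried through the MCP Server and finally delivered to the restful service.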
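How the MCP Server forwards the caller's headers is not shown in this part. As a hypothetical, stdlib-only sketch, assuming the server has already extracted the client's headers into a map, copying them onto the outbound request to the restful service could look like this. The `forwardHeaders` helper and the `token` header name are illustrative assumptions, not the tutorial's API.

```java
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

final class HeaderForwardingSketch {

    // Transport-level headers that should not be copied verbatim;
    // java.net.http also rejects these as restricted headers by default
    private static final Set<String> SKIPPED = Set.of("host", "content-length", "connection", "upgrade", "expect");

    // Copy caller-supplied headers onto a GET request to the downstream REST endpoint
    static HttpRequest forwardHeaders(URI target, Map<String, String> callerHeaders) {
        HttpRequest.Builder builder = HttpRequest.newBuilder(target).GET();
        callerHeaders.forEach((name, value) -> {
            if (!SKIPPED.contains(name.toLowerCase())) {
                builder.header(name, value);
            }
        });
        return builder.build();
    }

    public static void main(String[] args) {
        Map<String, String> headers = new LinkedHashMap<>();
        headers.put("token", "demo-token"); // hypothetical auth header configured on the client side
        headers.put("Host", "ignored");     // dropped: transport-level header
        HttpRequest request = forwardHeaders(
                URI.create("http://localhost:101/time/city?timeZoneId=Asia%2FShanghai"), headers);
        System.out.println(request.headers().firstValue("token").orElse("missing")); // demo-token
    }
}
```

Filtering before calling `header(...)` matters because `HttpRequest.Builder` throws `IllegalArgumentException` for restricted header names, so blindly copying every incoming header would break the call.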

Learning and discussion group

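Hi, I'm 影子 (Yingzi). I have worked at 🐻, a new-energy company, and 老铁, and I am now an AI R&D engineer and a Committer of the Spring AI Alibaba open-source community. I have recently set up a discussion group: one person walks fast, but a group walks far. I also maintain a long-running set of Feishu cloud-doc notes with systematic interview material on backend and big-data systems, available for free via private message.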