Spring Cloud Gateway 处理带文件的 multipart/form-data 请求（文件不丢失）

62 阅读 · 阅读约 3 分钟

需求场景

需要在 Gateway 层对带文件的 multipart/form-data 请求中的参数进行加解密、鉴权、认证等处理，同时不能丢失文件参数；处理完成后，将所有参数（包括文件）一并透传到后端服务。

版本情况

spring-cloud-starter-gateway:3.1.3

完整代码实现

注：以下为基于 AbstractGatewayFilterFactory 的实现。如果改用 GlobalFilter 实现，核心代码不变。


import com.alibaba.nacos.common.utils.StringUtils;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.reactivestreams.Publisher;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpCookie;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.MultipartBodyBuilder;
import org.springframework.http.client.reactive.ClientHttpRequest;
import org.springframework.http.codec.multipart.FilePart;
import org.springframework.http.codec.multipart.FormFieldPart;
import org.springframework.http.codec.multipart.Part;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SignalType;

import java.net.URI;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;

/**
 * Gateway filter factory that authenticates multipart/form-data requests that carry files
 * (this example decrypts one request parameter).
 *
 * <p>Processing steps:
 * <ol>
 *   <li>Requests whose Content-Type is not compatible with multipart/form-data pass through untouched.</li>
 *   <li>The configured field (default {@code paramData}) must be present as a plain form field.</li>
 *   <li>The value of that field is processed (here: decrypted).</li>
 *   <li>The multipart body is rebuilt from the processed value plus every other original part
 *       (files and text fields, including repeated values under the same name) and forwarded.</li>
 * </ol>
 *
 * <p>Use case: a front gateway that must verify request data before it reaches the backend.
 *
 * @author zhongjie
 */
@Slf4j
@Component
public class AuthWithFileGatewayFilterFactory extends AbstractGatewayFilterFactory<AuthWithFileGatewayFilterFactory.Config> {

    /** Field that is decrypted when the route does not configure a non-blank {@code encryptedField}. */
    private static final String DEFAULT_ENCRYPTED_FIELD = "paramData";

    public AuthWithFileGatewayFilterFactory() {
        super(Config.class);
    }

    /**
     * Enables the shortcut route syntax {@code - AuthWithFile=paramData}: the single positional
     * argument is bound to the {@code encryptedField} property of {@link Config}. The explicit
     * {@code name}/{@code args} form keeps working as before.
     *
     * @return the config property names in shortcut-argument order
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Collections.singletonList("encryptedField");
    }

    /**
     * Creates the gateway filter for one route.
     *
     * @param config filter configuration; may be {@code null}, in which case the default field is used
     * @return a filter that rewrites multipart requests and passes all other requests through
     */
    @Override
    public GatewayFilter apply(Config config) {
        // The config is fixed for the lifetime of the returned filter, so resolve the field once.
        final String encryptedField = resolveEncryptedField(config);
        return (exchange, chain) -> {
            ServerHttpRequest request = exchange.getRequest();
            if (!MediaType.MULTIPART_FORM_DATA.isCompatibleWith(request.getHeaders().getContentType())) {
                return chain.filter(exchange);
            }

            return exchange.getMultipartData()
                    .map(parts -> {
                        FormFieldPart paramData = requireEncryptedField(parts, encryptedField);
                        String decryptedParamData = decrypt(paramData.value());
                        MultipartBodyBuilder rebuilt = rebuildMultipartBody(parts, encryptedField, decryptedParamData);
                        return BodyInserters.fromMultipartData(rebuilt.build());
                    })
                    .flatMap(bodyInserter -> {
                        HttpHeaders headers = new HttpHeaders();
                        headers.putAll(exchange.getRequest().getHeaders());
                        // The new Content-Type (with a fresh boundary) is computed by the body inserter,
                        // which writes it into these same headers via outputMessage.getHeaders().
                        // The original Content-Length no longer matches the rebuilt body.
                        headers.remove(HttpHeaders.CONTENT_LENGTH);

                        CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange, headers);
                        return bodyInserter.insert(outputMessage, new BodyInserterContext())
                                .then(Mono.defer(() -> {
                                    ServerHttpRequest decorator = decorate(exchange, headers, outputMessage);
                                    return chain.filter(exchange.mutate().request(decorator).build());
                                }))
                                .onErrorResume(throwable -> release(exchange, outputMessage, throwable));
                    })
                    .doOnSuccess(aVoid -> log.info("request completed"))
                    .doOnError(throwable -> log.error("request error", throwable))
                    .doFinally(signalType -> {
                        if (signalType != SignalType.ON_ERROR && signalType != SignalType.ON_COMPLETE) {
                            log.info("unknow signal type is {}", signalType);
                        }
                    });
        };
    }

    /**
     * Returns the configured field name, falling back to {@link #DEFAULT_ENCRYPTED_FIELD}
     * when the config is missing or the field is null/blank.
     */
    private static String resolveEncryptedField(Config config) {
        return Optional.ofNullable(config)
                .map(Config::getEncryptedField)
                .filter(field -> !field.trim().isEmpty())
                .orElse(DEFAULT_ENCRYPTED_FIELD);
    }

    /**
     * Returns the form field that carries the encrypted data.
     *
     * @param parts          parsed multipart body of the incoming request
     * @param encryptedField name of the field to look up
     * @return the first part registered under {@code encryptedField}
     * @throws IllegalArgumentException if the body is empty, the field is missing,
     *                                  or the field was sent as a file rather than a form field
     */
    private static FormFieldPart requireEncryptedField(MultiValueMap<String, Part> parts, String encryptedField) {
        if (parts.isEmpty()) {
            throw new IllegalArgumentException("empty multipart data");
        }
        Part part = parts.getFirst(encryptedField);
        if (part == null) {
            throw new IllegalArgumentException("empty encrypted data");
        }
        if (!(part instanceof FormFieldPart)) {
            // Reject explicitly instead of letting an unchecked cast fail with ClassCastException.
            throw new IllegalArgumentException("encrypted data must be a form field, not a file: " + encryptedField);
        }
        return (FormFieldPart) part;
    }

    /**
     * Builds the outgoing multipart body: the encrypted field is replaced by its decrypted value,
     * and every other part is copied over.
     *
     * @param parts              parsed multipart body of the incoming request
     * @param encryptedField     name of the field whose value is replaced
     * @param decryptedParamData replacement value for {@code encryptedField}
     * @return a builder holding the complete outgoing body
     */
    private static MultipartBodyBuilder rebuildMultipartBody(MultiValueMap<String, Part> parts,
                                                             String encryptedField,
                                                             String decryptedParamData) {
        MultipartBodyBuilder builder = new MultipartBodyBuilder();
        for (Map.Entry<String, List<Part>> entry : parts.entrySet()) {
            String key = entry.getKey();
            if (encryptedField.equals(key)) {
                builder.part(encryptedField, decryptedParamData); // decrypted text replaces the original
                continue;
            }
            // Copy every value, not just the first: one field name (e.g. "files") may carry several uploads.
            for (Part part : entry.getValue()) {
                copyPart(builder, key, part);
            }
        }
        return builder;
    }

    /**
     * Appends one original part to {@code builder}.
     *
     * <p>File parts are forwarded as their original {@code DataBuffer} content stream, together with
     * their filename and Content-Type. Text parts are forwarded by value.
     */
    private static void copyPart(MultipartBodyBuilder builder, String key, Part part) {
        if (part instanceof FilePart) {
            FilePart filePart = (FilePart) part;
            builder.asyncPart(key, filePart.content(), DataBuffer.class)
                    .filename(filePart.filename())
                    .contentType(filePart.headers().getContentType());
        } else if (part instanceof FormFieldPart) {
            builder.part(key, ((FormFieldPart) part).value());
        } else {
            log.warn("unknow part type");
        }
    }

    /**
     * Error path of the body rewrite: releases the cached body buffers (if any were written)
     * and then re-emits {@code throwable}.
     *
     * @param exchange      current exchange (unused here; kept for subclass hooks)
     * @param outputMessage the message the body inserter wrote into
     * @param throwable     the original failure
     * @return a {@code Mono} that always completes with {@code throwable}
     */
    protected Mono<Void> release(ServerWebExchange exchange, CachedBodyOutputMessage outputMessage,
                                 Throwable throwable) {
        if (outputMessage.isCached()) {
            return outputMessage.getBody().map(DataBufferUtils::release).then(Mono.error(throwable));
        }
        return Mono.error(throwable);
    }

    /**
     * Wraps the original request so that downstream filters see the rebuilt headers and body.
     *
     * @param exchange      current exchange
     * @param headers       headers for the rewritten request, including the new Content-Type
     * @param outputMessage holder of the rewritten body
     * @return the decorated request
     */
    ServerHttpRequestDecorator decorate(ServerWebExchange exchange, HttpHeaders headers,
                                        CachedBodyOutputMessage outputMessage) {
        return new ServerHttpRequestDecorator(exchange.getRequest()) {
            @Override
            public HttpHeaders getHeaders() {
                long contentLength = headers.getContentLength();
                HttpHeaders httpHeaders = new HttpHeaders();
                httpHeaders.putAll(headers);
                if (contentLength > 0) {
                    httpHeaders.setContentLength(contentLength);
                }
                else {
                    // Length of the rebuilt body is unknown up front, so stream it chunked.
                    httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                }
                return httpHeaders;
            }

            @Override
            public Flux<DataBuffer> getBody() {
                return outputMessage.getBody();
            }
        };
    }

    // Example decryption hook: replace with your own decryption logic. Currently returns the input unchanged.
    private String decrypt(String encryptedData) {
        return encryptedData;
    }

    /**
     * In-memory {@link ClientHttpRequest} used as the target of a {@link BodyInserter}. It captures
     * the written body publisher instead of sending it anywhere. Request metadata (method, URI,
     * cookies) is delegated to the exchange's original request.
     */
    static class CachedBodyOutputMessage implements ClientHttpRequest {

        private final ServerWebExchange serverWebExchange;

        private final ServerHttpRequest serverRequest;

        private final HttpHeaders httpHeaders;

        private boolean cached = false;

        // Until writeWith(...) is called, reading the body fails loudly instead of yielding nothing.
        private Flux<DataBuffer> body = Flux
                .error(new IllegalStateException("The body is not set. " + "Did handling complete with success?"));

        public CachedBodyOutputMessage(ServerWebExchange exchange, HttpHeaders httpHeaders) {
            this.serverWebExchange = exchange;
            this.serverRequest = exchange.getRequest();
            this.httpHeaders = httpHeaders;
        }

        @Override
        public HttpMethod getMethod() {
            return serverRequest.getMethod();
        }

        @Override
        public URI getURI() {
            return serverRequest.getURI();
        }

        @Override
        public MultiValueMap<String, HttpCookie> getCookies() {
            return serverRequest.getCookies();
        }

        @Override
        @SuppressWarnings("unchecked") // callers choose T; the native object is the original ServerHttpRequest
        public <T> T getNativeRequest() {
            return (T) serverRequest;
        }

        @Override
        public DataBufferFactory bufferFactory() {
            return serverWebExchange.getResponse().bufferFactory();
        }

        @Override
        public void beforeCommit(Supplier<? extends Mono<Void>> action) {
            // Nothing is ever committed, so commit callbacks are ignored.
        }

        @Override
        public boolean isCommitted() {
            return false;
        }

        @Override
        public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
            this.body = Flux.from(body);
            this.cached = true;
            return Mono.empty();
        }

        @Override
        public Mono<Void> writeAndFlushWith(Publisher<? extends Publisher<? extends DataBuffer>> body) {
            return writeWith(Flux.from(body).flatMap(p -> p));
        }

        @Override
        public Mono<Void> setComplete() {
            return writeWith(Flux.empty());
        }

        @Override
        public HttpHeaders getHeaders() {
            return this.httpHeaders;
        }

        boolean isCached() {
            return this.cached;
        }

        /**
         * Returns the captured request body, or an error stream if no body was ever written.
         *
         * @return body as {@link Flux}
         */
        public Flux<DataBuffer> getBody() {
            return this.body;
        }

    }

    @Data
    public static class Config {
        // Request field that the gateway authenticates/decrypts, e.g. "paramData".
        private String encryptedField;
    }
}

路由配置

spring:
  cloud:
    gateway:
      routes:
        # 路由标识
        - id: upload-with-auth
          # 下游服务地址(示例为 http://localhost:8081)
          uri: http://localhost:8081
          predicates:
            # 匹配上传接口
            - Path=/upload/**
          filters:
            # 使用我们自定义的 AuthWithFileGatewayFilterFactory
            # 也可使用简写形式：AuthWithFile=加密的字段名
            # （简写形式要求过滤器工厂重写 shortcutFieldOrder()，否则参数无法绑定到 encryptedField）
            # 本例中上游（客户端）把加密后的 JSON 放在字段 paramData 里
            - name: AuthWithFile
              args:
                encryptedField: paramData