需求场景
需要在gateway层对带文件的multipart/form-data类型请求的参数进行加解密、鉴权、认证等操作,同时也不能丢失文件参数,处理完成后一起透传到后端服务。
版本情况
spring-cloud-starter-gateway:3.1.3
完整代码实现
注:以下为AbstractGatewayFilterFactory的实现。如果要在GlobalFilter中实现,核心代码不变。
import com.alibaba.nacos.common.utils.StringUtils;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.reactivestreams.Publisher;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpCookie;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.MultipartBodyBuilder;
import org.springframework.http.client.reactive.ClientHttpRequest;
import org.springframework.http.codec.multipart.FilePart;
import org.springframework.http.codec.multipart.FormFieldPart;
import org.springframework.http.codec.multipart.Part;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SignalType;

import java.net.URI;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;
/**
 * Gateway filter factory that authenticates requests carrying file uploads
 * (in this example: decrypts one request field).
 * <p>
 * Responsibilities:
 * <ol>
 *   <li>Pass requests that are not {@code multipart/form-data} through untouched.</li>
 *   <li>Require the configured field (default {@code paramData}) to be present as a text field.</li>
 *   <li>Process that field via {@link #decrypt(String)}.</li>
 *   <li>Rebuild the multipart body from the processed value plus every original file/text part
 *       and forward it downstream.</li>
 * </ol>
 * Use case: request data must be authenticated/decrypted in the front gateway before it reaches
 * the backend service.
 * <p>
 * Route configuration accepts both the expanded form ({@code args: encryptedField: <name>}) and
 * the shortcut form {@code AuthWithFile=<name>}.
 *
 * @author zhongjie
 */
@Slf4j
@Component
public class AuthWithFileGatewayFilterFactory
        extends AbstractGatewayFilterFactory<AuthWithFileGatewayFilterFactory.Config> {

    /** Field processed when the route does not configure {@code encryptedField}. */
    private static final String DEFAULT_ENCRYPTED_FIELD = "paramData";

    public AuthWithFileGatewayFilterFactory() {
        super(Config.class);
    }

    /**
     * Maps the single shortcut argument of {@code AuthWithFile=<name>} onto
     * {@link Config#encryptedField}. Without this override the shortcut value is bound to a
     * generated key, never reaches {@code Config}, and the default field is used silently.
     *
     * @return the config property names, in shortcut-argument order
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Collections.singletonList("encryptedField");
    }

    /**
     * Creates the filter for one route.
     *
     * @param config route configuration; when it is {@code null} or has no {@code encryptedField},
     *               {@value #DEFAULT_ENCRYPTED_FIELD} is used
     * @return the configured {@link GatewayFilter}
     */
    @Override
    public GatewayFilter apply(Config config) {
        // The route config does not change per request, so resolve the field name once.
        String encryptedField = Optional.ofNullable(config)
                .map(Config::getEncryptedField)
                .orElse(DEFAULT_ENCRYPTED_FIELD);
        return (exchange, chain) -> {
            ServerHttpRequest request = exchange.getRequest();
            if (!MediaType.MULTIPART_FORM_DATA.isCompatibleWith(request.getHeaders().getContentType())) {
                return chain.filter(exchange);
            }
            return exchange.getMultipartData()
                    .map(parts -> BodyInserters.fromMultipartData(
                            rebuildMultipartBody(parts, encryptedField).build()))
                    .flatMap(bodyInserter -> {
                        HttpHeaders headers = new HttpHeaders();
                        headers.putAll(exchange.getRequest().getHeaders());
                        // The new content type will be computed by the body inserter and then set
                        // in the request decorator; the old length no longer matches the new body.
                        headers.remove(HttpHeaders.CONTENT_LENGTH);
                        CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange, headers);
                        return bodyInserter.insert(outputMessage, new BodyInserterContext())
                                .then(Mono.defer(() -> {
                                    ServerHttpRequest decorator = decorate(exchange, headers, outputMessage);
                                    return chain.filter(exchange.mutate().request(decorator).build());
                                }))
                                .onErrorResume(throwable -> release(exchange, outputMessage, throwable));
                    })
                    .doOnSuccess(aVoid -> log.info("request completed"))
                    .doOnError(throwable -> log.error("request error", throwable))
                    .doFinally(signalType -> {
                        // doFinally sees ON_COMPLETE, ON_ERROR or CANCEL; only CANCEL is logged here.
                        if (signalType != SignalType.ON_ERROR && signalType != SignalType.ON_COMPLETE) {
                            log.info("unknow signal type is {}", signalType);
                        }
                    });
        };
    }

    /**
     * Builds the outgoing multipart body. The configured field is replaced by its decrypted value.
     * Every other part is copied through, including repeated values that share one field name,
     * e.g. several files uploaded under the same name.
     *
     * @param parts          parsed multipart parts of the incoming request
     * @param encryptedField name of the field to decrypt
     * @return builder holding the rebuilt parts
     * @throws IllegalArgumentException if there are no parts, or the field is missing or is not a
     *                                  text field
     */
    private MultipartBodyBuilder rebuildMultipartBody(MultiValueMap<String, Part> parts, String encryptedField) {
        if (parts.isEmpty()) {
            throw new IllegalArgumentException("empty multipart data");
        }
        Part encryptedPart = parts.getFirst(encryptedField);
        if (encryptedPart == null) {
            throw new IllegalArgumentException("empty encrypted data");
        }
        if (!(encryptedPart instanceof FormFieldPart)) {
            // Reject e.g. a file sent under this name explicitly instead of failing on the cast below.
            throw new IllegalArgumentException("encrypted data must be a form field");
        }
        String decryptedParamData = decrypt(((FormFieldPart) encryptedPart).value());

        MultipartBodyBuilder builder = new MultipartBodyBuilder();
        for (Map.Entry<String, List<Part>> entry : parts.entrySet()) {
            String name = entry.getKey();
            if (StringUtils.equals(name, encryptedField)) {
                builder.part(encryptedField, decryptedParamData); // decrypted text replaces the original
                continue;
            }
            for (Part part : entry.getValue()) {
                copyPart(builder, name, part);
            }
        }
        return builder;
    }

    /**
     * Appends one original part to {@code builder} under {@code name}. File parts are streamed
     * through with their filename and content type. Text fields are re-added by value. Any other
     * part type is skipped with a warning.
     */
    private void copyPart(MultipartBodyBuilder builder, String name, Part part) {
        if (part instanceof FilePart) {
            FilePart filePart = (FilePart) part;
            builder.asyncPart(name, filePart.content(), DataBuffer.class)
                    .filename(filePart.filename())
                    .contentType(filePart.headers().getContentType());
        } else if (part instanceof FormFieldPart) {
            builder.part(name, ((FormFieldPart) part).value());
        } else {
            log.warn("unknow part type");
        }
    }

    /**
     * Error path of the body rewrite. If a body was written to {@code outputMessage}, this
     * subscribes to it and releases each emitted buffer. In all cases it then fails with
     * {@code throwable}.
     *
     * @param exchange      current exchange (not used by this implementation)
     * @param outputMessage message the body inserter wrote to
     * @param throwable     error to propagate
     * @return a {@link Mono} that errors with {@code throwable}
     */
    protected Mono<Void> release(ServerWebExchange exchange, CachedBodyOutputMessage outputMessage,
            Throwable throwable) {
        if (outputMessage.isCached()) {
            return outputMessage.getBody().map(DataBufferUtils::release).then(Mono.error(throwable));
        }
        return Mono.error(throwable);
    }

    /**
     * Wraps the original request so that downstream sees the rebuilt headers and body.
     * If {@code headers} carries a positive Content-Length it is kept; otherwise
     * {@code Transfer-Encoding: chunked} is set.
     *
     * @param exchange      current exchange
     * @param headers       headers for the forwarded request
     * @param outputMessage message holding the rebuilt body
     * @return the decorated request
     */
    ServerHttpRequestDecorator decorate(ServerWebExchange exchange, HttpHeaders headers,
            CachedBodyOutputMessage outputMessage) {
        return new ServerHttpRequestDecorator(exchange.getRequest()) {
            @Override
            public HttpHeaders getHeaders() {
                long contentLength = headers.getContentLength();
                HttpHeaders httpHeaders = new HttpHeaders();
                httpHeaders.putAll(headers);
                if (contentLength > 0) {
                    httpHeaders.setContentLength(contentLength);
                } else {
                    httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                }
                return httpHeaders;
            }

            @Override
            public Flux<DataBuffer> getBody() {
                return outputMessage.getBody();
            }
        };
    }

    /**
     * Example decryption hook: currently returns its input unchanged.
     * Replace with the real decryption logic.
     */
    private String decrypt(String encryptedData) {
        return encryptedData;
    }

    /**
     * In-memory {@link ClientHttpRequest} that captures whatever a {@link BodyInserter} writes,
     * so the body can be replayed as the forwarded request body. Request metadata (method, URI,
     * cookies) is delegated to the original server request.
     */
    static class CachedBodyOutputMessage implements ClientHttpRequest {
        private final ServerWebExchange serverWebExchange;
        private final ServerHttpRequest serverRequest;
        private final HttpHeaders httpHeaders;
        private boolean cached = false;
        // Until writeWith is called, reading the body fails loudly instead of yielding nothing.
        private Flux<DataBuffer> body = Flux
                .error(new IllegalStateException("The body is not set. " + "Did handling complete with success?"));

        public CachedBodyOutputMessage(ServerWebExchange exchange, HttpHeaders httpHeaders) {
            this.serverWebExchange = exchange;
            this.serverRequest = exchange.getRequest();
            this.httpHeaders = httpHeaders;
        }

        @Override
        public HttpMethod getMethod() {
            return serverRequest.getMethod();
        }

        /**
         * Abstract in Spring Framework 5.3's {@code HttpRequest} (used by gateway 3.1.x),
         * so it must be implemented for this class to compile there.
         */
        @Override
        public String getMethodValue() {
            return serverRequest.getMethodValue();
        }

        @Override
        public URI getURI() {
            return serverRequest.getURI();
        }

        @Override
        public MultiValueMap<String, HttpCookie> getCookies() {
            return serverRequest.getCookies();
        }

        @Override
        @SuppressWarnings("unchecked") // caller chooses T; the native request is the server request
        public <T> T getNativeRequest() {
            return (T) serverRequest;
        }

        @Override
        public DataBufferFactory bufferFactory() {
            return serverWebExchange.getResponse().bufferFactory();
        }

        @Override
        public void beforeCommit(Supplier<? extends Mono<Void>> action) {
            // Nothing is ever committed, so commit actions are ignored.
        }

        @Override
        public boolean isCommitted() {
            return false;
        }

        @Override
        public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
            this.body = Flux.from(body);
            this.cached = true;
            return Mono.empty();
        }

        @Override
        public Mono<Void> writeAndFlushWith(Publisher<? extends Publisher<? extends DataBuffer>> body) {
            return writeWith(Flux.from(body).flatMap(p -> p));
        }

        @Override
        public Mono<Void> setComplete() {
            return writeWith(Flux.empty());
        }

        @Override
        public HttpHeaders getHeaders() {
            return this.httpHeaders;
        }

        boolean isCached() {
            return this.cached;
        }

        /**
         * Return the captured body, or an error stream if no body was ever written.
         *
         * @return body as {@link Flux}
         */
        public Flux<DataBuffer> getBody() {
            return this.body;
        }
    }

    /** Per-route configuration of {@link AuthWithFileGatewayFilterFactory}. */
    @Data
    public static class Config {
        /** Name of the request field to process in the gateway, e.g. {@code "paramData"}. */
        private String encryptedField;
    }
}
路由配置
spring:
  cloud:
    gateway:
      routes:
        # Route id
        - id: upload-with-auth
          # Downstream service address (http://localhost:8081 in this example)
          uri: http://localhost:8081
          predicates:
            # Match the upload endpoints
            - Path=/upload/**
          filters:
            # Our custom AuthWithFileGatewayFilterFactory, referenced without its "GatewayFilterFactory" suffix.
            # Shortcut form "- AuthWithFile=<encrypted field name>" only binds the field name
            # if the factory maps it via shortcutFieldOrder(); the expanded form below always does.
            # In this example the upstream client puts the encrypted JSON into the field paramData.
            - name: AuthWithFile
              args:
                encryptedField: paramData