核心原理
借鉴 spring-cloud-gateway 的过滤链。
- NettyRoutingFilter:通过 HttpClient 将请求转发到目标地址
- NettyWriteResponseFilter:将目标返回的响应结果回写给客户端
核心调整点
- 通过 Cookie 或请求头识别并记录代理目标地址
- 将响应中的跳转地址改写为当前服务器地址,这样后续请求仍然由代理逻辑处理
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import java.net.URI;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeoutException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.buffer.*;
import org.springframework.http.*;
import org.springframework.http.server.reactive.AbstractServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.netty.Connection;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.HttpClientResponse;
/**
 * Reactive {@link WebFilter} that forwards selected requests to a remote target and streams the
 * target's response back to the caller. The forwarding and write-back steps follow the structure
 * of Spring Cloud Gateway's NettyRoutingFilter and NettyWriteResponseFilter.
 *
 * <p>A request is proxied when a target address can be resolved for it, trying in order:
 * <ol>
 *   <li>the {@code proxyAddress} query parameter (the value is also stored in the cache, keyed by
 *       the {@code BMAP_SECKEY} cookie value or, when that cookie is absent, by a new UUID);</li>
 *   <li>the cache entry keyed by the {@code BMAP_SECKEY} cookie value;</li>
 *   <li>the cache entry keyed by the value of the {@code proxyAddress} request header.</li>
 * </ol>
 * Requests without a target continue down the normal chain; a fixed set of front-end route paths
 * is rewritten to {@code /index.html} first so a page refresh does not end in a 404.
 *
 * <p>NOTE(review): the target address is taken verbatim from a client-supplied query parameter,
 * so any caller can make this server issue requests to an arbitrary host. Consider validating it
 * against an allow-list before caching it.
 *
 * @author liangguohun
 * @date 2022/5/19 15:16
 */
@Slf4j
@Component
public class ProxyWebFilter implements WebFilter {

    /** Cookie whose value keys the cookie-to-target mapping in the cache. */
    private static final String SESSION_COOKIE_NAME = "BMAP_SECKEY";

    /** Name shared by the query parameter, the request header and the response header. */
    private static final String PROXY_ADDRESS_KEY = "proxyAddress";

    /** Matches the scheme and authority of an http(s) URL, stopping at the path, query or fragment. */
    private static final Pattern SCHEME_AND_AUTHORITY = Pattern.compile("https?://[^/?#]+");

    /** Maximum time to wait for the target's response before answering 504. */
    private static final Duration RESPONSE_TIMEOUT = Duration.ofMillis(3000L);

    /** Front-end route paths that are served by {@code /index.html}. */
    private static final List<String> SPA_ROUTES = Arrays.asList("/",
            "/alarm",
            "/authority",
            "/company",
            "/config",
            "/firmware",
            "/firmwarelist",
            "/gatewaylist",
            "/group",
            "/grouplist",
            "/log",
            "/map",
            "/statistics",
            "/task",
            "/userlist",
            "/alert-rule"
    );

    private final List<MediaType> streamingMediaTypes = Arrays.asList(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_STREAM_JSON);

    private final HttpClient httpClient = HttpClient.create();

    @Resource
    CacheService cacheService;

    /** Address of this application; substituted into {@code Location} headers returned by the target. */
    @Value("${applicationAddress}")
    private String applicationAddress;

    @Resource
    ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider;

    // do not use this headersFilters directly, use getHeadersFilters() instead.
    private volatile List<HttpHeadersFilter> headersFilters;

    /**
     * Returns the header filters, fetching them from the provider on first use.
     *
     * @return the filters supplied by the provider, or {@code null} when it has none available
     */
    public List<HttpHeadersFilter> getHeadersFilters() {
        List<HttpHeadersFilter> filters = headersFilters;
        if (filters == null) {
            filters = headersFiltersProvider.getIfAvailable();
            headersFilters = filters;
        }
        return filters;
    }

    /**
     * Proxies the request when a target can be resolved for it; otherwise passes it down the chain,
     * rewriting known front-end route paths to {@code /index.html}.
     *
     * @param exchange the current exchange
     * @param chain the remaining filter chain
     * @return completion signal of the proxied write or of the delegated chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        ProxyTarget target = resolveTarget(exchange);
        if (target != null) {
            return proxy(exchange, chain, target);
        }
        // Serve index.html for front-end routes so that refreshing such a page does not return 404.
        String path = exchange.getRequest().getURI().getPath();
        if (SPA_ROUTES.contains(path)) {
            return chain.filter(exchange.mutate().request(exchange.getRequest().mutate().path("/index.html").build()).build());
        }
        return chain.filter(exchange);
    }

    /**
     * Resolves the proxy target for a request from the query parameter, the session cookie or the
     * request header, recording a new mapping in the cache when the query parameter is present.
     *
     * @return the resolved target, or {@code null} when the request must not be proxied
     */
    @Nullable
    private ProxyTarget resolveTarget(ServerWebExchange exchange) {
        ServerHttpRequest request = exchange.getRequest();

        // Other cookies may be present without the session cookie, so look it up null-safely.
        HttpCookie sessionCookie = request.getCookies().getFirst(SESSION_COOKIE_NAME);
        String cookie = sessionCookie != null ? sessionCookie.getValue() : null;

        String requested = request.getQueryParams().getFirst(PROXY_ADDRESS_KEY);
        // A blank parameter cannot name a target; treat it as absent.
        String proxyAddress = StringUtils.hasText(requested) ? requested : null;
        String cacheKey = null;

        if (proxyAddress != null) {
            if (cookie != null) {
                cacheService.putValue(cookie, proxyAddress);
            } else {
                // No session cookie: hand out a generated key the client can send back as a header.
                cacheKey = UUID.randomUUID().toString();
                cacheService.putValue(cacheKey, proxyAddress);
            }
        }

        if (proxyAddress == null && cookie != null) {
            proxyAddress = cacheService.getValue(cookie);
        }
        if (proxyAddress == null) {
            // Fall back to the key sent in the request header, also when the cookie had no mapping.
            cacheKey = request.getHeaders().getFirst(PROXY_ADDRESS_KEY);
            if (cacheKey != null) {
                proxyAddress = cacheService.getValue(cacheKey);
            }
        }
        return proxyAddress == null ? null : new ProxyTarget(proxyAddress, cacheKey);
    }

    /**
     * Sends the request to the target and writes the target's response to the exchange.
     * Requests that are already routed or are not http/https are passed down the chain untouched.
     */
    private Mono<Void> proxy(ServerWebExchange exchange, WebFilterChain chain, ProxyTarget target) {
        ServerHttpRequest request = exchange.getRequest();
        URI requestUrl = request.getURI();
        String scheme = requestUrl.getScheme();
        if (isAlreadyRouted(exchange) || (!"http".equals(scheme) && !"https".equals(scheme))) {
            return chain.filter(exchange);
        }
        setAlreadyRouted(exchange);

        final HttpMethod method = HttpMethod.valueOf(request.getMethodValue());
        final String url = turn2Target(requestUrl.toASCIIString(), target.address);
        exchange.getAttributes().put("GATEWAY_REQUEST_URL_ATTR", url);

        // Copy the incoming request headers into a Netty header set for the outgoing request.
        HttpHeaders filtered = request.getHeaders();
        final DefaultHttpHeaders httpHeaders = new DefaultHttpHeaders();
        filtered.forEach(httpHeaders::set);

        // The exchange attribute decides whether the original Host header is kept.
        final boolean preserveHost = exchange.getAttributeOrDefault("PRESERVE_HOST_HEADER_ATTRIBUTE", false);
        final String cacheKey = target.cacheKey;

        Flux<HttpClientResponse> responseFlux = httpClient
                .headers(headers -> {
                    headers.add(httpHeaders);
                    headers.remove(HttpHeaders.HOST);
                    if (preserveHost) {
                        String host = request.getHeaders().getFirst(HttpHeaders.HOST);
                        if (host != null) {
                            headers.add(HttpHeaders.HOST, host);
                        }
                    }
                })
                .request(method)
                .uri(url)
                .send((req, nettyOutbound) -> nettyOutbound.send(request.getBody().map(this::getByteBuf)))
                .responseConnection((res, connection) -> captureClientResponse(exchange, res, connection, cacheKey));

        responseFlux = responseFlux
                .timeout(RESPONSE_TIMEOUT, Mono.error(new TimeoutException("Response took longer than timeout: " + RESPONSE_TIMEOUT)))
                .onErrorMap(TimeoutException.class, th -> new ResponseStatusException(HttpStatus.GATEWAY_TIMEOUT, th.getMessage(), th));

        return responseFlux
                .then(Mono.defer(() -> writeClientResponse(exchange)))
                .doOnCancel(() -> cleanup(exchange))
                // Also dispose a non-persistent connection when the exchange fails.
                .doOnError(th -> cleanup(exchange));
    }

    /**
     * Stores the client response and its connection on the exchange and copies status and
     * (filtered) headers onto the server response; the body is written later from the connection.
     *
     * @param cacheKey key to echo back in the {@code proxyAddress} response header, may be {@code null}
     */
    private Mono<HttpClientResponse> captureClientResponse(ServerWebExchange exchange, HttpClientResponse res,
            Connection connection, @Nullable String cacheKey) {
        exchange.getAttributes().put("CLIENT_RESPONSE_ATTR", res);
        exchange.getAttributes().put("CLIENT_RESPONSE_CONN_ATTR", connection);
        ServerHttpResponse response = exchange.getResponse();

        // put headers and status so filters can modify the response
        HttpHeaders headers = new HttpHeaders();
        res.responseHeaders().forEach(entry -> headers.add(entry.getKey(), entry.getValue()));
        String contentTypeValue = headers.getFirst(HttpHeaders.CONTENT_TYPE);
        if (StringUtils.hasLength(contentTypeValue)) {
            exchange.getAttributes().put("ORIGINAL_RESPONSE_CONTENT_TYPE_ATTR", contentTypeValue);
        }
        setResponseStatus(res, response);

        // make sure headers filters run after setting status so it is available in response
        HttpHeaders filteredResponseHeaders = HttpHeadersFilter.filter(getHeadersFilters(), headers, exchange, HttpHeadersFilter.Type.RESPONSE);
        turnTarget(filteredResponseHeaders);
        if (!filteredResponseHeaders.containsKey(HttpHeaders.TRANSFER_ENCODING)
                && filteredResponseHeaders.containsKey(HttpHeaders.CONTENT_LENGTH)) {
            // It is not valid to have both the transfer-encoding header and the content-length
            // header, so drop transfer-encoding when content-length is present.
            response.getHeaders().remove(HttpHeaders.TRANSFER_ENCODING);
        }
        exchange.getAttributes().put("RESPONSE_HEADER_NAMES", filteredResponseHeaders.keySet());
        response.getHeaders().putAll(filteredResponseHeaders);
        if (cacheKey != null) {
            // Only echo a key when one exists; a null header value must never be added.
            response.getHeaders().add(PROXY_ADDRESS_KEY, cacheKey);
        }
        return Mono.just(res);
    }

    /** Streams the body received on the stored client connection into the server response. */
    private Mono<Void> writeClientResponse(ServerWebExchange exchange) {
        ServerHttpResponse response = exchange.getResponse();
        Connection connection = exchange.getAttribute("CLIENT_RESPONSE_CONN_ATTR");
        if (connection == null) {
            return Mono.empty();
        }
        final Flux<DataBuffer> body = connection
                .inbound()
                .receive()
                .retain()
                .map(byteBuf -> wrap(byteBuf, response));
        MediaType contentType = null;
        try {
            contentType = response.getHeaders().getContentType();
        } catch (Exception e) {
            // An unparsable Content-Type only disables the streaming write path.
            if (log.isTraceEnabled()) {
                log.trace("invalid media type", e);
            }
        }
        return (isStreamingMediaType(contentType)
                ? response.writeAndFlushWith(body.map(Flux::just))
                : response.writeWith(body));
    }

    /**
     * Rewrites the scheme and authority of every {@code Location} header value to this
     * application's address, so that redirects issued by the target come back to this server.
     *
     * @param filteredResponseHeaders response headers to modify in place
     * @author liangguohun
     * @date 2022/5/27 16:34
     */
    private void turnTarget(HttpHeaders filteredResponseHeaders) {
        List<String> hrefs = filteredResponseHeaders.get("Location");
        if (hrefs == null || hrefs.isEmpty()) {
            return;
        }
        List<String> newHrefs = new ArrayList<>(hrefs.size());
        for (String href : hrefs) {
            String tgt = replaceSchemeAndAuthority(href, applicationAddress);
            newHrefs.add(tgt);
            log.info(">>>>>>>>>{}", tgt);
        }
        filteredResponseHeaders.put("Location", newHrefs);
    }

    /**
     * Replaces the scheme and authority of the request URL with the proxy target address.
     *
     * @param orinalUrl the original request URL
     * @param proxyAddress the target address (scheme and authority) to substitute
     * @return the URL pointing at the proxy target
     */
    private String turn2Target(String orinalUrl, String proxyAddress) {
        return replaceSchemeAndAuthority(orinalUrl, proxyAddress);
    }

    /**
     * Replaces the first {@code http(s)://authority} occurrence in {@code url}. The replacement is
     * quoted so that {@code $} and {@code \} in it are inserted literally instead of being
     * interpreted as regex group references.
     */
    private static String replaceSchemeAndAuthority(String url, String replacement) {
        return SCHEME_AND_AUTHORITY.matcher(url).replaceFirst(Matcher.quoteReplacement(replacement));
    }

    /**
     * Converts a Netty buffer into a {@link DataBuffer} produced by the response's buffer factory.
     *
     * @throws IllegalArgumentException when the buffer factory type is not supported
     */
    protected DataBuffer wrap(ByteBuf byteBuf, ServerHttpResponse response) {
        DataBufferFactory bufferFactory = response.bufferFactory();
        if (bufferFactory instanceof NettyDataBufferFactory) {
            return ((NettyDataBufferFactory) bufferFactory).wrap(byteBuf);
        }
        // MockServerHttpResponse creates these
        if (bufferFactory instanceof DefaultDataBufferFactory) {
            DataBuffer buffer = ((DefaultDataBufferFactory) bufferFactory).allocateBuffer(byteBuf.readableBytes());
            buffer.write(byteBuf.nioBuffer());
            // The bytes were copied, so the Netty buffer is no longer needed.
            byteBuf.release();
            return buffer;
        }
        throw new IllegalArgumentException("Unkown DataBufferFactory type " + bufferFactory.getClass());
    }

    /** Disposes the stored client connection when it is still active and not persistent. */
    private void cleanup(ServerWebExchange exchange) {
        Connection connection = exchange.getAttribute("CLIENT_RESPONSE_CONN_ATTR");
        if (connection != null && connection.channel().isActive() && !connection.isPersistent()) {
            connection.dispose();
        }
    }

    /** Returns whether the content type is compatible with one of the streaming media types. */
    // TODO: use framework if possible
    private boolean isStreamingMediaType(@Nullable MediaType contentType) {
        if (contentType == null) {
            return false;
        }
        for (MediaType streamingType : streamingMediaTypes) {
            if (streamingType.isCompatibleWith(contentType)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Exposes a {@link DataBuffer} as a Netty {@link ByteBuf}.
     *
     * @throws IllegalArgumentException when the data buffer type is not supported
     */
    protected ByteBuf getByteBuf(DataBuffer dataBuffer) {
        if (dataBuffer instanceof NettyDataBuffer) {
            return ((NettyDataBuffer) dataBuffer).getNativeBuffer();
        }
        // MockServerHttpResponse creates these
        if (dataBuffer instanceof DefaultDataBuffer) {
            return Unpooled.wrappedBuffer(((DefaultDataBuffer) dataBuffer).getNativeBuffer());
        }
        throw new IllegalArgumentException("Unable to handle DataBuffer of type " + dataBuffer.getClass());
    }

    /**
     * Copies the client response status onto the server response. Status codes unknown to
     * {@link HttpStatus} are set as a raw code on the underlying {@link AbstractServerHttpResponse}.
     *
     * @throws IllegalStateException when the code is unknown and the response type cannot take a raw code
     */
    private void setResponseStatus(HttpClientResponse clientResponse, ServerHttpResponse response) {
        int code = clientResponse.status().code();
        HttpStatus status = HttpStatus.resolve(code);
        if (status != null) {
            response.setStatusCode(status);
            return;
        }
        // Unwrap decorators to reach the response that accepts a raw status code.
        ServerHttpResponse target = response;
        while (target instanceof ServerHttpResponseDecorator) {
            target = ((ServerHttpResponseDecorator) target).getDelegate();
        }
        if (target instanceof AbstractServerHttpResponse) {
            ((AbstractServerHttpResponse) target).setRawStatusCode(code);
            return;
        }
        // TODO: log warning here, not throw error?
        throw new IllegalStateException("Unable to set status code " + code
                + " on response of type " + target.getClass().getName());
    }

    /** Resolved proxy target plus the cache key to echo to the client, if any. */
    private static final class ProxyTarget {

        /** Scheme and authority the request is forwarded to. */
        private final String address;

        /** Generated or header-supplied cache key; {@code null} when the cookie identified the target. */
        @Nullable
        private final String cacheKey;

        private ProxyTarget(String address, @Nullable String cacheKey) {
            this.address = address;
            this.cacheKey = cacheKey;
        }
    }
}