基于 WebFlux 实现代理过滤器

387 阅读3分钟

核心原理

借鉴 spring-cloud-gateway 的过滤链。

  1. NettyRoutingFilter 通过 HttpClient 将请求转发到目标地址；
  2. NettyWriteResponseFilter 将目标服务的响应结果回写给客户端。

核心调整点

  1. 通过 cookie 或请求头识别并记录代理目标地址；
  2. 将响应中的跳转地址（Location）改写为当前服务器地址，使后续访问都经由代理逻辑处理。
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.buffer.*;
import org.springframework.http.*;
import org.springframework.http.server.reactive.AbstractServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.netty.Connection;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.HttpClientResponse;

import javax.annotation.Resource;
import java.net.URI;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeoutException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * WebFilter that turns this WebFlux application into a reverse proxy for selected requests,
 * modelled on spring-cloud-gateway's {@code NettyRoutingFilter} (forwarding) and
 * {@code NettyWriteResponseFilter} (writing the upstream response back).
 *
 * <p>The proxy target is resolved per request as follows:
 * <ol>
 *   <li>A {@code proxyAddress} query parameter names the target directly. It is cached under the
 *       {@code BMAP_SECKEY} cookie value or, when that cookie is absent, under a freshly
 *       generated UUID.</li>
 *   <li>Otherwise, if the {@code BMAP_SECKEY} cookie is present, the target cached for its value
 *       is used.</li>
 *   <li>Otherwise, a {@code proxyAddress} request header is used as a UUID key into the
 *       cache.</li>
 * </ol>
 * Whenever a UUID key is involved it is echoed back in the {@code proxyAddress} response header.
 * When a target is found, the request is replayed against it with a reactor-netty
 * {@link HttpClient}, and the upstream status, headers and body are copied onto this exchange.
 * The scheme and authority of upstream {@code Location} headers are rewritten to
 * {@code applicationAddress}, so that redirects come back through this proxy.
 *
 * <p>Requests that are not proxied and whose path is one of the front-end routes in
 * {@code SPA_ROUTES} are rewritten to {@code /index.html}. This stops a refresh on a client-side
 * route from returning 404.
 *
 * <p>NOTE(review): the target address comes from a client-supplied query parameter, so this
 * filter behaves as an open proxy (SSRF risk) unless access is restricted elsewhere. Confirm
 * that this is intended.
 *
 * @author  liangguohun
 * @date    2022/5/19 15:16
 */
@Slf4j
@Component
public class ProxyWebFilter implements WebFilter {

    /** Name of the query parameter (target address) and of the request/response header (UUID key). */
    private static final String PROXY_ADDRESS_KEY = "proxyAddress";

    /** Cookie whose value keys the cached proxy target for a browser session. */
    private static final String SESSION_COOKIE = "BMAP_SECKEY";

    /** Timeout applied to the upstream response; on expiry the client gets 504 Gateway Timeout. */
    private static final Duration RESPONSE_TIMEOUT = Duration.ofMillis(3000L);

    /**
     * Matches the leading {@code http(s)://host[:port]} part of a URL. The host is everything
     * after {@code ://} up to the first {@code /} or {@code ?}.
     */
    private static final Pattern ORIGIN_PATTERN = Pattern.compile("https?://[^/?]+");

    /** Front-end (SPA) routes that are answered with {@code /index.html} when requested directly. */
    private static final List<String> SPA_ROUTES = Arrays.asList("/",
            "/alarm",
            "/authority",
            "/company",
            "/config",
            "/firmware",
            "/firmwarelist",
            "/gatewaylist",
            "/group",
            "/grouplist",
            "/log",
            "/map",
            "/statistics",
            "/task",
            "/userlist",
            "/alert-rule"
    );

    /** Content types whose body is flushed chunk by chunk instead of being written as one stream. */
    private final List<MediaType> streamingMediaTypes = Arrays.asList(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_STREAM_JSON);

    private final HttpClient httpClient = HttpClient.create();

    /** Store that maps session-cookie values or UUIDs to proxy target addresses. */
    @Resource
    CacheService cacheService;

    /** Address of this application. It replaces the upstream scheme and authority in Location headers. */
    @Value("${applicationAddress}")
    private String applicationAddress;

    @Resource
    ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider;

    // do not use this headersFilters directly, use getHeadersFilters() instead.
    private volatile List<HttpHeadersFilter> headersFilters;

    /**
     * Lazily resolves the response header filters from the context.
     *
     * @return the available header filters, or {@code null} if none are registered
     */
    public List<HttpHeadersFilter> getHeadersFilters() {
        if (headersFilters == null) {
            headersFilters = headersFiltersProvider.getIfAvailable();
        }
        return headersFilters;
    }

    /**
     * Proxies the request if a target can be resolved for it. Otherwise it applies the SPA
     * route fallback, or hands the request to the rest of the chain.
     *
     * @param exchange the current exchange
     * @param chain    the remaining filter chain
     * @return completion of the proxied or delegated handling
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        String cookie = readSessionCookie(request);
        String proxyAddress = request.getQueryParams().getFirst(PROXY_ADDRESS_KEY);
        String uuid = null;

        if (proxyAddress != null) {
            // Explicit target in the query string: remember it for follow-up requests.
            if (cookie != null) {
                cacheService.putValue(cookie, proxyAddress);
            } else {
                uuid = UUID.randomUUID().toString();
                cacheService.putValue(uuid, proxyAddress);
            }
        } else if (cookie != null) {
            // Follow-up request from a browser session that previously selected a target.
            proxyAddress = cacheService.getValue(cookie);
        } else {
            // Follow-up request from a client that echoes the UUID issued earlier.
            uuid = request.getHeaders().getFirst(PROXY_ADDRESS_KEY);
            if (uuid != null) {
                proxyAddress = cacheService.getValue(uuid);
            }
        }

        if (proxyAddress != null) {
            return proxy(exchange, chain, proxyAddress, uuid);
        }

        // A browser refresh on a client-side route would otherwise 404: serve the SPA entry page.
        String path = request.getURI().getPath();
        if (SPA_ROUTES.contains(path)) {
            ServerHttpRequest indexRequest = request.mutate().path("/index.html").build();
            return chain.filter(exchange.mutate().request(indexRequest).build());
        }

        return chain.filter(exchange);
    }

    /**
     * Returns the value of the session cookie, or {@code null} when the request does not carry it.
     * Other cookies alone must not make this fail.
     */
    private static String readSessionCookie(ServerHttpRequest request) {
        HttpCookie sessionCookie = request.getCookies().getFirst(SESSION_COOKIE);
        return sessionCookie == null ? null : sessionCookie.getValue();
    }

    /**
     * Forwards the current request to {@code proxyAddress} and streams the upstream response back.
     * If the exchange was already routed, or its scheme is not http/https, the request is passed
     * down the chain instead.
     *
     * <p>NOTE(review): {@code isAlreadyRouted}/{@code setAlreadyRouted} are not declared in this
     * class. Presumably they are statically imported from spring-cloud-gateway's
     * {@code ServerWebExchangeUtils}; confirm this.
     *
     * @param proxyAddress target scheme + authority that replaces the request's own
     * @param uuid         cache key to echo in the {@code proxyAddress} response header, may be null
     */
    private Mono<Void> proxy(ServerWebExchange exchange, WebFilterChain chain, String proxyAddress, String uuid) {
        ServerHttpRequest request = exchange.getRequest();
        URI requestUrl = request.getURI();
        String scheme = requestUrl.getScheme();
        if (isAlreadyRouted(exchange) || (!"http".equals(scheme) && !"https".equals(scheme))) {
            return chain.filter(exchange);
        }
        setAlreadyRouted(exchange);

        final HttpMethod method = HttpMethod.valueOf(request.getMethodValue());
        String url = turn2Target(requestUrl.toASCIIString(), proxyAddress);
        exchange.getAttributes().put("GATEWAY_REQUEST_URL_ATTR", url);

        // Copy every incoming request header onto the outbound Netty request.
        final DefaultHttpHeaders httpHeaders = new DefaultHttpHeaders();
        request.getHeaders().forEach(httpHeaders::set);

        // Whether the client's Host header should be forwarded unchanged.
        boolean preserveHost = exchange.getAttributeOrDefault("PRESERVE_HOST_HEADER_ATTRIBUTE", false);

        Flux<HttpClientResponse> responseFlux = httpClient
                .headers(headers -> {
                    headers.add(httpHeaders);
                    headers.remove(HttpHeaders.HOST);
                    if (preserveHost) {
                        String host = request.getHeaders().getFirst(HttpHeaders.HOST);
                        // Netty rejects null header values, so only re-add a Host that exists.
                        if (host != null) {
                            headers.add(HttpHeaders.HOST, host);
                        }
                    }
                })
                .request(method)
                .uri(url)
                .send((req, nettyOutbound) -> nettyOutbound.send(request.getBody().map(this::getByteBuf)))
                .responseConnection((res, connection) -> {
                    // Only status and headers are copied here; the body is written after the flux
                    // completes (see writeResponseBody), mirroring NettyWriteResponseFilter.
                    copyResponseMetadata(exchange, res, connection, uuid);
                    return Mono.just(res);
                });

        return responseFlux
                .timeout(RESPONSE_TIMEOUT, Mono.error(new TimeoutException("Response took longer than timeout: " + RESPONSE_TIMEOUT)))
                .onErrorMap(TimeoutException.class, th -> new ResponseStatusException(HttpStatus.GATEWAY_TIMEOUT, th.getMessage(), th))
                .then(Mono.defer(() -> writeResponseBody(exchange)))
                .doOnCancel(() -> cleanup(exchange));
    }

    /**
     * Stores the upstream response and connection as exchange attributes. It then copies the
     * upstream status and filtered headers onto the local response, with Location headers
     * rewritten.
     *
     * @param uuid cache key added as the {@code proxyAddress} response header when non-null
     */
    private void copyResponseMetadata(ServerWebExchange exchange, HttpClientResponse res, Connection connection, String uuid) {
        exchange.getAttributes().put("CLIENT_RESPONSE_ATTR", res);
        exchange.getAttributes().put("CLIENT_RESPONSE_CONN_ATTR", connection);

        ServerHttpResponse response = exchange.getResponse();
        HttpHeaders headers = new HttpHeaders();
        res.responseHeaders().forEach(entry -> headers.add(entry.getKey(), entry.getValue()));

        String contentTypeValue = headers.getFirst(HttpHeaders.CONTENT_TYPE);
        if (StringUtils.hasLength(contentTypeValue)) {
            exchange.getAttributes().put("ORIGINAL_RESPONSE_CONTENT_TYPE_ATTR", contentTypeValue);
        }

        // Set the status before running header filters so that they can see it.
        setResponseStatus(res, response);
        HttpHeaders filteredResponseHeaders = HttpHeadersFilter.filter(getHeadersFilters(), headers, exchange, HttpHeadersFilter.Type.RESPONSE);
        turnTarget(filteredResponseHeaders);
        if (!filteredResponseHeaders.containsKey(HttpHeaders.TRANSFER_ENCODING)
                && filteredResponseHeaders.containsKey(HttpHeaders.CONTENT_LENGTH)) {
            // Transfer-Encoding and Content-Length must not both be present; Content-Length wins.
            response.getHeaders().remove(HttpHeaders.TRANSFER_ENCODING);
        }

        exchange.getAttributes().put("RESPONSE_HEADER_NAMES", filteredResponseHeaders.keySet());
        response.getHeaders().putAll(filteredResponseHeaders);
        if (uuid != null) {
            response.getHeaders().add(PROXY_ADDRESS_KEY, uuid);
        }
    }

    /**
     * Streams the upstream body from the stored connection into the local response. Streaming
     * media types are flushed chunk by chunk.
     *
     * @return completion of the write, or empty if no upstream connection was recorded
     */
    private Mono<Void> writeResponseBody(ServerWebExchange exchange) {
        ServerHttpResponse response = exchange.getResponse();
        Connection connection = exchange.getAttribute("CLIENT_RESPONSE_CONN_ATTR");
        if (connection == null) {
            return Mono.empty();
        }
        final Flux<DataBuffer> body = connection
                .inbound()
                .receive()
                .retain()
                .map(byteBuf -> wrap(byteBuf, response));

        MediaType contentType = null;
        try {
            contentType = response.getHeaders().getContentType();
        } catch (InvalidMediaTypeException e) {
            // Best effort: an unparsable Content-Type just means "not streaming".
            if (log.isTraceEnabled()) {
                log.trace("invalid media type", e);
            }
        }
        return isStreamingMediaType(contentType)
                ? response.writeAndFlushWith(body.map(Flux::just))
                : response.writeWith(body);
    }

    /**
     * Rewrites the scheme and authority of every {@code Location} header to
     * {@code applicationAddress}, so that follow-up requests come back through this proxy.
     *
     * @author  liangguohun
     * @date    2022/5/27 16:34
     */
    private void turnTarget(HttpHeaders filteredResponseHeaders) {
        List<String> hrefs = filteredResponseHeaders.get(HttpHeaders.LOCATION);
        if (hrefs == null || hrefs.isEmpty()) {
            return;
        }
        // Quote the replacement so that '$' or '\' in the configured address are taken literally.
        String replacement = Matcher.quoteReplacement(applicationAddress);
        List<String> newHrefs = new ArrayList<>(hrefs.size());
        for (String href : hrefs) {
            String tgt = ORIGIN_PATTERN.matcher(href).replaceFirst(replacement);
            newHrefs.add(tgt);
            log.info(">>>>>>>>>{}", tgt);
        }
        filteredResponseHeaders.put(HttpHeaders.LOCATION, newHrefs);
    }

    /**
     * Replaces the first {@code http(s)://host[:port]} in {@code originalUrl} with
     * {@code proxyAddress}. The address is client-supplied, so it is quoted and never interpreted
     * as regex replacement syntax.
     */
    private String turn2Target(String originalUrl, String proxyAddress) {
        return ORIGIN_PATTERN.matcher(originalUrl).replaceFirst(Matcher.quoteReplacement(proxyAddress));
    }

    /**
     * Adapts an upstream Netty buffer to the response's buffer factory. Netty factories wrap it
     * without copying; a {@link DefaultDataBufferFactory} (e.g. in mock responses) gets a copy,
     * and the source buffer is released.
     *
     * @throws IllegalArgumentException for any other factory type
     */
    protected DataBuffer wrap(ByteBuf byteBuf, ServerHttpResponse response) {
        DataBufferFactory bufferFactory = response.bufferFactory();
        if (bufferFactory instanceof NettyDataBufferFactory) {
            NettyDataBufferFactory factory = (NettyDataBufferFactory) bufferFactory;
            return factory.wrap(byteBuf);
        }
        // MockServerHttpResponse creates these
        else if (bufferFactory instanceof DefaultDataBufferFactory) {
            DataBuffer buffer = ((DefaultDataBufferFactory) bufferFactory).allocateBuffer(byteBuf.readableBytes());
            buffer.write(byteBuf.nioBuffer());
            byteBuf.release();
            return buffer;
        }
        throw new IllegalArgumentException("Unkown DataBufferFactory type " + bufferFactory.getClass());
    }

    /** On cancellation, disposes the upstream connection if it is still active and not persistent. */
    private void cleanup(ServerWebExchange exchange) {
        Connection connection = exchange.getAttribute("CLIENT_RESPONSE_CONN_ATTR");
        if (connection != null && connection.channel().isActive() && !connection.isPersistent()) {
            connection.dispose();
        }
    }

    /** Returns true if {@code contentType} is compatible with one of {@code streamingMediaTypes}. */
    private boolean isStreamingMediaType(@Nullable MediaType contentType) {
        if (contentType == null) {
            return false;
        }
        for (MediaType streamingType : streamingMediaTypes) {
            if (streamingType.isCompatibleWith(contentType)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Converts a request-body {@link DataBuffer} into a Netty {@link ByteBuf} for the outbound
     * request. A Netty-backed buffer is unwrapped, and a {@link DefaultDataBuffer} is wrapped.
     *
     * @throws IllegalArgumentException for any other buffer type
     */
    protected ByteBuf getByteBuf(DataBuffer dataBuffer) {
        if (dataBuffer instanceof NettyDataBuffer) {
            NettyDataBuffer buffer = (NettyDataBuffer) dataBuffer;
            return buffer.getNativeBuffer();
        }
        // MockServerHttpResponse creates these
        else if (dataBuffer instanceof DefaultDataBuffer) {
            DefaultDataBuffer buffer = (DefaultDataBuffer) dataBuffer;
            return Unpooled.wrappedBuffer(buffer.getNativeBuffer());
        }
        throw new IllegalArgumentException("Unable to handle DataBuffer of type " + dataBuffer.getClass());
    }

    /**
     * Copies the upstream status code onto the local response. Non-standard codes that
     * {@link HttpStatus} cannot resolve are set through the raw-status API of the undecorated
     * {@link AbstractServerHttpResponse}.
     *
     * @throws IllegalStateException if a non-standard code cannot be set on the response type
     */
    private void setResponseStatus(HttpClientResponse clientResponse, ServerHttpResponse response) {
        HttpStatus status = HttpStatus.resolve(clientResponse.status().code());
        if (status != null) {
            response.setStatusCode(status);
        }
        else {
            while (response instanceof ServerHttpResponseDecorator) {
                response = ((ServerHttpResponseDecorator) response).getDelegate();
            }
            if (response instanceof AbstractServerHttpResponse) {
                ((AbstractServerHttpResponse) response).setRawStatusCode(clientResponse.status().code());
            }
            else {
                // TODO: log warning here, not throw error?
                throw new IllegalStateException("Unable to set status code " + clientResponse.status().code()
                        + " on response of type " + response.getClass().getName());
            }
        }
    }

}