Spring Boot：重复读取请求的 InputStream

105 阅读1分钟
public class CacheRequestFilter implements Filter {

    /**
     * BodyInputStream匿名类构造器
     */
    private static Constructor<?> BODY_INPUT_STREAM_CONSTRUCTOR = null;

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException, ServletException {
        // 如不提前获取,后续接口获取不到参数
        if (request instanceof MultipartHttpServletRequest) {
            MultipartHttpServletRequest multipartRequest = (MultipartHttpServletRequest) request;
            Map<String, MultipartFile> fileMap = multipartRequest.getFileMap();
        }
        if (request instanceof HttpServletRequest) {
            ContentCachingRequestWrapper wrapper = buildWrapper((HttpServletRequest) request);
            Map<String, String[]> parameterMap = request.getParameterMap();
            IOUtils.copy(wrapper.getInputStream(), new ByteArrayOutputStream());  // 将请求体读取并缓存到wrapper中
            chain.doFilter(wrapper, response);
        } else {
            chain.doFilter(request, response);
        }
    }

    public ContentCachingRequestWrapper buildWrapper(HttpServletRequest request) {
        return new ContentCachingRequestWrapper(request) {
            /**
             * 重写获取流方法
             * @return
             * @throws IOException
             */
            @Override
            public ServletInputStream getInputStream() throws IOException {
                ServletInputStream inputStream = super.getInputStream();
                //流已经被读取,则读取ContentCachingRequestWrapper的缓存
                if (inputStream.isFinished()) {
                    try {
                        //利用ContentCachingRequestWrapper可重复读字节数组的特性,创建一个ServletInputStream类型的流并返回
                        return createInstanceStream(this.getContentAsByteArray());
                    } catch (InvocationTargetException | InstantiationException | IllegalAccessException e) {
                        throw new RuntimeException(e);
                    }

                } else {
                    //第一次读则返回原始流,获取原始流之后才会有缓存即ContentCachingRequestWrapper.getContentAsByteArray方法才能使用
                    return inputStream;
                }
            }
        };
    }

    /**
     * 构造BodyInputStream输入流
     *
     * @return
     */
    public ServletInputStream createInstanceStream(byte[] body) throws InvocationTargetException, InstantiationException, IllegalAccessException {
        Object instance = BODY_INPUT_STREAM_CONSTRUCTOR.newInstance(body);
        return (ServletInputStream) instance;
    }

    /**
     * 初始化BodyInputStream匿名类构造器,静态初始化加载节省性能
     */
    static {
        //使用的依赖包是 spring-webmvc-5.3.19.jar
        try {
            Class<?> clazz = Class.forName("org.springframework.web.servlet.function.DefaultServerRequestBuilder$BodyInputStream");
            Constructor<?> bodyInputStream = clazz.getDeclaredConstructor(byte[].class);
            bodyInputStream.setAccessible(true);
            BODY_INPUT_STREAM_CONSTRUCTOR = bodyInputStream;
        } catch (ClassNotFoundException | NoSuchMethodException e) {
            throw new RuntimeException(e);
        }

    }

}
/**
 * Spring configuration that installs {@link CacheRequestFilter} in the servlet filter chain.
 */
@Configuration
public class FilterConfig {

    /**
     * Registers the request-body caching filter for every request path.
     *
     * @return a filter registration wrapping a new {@link CacheRequestFilter}
     */
    @Bean
    public FilterRegistrationBean<CacheRequestFilter> loggingFilter() {
        CacheRequestFilter cacheRequestFilter = new CacheRequestFilter();
        FilterRegistrationBean<CacheRequestFilter> registration =
                new FilterRegistrationBean<>(cacheRequestFilter);
        // Match all paths.
        registration.addUrlPatterns("/*");
        return registration;
    }
}