项目“瑞士军刀”:全局 Spring Bean 与 Web 上下文访问工具类

50 阅读4分钟

前言

在标准 Spring 开发中,我们通常使用 @Autowired 进行依赖注入。但在一些非 Bean 管理的场景(如:静态工具类、某些三方框架的扩展点、或复杂的单例模式)中,我们无法直接注入对象。

在 Web 场景,我们只能在 Controller 的方法参数中获取 HttpServletRequestResponse。但在架构底层(如切面、拦截器、甚至是 Service 层),如果需要获取当前请求的 Header、IP 地址或 Session 信息,层层传递参数显然太笨重。

为了解决上面两个项目开发过程中的痛点,我们需要实现 ApplicationContextAware 接口,构建一个全局的 SpringContextHolder;利用 RequestContextHolder,在代码的任何角落静默获取当前线程关联的请求上下文。

ApplicationContextAware

ApplicationContextAware 接口可以感知 Spring 中 ApplicationContext,通过实现 ApplicationContextAware 接口,Spring 会在启动时自动把容器引用(ApplicationContext)“注入”到这个类中。我们把它存入一个静态变量,就能在任何地方手动获取 Bean。

基本用法:

@Component
public class SpringContextHolder implements ApplicationContextAware {

    private static ApplicationContext applicationContext;

    @Override
    public void setApplicationContext(ApplicationContext context) throws BeansException {
        applicationContext = context;
    }

    // 通过类型获取 Bean
    public static <T> T getBean(Class<T> clazz) {
        return applicationContext.getBean(clazz);
    }

    // 通过名称获取 Bean
    public static Object getBean(String name) {
        return applicationContext.getBean(name);
    }
}

RequestContextHolder

RequestContextHolder 利用了 ThreadLocal。Spring MVC 在接收到请求时,会将当前请求的 RequestResponse 绑定到处理该请求的线程上。只要你还在同一个线程内,就能随时“隔空取物”。

基本用法:

public class WebUtils {

    public static HttpServletRequest getRequest() {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        return attributes != null ? attributes.getRequest() : null;
    }

    // 快捷获取 Header
    public static String getHeader(String name) {
        HttpServletRequest request = getRequest();
        return request != null ? request.getHeader(name) : null;
    }
}

扩展 ApplicationContextAware 能力

下面内容是对实际项目中,常见场景的封装:

@Slf4j
public class SpringContextHolder implements ApplicationContextAware, DisposableBean {

    private static ApplicationContext applicationContext = null;

    /**
     * 获取 ApplicationContext
     */
    public static ApplicationContext getApplicationContext() {
        assertContextInjected();
        return applicationContext;
    }

    /**
     * 获取环境属性
     *
     * @param key 属性名
     * @return 属性值
     */
    public static String getProperty(String key) {
        assertContextInjected();
        return applicationContext.getEnvironment().getProperty(key);
    }

    /**
     * 获取环境属性
     *
     * @param key          属性名
     * @param defaultValue 默认值
     * @return 属性值
     */
    public static String getProperty(String key, String defaultValue) {
        assertContextInjected();
        return applicationContext.getEnvironment().getProperty(key, defaultValue);
    }

    /**
     * 获取环境属性
     *
     * @param key        属性名
     * @param targetType 目标类型
     * @param <T>        泛型
     * @return 属性值
     */
    public static <T> T getProperty(String key, Class<T> targetType) {
        assertContextInjected();
        return applicationContext.getEnvironment().getProperty(key, targetType);
    }

    /**
     * 从静态变量 applicationContext 中取得 Bean, 自动转型为所赋值对象的类型
     */
    @SuppressWarnings("unchecked")
    public static <T> T getBean(String name) {
        assertContextInjected();
        return (T) applicationContext.getBean(name);
    }

    /**
     * 从静态变量 applicationContext 中取得 Bean, 自动转型为所赋值对象的类型
     */
    public static <T> T getBean(Class<T> requiredType) {
        assertContextInjected();
        return applicationContext.getBean(requiredType);
    }

    /**
     * 发布事件
     *
     * @param event 事件对象
     */
    public static void publishEvent(ApplicationEvent event) {
        if (applicationContext == null) {
            return;
        }
        applicationContext.publishEvent(event);
    }

    /**
     * 清除 SpringContextHolder 中的 ApplicationContext 为 null
     */
    public static void clearHolder() {
        if (log.isDebugEnabled()) {
            log.debug("清除 SpringContextHolder 中的 ApplicationContext: {}", applicationContext);
        }
        applicationContext = null;
    }

    /**
     * 当 Spring 容器初始化这个 Bean 时,将 ApplicationContext 注入到静态变量中
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        SpringContextHolder.applicationContext = applicationContext;
    }

    @Override
    public void destroy() {
        SpringContextHolder.clearHolder();
    }

    /**
     * 检查 ApplicationContext 是否注入
     */
    private static void assertContextInjected() {
        if (applicationContext == null) {
            throw new IllegalStateException(
                    "applicationContext 属性未注入, 请确保 SpringContextHolder 已注册为 Bean,且在 Spring 上下文初始化完成后调用。");
        }
    }
}

扩展 RequestContextHolder 能力

下面内容是对实际项目中,常见场景的封装:

@Slf4j
@UtilityClass
public class WebUtils {

    /**
     * 获取 HttpServletRequest
     */
    public HttpServletRequest getRequest() {
        return Optional.ofNullable(RequestContextHolder.getRequestAttributes())
                .map(ServletRequestAttributes.class::cast)
                .map(ServletRequestAttributes::getRequest)
                .orElse(null);
    }

    /**
     * 获取 HttpServletResponse
     */
    public HttpServletResponse getResponse() {
        return Optional.ofNullable(RequestContextHolder.getRequestAttributes())
                .map(ServletRequestAttributes.class::cast)
                .map(ServletRequestAttributes::getResponse)
                .orElse(null);
    }

    /**
     * 判断是否是 JSON 请求
     */
    public boolean isJsonRequest(HandlerMethod handlerMethod) {
        // 1. 类上有 @RestController
        if (handlerMethod.getBeanType().isAnnotationPresent(RestController.class)) {
            return true;
        }
        // 2. 方法或类上有 @ResponseBody
        if (handlerMethod.getMethodAnnotation(ResponseBody.class) != null ||
                handlerMethod.getBeanType().isAnnotationPresent(ResponseBody.class)) {
            return true;
        }
        return false;
    }

    /**
     * 渲染 JSON 数据到响应
     *
     * @param response 响应对象
     * @param json     JSON 字符串
     */
    public void renderJson(HttpServletResponse response, String json) {
        if (response == null) {
            log.warn("HttpServletResponse 为空,跳过 JSON 渲染");
            return;
        }
        Assert.notNull(json, "JSON 字符串不能为空");

        response.setCharacterEncoding(StandardCharsets.UTF_8.name());
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        try (PrintWriter writer = response.getWriter()) {
            writer.write(json);
            writer.flush();
        } catch (IOException e) {
            log.error("渲染 JSON 失败", e);
        }
    }

    /**
     * 获取客户端 IP
     * <p>
     * 考虑了 Nginx 等反向代理的场景
     *
     * @return IP 地址
     */
    public String getIpAddr() {
        HttpServletRequest request = getRequest();
        if (request == null) {
            return null;
        }
        return getIpAddr(request);
    }

    /**
     * 获取客户端 IP
     *
     * @param request 请求对象
     * @return IP 地址
     */
    public String getIpAddr(HttpServletRequest request) {
        if (request == null) {
            return null;
        }
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }
        // 对于通过多个代理的情况,第一个IP为客户端真实IP,多个IP按照','分割
        if (ip != null && ip.contains(",")) {
            ip = ip.split(",")[0].trim();
        }
        return ip;
    }

    /**
     * 获取请求头
     *
     * @param name header 名称
     * @return header 值
     */
    public String getHeader(String name) {
        HttpServletRequest request = getRequest();
        return request != null ? request.getHeader(name) : null;
    }

    /**
     * 获取 Parameter
     *
     * @param name parameter 名称
     * @return parameter 值
     */
    public String getParameter(String name) {
        HttpServletRequest request = getRequest();
        return request != null ? request.getParameter(name) : null;
    }
}