别再让用户ID信息污染你的业务参数了!ThreadLocal一招搞定上下文传递

34 阅读6分钟

ThreadLocal核心概念

什么是ThreadLocal?

ThreadLocal是Java提供的线程局部变量存储类,每个线程都有自己独立的变量副本,互不干扰,实现了线程隔离的数据存储。

// 简单的使用demo
public class ThreadLocalDemo {
    // 创建一个 ThreadLocal 变量
    private static ThreadLocal<String> threadLocal = new ThreadLocal<>();
    
    public static void main(String[] args) {
        threadLocal.set("Main Thread Value");
        System.out.println(Thread.currentThread().getName() + ": " + threadLocal.get());
    }
}

与普通变量区别

普通变量:所有线程共享

ThreadLocal 变量 - 线程隔离

// 共享变量 - 所有线程共享
class SharedVariable {
    private static String sharedValue = "initial";
    
    public static void main(String[] args) {
        new Thread(() -> {
            sharedValue = "Thread1";
            System.out.println("Thread1: " + sharedValue);
        }).start();
        
        new Thread(() -> {
            sharedValue = "Thread2";
            System.out.println("Thread2: " + sharedValue);
        }).start();
        
        // 输出不确定,存在线程安全问题
    }
}

// ThreadLocal 变量 - 线程隔离
class ThreadLocalVariable {
    private static ThreadLocal<String> threadLocal = new ThreadLocal<>();
    
    public static void main(String[] args) {
        Thread thread1 = new Thread(() -> {
            threadLocal.set("Thread1 Value");
            System.out.println("Thread1: " + threadLocal.get());
        });
        
        Thread thread2 = new Thread(() -> {
            threadLocal.set("Thread2 Value");
            System.out.println("Thread2: " + threadLocal.get());
        });
        
        thread1.start();
        thread2.start();
        // 输出确定:每个线程有自己的值
    }
}

ThreadLocal的典型应用场景

场景背景

用户请求的控制器链中,多个服务层方法需要共享一些上下文信息(如用户IP、用户代理、跟踪ID等)。比如:在Spring项目中有个控制器接口,该控制器接口调用ServiceA方法,ServiceA方法又调用ServiceB方法,ServiceB方法又调用了ServiceC的方法,此时ServiceC中想要获取当前用户的IP。

传统方案

方案: 在控制器层接口中使用request获取当前用户IP,然后将IP作为参数一层层传递到ServiceC方法中。 缺点: 所有经过的方法都需要加上IP参数。如果ServiceC方法还需要其他用户信息,比如用户的UA标识等等。就需要在所有经过的方法上加上这些所需要的参数。这样会导致方法参数过多,耦合很重。

使用ThreadLocal的优化解决方案

可以在拦截器或者在控制器接口中把用户IP或者其他的一些上下文信息存放在ThreadLocal中,这样无论是哪个方法想要使用这些上下文信息,只要是在当前线程的执行路径上,都可以通过ThreadLocal直接获取。

示例实现代码

package com.example.demo.context;

import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

/**
 * 请求上下文管理器
 * 使用 ThreadLocal 存储请求级别的数据
 */
@Component
public class RequestContext {
    
    // ThreadLocal 存储上下文数据
    private static final ThreadLocal<Map<String, Object>> CONTEXT_HOLDER = 
        ThreadLocal.withInitial(HashMap::new);
    
    // 上下文键常量
    public static final String TRACE_ID = "traceId";
    public static final String USER_IP = "userIp";
    public static final String USER_AGENT = "userAgent";
    public static final String USER_ID = "userId";
    public static final String REQUEST_START_TIME = "requestStartTime";
    
    /**
     * 设置上下文值
     */
    public static void set(String key, Object value) {
        CONTEXT_HOLDER.get().put(key, value);
    }
    
    /**
     * 获取上下文值
     */
    @SuppressWarnings("unchecked")
    public static <T> T get(String key) {
        return (T) CONTEXT_HOLDER.get().get(key);
    }
    
    /**
     * 获取上下文值,带默认值
     */
    @SuppressWarnings("unchecked")
    public static <T> T get(String key, T defaultValue) {
        T value = (T) CONTEXT_HOLDER.get().get(key);
        return value != null ? value : defaultValue;
    }
    
    /**
     * 获取 Trace ID
     */
    public static String getTraceId() {
        String traceId = get(TRACE_ID);
        if (traceId == null) {
            traceId = generateTraceId();
            set(TRACE_ID, traceId);
        }
        return traceId;
    }
    
    /**
     * 获取用户 IP
     */
    public static String getUserIp() {
        return get(USER_IP, "unknown");
    }
    
    /**
     * 获取用户代理
     */
    public static String getUserAgent() {
        return get(USER_AGENT, "unknown");
    }
    
    /**
     * 获取用户 ID
     */
    public static Long getUserId() {
        return get(USER_ID, 0L);
    }
    
    /**
     * 获取请求处理时间
     */
    public static long getRequestDuration() {
        Long startTime = get(REQUEST_START_TIME, 0L);
        if (startTime > 0) {
            return System.currentTimeMillis() - startTime;
        }
        return 0;
    }
    
    /**
     * 从当前请求中提取客户端 IP
     * 考虑了代理、负载均衡等情况
     */
    public static String extractClientIp(HttpServletRequest request) {
        // 1. 从 X-Forwarded-For 获取
        String ip = request.getHeader("X-Forwarded-For");
        if (ip != null && !ip.isEmpty() && !"unknown".equalsIgnoreCase(ip)) {
            // 可能有多个 IP,取第一个
            int index = ip.indexOf(",");
            if (index != -1) {
                ip = ip.substring(0, index);
            }
            return ip.trim();
        }
        
        // 2. 从 X-Real-IP 获取
        ip = request.getHeader("X-Real-IP");
        if (ip != null && !ip.isEmpty() && !"unknown".equalsIgnoreCase(ip)) {
            return ip.trim();
        }
        
        // 3. 从 Proxy-Client-IP 获取
        ip = request.getHeader("Proxy-Client-IP");
        if (ip != null && !ip.isEmpty() && !"unknown".equalsIgnoreCase(ip)) {
            return ip.trim();
        }
        
        // 4. 从 WL-Proxy-Client-IP 获取
        ip = request.getHeader("WL-Proxy-Client-IP");
        if (ip != null && !ip.isEmpty() && !"unknown".equalsIgnoreCase(ip)) {
            return ip.trim();
        }
        
        // 5. 从 HTTP_CLIENT_IP 获取
        ip = request.getHeader("HTTP_CLIENT_IP");
        if (ip != null && !ip.isEmpty() && !"unknown".equalsIgnoreCase(ip)) {
            return ip.trim();
        }
        
        // 6. 从 HTTP_X_FORWARDED_FOR 获取
        ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        if (ip != null && !ip.isEmpty() && !"unknown".equalsIgnoreCase(ip)) {
            return ip.trim();
        }
        
        // 7. 最后从 request.getRemoteAddr() 获取
        return request.getRemoteAddr();
    }
    
    /**
     * 生成跟踪 ID
     */
    private static String generateTraceId() {
        return UUID.randomUUID().toString().replace("-", "").substring(0, 16);
    }
    
    /**
     * 清理上下文
     */
    public static void clear() {
        CONTEXT_HOLDER.remove();
    }
    
    /**
     * 获取整个上下文(用于调试)
     */
    public static Map<String, Object> getAll() {
        return new HashMap<>(CONTEXT_HOLDER.get());
    }
    
    /**
     * 打印上下文(用于调试)
     */
    public static void printContext() {
        Map<String, Object> context = getAll();
        System.out.println("=== Request Context ===");
        context.forEach((key, value) -> 
            System.out.println(key + ": " + value));
        System.out.println("======================");
    }
}
package com.example.demo.interceptor;

import com.example.demo.context.RequestContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/**
 * 请求上下文拦截器
 * 在请求开始时初始化 ThreadLocal 上下文
 * 在请求结束后清理 ThreadLocal
 */
@Component
public class RequestContextInterceptor implements HandlerInterceptor {
    
    private static final Logger logger = LoggerFactory.getLogger(RequestContextInterceptor.class);
    
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) 
            throws Exception {
        
        // 记录请求开始时间
        RequestContext.set(RequestContext.REQUEST_START_TIME, System.currentTimeMillis());
        
        // 设置 Trace ID
        RequestContext.set(RequestContext.TRACE_ID, RequestContext.getTraceId());
        
        // 设置用户 IP
        String userIp = RequestContext.extractClientIp(request);
        RequestContext.set(RequestContext.USER_IP, userIp);
        
        // 设置 User-Agent
        String userAgent = request.getHeader("User-Agent");
        RequestContext.set(RequestContext.USER_AGENT, userAgent != null ? userAgent : "unknown");
        
        // 模拟从 Token 或 Session 中获取用户 ID
        // 实际项目中可以从 JWT Token 或 Session 中解析
        Long userId = extractUserIdFromRequest(request);
        RequestContext.set(RequestContext.USER_ID, userId);
        
        // 设置请求相关的其他信息
        RequestContext.set("requestMethod", request.getMethod());
        RequestContext.set("requestURI", request.getRequestURI());
        RequestContext.set("requestURL", request.getRequestURL().toString());
        
        // 记录请求开始日志
        logger.info("[{}] Request started: {} {} from IP: {}, User-Agent: {}", 
            RequestContext.getTraceId(),
            request.getMethod(),
            request.getRequestURI(),
            userIp,
            userAgent
        );
        
        return true;
    }
    
    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, 
                          Object handler, ModelAndView modelAndView) throws Exception {
        // 在视图渲染前执行
    }
    
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, 
                               Object handler, Exception ex) throws Exception {
        
        // 记录请求处理时间
        long duration = RequestContext.getRequestDuration();
        
        // 根据是否有异常记录不同级别的日志
        if (ex != null) {
            logger.error("[{}] Request completed with error: {} {} - Duration: {}ms, Error: {}", 
                RequestContext.getTraceId(),
                request.getMethod(),
                request.getRequestURI(),
                duration,
                ex.getMessage()
            );
        } else {
            logger.info("[{}] Request completed: {} {} - Status: {}, Duration: {}ms", 
                RequestContext.getTraceId(),
                request.getMethod(),
                request.getRequestURI(),
                response.getStatus(),
                duration
            );
        }
        
        // 重要:必须清理 ThreadLocal,防止内存泄漏
        RequestContext.clear();
    }
    
    /**
     * 从请求中提取用户 ID
     * 实际项目中可以从 JWT Token 或 Session 中解析
     */
    private Long extractUserIdFromRequest(HttpServletRequest request) {
        // 模拟从 Header 中获取用户 ID
        String userIdHeader = request.getHeader("X-User-Id");
        if (userIdHeader != null && !userIdHeader.isEmpty()) {
            try {
                return Long.parseLong(userIdHeader);
            } catch (NumberFormatException e) {
                // 解析失败,返回默认值
            }
        }
        
        // 模拟从 Cookie 中获取
        /*
        Cookie[] cookies = request.getCookies();
        if (cookies != null) {
            for (Cookie cookie : cookies) {
                if ("userId".equals(cookie.getName())) {
                    try {
                        return Long.parseLong(cookie.getValue());
                    } catch (NumberFormatException e) {
                        // 解析失败
                    }
                }
            }
        }
        */
        
        // 模拟从 Session 中获取
        /*
        HttpSession session = request.getSession(false);
        if (session != null) {
            Object userIdObj = session.getAttribute("userId");
            if (userIdObj instanceof Long) {
                return (Long) userIdObj;
            }
        }
        */
        
        // 返回默认值(未登录用户)
        return 0L;
    }
}
package com.example.demo.service;

import com.example.demo.context.RequestContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;

/**
 * 最底层的服务
 * 需要获取用户 IP 和其他上下文信息
 */
@Service
public class ServiceC {
    
    private static final Logger logger = LoggerFactory.getLogger(ServiceC.class);
    
    /**
     * 执行最终检查
     * 这里需要获取用户 IP 进行风控检查
     */
    public String performFinalCheck() {
        // 直接从 ThreadLocal 获取所有需要的信息
        String traceId = RequestContext.getTraceId();
        Long userId = RequestContext.getUserId();
        String userIp = RequestContext.getUserIp();
        String userAgent = RequestContext.getUserAgent();
        
        // 获取请求开始时间,计算处理时长
        long duration = RequestContext.getRequestDuration();
        
        logger.info("[{}] ServiceC: Performing final check", traceId);
        logger.debug("[{}] ServiceC: User details - IP: {}, User-Agent: {}", 
            traceId, userIp, userAgent);
        logger.debug("[{}] ServiceC: Request duration so far: {}ms", traceId, duration);
        
        // 模拟基于 IP 的风控检查
        boolean riskCheckPassed = performRiskCheck(userIp, userId);
        
        if (!riskCheckPassed) {
            logger.warn("[{}] ServiceC: Risk check failed for user {} from IP {}", 
                traceId, userId, userIp);
            return "RISK_CHECK_FAILED";
        }
        
        // 记录审计日志
        logAudit(traceId, userId, userIp, userAgent);
        
        return "CHECK_PASSED";
    }
    
    /**
     * 模拟风控检查
     * 基于用户 IP 进行简单的检查
     */
    private boolean performRiskCheck(String userIp, Long userId) {
        // 这里可以实现复杂的风控逻辑
        // 例如:检查 IP 是否在黑名单中,是否高频请求等
        
        // 简单示例:禁止某些 IP 段
        if (userIp.startsWith("192.168.") || userIp.startsWith("10.")) {
            // 内网 IP,认为是安全的
            return true;
        }
        
        // 模拟检查:禁止特定 IP
        if ("123.456.789.000".equals(userIp)) {
            return false;
        }
        
        return true;
    }
    
    /**
     * 记录审计日志
     */
    private void logAudit(String traceId, Long userId, String userIp, String userAgent) {
        // 在实际项目中,这里会记录到数据库或日志系统
        logger.info("[{}] AUDIT: User {} from IP {} with agent {} performed an action", 
            traceId, userId, userIp, userAgent);
    }
    
    /**
     * 另一个需要用户信息的方法
     */
    public void anotherMethod() {
        // 同样可以直接获取上下文信息
        String traceId = RequestContext.getTraceId();
        String userIp = RequestContext.getUserIp();
        
        // 这里可以看到,我们不需要传递任何参数
        // 就可以在任意深度的调用中获取到用户信息
    }
}