ThreadLocal使用的技巧

339 阅读5分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」的第6天,点击查看活动详情

ThreadLocal是什么

ThreadLocal顾名思义与线程相绑定,它的出现也是为了解决线程竞争;线程竞争一般采用加锁来保证互斥,当然也可以采用线程隔离。线程隔离是指每个线程操作各自的变量,互不相关。这也是解决线程竞争的解决办法。

ThreadLocal的特点

  1. ThreadLocal可以为当前线程关联一个数据。(它可以像Map一样存取数据,key为当前线程)
  2. 每一个 ThreadLocal对象,只能为当前线程关联一个数据,如果要为当前线程关联多个数据,就需要使用多个ThreadLocal对象实例。
  3. 每个ThreadLocal对象实例定义的时候,一般都是static类型
  4. ThreadLocal中保存数据,在线程销毁后。会由JVM虚拟自动释放。

ThreadLocal如何使用

常用方法
  1. set(T value): 设置当前线程的线程局部变量的值。
  2. get(): 获取当前线程的线程局部变量的值。
  3. remove(): 移除当前线程的线程局部变量的值。
使用场景:
场景一:spring中的事务

在Spring的@Transaction事务声明的注解中就使用ThreadLocal保存了当前的Connection对象,避免在本次调用的不同方法中使用不同的Connection对象,这样才能保证事务的acid特性,否则无法保证。

场景二:全局存储用户信息

可以尝试使用ThreadLocal替代Session的使用,当用户要访问需要授权的接口的时候,可以现在拦截器中将用户的Token存入ThreadLocal中;之后在本次访问中任何需要用户用户信息的都可以直接冲ThreadLocal中拿取数据。

场景三:亡羊补牢

面对一个无法改造的旧的项目,需要在原有的一个比较复杂的调用链路上新增一个参数并向下传递,如果我们再每个方法上都新增一个参数,那么这种成本较高,改动粒度也较大。对测试场景覆盖面我们难以把控,面对这种问题,我们可以通过ThreadLocal来保存参数,在其他方法中通过ThreadLocal来获取。

项目中的实际应用

项目中存在一个数据上报的功能: 将业务数据上报到大数据中心。面对这种需求,为了便于开发人员尽量减少对原有代码逻辑的改动情况下,我设计了一个组件,其中就需要ThreadLocal来保存业务参数;

public class TransportContext {

    /**
     * 用于数据传递
     */
    private static final ThreadLocal<TransportContext> LOCAL =
            ThreadLocal.withInitial(TransportContext::new);

    /**
     * 附加参数 便于属性传递
     */
    private final Map<String, Object> attachments = new HashMap<>();

    /**
     * 用于判断自动上报标记
     */
    private Boolean autoReportEnable = Boolean.FALSE;

    /**
     * @see AsyncInvokerFilter
     */
    private boolean enableAsync = true;

    /**
     * 指定topci
     */
    private String topic;

    /**
     * 耗时日志监控
     */
    private long costTimeLogEnable = 200;

    public TransportContext() {
    }

    public static TransportContext getContext() {
        return LOCAL.get();
    }

    /**
     * 清理上下文
     */
    public static void cleanContext() {
        // 清理上下文数据
        TransportContext.getContext().remove();
    }

    public Object getAttachment(String key) {
        return attachments.get(key);
    }

    /**
     * get attachment.
     *
     * @param key
     * @return attachment
     */
    public Object getObjectAttachment(String key) {
        return attachments.get(key);
    }

    /**
     * 设置
     *
     * @param key
     * @param value
     * @return
     */
    public TransportContext setAttachment(String key, Object value) {
        return setObjectAttachment(key, value);
    }

    /**
     * 设置值
     *
     * @param key
     * @param value
     * @return
     */
    public TransportContext setObjectAttachment(String key, Object value) {
        if (value == null) {
            attachments.remove(key);
        } else {
            attachments.put(key, value);
        }
        return this;
    }

    public TransportContext removeAttachment(String key) {
        attachments.remove(key);
        return this;
    }

    public void remove() {
        LOCAL.remove();
    }

    public Map<String, Object> getAttachments() {
        return attachments;
    }

    public Boolean getAutoReportEnable() {
        return autoReportEnable;
    }

    public void setAutoReportEnable(Boolean autoReportEnable) {
        this.autoReportEnable = autoReportEnable;
    }

    public boolean isEnableAsync() {
        return enableAsync;
    }

    /**
     * 设置是否启用异步
     *
     * @param enableAsync
     */
    public void setEnableAsync(boolean enableAsync) {
        this.enableAsync = enableAsync;
    }

    public String getTopic() {
        return topic;
    }

    public void setTopic(String topic) {
        this.topic = topic;
    }

    public long getCostTimeLogEnable() {
        return costTimeLogEnable;
    }

    public void setCostTimeLogEnable(long costTimeLogEnable) {
        this.costTimeLogEnable = costTimeLogEnable;
    }
}

业务方在使用时不需要额外去封装一个对象来存储业务参数,而是直接通过这个上下文来存储业务参数。我们再切面层拦截到时就可以从上下文中渠道对应的参数值,最后数据上报完成后上下文中的数据就应该被清除掉。

切面层数据上报:

@Slf4j
public class TransportControllerAspect {

    /**
     * 配置中心
     */
    private final TransportConfigService configService;
    private TransportBootStrap bootStrap;

    /**
     * 注解配置解析器
     */
    private ExportConfigResolver configResolver = DefaultExportConfigResolver.getInstance();

    public TransportControllerAspect(TransportConfigService configService, TransportBootStrap bootStrap) {
        this.configService = configService;
        this.bootStrap = bootStrap;
    }

    /**
     * 配置注解
     */
    @Pointcut(value = "@annotation(com.psd.commons.transport.annotations.ExportActive)")
    public void annotationPointCut() {

    }

    /**
     * 方法环绕
     *
     * @param joinPoint
     * @return
     */
    @Around(value = "annotationPointCut()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        // 请求时间
        long requestTime = Constants.getCurrentTime();
        // 返回结果
        Object proceed = joinPoint.proceed();
        // 结束时间
        long endTime = Constants.getCurrentTime();

        try {
            //  controller 拦截到了 然后biz也拦截了 此时应该只有一次拦截 需要丢弃后一个的数据 以前面的为准
            if (TransportContext.getContext().getAutoReportEnable()) {
                // 说明当前已经存在 本次不在进行上报 避免重复
                log.warn("transport circle is disable");
                return proceed;
            }
            // 非方法级别拦截 直接跳过
            if (!(joinPoint.getSignature() instanceof MethodSignature)) {
                return proceed;
            }

            // 强转
            MethodSignature signature = (MethodSignature) joinPoint.getSignature();
            // 方法参数
            Object[] args = joinPoint.getArgs();

            // step1 解析注解 获取注解信息 获取配置
            ExportActiveConfig exportActiveConfig = configResolver.resolver(signature);
            if (!configResolver.isEnableTransport(exportActiveConfig, configService)) {
                // 清理上下文
                TransportContext.cleanContext();
                return proceed;
            }

            // step2 设置当前线程自动上报标记.参数确定
            TransportContext.getContext().setAutoReportEnable(Boolean.TRUE);
            TransportContext.getContext().setTopic(exportActiveConfig.getTopic());

            // step请求参数处理
            StandardTransportEnvironment requestEnvironment = null;
            if (exportActiveConfig.getNeedRequest()) {
                requestEnvironment = new StandardTransportEnvironment(args, signature.getParameterTypes(),
                        Constants.REQUEST_RESOURCE_NAME);
            }
            TransportEnvironment respEnvironment = null;
            if (exportActiveConfig.getNeedResp()) {
                respEnvironment = new ResponseStandardTransportEnvironment(proceed);
            }

            // step4 执行数据整形
            invokerDataReshapes(exportActiveConfig, requestEnvironment, respEnvironment);

            // step5 解析配置
            TransportMetaData exportMetadata = configResolver.getExportMetadata();
            // 解析发送方 接收方
            String sendUserId = configResolver.resolverSendUserId(exportActiveConfig, exportMetadata, signature.getMethod(), args);
            String receiverUserId = configResolver.resolverReceiverUserId(exportActiveConfig, exportMetadata, signature.getMethod(), args);

            // step6 构造协议数据
            Invocation invocation = buildInvocation(exportActiveConfig, exportMetadata, requestEnvironment,
                    respEnvironment, sendUserId, receiverUserId, requestTime, endTime);

            // step7 执行
            bootStrap.start(invocation);
        } catch (Exception e) {
            log.error("transport parse error", e);
        }
        return proceed;
    }

    /**
     * 构建执行元数据
     *
     * @param exportActiveConfig
     * @param exportMetadata      元数据
     * @param requestEnvironment  request
     * @param responseEnvironment response
     * @param sendUserId          发送方  可能为null
     * @param receiverUserId      接收方 可能为空
     * @param requestTime         请求时间
     * @param endTime             结束时间
     * @return
     */
    private Invocation buildInvocation(ExportActiveConfig exportActiveConfig, TransportMetaData exportMetadata,
                                       TransportEnvironment requestEnvironment, TransportEnvironment responseEnvironment,
                                       String sendUserId, String receiverUserId, long requestTime, long endTime) {
        // 获取附加参数
        Map<String, Object> attachments = TransportContext.getContext().getAttachments();

        // 构造执行元数据
        return new TransportInvocation(exportActiveConfig.getGroupKey(), exportActiveConfig.getActionKey(),exportActiveConfig.getVersion(),
                exportMetadata, attachments, sendUserId, receiverUserId, requestEnvironment, responseEnvironment, requestTime, endTime);
    }

    /**
     * 数据整形
     *
     * @param exportActiveConfig
     * @param requestEnvironment
     * @param respEnvironment
     * @return
     */
    private void invokerDataReshapes(ExportActiveConfig exportActiveConfig, TransportEnvironment requestEnvironment, TransportEnvironment respEnvironment) {
        if (requestEnvironment == null && respEnvironment == null) {
            return;
        }
        List<DataReshape> dataReshapes = TransportExtensionLoader.getExtensionLoader(DataReshape.class).getActivateExtension(exportActiveConfig.getGroupKey(),
                exportActiveConfig.getActionKey(), null);

        if (CollectionUtils.isEmpty(dataReshapes)) {
            return;
        }
        log.warn("data reshape start");
        // 数据整形
        for (DataReshape dataReshape : dataReshapes) {
            dataReshape.reshape(exportActiveConfig.getGroupKey(), exportActiveConfig.getActionKey(), requestEnvironment, respEnvironment);
        }
        log.warn("data reshape end");
    }

数据清除
@Slf4j
@Order(-1000)
public class ContextFilter implements Filter {

    @Override
    public void invoker(Invoker invoker, Invocation invocation) {
        // 获取开始上报时间
        long startTransportTime = invocation.startTransportTime();
        // 获取耗时日志时长
        long costTimeLogEnable = TransportContext.getContext().getCostTimeLogEnable();

        try {
            invoker.invoke(invocation);
        } finally {
            // 清除上下文
            TransportContext.cleanContext();
            // 计算耗时时间
            long transportCostTime = Constants.getCurrentTime() - startTransportTime;
            if (transportCostTime > costTimeLogEnable) {
                log.warn("transport time threshold type:{},action:{} costTime:{}", invocation.group(), invocation.actionKey(), transportCostTime);
            }
        }
    }

    @Override
    public String value() {
        return "contextFilter";
    }
}

以上就是我对ThreadLocal的使用相关介绍。