从源码的层次解析TransmittableThreadLocal

905 阅读9分钟

「这是我参与2022首次更文挑战的第5天,活动详情查看:2022首次更文挑战

一:简述

我们知道要实现父子线程之间的数据传递,可以使用InheritableThreadLocal,因为InheritableThreadLocal会在子线程调用构造方法的时候将父线程的数据拷贝到子线程中。然而我们在实际开发中大部分时候并不会直接new线程,而是使用线程池来复用线程,这种情况下InheritableThreadLocal是无法保证数据的传递的。而TransmittableThreadLocal(简称ttl)可以帮助我们解决这个问题,今天我们就聊聊TransmittableThreadLocal的实现原理。

二:如何使用TransmittableThreadLocal

引入依赖,这里使用2.12.0版本,后续的源码分析也是基于此版本。

        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>transmittable-thread-local</artifactId>
            <version>2.12.0</version>
        </dependency>

方式一:修饰Runnable和Callable

通过TtlRunnable获取包装后的Runnable(或者Callable)来实现。

public static void main(String[] args) {
        TransmittableThreadLocal<Integer> ttl = new TransmittableThreadLocal<>();
        ttl.set(1);

        Executor executor = Executors.newFixedThreadPool(4);

        //TtlCallable.get()
        executor.execute(TtlRunnable.get(()->{
            Integer integer = ttl.get();
            System.out.println(integer);
        }));
    }

方式二:修饰线程池

通过TtlExecutors获取包装后的线程池来实现。

public static void main(String[] args) {
        TransmittableThreadLocal<Integer> ttl = new TransmittableThreadLocal<>();
        ttl.set(1);

        Executor ttlExecutor = TtlExecutors.getTtlExecutor(Executors.newFixedThreadPool(4));

        ttlExecutor.execute(()->{
            Integer integer = ttl.get();
            System.out.println(integer);
        });
    }

方式三:使用Java Agent来修饰JDK线程池实现类

这种方式实现线程池的传递是透明的,业务代码中没有修饰Runnable或是线程池的代码。即可以做到应用代码无侵入。只需要下载transmittable-thread-local的jar包,并且修改Java的启动参数,使用-javaagent参数指定jar包的路径:

-javaagent:xx/xx/transmittable-thread-local-2.12.0.jar

这种方法是通过java Agent修饰了线程池,因为这种方式对于代码是无侵入的,所以大部分时候我们使用这种方式。但是Java Agent的方式也会有一些问题,例如和其他Agent一起使用时会失效的问题。(目前的解决方案就是把ttl的Agent放在在最前面)

这里对于ttl的使用就不再过多的阐述,如果有疑问可以去翻一翻github上的官方文档以及一些issues看看。地址:transmittable-thread-local的github地址

三:TransmittableThreadLocal的实现原理

首先ttl继承了InheritableThreadLocal,并且重写了set()方法和get()方法,所以ttl是带有InheritableThreadLocal的功能的。

接下来我们通过源码解答以下几个问题

问题一:ttl的数据如何存储

这点我们直接看ttl的get(),set(),remove()三个方法就可以知道。

set()方法

如果需要忽略空值,而且set方法的value是null,那么就调用remove()方法尝试删除值。否则先调用父类InheritableThreadLocal的set方法,这样是为了兼容InheritableThreadLocal的功能,真正存储ttl数据的方法是addThisToHolder()。

public final void set(T value) {
        //disableIgnoreNullValueSemantics 表示是否忽略空值 默认是false 如果value为null 并且 忽略空值 那么就调用remove()方法删除值
        if (!disableIgnoreNullValueSemantics && null == value) {
            // may set null to remove value
            remove();
        } else {
            //调用父类的set方法 也就是InheritableThreadLocal的set方法
            //为了兼容InheritableThreadLocal的功能
            super.set(value);
            //然后调用addThisToHolder()方法保存值
            addThisToHolder();
        }
    }

addThisToHolder()

先判断holder中是否有当前的ttl,如果没有的话就添加。通过这个方法我们可以知道数据是存在在holder的静态成员变量中。

private void addThisToHolder() {
        if (!holder.get().containsKey(this)) {
            holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
        }
    }

通过holder的赋值代码我们可以知道,holder是一个泛型为WeakHashMap<TransmittableThreadLocal, ?>的InheritableThreadLocal,并且重写了InheritableThreadLocal的initialValue()和childValue()方法。

注:WeakHashMap是一个key为弱引用的map,ttl把WeakHashMap当做一个set使用,key为当前的ttl,而value一直为null。使用WeakHashMap也是为了避免内存泄漏的问题。

    private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
            new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
                }

                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
                }
            };

get()方法

因为set方法先调用时调用了父类的set()方法,所以通过父类的get方法可以获取到value,获取到之后判断是否为null,然后调用addThisToHolder()。

    public final T get() {
        T value = super.get();
        if (disableIgnoreNullValueSemantics || null != value) addThisToHolder();
        return value;
    }

remove()

调用removeThisFromHolder()方法删除holder中的值,然后调用父类的remove()方法删除InheritableThreadLocal中的值。

public final void remove() {
        removeThisFromHolder();
        super.remove();
    }

removeThisFromHolder()

通过holder移除当前的ttl

    private void removeThisFromHolder() {
        holder.get().remove(this);
    }

小结:

ttl的数据是保存在holder的成员变量中,而holder是一个泛型为WeakHashMap的InheritableThreadLocal,其中WeakHashMap的key存储当前的ttl,而value一直为null, 同时为了兼容InheritableThreadLocal的功能,在get(),set(),remove()方法中都调用了对应的父类方法,也就同时维护了InheritableThreadLocal中的数据。

问题二:ttl是如何将主线程的值传递下去的呢?

其实我们可以猜测出TransmittableThreadLocal的实现原理,无非就是要在任务提交前将主线程的数据复制到执行线程中,这样就可以达到数据传递的目的了。所以我们可以想到利用装饰器模式或者代理模式对Runnable(或者Callnable)进行功能的增强即可。

接下来我们分析源码看看ttl的作者是怎么做的?

我们以封装线程池的代码为入口看其实现原理

Executor ttlExecutor = TtlExecutors.getTtlExecutor(Executors.newFixedThreadPool(4));

如果Executor不为空,没有通过Agent加载并且也没有被包装过,那么new一个包装的Executor并且返回。

getTtlExecutor()

public static Executor getTtlExecutor(@Nullable Executor executor) {
        if (TtlAgent.isTtlAgentLoaded() || null == executor || executor instanceof TtlEnhanced) {
            return executor;
        }
        return new ExecutorTtlWrapper(executor, true);
    }

我们可以看到这是一个装饰器模式,那么肯定会对关键的submit(),execute()方法的功能进行增强。

1644934986(1).png

我们看ExecutorTtlWrapper的execute方法,增强的部分就是将传进来的Runnable通过TtlRunnable.get()方法包装成TtlRunnable,所以其实无论是对线程池的包装还是对Runnable的包装其实本质是一样的,最后都是对Runnable任务进行包装(Agent方式其实也是对线程池进行包装,只不过是在加载类的时候隐式替换JDK的相应类)。

public void execute(@NonNull Runnable command) {
        executor.execute(TtlRunnable.get(command, false, idempotent));
    }

我们看TtlRunnable的get()方法,判断Runnable是否已经被包装过,如果没有就利用原Runnable创建一个TtlRunnable并且返回。

public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
        if (null == runnable) return null;

        if (runnable instanceof TtlEnhanced) {
            // avoid redundant decoration, and ensure idempotency
            if (idempotent) return (TtlRunnable) runnable;
            else throw new IllegalStateException("Already TtlRunnable!");
        }
        return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
    }

我们可以看到TtlRunnable也是使用了装饰器模式,线程池会执行Runnable任务的run()方法,那么肯定会对关键的run()方法的功能进行加强。

1644935654(1).jpg

所以我们看run()方法,果然不出所料

我们看TtlRunnable的run()方法中的具体实现

TtlRunnable#run()

public void run() {
        //获取主线程的数据
        final Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }
        //复制主线程的数据到当前任务线程中
        final Object backup = replay(captured);
        try {
            //执行原任务的逻辑
            runnable.run();
        } finally {
            //恢复原来的数据
            restore(backup);
        }
    }

包装后的run()方法对数据的处理包含三步

第一步:获取主线程数据的快照

第二步:将主线程的数据保存到执行线程中

第三步:恢复执行线程原来的数据

我们先看第一步:

capturedRef是一个原子变量,里面保存了父线程的数据快照

this.capturedRef = new AtomicReference<Object>(capture());

captureTtlValues()方法获取主线程的ttl的value,而captureThreadLocalValues()保存主线程ThreadLocal的value,将它们封装成Snapshot对象返回。

 public static Object capture() {
            return new Snapshot(captureTtlValues(), captureThreadLocalValues());
        }

captureTtlValues()

captureTtlValues()获取TransmittableThreadLocal 的快照。

private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
    WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
    // 从 TransmittableThreadLocal的holder 中
    //遍历holder中所有的 TransmittableThreadLocal,将TransmittableThreadLocal 取出和值复制到 Map 中。
    for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
        ttl2Value.put(threadLocal, threadLocal.copyValue());
    }
    return ttl2Value;
}

captureThreadLocalValues()

captureThreadLocalValues()获取ThreadLocal的快照。

private static WeakHashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
    final WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value = new WeakHashMap<ThreadLocal<Object>, Object>();
    // 从threadLocalHolder中,遍历所有的ThreadLocal,将ThreadLocal值复制到Map中。
    for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        final TtlCopier<Object> copier = entry.getValue();

        threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
    }
    return threadLocal2Value;
}

接下来看第二步,复制主线程的值到当前执行线程中,这部分逻辑由replay()方法完成。

        public static Object replay(@NonNull Object captured) {
            final Snapshot capturedSnapshot = (Snapshot) captured;
            return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
        }

replay()方法调用replayTtlValues()和replayThreadLocalValues()两个方法分别对父线程的ttl和ThreadLocal进行复制,然后将当前执行线程的原值保存起来,用于之后的恢复操作。

private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
    WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
  
    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        // 遍历holder,将原数据保存到backup中 用于恢复
        backup.put(threadLocal, threadLocal.get());

        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        // 如果不是主线程快照中的值,那么就先移除
        if (!captured.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    //覆盖holder中的值
    setTtlValuesTo(captured);

    // TransmittableThreadLocal 的回调方法,在任务执行前执行。
    doExecuteCallback(true);

    return backup;
}
private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
        TransmittableThreadLocal<Object> threadLocal = entry.getKey();
        //调用ttl的set()方法的时候会把TransmittableThreadLocal注册到holder中
        threadLocal.set(entry.getValue());
    }
}
private static WeakHashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(@NonNull WeakHashMap<ThreadLocal<Object>, Object> captured) {
    final WeakHashMap<ThreadLocal<Object>, Object> backup = new WeakHashMap<ThreadLocal<Object>, Object>();

    for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        //将原数据保存到backup中 用于恢复
        backup.put(threadLocal, threadLocal.get());

        final Object value = entry.getValue();
        // 如果值是标记已删除,则清除
        if (value == threadLocalClearMark) threadLocal.remove();
        else threadLocal.set(value);
    }

    return backup;
}

第三步恢复执行线程原有的数据

利用第二部保存的原数据进行恢复,在restore()方法中进行数据的恢复。为什么要对数据进行恢复呢,照理任务执行完成之后就可以不管了,但是线程池有如果使用的是CallerRunsPolicy的拒绝策略,那么当前执行线程就是主线程,如果执行线程过程中对数据进行了修改,而不恢复原有的数据,那么就会导致之后的线程获取不到正确的数据了。

public static void restore(@NonNull Object backup) {
            final Snapshot backupSnapshot = (Snapshot) backup;
            restoreTtlValues(backupSnapshot.ttl2Value);
            restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
        }

restoreTtlValues()方法用于恢复ttl数据

private static void restoreTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
            // call afterExecute callback
            doExecuteCallback(false);

            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();

                // clear the TTL values that is not in backup
                // avoid the extra TTL values after restore
                //原先备份数据中没有的直接删除
                if (!backup.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }

            // restore TTL values
            setTtlValuesTo(backup);
        }

restoreThreadLocalValues()方法恢复执行线程原来的ThreadLocal的值

private static void restoreThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> backup) {
            for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
                final ThreadLocal<Object> threadLocal = entry.getKey();
                threadLocal.set(entry.getValue());
            }
        }

四:总结

ttl帮助我们解决了InheritableThreadLocal的短板,线程复用情况下数据传递不了的问题。而我们通过猜想并且通过阅读源码对我们的猜想进行了验证。

如果有任何疑问或者不足欢迎在下方留言。另外如果本篇文章对你有所帮助,那么点个赞再走吧。

相关文章:

# 用了这么久ThreadLocal,它的原理你还不懂吗

# ThreadLocal的plus版--InheritableThreadLocal