TransmittableThreadLocal - Study Notes


TransmittableThreadLocal, abbreviated as TTL in this article

InheritableThreadLocal, abbreviated as ITL

The Thread.inheritableThreadLocals field, abbreviated as itls

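Introduction

Why TTL exists

TTL solves the problem of passing ThreadLocal values from a parent thread to its child threads when threads are reused, as in a thread pool. If threads are never reused, ITL is fully sufficient.

Differences between ITL and TTL

ITL copies the parent thread's inheritableThreadLocals into the child thread once, when the child thread is created. TTL instead captures the values when a task is wrapped for submission and sets them in the worker thread right before the task runs, so each task sees the submitter's values even on a reused thread.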
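A minimal sketch of the ITL limitation, using only the JDK (no TTL). With a single-thread pool, the worker thread is created inside the first submit call, in the submitting thread, so it copies the caller's ITL value at that moment. Later submits reuse the same worker, so a value changed afterwards in the caller never reaches it. The class name ItlReuseDemo is made up for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ItlReuseDemo {

    private static final InheritableThreadLocal<String> ITL = new InheritableThreadLocal<>();

    // Submits two tasks to a single-thread pool and returns the value each task saw.
    public static List<String> run() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        Callable<String> read = ITL::get;
        List<String> seen = new ArrayList<>();
        try {
            ITL.set("v1");
            // First submit creates the worker thread here, in the caller, so it copies "v1"
            seen.add(pool.submit(read).get());

            ITL.set("v2");
            // The same worker is reused; ITL copied values only at thread creation, so it still sees "v1"
            seen.add(pool.submit(read).get());
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
            ITL.remove();
        }
        return seen;
    }

    public static void main(String[] args) {
        System.out.println(run()); // [v1, v1]
    }
}
```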

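Overall flow

  1. Main thread: record the TTLs in use. When TTL#set or TTL#get is called, the ttl is added to holder.

  2. Main thread: when wrapping the task for the child thread (TtlRunnable#get), take a snapshot of the ttls in holder (Transmitter#capture).

  3. Child thread: before #run, install the main thread's ttl snapshot (TtlRunnable#run => Transmitter#replay):

    1. back up the child thread's current ttl values;
    2. set the values from the ttl snapshot in the child thread.
  4. Child thread: after #run completes, restore the backup (Transmitter#restore).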
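The four steps above can be sketched with a hypothetical MiniTtl class in plain Java. It is a simplified stand-in written for illustration, not TTL's implementation: its holder is a plain ThreadLocal, and it has no callbacks, no copy hook and no support for plain ThreadLocals. The demo method mirrors the TtlTest2 demo in this article.

```java
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical, simplified stand-in for TransmittableThreadLocal (illustration only, not the library's code).
public class MiniTtl<T> extends InheritableThreadLocal<T> {

    // Step 1: per-thread record of the MiniTtl instances used in that thread (weak keys, dummy values)
    private static final ThreadLocal<Map<MiniTtl<?>, Boolean>> HOLDER = ThreadLocal.withInitial(WeakHashMap::new);

    @Override
    public void set(T value) {
        super.set(value);
        HOLDER.get().put(this, Boolean.TRUE);
    }
    @Override
    public T get() {
        T value = super.get();
        if (value != null) HOLDER.get().put(this, Boolean.TRUE);
        return value;
    }
    private T rawGet() { return super.get(); } // read without touching HOLDER

    // Step 2 (submitting thread): snapshot the value of every recorded instance
    @SuppressWarnings("unchecked")
    static Map<MiniTtl<Object>, Object> capture() {
        Map<MiniTtl<Object>, Object> snapshot = new HashMap<>();
        for (MiniTtl<?> ttl : HOLDER.get().keySet()) snapshot.put((MiniTtl<Object>) ttl, ttl.rawGet());
        return snapshot;
    }

    // Step 3 (worker, before run): back up the worker's own values, then install the snapshot
    @SuppressWarnings("unchecked")
    static Map<MiniTtl<Object>, Object> replay(Map<MiniTtl<Object>, Object> captured) {
        Map<MiniTtl<Object>, Object> backup = new HashMap<>();
        for (Iterator<MiniTtl<?>> it = HOLDER.get().keySet().iterator(); it.hasNext(); ) {
            MiniTtl<?> ttl = it.next();
            backup.put((MiniTtl<Object>) ttl, ttl.rawGet());
            if (!captured.containsKey(ttl)) { it.remove(); ttl.remove(); } // not in snapshot: hide it
        }
        setAll(captured);
        return backup;
    }

    // Step 4 (worker, after run): drop what the task added, then put the backup back
    static void restore(Map<MiniTtl<Object>, Object> backup) {
        for (Iterator<MiniTtl<?>> it = HOLDER.get().keySet().iterator(); it.hasNext(); ) {
            MiniTtl<?> ttl = it.next();
            if (!backup.containsKey(ttl)) { it.remove(); ttl.remove(); }
        }
        setAll(backup);
    }

    private static void setAll(Map<MiniTtl<Object>, Object> values) {
        for (Map.Entry<MiniTtl<Object>, Object> e : values.entrySet()) e.getKey().set(e.getValue());
    }

    // Decorator in the spirit of TtlRunnable#get: capture now, replay/restore around run()
    static Runnable wrap(Runnable task) {
        Map<MiniTtl<Object>, Object> captured = capture();
        return () -> {
            Map<MiniTtl<Object>, Object> backup = replay(captured);
            try {
                task.run();
            } finally {
                restore(backup);
            }
        };
    }

    // Mirrors the TtlTest2 demo and returns {log1, log2, log3}
    public static String[] demo() {
        MiniTtl<String> ttl1 = new MiniTtl<>();
        ExecutorService pool = Executors.newFixedThreadPool(2);
        String[] log = new String[3];
        try {
            ttl1.set("main");
            pool.submit(wrap(() -> {
                log[0] = ttl1.get();
                ttl1.set("thread");
                log[1] = ttl1.get();
            })).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
        log[2] = ttl1.get();
        return log;
    }

    public static void main(String[] args) {
        System.out.println(String.join(", ", demo())); // main, thread, main
    }
}
```

The try/finally in wrap is the important design choice: the worker's own values are put back even if the task throws, so a pooled thread never keeps a submitter's values after the task ends.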

Flow analysis

[Figures: two flow-analysis diagrams; the images were not preserved]

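Core classes

The TTL class

  • TTL extends ITL and overrides #set and #get; both call #addThisToHolder

  • #addThisToHolder: stores the ttl object as a key in holder, so that it can be captured later for child threads

  • Adds the #beforeExecute and #afterExecute callback methods, invoked around the task in the child thread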

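The holder field

  • holder is declared private static final, and its type is an ITL (in the source: InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>, used as a weak set)

  • holder records ttl objects; since it is itself a ThreadLocal, the records are kept per thread

    • it only records ttls that have been used in the thread, i.e. whose TTL#set or TTL#get has been called
    • it only records TransmittableThreadLocal instances, not plain ThreadLocals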
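As far as I can tell from the TTL source, holder's childValue copies the parent's map into a new WeakHashMap, so each thread owns its own record set. The hypothetical HolderDemo below reproduces only that shape, with plain objects standing in for TTL instances, to show the per-thread isolation: a child thread starts with a copy of the parent's records, and anything it registers afterwards stays in its own copy.

```java
import java.util.Collections;
import java.util.Set;
import java.util.WeakHashMap;

// Hypothetical sketch of a holder shaped like TTL's: an InheritableThreadLocal whose value is a
// weak set of "used" instances, with childValue copying the set for each newly created thread.
public class HolderDemo {

    static final Object TTL_A = new Object(); // stand-ins for two TTL instances
    static final Object TTL_B = new Object();

    static final InheritableThreadLocal<Set<Object>> HOLDER = new InheritableThreadLocal<Set<Object>>() {
        @Override
        protected Set<Object> initialValue() {
            return Collections.newSetFromMap(new WeakHashMap<>());
        }

        @Override
        protected Set<Object> childValue(Set<Object> parentValue) {
            Set<Object> copy = Collections.newSetFromMap(new WeakHashMap<>());
            copy.addAll(parentValue); // the child starts with the parent's records, in its own set
            return copy;
        }
    };

    // Returns {records seen by the parent, records seen by the child} after the child registers TTL_B
    public static int[] demo() {
        HOLDER.get().add(TTL_A); // parent thread "uses" TTL_A
        int[] childCount = new int[1];
        Thread child = new Thread(() -> {
            HOLDER.get().add(TTL_B); // recorded only in the child's copy
            childCount[0] = HOLDER.get().size();
        });
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        return new int[] {HOLDER.get().size(), childCount[0]};
    }

    public static void main(String[] args) {
        int[] counts = demo();
        System.out.println("parent=" + counts[0] + ", child=" + counts[1]); // parent=1, child=2
    }
}
```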

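TtlRunnable

  • Factory pattern: TtlRunnable#get creates a TtlRunnable and takes a snapshot of the ttls recorded in holder
  • Decorator pattern: enhances a Runnable
    • before #run, sets the copied ttl values into the child thread's itls
    • after #run, restores the child thread's ttl values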

Transmitter

  • Takes a snapshot of the data (capture), replays it (replay) and restores the previous state (restore) when work moves from one thread to another

A demo first

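import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class TtlTest2 {

    private static final ThreadLocal<String> ttl1 = new TransmittableThreadLocal<>();

    private static final CountDownLatch countDownLatch = new CountDownLatch(1);

    public static void main(String[] args) throws InterruptedException {

        ttl1.set("main");

        ExecutorService pool = Executors.newFixedThreadPool(20);
        // Every Runnable passed to ttlExecutor is wrapped with TtlRunnable#get
        Executor ttlExecutor = TtlExecutors.getTtlExecutor(pool);

        ttlExecutor.execute(() -> {
            System.out.println("log1-> ttl: " + ttl1.get());
            ttl1.set("thread");
            System.out.println("log2-> ttl: " + ttl1.get());

            countDownLatch.countDown();

        });

        countDownLatch.await();
        System.out.println("log3-> ttl: " + ttl1.get());

        // The pool's non-daemon threads would otherwise keep the JVM running
        pool.shutdown();
    }

}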

Output:
log1-> ttl: main
log2-> ttl: thread
log3-> ttl: main

  • #capture
public static Object capture() {
    return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}

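/**
 * Copies the ttl objects recorded in the current thread's holder.
 * Returned Map: key = the ttl object, value = its value (a shallow copy by default: the same reference).
 * In the demo, the main thread created ttl1 and set it to "main", so the Map returned here is <ttl1, main>.
 */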
private static HashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
    HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<TransmittableThreadLocal<Object>, Object>();
    for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
        ttl2Value.put(threadLocal, threadLocal.copyValue());
    }
    return ttl2Value;
}

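/**
 * Also captures plain ThreadLocals registered in threadLocalHolder (via Transmitter#registerThreadLocal),
 * copying each value with its TtlCopier.
 */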
private static HashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
    final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<ThreadLocal<Object>, Object>();
    for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        final TtlCopier<Object> copier = entry.getValue();

        threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
    }
    return threadLocal2Value;
}
  • #replay

public static Object replay(@NonNull Object captured) {
    final Snapshot capturedSnapshot = (Snapshot) captured;
    return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}

/**
 * Backs up the ttl values in the current thread's holder, then installs the ttl snapshot <ttl1, main> in the current thread.
 */
private static HashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
    HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<TransmittableThreadLocal<Object>, Object>();

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

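        // back up the current thread's value
        backup.put(threadLocal, threadLocal.get());

        // clear the ttl values that are not in the captured snapshot, so the task sees exactly
        // the submitter's values and nothing left over from earlier work in this (reused) thread
        if (!captured.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }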
    }

    // install the snapshot <ttl1, main> in this thread: ttl1 becomes "main"
    setTtlValuesTo(captured);

    // invoke the #beforeExecute callbacks
    doExecuteCallback(true);

    return backup;
}


private static void setTtlValuesTo(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
        TransmittableThreadLocal<Object> threadLocal = entry.getKey();
        threadLocal.set(entry.getValue());
    }
}


private static HashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> captured) {
    final HashMap<ThreadLocal<Object>, Object> backup = new HashMap<ThreadLocal<Object>, Object>();

    for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        backup.put(threadLocal, threadLocal.get());

        final Object value = entry.getValue();
        if (value == threadLocalClearMark) threadLocal.remove();
        else threadLocal.set(value);
    }

    return backup;
}
  • #restore
public static void restore(@NonNull Object backup) {
    final Snapshot backupSnapshot = (Snapshot) backup;
    restoreTtlValues(backupSnapshot.ttl2Value);
    restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}

private static void restoreTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
    // invoke the #afterExecute callbacks
    doExecuteCallback(false);

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        // in the child thread, ttl1's value has already changed from main to thread
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        // clear the TTL values that are not in backup
        // avoid the extra TTL values after restore
        if (!backup.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // restore ttl1's previous value in this thread: thread -> main
    setTtlValuesTo(backup);
}

private static void restoreThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> backup) {
    for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        threadLocal.set(entry.getValue());
    }
}
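Why replay removes ttls that are not in the snapshot, and why restore removes ttls that are not in the backup, is easier to see on a toy model. This is a hypothetical sketch, not TTL code: a worker thread's ttl state is modelled as a plain Map from a name to a value. replay makes that state equal to the snapshot and returns a backup; restore makes it equal to the backup again.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical model: threadState stands for the values of the ttls registered in a worker thread.
public class ReplayModel {

    // Like replayTtlValues: back up, drop entries absent from the snapshot, then apply the snapshot
    public static Map<String, String> replay(Map<String, String> threadState, Map<String, String> captured) {
        Map<String, String> backup = new HashMap<>(threadState);
        threadState.keySet().retainAll(captured.keySet());
        threadState.putAll(captured);
        return backup;
    }

    // Like restoreTtlValues: drop entries the task added, then put the backup values back
    public static void restore(Map<String, String> threadState, Map<String, String> backup) {
        threadState.keySet().retainAll(backup.keySet());
        threadState.putAll(backup);
    }

    public static void main(String[] args) {
        Map<String, String> worker = new HashMap<>();
        worker.put("tenant", "left-over-from-previous-task"); // set earlier in this pooled thread
        Map<String, String> captured = new HashMap<>();
        captured.put("user", "main");

        Map<String, String> backup = replay(worker, captured);
        System.out.println(worker); // {user=main}: the left-over value is hidden from the task

        worker.put("user", "thread");
        worker.put("traceId", "t-1"); // the task adds a new value
        restore(worker, backup);
        System.out.println(worker); // {tenant=left-over-from-previous-task}
    }
}
```

Without the removal step in replay, the task would also see "tenant", a value that the submitting thread never had.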

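Usage notes:

Use the decorated Runnable.

Get the thread pool through TtlExecutors#getTtlExecutorService, or override ThreadPoolTaskExecutor's methods so that they call TtlRunnable#get explicitly:

import com.alibaba.ttl.TtlCallable;
import com.alibaba.ttl.TtlRunnable;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

public class MyThreadPoolTaskExecutor extends ThreadPoolTaskExecutor {
    // Capture the caller's ttl values now; they are replayed in the worker thread
    @Override
    public void execute(Runnable task) {
        super.execute(TtlRunnable.get(task));
    }
    // ThreadPoolTaskExecutor#submit does not go through #execute, so wrap it as well
    @Override
    public Future<?> submit(Runnable task) {
        return super.submit(TtlRunnable.get(task));
    }
    @Override
    public <T> Future<T> submit(Callable<T> task) {
        return super.submit(TtlCallable.get(task));
    }
}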

The deep vs. shallow copy problem

Demo of the problem

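import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// User: assumed to be a mutable bean with name/age, getters/setters and a Lombok-style toString (e.g. @Data @AllArgsConstructor)
public class TtlTest3 {

    private static final ThreadLocal<User> ttl1 = new TransmittableThreadLocal<>();

    private static final CountDownLatch countDownLatch = new CountDownLatch(1);

    public static void main(String[] args) throws InterruptedException {

        ttl1.set(new User("张三", 20));

        ExecutorService pool = Executors.newFixedThreadPool(20);
        Executor ttlExecutor = TtlExecutors.getTtlExecutor(pool);

        ttlExecutor.execute(() -> {
            System.out.println("log1-> ttl: " + ttl1.get());
            // By default the task receives the same User object that the main thread holds
            User user = ttl1.get();
            user.setName("李四");
            ttl1.set(user);
            System.out.println("log2-> ttl: " + ttl1.get());

            countDownLatch.countDown();
        });

        countDownLatch.await();

        System.out.println("log3-> ttl: " + ttl1.get());
        pool.shutdown();
    }

}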

Output:
log1-> ttl: User(name=张三, age=20)
log2-> ttl: User(name=李四, age=20)
log3-> ttl: User(name=李四, age=20)

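Cause: Transmitter#capture stores the result of TTL#copy, which by default returns the parent value itself. The task and the main thread therefore share one User object, and the task's setName is visible in the main thread (log3).

Solution: override TTL#copy to return a new object, and declare ttl1 with that subclass (private static final ThreadLocal<User> ttl1 = new UserTtl();)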

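import com.alibaba.ttl.TransmittableThreadLocal;

public class UserTtl extends TransmittableThreadLocal<User> {

    // Called during capture: give the task its own User instead of the parent's reference
    @Override
    public User copy(User parentValue) {
        return parentValue == null ? null : new User(parentValue.getName(), parentValue.getAge());
    }
}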
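To see why overriding copy helps, here is a hypothetical model in plain Java with no TTL dependency. A copier function stands in for TTL#copy: the identity copier behaves like the default copy, which hands the task the parent's reference, while a copier that builds a new User behaves like UserTtl. The nested User class is a minimal stand-in, and ASCII names replace the demo's names.

```java
import java.util.function.UnaryOperator;

// Hypothetical model, not TTL code: the snapshot hands the task copier.apply(parentValue).
public class CopyDemo {

    // Minimal stand-in for the article's User class
    static final class User {
        String name;
        final int age;

        User(String name, int age) {
            this.name = name;
            this.age = age;
        }

        @Override
        public String toString() {
            return "User(name=" + name + ", age=" + age + ")";
        }
    }

    // Captures with the given copier, lets the "task" rename its value, and returns the parent's value
    static User parentAfterTaskRenames(UnaryOperator<User> copier) {
        User parentValue = new User("zhangsan", 20);
        User taskValue = copier.apply(parentValue); // what the task receives
        taskValue.name = "lisi";                    // the task mutates its value in place
        return parentValue;
    }

    public static void main(String[] args) {
        // Identity copier, like the default TTL#copy: the parent sees the rename
        System.out.println(parentAfterTaskRenames(u -> u));                        // User(name=lisi, age=20)
        // New-object copier, like UserTtl#copy: the parent is unaffected
        System.out.println(parentAfterTaskRenames(u -> new User(u.name, u.age))); // User(name=zhangsan, age=20)
    }
}
```

A copy that only creates a new top-level object is still shallow for any nested mutable fields; those need copying too if the task may modify them.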