ThreadLocal, InheritableThreadLocal, TransmittableThreadLocal

252 阅读5分钟

1. 说明

本文不讲基础,旨在讨论使用 ThreadLocal 中需要注意的一些问题。

2. ThreadLocal

1. 结构

threadLocal.png

  1. 每个线程里面有一个局部变量 - threadLocalMap,所以threadLocalMap天然做到了线程隔离。
  2. threadLocal.get()的原理:通过 Thread.currentThread()获得当前的线程,然后拿到 threadLocalMap,通过 key = threadLocal,就可以定位到线程绑定的值。

2. 为什么 threadLocalMap的key要设置成弱引用?

简单说下弱引用:

如果一个对象只被弱引用对象引用,那么当内存不足的时候会回收该对象。

强引用

ts1.png

  1. 假设在业务代码中使用完ThreadLocal,threadLocal Ref被回收了
  2. 当前线程一直活跃(例如线程池里面的核心线程),所以当前线程的threadLocalMap的Entry强引用了threadLocal,造成threadLocal无法被回收。

弱引用

ts.png

  1. 假设在业务代码中使用完ThreadLocal,threadLocal Ref被回收了

  2. 由于entry的key是弱引用,所以当内存不足的时候,会把 threadLocal 设置为null。

  3. 单GC仅仅是让key的内存释放,后续还要根据 key 是否为 null 来进一步释放值的内存,释放时机有

    1. 获取 key 发现 null key
    2. set key 时,会使用启发式扫描,清除临近的 null key,启发次数与元素个数,是否发现 null key 有关
    3. remove 时(推荐),因为一般使用 ThreadLocal 时都把它作为静态变量,因此 GC 无法回收

这就意味着使用完ThreadLocal,CurrentThread依然运行的前提下,就算忘记调用remove方法,弱引用比强引用可以多一层保障:弱引用的ThreadLocal会被回收,对应的value在下一次ThreadLocalMap调用set,get,remove中的任一方法的时候会被清除,从而避免内存泄漏。但是通常我们的 threadLocal 都会被设置成一个静态变量,提供全局访问的API,所以强烈要求每次使用完 remove()

3. 脏数据

线程池情况下,线程会被复用,所以可能会使用之前threadLocal设置的值。比如SpringMVC执行请求的线程池。所以每次请求执行完成后手动remove。

3. InheritableThreadLocal

解决什么?

public static void main(String[] args) throws InterruptedException {
    ThreadLocal<String> threadLocal = new ThreadLocal<>();
    threadLocal.set("父线程设置的值");
    TimeUnit.MILLISECONDS.sleep(
        500
    );
    new Thread(() -> System.out.println("子线程获取的值:" + threadLocal.get())).start();
}
// 子线程获取的值:null

现在想要达成的效果是 子进程能够共享父进程的值

public static void main(String[] args) throws InterruptedException {
    ThreadLocal<String> threadLocal = new InheritableThreadLocal<>();
    threadLocal.set("父线程设置的值");
    TimeUnit.MILLISECONDS.sleep(
        500
    );
    new Thread(() -> System.out.println("子线程获取的值:" + threadLocal.get())).start();
}
// 子线程获取的值:父线程设置的值

问题

public static void main(String[] args) throws InterruptedException {
    ExecutorService executor = Executors.newFixedThreadPool(1);
    Student student = new Student();
    student.setUsername("父线程");
    ThreadLocal<Student> threadLocal = new InheritableThreadLocal<>();
    threadLocal.set(student);
    executor.execute(() -> System.out.println("子线程获得值 : " + threadLocal.get()));
    TimeUnit.MILLISECONDS.sleep(500);
    student.setUsername("父进程新设置的值");
    executor.execute(() -> System.out.println("子线程获得值 : " + threadLocal.get()));
    // 子线程获得值 : Student(username=父线程)
    // 子线程获得值 : Student(username=父进程新设置的值)
}

其实说实话吧,这个问题也不是问题,其实就是父线程对student的改变对子线程可见。原因是因为父线程new子线程的时候,会将自身 InheritableThreadLocal 的值拷贝给 子线程的InheritableThreadLocal,这个时候拷贝是直接拷贝的student的引用(浅拷贝)。所以自然而然父线程对引用的任何修改,子线程都能看见。

如果你想要深拷贝,你只需要继承 InheritableThreadLocal ,实现关键方法即可

public static void main(String[] args) throws InterruptedException {
    @SuppressWarnings("all")
    ThreadLocal<Student> threadLocal = new InheritableThreadLocal() {
        @Override
        protected Object childValue(Object parentValue) {
            // 实现深拷贝
            if (parentValue instanceof Student) {
                Student student = (Student) parentValue;
                Student copyStudent = new Student();
                copyStudent.setUsername(student.getUsername());
                return copyStudent;
            } else {
                return super.childValue(parentValue);
            }
        }
    };
    ExecutorService executor = Executors.newFixedThreadPool(1);
    Student student = new Student();
    student.setUsername("父线程");
    threadLocal.set(student);
    executor.execute(() -> System.out.println("子线程获得值 : " + threadLocal.get()));
    TimeUnit.MILLISECONDS.sleep(500);
    student.setUsername("父进程新设置的值");
    executor.execute(() -> System.out.println("子线程获得值 : " + threadLocal.get()));
    // 子线程获得值 : Student(username=父线程)
    // 子线程获得值 : Student(username=父线程)
}

InheritableThreadLocal 的原理其实就是 父线程构建 子线程的时候会copy一下父线程的InheritableThreadLocal,这里的 copy 默认是直接拷贝引用。但是 InheritableThreadLocal 给这个copy提供了一个扩展,也就是说你可以去继承 InheritableThreadLocal 实现你自己的copy逻辑。将浅拷贝逻辑转化为深拷贝逻辑。

4. TransmittableThreadLocal

解决什么?

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class Demo {

    /**
     * 模拟 tomcat 线程池
     */
    static ExecutorService EXECUTOR = Executors.newFixedThreadPool(1);
    static ThreadLocal<String> TRACE_ID_HOLDER = new InheritableThreadLocal<>();

    public static void main(String[] args) {
        for (int i = 0; i < 3; i++) {
            String traceId = i + 1 + "";
            System.out.println("主线程获得 traceId : " + traceId);
            mockReq(traceId);
            // 主线程获得 traceId : 1
            // 主线程获得 traceId : 2
            // 主线程获得 traceId : 3
            // 子线程获得 traceId : 1
            // 子线程获得 traceId : 1
            // 子线程获得 traceId : 1
        }
    }

    /**
     * mock req
     *
     * @param traceId
     */
    public static void mockReq(String traceId) {
        TRACE_ID_HOLDER.set(traceId);
        EXECUTOR.execute(
                () -> System.out.println("子线程获得 traceId : " + TRACE_ID_HOLDER.get())
        );
    }
}

这个场景是 日志traceId透传。可以看到开启的任务只有第一次能拿到请求的traceId,后面的还是拿到的第一次请求的traceId。这是因为 InheritableThreadLocal 是通过 new Thread 帮助子线程拷贝 InheritableThreadLocal 相关值。但这是线程池,线程池里面只有一个线程,也就是 new Thread的机会只有一次,自然而然也就是第一次。

解决:

原理:

ts3.png

使用:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;


import static com.hdu.transmitableThreadLocal.TransmitableRunnable.wrap;

public class Main {
    public static void main(String[] args) {
        TransmitableThreadLocal<String> threadLocal = new TransmitableThreadLocal<>();
        threadLocal.set("1");
        ExecutorService executor = Executors.newFixedThreadPool(1);
        executor.execute(
            TransmitableRunnable.wrap(
                () -> System.out.println("子线程打印 : " + threadLocal.get())
            )
        );
        threadLocal.set("2");
        executor.execute(
            TransmitableRunnable.wrap(
                () -> System.out.println("子线程打印 : " + threadLocal.get())
            )
        );
        threadLocal.set("3");
        executor.execute(
            TransmitableRunnable.wrap(
                () -> System.out.println("子线程打印 : " + threadLocal.get())
            )
        );
        //子线程打印 : 1
        //子线程打印 : 2
        //子线程打印 : 3
        executor.shutdown();
    }
}

关键源码

TransmitableThreadLocal
public class TransmitableThreadLocal<T> extends ThreadLocal<T> {

    @Override
    public void set(T value) {
        super.set(value);
        TransmitableThreadLocalHolder.register(this);
    }

    @Override
    public void remove() {
        super.remove();
        TransmitableThreadLocalHolder.unregister(this);
    }
}

TransmitableThreadLocalHolder
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class TransmitableThreadLocalHolder {

    private final static Map<Thread, Map<TransmitableThreadLocal<?>, Object>> HOLDER
        = new ConcurrentHashMap<>();

    public static void register(TransmitableThreadLocal<?> threadLocal) {
        HOLDER.computeIfAbsent(Thread.currentThread(), k -> new ConcurrentHashMap<>())
            .put(threadLocal, threadLocal.get());
    }

    public static Map<TransmitableThreadLocal<?>, Object> get() {
        return HOLDER.get(Thread.currentThread());
    }

    public static void unregister(TransmitableThreadLocal<?> transmitableThreadLocal) {
        if (HOLDER.containsKey(Thread.currentThread())) {
            HOLDER.get(Thread.currentThread()).remove(transmitableThreadLocal);
            if (HOLDER.get(Thread.currentThread()).isEmpty()) {
                HOLDER.remove(Thread.currentThread());
            }
        }
    }
}

TransmitableRunnable
import java.util.HashMap;
import java.util.Map;

public class TransmitableRunnable implements Runnable {

    private final Runnable ORIGINAL;
    @SuppressWarnings("all")
    private final Map<TransmitableThreadLocal, Object> CONTEXT;


    public static TransmitableRunnable wrap(Runnable original) {
        return new TransmitableRunnable(original);
    }


    private TransmitableRunnable(Runnable ORIGINAL) {
        this.ORIGINAL = ORIGINAL;
        CONTEXT = new HashMap<>();
        // 抓取当前线程的所有 TransmitableThreadLocal
        CONTEXT.putAll(TransmitableThreadLocalHolder.get());
    }

    @Override
    public void run() {
        // 回放 父线程的所有 TransmitableThreadLocal
        CONTEXT.forEach(
            TransmitableThreadLocal::set
        );
        // 执行原始的 Runnable逻辑
        ORIGINAL.run();
    }
}

阿里开源的 transmittableThreadLocal 差不多就是这个原理。

5. 源码

threadLocalDemo: threadLocalDemo (gitee.com)