1. 说明
本文不讲基础,旨在讨论使用 ThreadLocal 中需要注意的一些问题。
2. ThreadLocal
1. 结构
- 每个线程里面有一个局部变量 - threadLocalMap,所以threadLocalMap天然做到了线程隔离。
- threadLocal.get()的原理:通过 Thread.currentThread()获得当前的线程,然后拿到 threadLocalMap,通过 key = threadLocal,就可以定位到线程绑定的值。
2. 为什么 threadLocalMap的key要设置成弱引用?
简单说下弱引用:
如果一个对象只被弱引用对象引用,那么当内存不足的时候会回收该对象。
强引用
- 假设在业务代码中使用完ThreadLocal,threadLocal Ref被回收了
- 当前线程一直活跃(例如线程池里面的核心线程),所以当前线程的threadLocalMap的Entry强引用了threadLocal,造成threadLocal无法被回收。
弱引用
-
假设在业务代码中使用完ThreadLocal,threadLocal Ref被回收了
-
由于entry的key是弱引用,所以当内存不足的时候,会把 threadLocal 设置为null。
-
单GC仅仅是让key的内存释放,后续还要根据 key 是否为 null 来进一步释放值的内存,释放时机有
- 获取 key 发现 null key
- set key 时,会使用启发式扫描,清除临近的 null key,启发次数与元素个数,是否发现 null key 有关
- remove 时(推荐),因为一般使用 ThreadLocal 时都把它作为静态变量,因此 GC 无法回收
这就意味着使用完ThreadLocal,CurrentThread依然运行的前提下,就算忘记调用remove方法,弱引用比强引用可以多一层保障:弱引用的ThreadLocal会被回收,对应的value在下一次ThreadLocalMap调用set,get,remove中的任一方法的时候会被清除,从而避免内存泄漏。但是通常我们的 threadLocal 都会被设置成一个静态变量,提供全局访问的API,所以强烈要求每次使用完 remove()
3. 脏数据
线程池情况下,线程会被复用,所以可能会使用之前threadLocal设置的值。比如SpringMVC执行请求的线程池。所以每次请求执行完成后手动remove。
3. InheritableThreadLocal
解决什么?
public static void main(String[] args) throws InterruptedException {
ThreadLocal<String> threadLocal = new ThreadLocal<>();
threadLocal.set("父线程设置的值");
TimeUnit.MILLISECONDS.sleep(
500
);
new Thread(() -> System.out.println("子线程获取的值:" + threadLocal.get())).start();
}
// 子线程获取的值:null
现在想要达成的效果是 子进程能够共享父进程的值
public static void main(String[] args) throws InterruptedException {
ThreadLocal<String> threadLocal = new InheritableThreadLocal<>();
threadLocal.set("父线程设置的值");
TimeUnit.MILLISECONDS.sleep(
500
);
new Thread(() -> System.out.println("子线程获取的值:" + threadLocal.get())).start();
}
// 子线程获取的值:父线程设置的值
问题
public static void main(String[] args) throws InterruptedException {
ExecutorService executor = Executors.newFixedThreadPool(1);
Student student = new Student();
student.setUsername("父线程");
ThreadLocal<Student> threadLocal = new InheritableThreadLocal<>();
threadLocal.set(student);
executor.execute(() -> System.out.println("子线程获得值 : " + threadLocal.get()));
TimeUnit.MILLISECONDS.sleep(500);
student.setUsername("父进程新设置的值");
executor.execute(() -> System.out.println("子线程获得值 : " + threadLocal.get()));
// 子线程获得值 : Student(username=父线程)
// 子线程获得值 : Student(username=父进程新设置的值)
}
其实说实话吧,这个问题也不是问题,其实就是父线程对student的改变对子线程可见。原因是因为父线程new子线程的时候,会将自身 InheritableThreadLocal 的值拷贝给 子线程的InheritableThreadLocal,这个时候拷贝是直接拷贝的student的引用(浅拷贝)。所以自然而然父线程对引用的任何修改,子线程都能看见。
如果你想要深拷贝,你只需要继承 InheritableThreadLocal ,实现关键方法即可
public static void main(String[] args) throws InterruptedException {
@SuppressWarnings("all")
ThreadLocal<Student> threadLocal = new InheritableThreadLocal() {
@Override
protected Object childValue(Object parentValue) {
// 实现深拷贝
if (parentValue instanceof Student) {
Student student = (Student) parentValue;
Student copyStudent = new Student();
copyStudent.setUsername(student.getUsername());
return copyStudent;
} else {
return super.childValue(parentValue);
}
}
};
ExecutorService executor = Executors.newFixedThreadPool(1);
Student student = new Student();
student.setUsername("父线程");
threadLocal.set(student);
executor.execute(() -> System.out.println("子线程获得值 : " + threadLocal.get()));
TimeUnit.MILLISECONDS.sleep(500);
student.setUsername("父进程新设置的值");
executor.execute(() -> System.out.println("子线程获得值 : " + threadLocal.get()));
// 子线程获得值 : Student(username=父线程)
// 子线程获得值 : Student(username=父线程)
}
InheritableThreadLocal 的原理其实就是 父线程构建 子线程的时候会copy一下父线程的InheritableThreadLocal,这里的 copy 默认是直接拷贝引用。但是 InheritableThreadLocal 给这个copy提供了一个扩展,也就是说你可以去继承 InheritableThreadLocal 实现你自己的copy逻辑。将浅拷贝逻辑转化为深拷贝逻辑。
4. TransmittableThreadLocal
解决什么?
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class Demo {
/**
* 模拟 tomcat 线程池
*/
static ExecutorService EXECUTOR = Executors.newFixedThreadPool(1);
static ThreadLocal<String> TRACE_ID_HOLDER = new InheritableThreadLocal<>();
public static void main(String[] args) {
for (int i = 0; i < 3; i++) {
String traceId = i + 1 + "";
System.out.println("主线程获得 traceId : " + traceId);
mockReq(traceId);
// 主线程获得 traceId : 1
// 主线程获得 traceId : 2
// 主线程获得 traceId : 3
// 子线程获得 traceId : 1
// 子线程获得 traceId : 1
// 子线程获得 traceId : 1
}
}
/**
* mock req
*
* @param traceId
*/
public static void mockReq(String traceId) {
TRACE_ID_HOLDER.set(traceId);
EXECUTOR.execute(
() -> System.out.println("子线程获得 traceId : " + TRACE_ID_HOLDER.get())
);
}
}
这个场景是 日志traceId透传。可以看到开启的任务只有第一次能拿到请求的traceId,后面的还是拿到的第一次请求的traceId。这是因为 InheritableThreadLocal 是通过 new Thread 帮助子线程拷贝 InheritableThreadLocal 相关值。但这是线程池,线程池里面只有一个线程,也就是 new Thread的机会只有一次,自然而然也就是第一次。
解决:
原理:
使用:
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import static com.hdu.transmitableThreadLocal.TransmitableRunnable.wrap;
public class Main {
public static void main(String[] args) {
TransmitableThreadLocal<String> threadLocal = new TransmitableThreadLocal<>();
threadLocal.set("1");
ExecutorService executor = Executors.newFixedThreadPool(1);
executor.execute(
TransmitableRunnable.wrap(
() -> System.out.println("子线程打印 : " + threadLocal.get())
)
);
threadLocal.set("2");
executor.execute(
TransmitableRunnable.wrap(
() -> System.out.println("子线程打印 : " + threadLocal.get())
)
);
threadLocal.set("3");
executor.execute(
TransmitableRunnable.wrap(
() -> System.out.println("子线程打印 : " + threadLocal.get())
)
);
//子线程打印 : 1
//子线程打印 : 2
//子线程打印 : 3
executor.shutdown();
}
}
关键源码
TransmitableThreadLocal
public class TransmitableThreadLocal<T> extends ThreadLocal<T> {
@Override
public void set(T value) {
super.set(value);
TransmitableThreadLocalHolder.register(this);
}
@Override
public void remove() {
super.remove();
TransmitableThreadLocalHolder.unregister(this);
}
}
TransmitableThreadLocalHolder
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class TransmitableThreadLocalHolder {
private final static Map<Thread, Map<TransmitableThreadLocal<?>, Object>> HOLDER
= new ConcurrentHashMap<>();
public static void register(TransmitableThreadLocal<?> threadLocal) {
HOLDER.computeIfAbsent(Thread.currentThread(), k -> new ConcurrentHashMap<>())
.put(threadLocal, threadLocal.get());
}
public static Map<TransmitableThreadLocal<?>, Object> get() {
return HOLDER.get(Thread.currentThread());
}
public static void unregister(TransmitableThreadLocal<?> transmitableThreadLocal) {
if (HOLDER.containsKey(Thread.currentThread())) {
HOLDER.get(Thread.currentThread()).remove(transmitableThreadLocal);
if (HOLDER.get(Thread.currentThread()).isEmpty()) {
HOLDER.remove(Thread.currentThread());
}
}
}
}
TransmitableRunnable
import java.util.HashMap;
import java.util.Map;
public class TransmitableRunnable implements Runnable {
private final Runnable ORIGINAL;
@SuppressWarnings("all")
private final Map<TransmitableThreadLocal, Object> CONTEXT;
public static TransmitableRunnable wrap(Runnable original) {
return new TransmitableRunnable(original);
}
private TransmitableRunnable(Runnable ORIGINAL) {
this.ORIGINAL = ORIGINAL;
CONTEXT = new HashMap<>();
// 抓取当前线程的所有 TransmitableThreadLocal
CONTEXT.putAll(TransmitableThreadLocalHolder.get());
}
@Override
public void run() {
// 回放 父线程的所有 TransmitableThreadLocal
CONTEXT.forEach(
TransmitableThreadLocal::set
);
// 执行原始的 Runnable逻辑
ORIGINAL.run();
}
}
阿里开源的 transmittableThreadLocal 差不多就是这个原理。