ThreadLocal使用和原理

328 阅读9分钟

ThreadLocal介绍

ThreadLocal是一个工具类,主要用作在多线程环境下全局变量使用,与普通变量的区别在于ThreadLocal用于多线程中,用于维护线程私有的全局变量。

ThreadLocal解决了什么问题

ThreadLocal主要解决了全局变量线程私有的问题,比如以下场景:

当需要对一个日期执行格式化的时候,需要用到SimpleDateFormat对象,但是这个对象不是线程安全的,如果对线程共用 一个SimpleDateFormat,会存在线程安全的问题,出现多个线程格式化后出现了一样的时间问题。

 private static SimpleDateFormat  dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
 public void doSomething(Date date){
     dateFormat.format(date);
 }

解决方式,我们每次格式化时间的时候,都新建一个SimpleDateFormat,因为执行新建的一个全新对象,因此每一个SimpleDateFormat都是独立的,不会出现线程安全问题。但是每次格式化都要新建,其实它们的功能都是一样,造成了不必要的性能消耗。

 public void doSomething(Date date){
     SimpleDateFormat  dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
     dateFormat.format(date);
 }

那么当我们想要有个全局变量,但是又是不会出现线程干扰,就需要ThreadLocal了,即写出了如下代码:

 public static ThreadLocal<SimpleDateFormat> dateFormatThreadLocal = new ThreadLocal<SimpleDateFormat>() {
         @Override
         protected SimpleDateFormat initialValue() {
             return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
         }
     };
 public void doSomething(Date date){
     dateFormatThreadLocal.get().format(date);
 }

除了以上问题,在日常使用中和框架里面,大多数是将ThreadLocal作为一个线程安全的存储和获取的一种方式,比如在Spring事务管理中TransactionSynchronizationManager就是使用了ThreadLocal来保存一次连接的connection对象,用于事务提交和回滚操作。

ThreadLocal使用

ThreadLocal的用途主要有三种:

  1. 保存线程上下文信息,在任意需要的地方可以获取。 由于ThreadLocal的特性,同一线程在某地方进行设置,在随后的任意地方都可以获取到。从而可以用来保存线程上下文信息。比如Spring的事务管理,用ThreadLocal存储Connection,从而各个DAO可以获取同一Connection,可以进行事务回滚,提交等操作。
  2. 线程安全的,避免某些情况需要考虑线程安全必须同步带来的性能损失。 由于不需要共享信息,自然就不存在竞争问题了,从而保证了某些情况下线程的安全,以及避免了某些情况需要考虑线程安全必须同步带来的性能损失,如上面举例的SimpleDateFormat的问题。
  3. 线程间数据隔离。

ThreadLocal原理

ThreadLocal原理主要是在于ThreadLocal的内部类ThreadLocalMap,该类是一个Map键值对结构的类,key存储了ThreadLocal对象的引用,是一个弱引用,value存储了具体的值。

ThreadLocalMap的定义:

 static class ThreadLocalMap {
 ​
         static class Entry extends WeakReference<ThreadLocal<?>> {
             /** The value associated with this ThreadLocal. */
             Object value;
 ​
             Entry(ThreadLocal<?> k, Object v) {
                 super(k);
                 value = v;
             }
         }
 }

内部是一个定制的散列表。

在使用过程中每个线程Thread类内部维护了两个ThreadLocal.ThreadLocalMap

  /* ThreadLocal values pertaining to this thread. This map is maintained
   * by the ThreadLocal class. */
 ThreadLocal.ThreadLocalMap threadLocals = null;
 ​
 /*
      * InheritableThreadLocal values pertaining to this thread. This map is
      * maintained by the InheritableThreadLocal class.
      */
 ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

threadLocals: 维护了属于当前线程的的ThreadLocal值。

inheritableThreadLocals:维护了可传递的ThreadLocal值,让子线程也能够获取到父线程的ThreadLocal值,有单独的一个扩展类InheritableThreadLocal

ThreadLocalMap

获取ThreadLocal值的源码:

ThreadLocal#get()

 public T get() {
     // 获取当前线程
     Thread t = Thread.currentThread();
     // 获取与当前线程绑定的ThreadLocalMap
     ThreadLocalMap map = getMap(t);
     if (map != null) {
         ThreadLocalMap.Entry e = map.getEntry(this);
         // 根据当前ThreadLocal对象,获取到值并返回
         if (e != null) {
             @SuppressWarnings("unchecked")
             T result = (T)e.value;
             return result;
         }
     }
     // 如果当前线程还没有初始化当前ThreadLocal值,则初始化
     return setInitialValue();
 }
 private T setInitialValue() {
     // 回调函数
     T value = initialValue();
     Thread t = Thread.currentThread();
     ThreadLocalMap map = getMap(t);
     // 如果线程的ThreadLocalMap已经初始化了,则直接设置初始值,否则为Thread线程创建ThreadlocalMap对象
     if (map != null)
         map.set(this, value);
     else
         createMap(t, value);
     return value;
 }

InheritableThreadLocal类

InheritableThreadLocal扩展自ThreadLocal,重写了三个函数

 public class InheritableThreadLocal<T> extends ThreadLocal<T> {
     // 父线程创建子线程的时候,想子线程设置InheritableThreadLocal的时候使用
     protected T childValue(T parentValue) {
         return parentValue;
     }
 ​
     // 由于重写了getMap,因此对InheritableThreadLocal类型的ThreadLocal调用getMap的时候,会获取到inheritableThreadLocals变量,不会影响其threadLocals变量
     ThreadLocalMap getMap(Thread t) {
        return t.inheritableThreadLocals;
     }
     
     // 和getMap同理,为Thread.inheritableThreadLocals变量赋值
     void createMap(Thread t, T firstValue) {
         t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
     }
 }

Thread.inheritableThreadLocals变量的创建流程:

 // (主线程)用户线程创建一个子线程
 new Thread();
 // 构造函数
 public Thread() {
     init(null, null, "Thread-" + nextThreadNum(), 0);
 }
 // 执行初始化
 private void init(ThreadGroup g, Runnable target, String name,
                   long stackSize) {
     // 初始化,注意需要关注最后一个参数为true
     init(g, target, name, stackSize, null, true);
 }
 // 变量inheritThreadLocals=true
 private void init(ThreadGroup g, Runnable target, String name,
                   long stackSize, AccessControlContext acc,
                   boolean inheritThreadLocals) {
     ... 其它代码省略
     // 判断是否符合传递ThreadLocals条件,注意如果父线程不传递,则子线程也无法传递
     if (inheritThreadLocals && parent.inheritableThreadLocals != null)
         // 符合传递条件,创建createInheritedMap
         this.inheritableThreadLocals =
         ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
 }
 // ThreadLocal#createInheritedMap
 static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
     return new ThreadLocalMap(parentMap);
 }
 private ThreadLocalMap(ThreadLocalMap parentMap) {
     Entry[] parentTable = parentMap.table;
     int len = parentTable.length;
     setThreshold(len);
     table = new Entry[len];
     // 复制 parentMap 的记录
     for (Entry e : parentTable) {
         if (e != null) {
             @SuppressWarnings("unchecked")
             ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
             if (key != null) {
                 Object value = key.childValue(e.value);
                 Entry c = new Entry(key, value);
                 int h = key.threadLocalHashCode & (len - 1);
                 while (table[h] != null)
                     h = nextIndex(h, len);
                 table[h] = c;
                 size++;
             }
         }
     }
 }

使用示例:

 private static InheritableThreadLocal<Integer> num = new InheritableThreadLocal<Integer>() {
     @Override
     protected Integer initialValue() {
         return 1;
     }
 };
 ​
 public static void main(String[] args) throws InterruptedException {
     num.set(2);
     TimeUnit.SECONDS.sleep(1);
     new Thread(() -> {
         Integer val = num.get();
         System.out.println("子线程获取到的val值:" + val);
     }).start();
 }

扩展

ThreadLocalMap的key为什么要设计成弱引用?

ThreadLocalMap中的hash表是一个自定义Entry数组,该Entry是一个弱引用扩展,并将key作为了引用。

 static class Entry extends WeakReference<ThreadLocal<?>> {
     /** The value associated with this ThreadLocal. */
     Object value;
 ​
     Entry(ThreadLocal<?> k, Object v) {
         super(k);
         value = v;
     }
 }

什么是弱引用?

弱引用用来描述非必需对象的,它的强度比软引用更弱一些,被弱引用关联的对象只能生存到下一次垃圾收集发生之前。当垃圾收集器工作时,无论当前内存是否足够,都会回收掉只被弱引用关联的对象。

为什么要这么设计?

这样设计的目的主要方便JVM垃圾回收,因为ThreadLocal的使用场景是随着线程的生命周期结束而失去了作用,但是我们不可能实时监控每个线程结束了就去清理一遍吧,或者要求每个线程在结束的时候都去清理一遍它使用的ThreadLocal。因此当将key设置成弱引用后,每次垃圾回收都能够及时回收掉线程里面的ThreadLocalMap使得key称为null。

看到这里可能要问了,那回收了之后,如果线程没有终止,那么下次是不是就取不到对应线程的值了?

这肯定是不可能的,因为ThreadLocalMap内部是采用hash算法的一个Entry[]数组,只要是同一个线程,它每次计算出来的hash值都是相同的,因此可以即时被回收了也能够获取到对应值。

 private Entry getEntry(ThreadLocal<?> key) {
     int i = key.threadLocalHashCode & (table.length - 1);
     Entry e = table[i];
     if (e != null && e.get() == key)
         return e;
     else
         return getEntryAfterMiss(key, i, e);
 }

Threadlocal导致内存泄漏

上面分析了ThreadLocalMap的key是一个弱引用,因此也引出了相关弱引用导致的问题,比如内存泄漏。

什么是内存泄漏?

内存泄漏指的是,存在JVM无法管理到的内存区域,比如一个油桶,流出去的油就不在油桶的管理区域里面。

在Java里面内存泄漏通常会由文件流没有及时关闭,导致已经不用的流处于开启状态,JVM无法回收,还有数据库连接等打开了连接没有关闭也会导致内存泄漏,内存泄漏非常不方便排除,并且泄漏次数过多肯定会导致内存溢出。

为什么会导致内存泄漏?

这主要是因为key设置成虚引用,当我们使用线程池的时候线程不会终止,那么当弱引用key被回收后,就会出现<null,value>这种情况,该ThreadLocal如果不再被使用,则会导致value永远不会被使用,并且线程池线程不终止,造成value内存泄漏。

内存泄漏代码:

 static ThreadLocal<Integer> num = ThreadLocal.withInitial(() -> 1);
 public static void main(String[] args) throws InterruptedException {
 ​
     List<Thread> threadList = new ArrayList<>();
     for (int i = 0; i < 10; i++) {
         Thread thread = new Thread(() -> {
             num.set(2);
         });
         threadList.add(thread);
         thread.start();
     }
     /*
         线程里面的threadLocals的key变为了null
         但是这10个thread对象又没有被回收,因为它们被list持有,因此造成了内存泄漏
     */
     num = null;
 }

使用线程池造成脏数据

上面分析了因为弱引用key被回收了,线程没有终止导致内存泄漏,如果该线程对任务执行完成后,执行下一个任务并且还要使用这个ThreadLocal对象,则还会造成脏数据,因为同一个线程计算出来的hash槽是一样的。

解决弱引用导致的问题

ThreadLocal提供了一个remove()方法,手动释放key为null的value,并且在get的时候也会清除掉key为null的引用,因此大多数情况是不会造成内存泄漏的,当然也建议每次使用完ThreadLocal的时候,都手动remove一下。

image-20220623212135052.png

ThreadLocalMap为什么不设计成Map<Thread,Value>这种形式

ThreadLocalMap键值对为:<ThreadLocal,Value>,如果让我设计是不是通常想到每个线程有自己的一个的value第一反应是直接设计成<Thread,Value>这种模式,每次拿值的时候,直接get(currentThread)即可拿到对应线程的值。

但是这样设计有什么问题?

不能及时销毁掉不使用的值

如果设计成<Thread,Value>,无法监控到Thread终止,如果Thread终止后,对象被回收了则这个Map里面将会造成内存泄漏,如果Thread不会终止(线程池)则需要线程里面的任务自己来控制这个Map是否要删除值,是否要添加,都得让用户自己管理,增加了编码成本。

子线程能否共享父线程的ThreadLocal

子线程是可以共享父线程ThreadLocal的,但是需要使用InheritableThreadLocal类创建才行,如果单纯的使用ThreadLocal是线程私有的,不会共享。

ps:这是一道面试题,当初碰到过,我直接回答的不能共享。这肯定是部队的,要看使用的是否是InheritableThreadLocal,如果是InheritableThreadLocal是可以共享的。