Java多线程之ThreadLocal

1,036 阅读8分钟

欢迎关注公众号:五小竹

ThreadLocal是什么?

Java API文档中是这么定义ThreadLocal的:

该类提供了线程局部 (thread-local) 变量。这些变量不同于它们的普通对应物,因为访问某个变量(通过其 get 或 set 方法)的每个线程都有自己的局部变量,它独立于变量的初始化副本。ThreadLocal 实例通常是类中的 private static 字段,它们希望将状态与某一个线程(例如,用户 ID 或事务 ID)相关联。

所以说ThreadLocal其实是一个将对象的作用范围限定在当前线程的一个容器类。在多线程的情况下,访问到局部变量时候,该变量可能被其他线程修改。而ThreadLocal提供了一种线程封闭的技术,使得每个线程独享该变量。如同Mybatis的一级缓存和二级缓存的作用域。一级缓存的作用域是一个SqlSession,在同一个SqlSession的查询会使用缓存。ThreadLocal的作用域就是当前的Thread。

那些年使用过的ThreadLocal

  1. 用于保存线程不安全的工具类,典型的需要使用的类就是 SimpleDateFormat。例如在阿里Java开发手册中的一个规定。 rRdqyj.png

  2. 每个线程内需要保存类似于全局变量的信息(例如在拦截器中获取的用户信息),可以让不同方法直接使用,避免参数传递的麻烦却不想被多线程共享(因为不同线程获取到的用户信息不一样)。

  3. 最常见的ThreadLocal使用场景为 用来解决数据库连接、Session管理等,例如Spring事务,事务是和线程绑定起来的,Spring框架在事务开始时会给当前线程绑定一个Jdbc Connection,在整个事务过程都是使用该线程绑定的connection来执行数据库操作,实现了事务的隔离性。Spring框架里面就是用的ThreadLocal来实现这种隔离

public abstract class TransactionSynchronizationManager {
//线程绑定的资源,比如DataSourceTransactionManager绑定是的某个数据源的一个Connection,在整个事务执行过程中
//都使用同一个Jdbc Connection
private static final ThreadLocal<Map<Object, Object>> resources =
		new NamedThreadLocal<>("Transactional resources");
		......
}

  1. Srping的RequestContextHolder, DateTimeContextHolder等。

ThreadLocal原码解析

ThreadLocal类结构

rRd4TP.png

主要属性

public class ThreadLocal<T> {
    /**
     ThreadLocal 的 hashCode,用于计算当前 ThreadLocal 在 ThreadLocalMap 中的索引位置
     */
    private final int threadLocalHashCode = nextHashCode();

    
    private static AtomicInteger nextHashCode =
        new AtomicInteger();

    /**
      哈希魔数
     */
    private static final int HASH_INCREMENT = 0x61c88647;

    /**
     *  计算HashCode (每次递增一个哈希魔数)
     */
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

ThreadLocalMap

ThreadLocalMap 其实是一个以ThreadLocal为Key,ThreadLocal 保存的值为Value的 Map,源码如下:

 static class ThreadLocalMap {
//WeakReference弱引用,当没有引用指向时,会被回收
        static class Entry extends WeakReference<ThreadLocal<?>> {
           
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
        //数组初始容量16
        private static final int INITIAL_CAPACITY = 16;
        private Entry[] table;
        private int size = 0;

        private int threshold; // Default to 0
        //扩容阀值 默认是len的三分之二
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }

      
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

get()方法

ThreadLocal如何实现变量的线程隔离独享呢?

从Thread类的源码中可以看到,每个线程都有自己的ThreadLocalMap,初始为null。

   //Thread类中有两个变量threadLocals和inheritableThreadLocals,都是ThreadLocal内部类ThreadLocalMap类型的变量
    ThreadLocal.ThreadLocalMap threadLocals = null;

    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

get()方法进来首先获取当前线程t,然后通过getMap(t)返回Thread t中绑定的threadLocals变量。就是上面所说的ThreadLocal.ThreadLocalMap threadLocals变量。如果他是null,则调用setInitalValue()。如果不为空,则通过t获取ThreadLocalMap.Entry。Thread类中的threadLocals变量是在调用ThreadLocal的get()/set()方法时赋值的。

    public T get() {
        //获取当前的线程
        Thread t = Thread.currentThread();
        //获取当前线程的map
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    } 

这里的map是保存到线程t中的,而不是保存到ThreadLocal中的,他是Thread t的threadLocals变量。

    
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }
    
    /**
    * 如果getMap(t)返回的线程t的threadLocal为空,这创建新的ThreadLocalMap绑定到线程t的threadLocals
    */
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
    
    ThreadLocalMap getMap(Thread t) {
        //返回当前线程的threadLocals。
        return t.threadLocals;
    }

set()

set方法同样的,先获取当前线程t的threadLocals变量。如果不为空,这以当前线程t为key,与value形成键值对存入当前线程的threadLocals。如果为空,则调用createMap()方法先创建,再绑定。

    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

set()深入剖析

 private void set(ThreadLocal<?> key, Object value) {

        Entry[] tab = table;
        int len = tab.length;
        //1.计算下标
        int i = key.threadLocalHashCode & (len-1);
        //2.查看 tab[i] 有没有值,有值的话,索引+1,直到找到没有值
        for (Entry e = tab[i]; e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                //2.1. 覆盖更新
                if (k == key) {
                    e.value = value;
                    return;
                }
                //2.2 key被清理 直接替换
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
           // 当前 i 位置是无值的,可以被当前thradLocal 使用
            tab[i] = new Entry(key, value);
            int sz = ++size;
            // 如果数组大小>=扩容阈值(数组大小的三分之二)
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
            //扩容
            rehash();
    }

在源码中看到计算数组索引位置的公式是:hashCode & 数组大小, ThreadLocalMap解决hash冲突的方法是线性探测法。与HashMap的数组+链表+红黑树的实现方式不同。线性探测法是根据当前的位置计算出下一个位置,如果发生冲突,就继续找下一个空位置。

InheritableThreadLocal

InheritableThreadLocal介绍

我们注意到Thread类中的另一个变量inheritableThreadLocals,这个变量是做什么的呢?看一下JDK文档的介绍。

该类扩展了 ThreadLocal,为子线程提供从父线程那里继承的值:在创建子线程时,子线程会接收所有可继承的线程局部变量的初始值,以获得父线程所具有的值。通常,子线程的值与父线程的值是一致的;但是,通过重写这个类中的 childValue 方法,子线程的值可以作为父线程值的一个任意函数。

ThreadLocal声明的线程私有变量无法给该线程的子线程使用,InheritableThreadLocal声明的线程私有变量可以给其子线程使用。简而言之,InheritableThreadLocal 可使子线程继承父线程的值, ThreadLocal 不能实现值继承。 但是子线程对变量的修改对父线程也是不可见的。

public class Demo {
	
    public static ThreadLocal<Integer> threadLocal =
    new ThreadLocal<Integer>();
    public static InheritableThreadLocal<Integer> inheritableThreadLocal = new InheritableThreadLocal<Integer>();

    public static void main(String[] args) throws InterruptedException {
        threadLocal.set(1);
        inheritableThreadLocal.set(2);
        //新创建一个子线程
        new Thread(()-> {
            System.out.println("子线程threadLocal:"+threadLocal.get());
            //在子线程中修改值
            inheritableThreadLocal.set(3);
            System.out.println("子线程inheritableThreadLocal:"+inheritableThreadLocal.get());
        }).start();
        Thread.sleep(1000);
        //输出main线程中的本地变量值
        System.out.println("main 线程中threadLocal:"+threadLocal.get());
        System.out.println("main 线程中inheritableThreadLocal:"+inheritableThreadLocal.get());
    }

}
---------------控制台输出--------------------
子线程threadLocal:null
子线程inheritableThreadLocal:3
main 线程中threadLocal:1
main 线程中inheritableThreadLocal:2

InheritableThreadLocal源码分析

    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }

new Thread()放最终会调到下面的init方法。该方法会获取父线程的inheritableThreadLocals变量,赋值给当前线程。

    private void init(ThreadGroup g, Runnable target, String name,long stackSize, AccessControlContext acc) {
       //此处省略n行源码,因为主要想看下面几行
       ......
       
        if (parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
              ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
      ......
    }

使用ThreadLocal要注意的问题

内存泄漏

在上面的Thread和ThreadLocal的源码中,我们了解到:每一个Thread都有自己的ThreadLocalMap类型的变量threadLocals,key为使用弱引用的ThreadLocal实例,value为线程变量的副本。这些对象之间的引用关系如下图,实线为强引用,虚线为弱引用。

rhs3CR.md.png

一旦发生GC,弱引用的对象就会被回收,ThreadLocalMap中就会存在key为null的Entry。如果线程一直没有执行完,那么这些key为null的Entry的value就会一直存在一条强引用链:Thread Ref -> Thread -> ThreaLocalMap -> Entry -> value一直无法回收,造成内存泄漏。

  • 如何避免内存泄漏呢?在用完ThreadLocal以后,调用remove()方法。清除key、value。
    private void remove(ThreadLocal<?> key) {
        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);
        for (Entry e = tab[i];
            e != null;
            e = tab[i = nextIndex(i, len)]) {
            if (e.get() == key) {
                e.clear(); //通过this.referent = null;清除key
                expungeStaleEntry(i); //清除value
                return;
            }
        }
    }
        
    private int expungeStaleEntry(int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;

        // expunge entry at staleSlot
        tab[staleSlot].value = null;
        tab[staleSlot] = null;
        size--;
        //此处省略N行,rehash
            ......
    }

线程池中线程复用的问题

我们在使用线程池来执行多线程操作时候,如果要使用ThreadLocal的话,就需要考虑线程复用的问题,因为不是每次执行都创建新的线程,大多数情况下是复用已有的线程,导致不同的执行任务从ThreadLocal中get到上一次线程的变量。解决办法同样是在线程任务执行完后,调用remove()方法。

总结

ThreadLocal提供了线程局部变量,一个线程局部变量在多个线程中,分别有独立的副本。正是他提供了这种隔离的技术,使我们可以通过当前线程来获取我们需要的变量。避免了我们在开发中参数层层传递。实现了低耦合,简化我们的开发。ThreadLocal在框架中随处可见。所以是时候去征服他了!