ThreadLocal 源码全详解(ThreadLocalMap)

3,816 阅读12分钟

1. ThreadLocal 源码分析

1.1 ThreadLocal 原理

首先我们得从 Thread 类讲起,在 Thread 类中有维护两个 ThreadLocal.ThreadLocalMap 对象(初始为 null,只有在调用 ThreadLocal 类的 set 或 get 时才创建它们):threadLocalsinheritableThreadLocals。也就是说每个 Thread 对象都有两个 ThreadLocalMap 对象,ThreadLocalMapThreadLocal 定制的 HashMap,是 ThreadLocal 的内部类,其 key 为弱引用的 ThreadLocal 对象,value 为对应设置的 Object 对象。

public class Thread implements Runnable {
    //......
    //与此线程有关的ThreadLocal值。由ThreadLocal类维护
    ThreadLocal.ThreadLocalMap threadLocals = null;

    //与此线程有关的InheritableThreadLocal值。由InheritableThreadLocal类维护
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
    //......
}

我们想设置 ThreadLocal 值时,通过查看源码我们可以发现,使用 ThreadLocal 的 set() 方法时实际是调用了当前线程的 ThreadLocalMap 的 set() 方法。ThreadLocal 的 set() 方法中,先用 Thread.currentThread() 获得当前线程对象 t ,通过当前线程对象 t 获取线程的 ThreadLocalMap 对象 map ,接着判断 map 是否为 null——为 null 则调用creadMap() 方法传入当前线程对象 t 和当前 set() 方法的入参 value 创建为当前线程创建 ThreadLocalMap 对象并 put value 添加变量;不为 null 则调用 map.set(value) 设置该 ThreadLocal 对象的值。

// ThreadLocal.java
public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

由此可见,变量是放在当前线程的 ThreadLocalMap 中,而 ThreadLocal 是 ThreadLocalMap 的封装,传递了变量值。

1.2 ThreadLocalMap 原理

ThreadLocalMap 的数据结构实际上是数组,对比 HashMap 它只有散列数组没有链表。

1.2.1 ThreadLocalMap 的四个属性
  • Entry[] table
  • INITIAL_CAPACITY
  • size
  • threshold
// 源码
static class ThreadLocalMap {
    
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    //初始容量默认为16,必须是2的幂
    private static final int INITIAL_CAPACITY = 16;

    // table每次resized,容量都得是2的幂
    private Entry[] table;

    // 当前table中的存放的元素数量
    private int size = 0;

    // 扩容阀值
    private int threshold; // Default to 0
	
    /**
     * 接下来还有 set()、get()、扩容方法、expungeStaleEntry()、cleanSomeSlots()等重要方法就不贴源码了
	 * ......
     */
} 
1.2.2 Hash 算法

ThreadLocalMap 实现了自己的hash 算法来解决散列表数组冲突。

int i = key.threadLocalHashCode & (len - 1);

这里的`i` 就是当前 key 在散列表中对应的数组下标位置。`len` 指的是`ThreadLocalMap` 当前的容量`capacity`。
而比较重要的是我们必须知道`key.threadLocalHashCode` 这个值是怎么计算的?
通过源码可以知道`threadLocalHashCode` 是`ThreadLocal` 的一个属性,其值是调用`ThreadLocal` 的`nextHahCode()` 方法获得的。
`nextHashCode()`:返回`AtomicInteger nextHahCode` 的值,并将`AtomicInteger nextHahCode`  自增一个常量值——`HASH_INCREMENT(0x61c88647)`。
> __特别提醒__:每创建一个`ThreadLocal` 对象(每将对象 hash 到 map 一次),`ThreadLocal.nextHashCode` 就增长`0x61c88647`。(`0x61c88647` 是斐波那契数,使用该数值作为 hash 增量可以使 hash 分布更加均匀。)
 
代码如下:
```java

public class ThreadLocal<T> {
    private final int threadLocalHashCode = nextHashCode();

    private static AtomicInteger nextHashCode = new AtomicInteger();

    private static final int HASH_INCREMENT = 0x61c88647;

    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    static class ThreadLocalMap {
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);

            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }
    } 
}

总结:ThreadLocalMaphash 算法很简单,就是使用斐波那契数的倍数(len -1) 按位与(这个结果其实就是斐波那契数的倍数capacity 取模)的结果作为当前 key 在散列表中的数组下标。

1.2.3 Hash 冲突

HashMap 如何解决 hash 冲突HashMap 解决冲突是使用链地址法,在数组上构造链表结构,将冲突的数据放在链表上,且每个数组元素也就是链表的长度超过某个数量后会将链表转换为红黑树。

ThreadLocalMap 使用的是线性探测的开放地址法去解决 hash 冲突。 当当前 key 存在 hash 冲突,会线性地往后探测直到找到为 null 的位置存入对象,或者找到 key 相同的位置覆盖更新原来的对象。在这过程中若发现不为空但 key 为 null 的桶(key 过期的 Entry 数据)则启动探测式清理操作。

4.1.2.4 ThreadLocal.set() 源码详解

ThreadLocalset() 方法的原理:先获得当前线程对象,传入线程对象到 getMap() 方法获得 ThreadLocalMap 对象,判断其是否存在,存在则用 map 的 set() 方法进行数据处理,否则调用 createMap() 方法传入当前线程对象和 value 创建 map。 代码如下:

// ThreadLocal 的 set() 方法的源码
public void set(T value){
  Thread t = Thread.currentThread();
  ThreadLocalMap map=getMap(t);
  if(map != null)
	  map.set(this, value);
  else
	  createMap(t, value);
}
void create(Thread t, T firstValue){
	t.threadLocals = new ThreadLocalMap(this, firstValue);
}

set() 方法的主要的核心逻辑还是在ThreadLocalMap

1.2.5 ThreadLocalMap.set() 原理详解

通过 ThreadLocalMap 中的 set() 方法可以新增或更新数据,这可以分为四种情况:

  • :通过 hash 计算后的位置对应的 Entry 数据为空:直接将数据存入该位置即可。

  • :位置对应的数据不为空,但 key 值和当前 ThreadLocal hash 计算后的 key 值相同:直接将数据更新覆盖到该位置。

  • hash 到的位置不为空,key 值和当前 hash 到的 key 值不相同,向后遍历且在找到 Entry 为 null 的位置或者 key 值相同的位置之前,未遇到 Entry not null 但 key 为 null 的情况:直接存入数据或更新数据。

  • hash 到的位置不为空,在向后遍历时遇到了 Entry not nullkey 为 null (假设该位置下标为x )的情况:

    1. 此时执行replaceStaleEntry() 方法(替换过期数据),从下标x 为起点向前遍历,初始化探测式清理的开始位置:slotToExpunge = staleSlot = x,进行探测式数据清理。
    2. staleSlot 开始向前遍历查找其他的过期数据,并更新清理过期数据的起始下标 slotToExpunge(遇到 key 为 null 的位置则更新 slotToExpu nge = 当前下标 ),直到遇到 Entry = null 停止向前遍历。
    3. staleSlot 开始向后遍历,直至遇到 Entry = null 或者 key = hash 后得到的 key。 * Entry = null:将数据覆盖替换掉staleSlot 位置上的Entry
    • key = hash 后得到的 key:将数据更新,然后与 staleSlot 的 Entry 交换。
    1. 在前 3 步的过程中若发现有两个或以上的key = null 则调用cleanSomeSlots(expungeStaleEntry(slotToExpunge), len) 方法清理过期元素。(从 slotToExpunge 开始向后检查并清理过期元素,此时主要是通过 expungeStaleEntry()cleanSomeSlots() 两个方法工作。)
1.2.6 ThreadLocalMap.set() 源码详解

set() 方法的代码如下:

// ThreadLocal.ThreadLocalMap.set()方法
private void set(ThreadLocal<?> key, Object value) {
    // 通过 key 计算出当前 key 在散列表对应的位置——i
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    
    // 从 i 开始向后遍历,查找找到为空的位置(也就是得到 tab[i]),注意:通过nextIndex()方法,在遍历完散列数组的最后位置后,遍历的下一个位置是 index=0
    /** 
     * private static int nextIndex(int i, int len) {
     *     return ((i + 1 < len) ? i + 1 : 0);
     * }
	*/
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
		// 遇到key相同,直接更新覆盖,返回
        if (k == key) {
            e.value = value;
            return;
        }
		// 遍历到到key=null(过期元素),执行replaceStaleEntry(),返回
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    // 在 空位置 存放数据
    tab[i] = new Entry(key, value);
    // size++
    int sz = ++size;
    // 调用boolean cleanSomeSlots()进行启发式清理过期元素
    // 若未清理到任何数据且size超过阈值threshold(len*2/3)则rehash(),rehash()中会先进行探测式清理过期元素,若此时size>=len/2(threshold-threshold/4)则扩容
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

通过上面的代码以及注释可以清晰地了解使用 set() 方法时的前三种情况的处理逻辑,第四种情况的主要处理逻辑都在 replaceStaleEntry() 方法中。

ThreadLocal.ThreadLocalMap.replaceStaleEntry() 方法代码如下:

// ThreadLocal.ThreadLocalMap.replaceStaleEntry()
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;
	// 从staleSlot向前遍历直到遇到Entry=null,期间遇到key=null时更新slotToExpunge
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;
	// 从staleSlot向后遍历,直到Entry=null停止
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {

        ThreadLocal<?> k = e.get();
		// 遇到key=key
        if (k == key) {
            // 更新该位置Entry并将该位置和staleSlot的Entry交换
            e.value = value;
			
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;
			// 若此时slotToExpunge=staleSlot,说明向前遍历时没有发现过期元素以及向后遍历也没发现过期元素,此时修改探测式清理过期元素的起始下标为i(也就是从i作为起始下标开始探测式清理)
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            // cleanSomeSlots()为启发式清理,expungeStaleEntry()为探测式清理
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }
		// 若遇到key=null 且 slotToExpunge=staleSlot,说明向前遍历未遇到过期元素但向后遍历遇到了过期元素,此时修改探测式清理过期元素的起始下标为i
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }
	// 从staleSlot向后遍历过程中遇到了Entry=null,此时直接将数据更新到staleSlot位置
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);
	// 若slotToExpunge!=staleSlot,说明向前遍历或者向后遍历过程中有遇到过期元素,此时slotToExpunge为向前遍历中“最远”的或者向后遍历中遇到的“最远”的key为null的下标,启动探测式清理后启动启发式清理。
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

1.2.7 探测式清理详解

探测式清理,也就是 expungeStaleEntry() 方法。

从开始位置向后遍历,清除过期元素,将遍历到的过期数据的 Entry 设置为 null ,沿途碰到的未过期的数据则将其 rehash 后重新在 table 中定位,如果定位到的位置有数据则往后遍历找到第一个 Entry=null 的位置存入。接着继续往后检查过期数据,直到遇到空的桶才终止探测。

代码如下:

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
	// 传入的staleSlot位置上的数据一定是过期数据,将staleSlot位置的置空
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;
	// for循环是向后遍历,直到遇到 Entry=null
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        // 若当前遍历的 key 为 null则将 Entry置空
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        //若当前遍历的 key 不为null,将其rehash并将key的原本位置Entry置空,再将key的Entry放入rehash后的位置以及其后面位置的第一个为null的位置
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    // 返回i,也就是探测式清理向后遍历中遇到的第一个为null的位置
    return i;
}

这可以使得 rehash 后的数据距离正确的位置(i= key.hashCode & (tab.len - 1))更近一些。能过提高整个散列表的查询性能。

1.2.8 启发式清理详解

启发式清理,cleanSomeSlots(int i, int n)

向后遍历 log2n\lfloor log_2 n \rfloor 个位置,下标 i 作为遍历的第一个位置。遍历中遇到位置上 key=null 时(假设该位置为 i ),同步调用 expungeStaleEntry(i) 方法。

代码如下:

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

注意:在 ThreadLocalMap.set() 方法的调用方法 ThreadLocalMap.replaceStaleEntry() ,一般会这样调用—— cleanSomeSlots(expungeStaleEntry(slotToExpunge), len)

1.2.9 扩容机制

ThreadLocalMap 的扩容是在 set() 方法之后才有可能执行的。在 set() 方法的最后,若在 set() 未清理到任何数据且 size 超过或等于阈值 threshold(也就是 len*2/3)则 rehash()rehash() 中会先进行探测式清理过期元素,若在 rehash() 清除过后 size>=len/2(也就是 threshold-threshold/4)则调用 resize() 扩容。

注意:阈值是 len*2/3

rehash() 的代码如下:

private void rehash() {
    // 该方法为从下标0出发,找到第一个 key=null 的位置j,以j为起始开始探测式清理
    expungeStaleEntries();
    // 阈值 threshold=len*2/3
	// 当前size超过或等于阈值的3/4时执行扩充
    if (size >= threshold - threshold / 4)
        resize();
}

private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j);
    }
}

扩容的具体实现是 resize() 。首先,扩容是 tab 直接扩容为原来的 2 倍的,然后遍历旧的散列表,重新计算每个元素的 hash 位置放到新的 tab 数组中,遇到 hash 冲突则往后寻找最近的 entry=null 的位置存放。最后重新计算 tab 执行扩容的阈值。

resize() 的代码如下:

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null;
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}
1.2.10 ThreadLocalMap.get() 详解

使用 get() 操作获取数据有 2 种情况:

  • **一:**通过传入的 key 计算出的位置,位置上 Entry!=null && k==key ,直接返回。
  • **二:**位置上的 Entry 的 key 和 传入的 key 不相等,则从该位置向后遍历,遇到 key=null 就启动探测式清理然后继续遍历,直到遍历到 key=传入的key 的位置,最后将位置上的 Entry 返回;或者位置上的 Entry 为空,返回 null。

ThreadLocalMap.get() 的代码如下:

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        // 第一种情况
        return e;
    else
        // 第二种情况
        return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        // 遍历到key=传入的key,返回该Entry
        if (k == key)
            return e;
        if (k == null)
            // 遍历中遇到 key=null,启动探测式清理
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    // 遍历中遇到了null
    return null;
}

2. ThreadLocal 的问题

2.1 ThreadLocal 内存泄露问题

WeakReference 弱引用:我们平时使用的引用基本上都是弱引用,弱引用可以理解为生活中可有可无的物品,当某对象只被弱引用时,在 GC 时一旦扫描到该对象,该对象就会被清理掉。

在 ThreadLocalMap 中的 Entry 的 key 是对 ThreadLocal 的 WeakReference 弱引用,而 value 是强引用。当 ThreadLocalMap 的某 ThreadLocal 对象只被弱引用,GC 发生时该对象会被清理,此时 key 为 null,但 value 为强引用不会被清理。此时 value 将访问不到也不被清理掉就可能会导致内存泄漏。

因此我们使用完 ThreadLocal 后最好手动调用 remove() 方法。但其实在 ThreadLocalMap 的实现中以及考虑到这种情况,因此在调用 set()get()remove() 方法时,会清理 key 为 null 的记录。

2.2 ThreadLocal 无法给子线程共享父线程的线程副本数据

异步场景下无法给子线程共享父线程的线程副本数据,可以通过 InheritableThreadLocal 类解决这个问题。

它的原理就是子线程是通过在父线程中调用 new Thread() 创建的,在 Thread 的构造方法中调用了 Thread的init 方法,在 init 方法中父线程数据会复制到子线程(ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);)。

但是我们做异步处理都是使用线程池,线程池会复用线程会导致问题出现。遇到这种情况我们需要自己解决。