ThreadLocal源码解析

71 阅读12分钟

ThreadLocal使用

public class ThreadLocalTest {
    public static void main(String[] args) throws InterruptedException {
        ThreadLocal<String> threadLocal = new ThreadLocal<>();
        new Thread(() -> {
            threadLocal.set("123");
            System.out.println(threadLocal.get());// 123
        }).start();
        Thread.sleep(1000);
        System.out.println(threadLocal.get());// null
    }
}

ThreadLocal是怎么做到在同一个线程内共享变量但是在另一个线程获取不到变量的呢?

ThreadLocal源码解析

构造器

空构造

public ThreadLocal() {
}

set方法

public void set(T value) {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程对应的 ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null)
        // map 不为空就调用 ThreadLocalMap 的 set 方法存值
        map.set(this, value);
    else
        // 如果是该线程第一次 set,就会为该线程初始化 ThreadLocalMap,并存入值
        createMap(t, value);
}

/*********************************************************************************/

// 获取线程对应的 ThreadLocalMap
ThreadLocalMap getMap(Thread t) {
    // 每个线程都会对应一个 ThreadLocalMap,ThreadLocalMap里面维护了一个Entry[]数组
    // Entry 就是用来存储线程的共享变量的
    return t.threadLocals;
}

/*********************************************************************************/

// 初始化 ThreadLocalMap,t 为当前线程,firstValue 为要存入的值
void createMap(Thread t, T firstValue) {
    // 这里直接对线程的 threadLocals 属性设值
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

在Thread类中有这么一个属性叫threadLocals,这就是前面说的每个线程都会对应一个ThreadLocalMap

ThreadLocal.ThreadLocalMap threadLocals = null;

ThreadLocalMap 的构造方法

// firstKey 就是 ThreadLocal,可以看出 ThreadLocal 的存值方式是key-value形式
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    // 初始化 Entry[] 数组
    table = new Entry[INITIAL_CAPACITY];
    // 利用当前的ThreadLocal的hash值计算下标
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    // 创建一个新的 Entry,放入 Entry[] 数组
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    // 设置扩容阈值,为容量的 2/3
    setThreshold(INITIAL_CAPACITY);
}

/*********************************************************************************/

// Entry 是 ThreadLocalMap 的一个静态内部类,可以看到弱引用的是ThreadLocal,
// 还有一个强引用,就是存入的value值
static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;
    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

ThreadLocal 的 hashcode 计算

// 每个 ThreadLocal 都会对应一个 hash 值
private final int threadLocalHashCode = nextHashCode();

// 注意,这个 AtomicInteger 是个静态变量,在类加载时就已经初始化好了,每个ThreaLocal对象都一样
private static AtomicInteger nextHashCode = new AtomicInteger();

// 0x61c88647 是一个神奇的数字,让哈希码能均匀的分布在2的N次方的数组里
private static final int HASH_INCREMENT = 0x61c88647;

// 每创建一个 ThreadLocal 就会调用原子类 AtomicInteger
// 加上一个 0x61c88647 值作为新创建的 ThreadLocal 的hash值
private static int nextHashCode() {
    return nextHashCode.getAndAdd(HASH_INCREMENT);
}

ThreadLocalMap的 set 方法

private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    // 计算 ThreadLocal 对应的数组下标
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        // 判断从下标取出的 Entry 对应的 ThreadLocal 是不是当前的 ThreadLocal
        if (k == key) {
            // 是的话,就直接覆盖并返回
            e.value = value;
            return;
        }
        // 判断 k 是不是空,如果不是空,就说明有哈希冲突,
        // 会将下标加1,接着判断,直到找到一个位置是空的或者在某个位置找到当前的ThreadLocal
        // nextIndex方法会将下标加1,下标加到len-1之后再加会变为0
        // 不可能所有位置都是满的,因为还没满就会扩容
        if (k == null) {
            // 若k是空,说明这个Entry里的ThreadLocal因为是弱引用,被回收了
            // 这时会处理掉这个没用的Entry,因为已经没有地方可以获取到Entry的value,可以当作垃圾了
            // 除了处理当前下标为i的Entry,也会搜索i附近的可回收的Entry
            // 同时,也会将当前需要插入的值进行插入
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // 退出循环后,此时的i对应的下标就是空的,可以创建一个Entry填入
    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 会尝试清理无效的Entry,如果有清理至少1个,就不会进行扩容
    // 如果没清理且size大于阈值,就会进入rehash方法尝试扩容
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

/*********************************************************************************/

// staleSlot :需要被回收的Entry的下标
// 该方法会找到一个位置插入key和value,并将插入的新Entry与 staleSlot 的Entry交换位置
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;
    // Back up to check for prior stale entry in current run.
    // We clean out whole runs at a time to avoid continual
    // incremental rehashing due to garbage collector freeing
    // up refs in bunches (i.e., whenever the collector runs).
    // slotToExpunge :第一个需要被回收的Entry的下标
    int slotToExpunge = staleSlot;
    // 向前搜索需要被回收的Entry的第一个下标,直到找到某个下标为空
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;
    // Find either the key or trailing null slot of run, whichever
    // occurs first
    // 向前找Entry,看看有没有合适的位置可以插入ThreadLocal的值
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        // If we find key, then we need to swap it
        // with the stale entry to maintain hash table order.
        // The newly stale slot, or any other stale slot
        // encountered above it, can then be sent to expungeStaleEntry
        // to remove or rehash all of the other entries in run.
        if (k == key) {
            // 找到了与需要插入的ThreadLocal相同的key,就直接覆盖掉旧值
            e.value = value;
            // 将当前的有效Entry与待回收的Entry交换位置,以维持Entry[]数组的顺序
            // 要让新Entry的下标位置与它在没有hash冲突时的下标位置之间不能有空值,否则就会key重复
            // 比如[e1,e2,e3],假设新的Entry是e3,要回收的是e2,那么就要交换e2和e3的位置
            // 如果不交换,等到e2被回收,就变成:[e1,null,e3],下次再插入一个跟
            // e3的key一样的值,就会插到e3前面,不会覆盖掉e3的value,就出现key重复了
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;
            // Start expunge at preceding stale entry if it exists
            // if条件为true就说明:在staleSlot之前没有找到可以回收的Entry
            // 所以设置slotToExpunge的值为i,i位置已经交换过,变成原本在staleSlot的无效Entry
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            // 从slotToExpunge开始,清除无效的Entry
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }
        // If we didn't find stale entry on backward scan, the
        // first stale entry seen while scanning for key is the
        // first still present in the run.
        // 如果 slotToExpunge == staleSlot,并且当前i位置是无效Entry,那么
        // i就是第一个无效Entry
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }
    // If key not found, put new entry in stale slot
    // 程序进行到这里,就说明并没有找到与key一致的ThreadLocal,就直接插在staleSlot的位置
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);
    // If there are any other stale entries in run, expunge them
    // slotToExpunge != staleSlot 说明存在需要清除的无效Entry
    if (slotToExpunge != staleSlot)
        // 从slotToExpunge开始,清除无效的Entry
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

/*********************************************************************************/

// 清除无效Entry,并且将因为hash冲突而往后排的ThreadLocal放到离原本下标更近的位置(rehash)
// 比如 [e1,e2,e3,null,null],假设按照&计算下标e3应该放在下标0的位置,但是因为hash冲突,
// 放在了下标2的位置,现在e2为无效节点,需要清除,那么同时也会把e3放到离0更近的位置也就是下标1
// 所以,最后会变为[e1,e3,null,null,null]
// 返回值为清除过程在数组中第一个为null的下标
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    // expunge entry at staleSlot
    // staleSlot的位置是第一个无效节点,先给他清除
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;
    // Rehash until we encounter null
    Entry e;
    int i;
    // 往后找无效节点
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            // 找到无效节点
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // 找到有效节点,计算其在没有hash冲突的时候的下标h
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                // 若与它现在的位置不一样,就将他尽量往前挪
                // 将原本的位置置为null
                tab[i] = null;
                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                // 从h往后找到第一个为null的地方,插入
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

/*********************************************************************************/

// 清除无效的Entry,只会向后扫描log2n次
// 返回值为是否在该方法中清理了无效节点
// i为扫描起点,n为扫描次数的控制参数
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            // 一旦发现有无效的Entry没清理,就会再调用上面的 expungeStaleEntry清理
            // 同时重置扫描次数,会再向后扫描log2n次
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
        // 相当于n/=2
    } while ( (n >>>= 1) != 0);
    return removed;
}

/*********************************************************************************/

private void rehash() {
    // 整个table扫描,清理无效的Entry
    expungeStaleEntries();
    // Use lower threshold for doubling to avoid hysteresis
    // 重新清理后若大小不小于阈值的3/4就扩容
    // 这里不是超过阈值才要扩容,为的是避免滞后性hysteresis?啥意思
    if (size >= threshold - threshold / 4)
        resize();
}

/*********************************************************************************/

// 整个table扫描,清理无效的Entry
private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j);
    }
}

/*********************************************************************************/

// 扩容,比较简单
private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }
    setThreshold(newLen);
    size = count;
    table = newTab;
}

set流程小结

每一个线程都有一个 ThreadLocalMap,这个 map 中维护了一个 Entry[] 数组,Entry 继承了 WeakReference,弱引用 ThreadLocal ,强引用value。

当调用 ThreadLocal 的 set 方法存入值时,会以当前的 ThreadLocal 作为 key,存入的值作为 value 来创建一个Entry对象放入线程对应的ThreadLocalMap

get方法

有了前面的铺垫,接下来的两个方法都比较容易看懂

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // 如果map还未创建或者没有对应的ThreadLocal,就会调用setInitialValue方法赋初始值
    // 默认是null,可以继承ThreadLocal,重写initialValue方法来设置默认值
    return setInitialValue();
}

/*********************************************************************************/

// 设置默认值
private T setInitialValue() {
    // 获取默认值,默认是null,可以继承ThreadLocal,重写initialValue方法来设置默认值
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

/*********************************************************************************/

// 获取默认值,默认是null,可以继承ThreadLocal,重写initialValue方法来设置默认值 
protected T initialValue() {
    return null;
}

ThreadLocalMap的get方法

private Entry getEntry(ThreadLocal<?> key) {
    // 计算下标
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
	    // 如果没有获取到Entry或者Entry的key不是当前的ThreadLocal
        // 就调用下面的方法进一步查找
        return getEntryAfterMiss(key, i, e);
}

/*********************************************************************************/

// 会往后查找,如果遇到无效的Entry会顺便帮忙清除
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            // 顺便帮忙清除无效的Entry
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

remove方法

比较简单

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}
private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            // 将节点的key也就是弱引用设置为空,这样当前节点就变成无效节点了
            e.clear();
            // 清除节点,顺便会检查有没有别的无效节点
            expungeStaleEntry(i);
            return;
        }
    }
}

小结

解决hash冲突的两种方法以及优缺点

开放地址法(ThreadLocalMap采用的,出现冲突就将下标后移):

  1. 容易产生堆积问题,不适于大规模的数据存储。
  2. 散列函数的设计对冲突会有很大的影响,插入时可能会出现多次冲突的现象。
  3. 删除的元素是多个冲突元素中的一个,需要对后面的元素作处理,实现较复杂。

链地址法(HashMap采用的,通过链表):

  1. 处理冲突简单,且无堆积现象,平均查找长度短。
  2. 链表中的结点是动态申请的,适合构造表不能确定长度的情况。
  3. 删除结点的操作易于实现。只要简单地删去链表上相应的结点即可。
  4. 指针需要额外的空间,故当结点规模较小时,开放定址法较为节省空间。

ThreadLocalMap 采用开放地址法原因

  1. ThreadLocal 中看到一个属性 HASH_INCREMENT = 0x61c88647 ,0x61c88647 是一个神奇的数字,让哈希码能均匀的分布在2的N次方的数组里, 即 Entry[] table,关于这个神奇的数字google 有很多解析,这里就不重复说了
  2. ThreadLocal 往往存放的数据量不会特别大(而且key 是弱引用又会被垃圾回收,及时让数据量更小),这个时候开放地址法简单的结构会显得更省空间,同时数组的查询效率也是非常高,加上第一点的保障,冲突概率也低

为什么Entry的key(ThreadLocal)要设置成弱引用

如果一旦发生ThreadLocal生命周期结束,但是又没有清空Entry(Entry没被清空,Entry的key也就是ThreadLocal还在使用)这种情况,并且又是强引用,会发生什么情况?就会发生如果这个线程不消亡,这个对象就回收不掉的情况,但是这个对象又是可达的(有Entry的key指向),这就产生了 可达但不使用 的情况,就是我们说的内存泄漏

但是如果我们把ThreadLocal设置成是弱引用,就能尽可能避免这个问题(线程执行完但是Entry没被清,下一次GC的时候,就能把这个对象回收了)

需要注意:这里的value是不能被定义成弱引用的,因为外部没有强引用指向它,但是key(ThreadLocal)用强引用,这点用ThreadLocal本身作为key的设计还是挺巧妙的

问题到这里,可能又有人有行的问题了——既然Entry里只有key被设置成弱引用,value没有设置,那岂不是value会很容易产生内存泄漏的问题?(因为这时候key被回收了,也就是key变成了null,但是value还是强引用,对象还在堆里,并且可达不使用,就是在ThreadLocalMap的Entry里产生了一堆key为null的东西)

的确是这样的,但是这个问题,jdk的设计者在设计的时候就在一定程度上进行了缓解——我们在调用ThreadLocal的get/set/remove的时候,底层源码会自动地把这一堆key为null的东西删除了,以便下一次GC把value回收掉