【笔记】ThreadLocal实现原理

92 阅读7分钟

参考:

面试官:小伙子,听说你看过ThreadLocal源码?(万字图文深度解析ThreadLocal) - 掘金

30 | 线程本地存储模式:没有共享,就没有伤害-极客时间

【【java】什么是ThreadLocal?】

【内存泄漏】测试ThreadLocal 在gc后引发的threadLocalMap的key为null,但value不为null的情况_thewindkee的博客-CSDN博客

ThreadLocal通常配合static可用作线程内的全局变量使用的。

public class demo {
    public static ThreadLocal<String> threadLocal=new ThreadLocal<>();

    public static void main(String[] args) {
        new Thread(()->{
            System.out.println(threadLocal.get()); // null
            threadLocal.set("01");
            System.out.println(threadLocal.get()); // 01
        }).start();
        new Thread(()->{
            System.out.println(threadLocal.get()); // null
            threadLocal.set("02");
            System.out.println(threadLocal.get()); // 02
        }).start();

    }
}

ThreadLocal类就好像起到了一个封装的作用,它将泛型T封装成了一个线程内的类型为T的全局变量。

ThreadLocalMap的hash算法

ThreadLocalMap的hash算法和HashMap的hash算法一样,都是&操作。

因此ThreadLocalMap的数组长度必然要求是2的幂次方,扩容倍数也必然是2的幂次方倍(一般扩容倍数取2,不能一下扩太凶)。

int i = key.threadLocalHashCode & (len-1); // hash算法

int newLen = oldLen * 2; // 2倍扩容

对象hash值的计算

ThreadLocalMap并没有直接使用的对象的Object#hashcode()方法计算出的hash值。

ThreadLocal的hash值的计算采用了斐波那契算法。

	private final int threadLocalHashCode = nextHashCode(); // 每次创建ThreadLocal对象时,获取一个新的hash值

    private static AtomicInteger nextHashCode = new AtomicInteger(); // 使用AtomicInteger保证并发安全

    private static final int HASH_INCREMENT = 0x61c88647; // 每次hash值的增长量

    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT); // getAndAdd()获取值,并加上HASH_INCREMENT
    }

增长值0x61c88647斐波那契数, 也叫 黄金分割数。

hash冲突的解决方案

ThreadLocalMap中是采用了线性探测的方式来解决的hash冲突。

但与正常的线性探测不同之处在于ThreadLocalMap中的key是弱引用,可能会发生key为null,而value还存在的情况(这种情况可叫做过期key、过期entry、过期数据)。

对于线性探测中遇到的过期entry,ThreadLocalMap会使用探测式清理的方式进行清除。

ThreadLocalMap的set方法

set方法源码

private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1); // 计算hash映射到数组的下标

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {   // 线性探测的过程中遇见了目标entry,直接更新数据即可
            e.value = value;
            return;
        }

        if (k == null) {
            replaceStaleEntry(key, value, i); // 线性探测的过程中遇见了过期数据,改用替换过期entry的策略。
            return;
        }
    }
	// 能到这里,说明线性探测过程中没有找到目标entry,也就是Map中没有这个key。
    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold) // 因为新增了数据,要判断是否需要扩容。(先判断能否通过启发式清理掉一些过期数据,然后再判断是否超过阈值)
        rehash();
}

替换过期数据

private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    int slotToExpunge = staleSlot; // slotToExpunge(待清理槽位起始位置),为探测式清理标记起始位置
    for (int i = prevIndex(staleSlot, len); // 向前进行线性探测,直到碰见entry为null。
         (e = tab[i]) != null;				// map中肯定有为null的entry,因为超过阈值就会扩容,数组不会满的。
         i = prevIndex(i, len))				

        if (e.get() == null)   // 如果向前线性探测的过程中发现新的过期数据,更新slotToExpunge值。
            slotToExpunge = i;

    // slotToExpunge标记完毕,开始查找目标entry。
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {

        ThreadLocal<?> k = e.get();

        if (k == key) {  // 向后线性探测的过程中找到目标entry,替换过期entry(过期槽位)
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            if (slotToExpunge == staleSlot) // 如果slotToExpunge标记中发现只有一个过期槽位staleSlot,那么在替换过期entry后,slotToExpunge的值也要随着改变
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); // 进行启发式和探测式清理过程
            return;
        }
    	// 在staleSlot后发现了新的过期槽位并且在slotToExpunge标记中发现只有一个过期槽位staleSlot,更新slotToExpunge。
        // 为什么这个时候也要更新slotToExpunge呢?因为staleSlot这个槽位最终必然会放入正常的entry,没有必要从staleSlot处开始进行清理
        // 那为什么向前查找的时候没有找到其他数据呢?我想起来了,应该是有为null的entry阻止了查找过程
        if (k == null && slotToExpunge == staleSlot) 
            slotToExpunge = i;
    }
	// 能到这里,说明map中没有目标key,需要创建新的entry,并放置到staleSlot槽位处
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    if (slotToExpunge != staleSlot) // 这时的相等就意味着staleSlot后直至为null的entry处都没有过期entry,也就没有清理的必要了。
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); // 进行启发式和探测式清理过程
}

ThreadLocalMap的探测式清理

/**
* 过程简单来说就是:从给定的过期槽位开始,向后探测,清掉期间遇见的过期槽位,对正常非null槽位进行再次hash,使得数据前移,不至于线性探测出现断层。
*/
private int expungeStaleEntry(int staleSlot) { // staleSlot起始过期槽位
    Entry[] tab = table;
    int len = tab.length;

	// 删去对过期entry的相关引用
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    Entry e;
    int i;
	// 向后探测,直到遇见null,清理掉期间遇见的过期entry,对正常的entry进行rehash
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) { //清理掉期间遇见的过期entry
            e.value = null;
            tab[i] = null;
            size--;
        } else { // 对正常的entry进行rehash,目的是前移数据,rehash的目的并不主要是为了查询效率,而是因为ThreadLocalMap解决hash冲突的方式所决定的。如果不进行rehash,get的时候因为连续空间中出现断层必然会引发问题(有些数据将永远查不到)
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) { //entry映射到的不是当前槽位,需要进行数据的前移,因为有过期槽位被清掉了。
                tab[i] = null;

                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i; //返回为null的槽位下标
}

ThreadLocalMap的启发式清理

/**
* i的下一个槽位就是开始清理的第一个槽位,n用来控制循环次数。
*/
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false; // removed代表这次清理过程是否有清掉过期entry
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) { // 如果碰见过期entry,触发探测式清理并重置循环次数
            n = len;
            removed = true;
            i = expungeStaleEntry(i); // 清掉i及之后的过期数据,返回探测式最后碰到的空槽位下标
        }
    } while ( (n >>>= 1) != 0); // 以无符号右移作为循环条件
    return removed;
}

ThreadLocalMap扩容机制

private void rehash() {
    expungeStaleEntries(); // 先进行探测式清理,看能否清掉一些过期数据
	// 官方注释:使用较低的倍增阈值以避免滞后
    if (size >= threshold - threshold / 4) // threshold为数组长度的2/3,那么这里就是说大于数组长度的一半时,进行扩容。为什么这里是数组长度的一半呢?
        resize();
}

private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j); // 找到一个过期entry后就调用有参的探测式清理
    }
}

\

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;  	// 两倍扩容
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null;
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null) // 找到重新hash后的位置
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

ThreadLocalMap的getEntry()方法

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key) // 如果直接命中目标值
        return e;
    else
        return getEntryAfterMiss(key, i, e); // 对应出现哈希冲突的情况
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null) // 如果这个entry过期了,触发探测式清理,导致数据前移
            expungeStaleEntry(i); 
        else		   // 如果这个entry未过期,跳到下一个entry
            i = nextIndex(i, len); 
        e = tab[i];
    }
    return null;
}

InheritableThreadLocal

ThreadLocal是无法在父子线程间共享的。所以出现了InheritableThreadLocal。

实现原理:

private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
    if (name == null) {
        throw new NullPointerException("name cannot be null");
    }

    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    this.stackSize = stackSize;
    tid = nextThreadID();
}

但InheritableThreadLocal仍然有缺陷,一般我们做异步化处理都是使用的线程池,而InheritableThreadLocal是在new Thread中的init()方法给赋值的,而线程池是线程复用的逻辑,所以这里会存在问题。

当然,有问题出现就会有解决问题的方案,阿里巴巴开源了一个TransmittableThreadLocal组件就可以解决这个问题,这里就不再延伸,感兴趣的可自行查阅资料。