ThreadLocal 源码解析(一)

709 阅读7分钟

前言

在开始之前先看下代码如下:

代码片段1
ThreadLocal<String> stringThreadLocal = new ThreadLocal<>();
new Thread(() -> {
    stringThreadLocal.set("测试");
    System.out.println("子线程获取:" + stringThreadLocal.get());
}).start();
Thread.sleep(200);
System.out.println("主线程获取:" + stringThreadLocal.get());

输出如下:

子线程获取:测试
主线程获取:null

是不是感觉哪里不对劲。按道理,如果是其他的类似于集合的数据结构,我只要在子线程进行了赋值操作,又在主线程执行了sleep()操作那么主线程应该是可以获取到最新的赋值的,比如将上面ThreadLocal改为ArraList,代码如下:

代码片段2
ArrayList<String> arrayList = new ArrayList<>();
new Thread(() -> {
    arrayList.add("测试");
    System.out.println("子线程获取:" + arrayList.get(0));
}).start();
Thread.sleep(200);
System.out.println("主线程获取:" + arrayList.get(0));

输出如下:

子线程获取:测试
主线程获取:测试

再看下如下代码:

代码片段3
ThreadLocal<String> stringThreadLocal = new ThreadLocal<>();
stringThreadLocal.set("测试");
Thread.sleep(200);
new Thread(() -> {
    System.out.println("子线程获取:" + stringThreadLocal.get());
}).start();
System.out.println("主线程获取:" + stringThreadLocal.get());

输出结果如下:

主线程获取:测试
子线程获取:null

通过以上三个代码片段,似乎可以发现一些端倪,那就是ThreadLoal通过set()操作进行的赋值操作的作用范围似乎只能是在当前线程,那它是如何实现这种线程隔离的效果的呢? Talk is cheap,show me the fuck code.

正文

正文主要分亮大部分ThreadLocal本身的函数、与ThreadLocal息息相关的ThreadLocalMap

ThreadLocal本身

先看构造函数:

public ThreadLocal() {
}

没啥特殊的,那就看下set()函数:

插入数据

代码片段4
public void set(T value) {
    //注释1
    Thread t = Thread.currentThread();
    //注释2
    ThreadLocalMap map = getMap(t);
    if (map != null)
    //注释4
        map.set(this, value);
    else
        //注释3
        createMap(t, value);
}

注释1处获取当前线程t,然后把t当做参数通过getMap,获取ThreadLocalMap,如下:

代码片段5
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

返回的是是Thread的一个成员变量,那我们就来看看threadLocals初始化和赋值的地方如下:

代码片段6
void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

可以发现threadLocals是在ThreadLocal#createMap()里面赋值的,再回到代码片段4的注释2显然一开始这里的map为null,逻辑走到了注释3的createMap(),也就是在这里完成了threadLocals的初始化以及赋值。关于ThreadLocalMap晚点再看,我们趁热打铁来看get()

取数据

get()函数如下:

代码片段7
public T get() {
    //注释1
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
     //注释2
    if (map != null) {
     //注释3
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
     //注释4
    return setInitialValue();
}

和set()一样,先获取当前线程然后通过getMap()获取ThreadLocalMap,因为前面已经调用过来set()函数,所以map不等于null,走到注释3,通过 ThreadLocalMap#getEntry()获取Entry并返回Entry的value,如果map等于走到注释4的逻辑 如下:

代码片段8
private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

这里可以看做是走了一遍set(null)的操作,并返回空值

小结

通过以上set()和get()方法的分析能够知道,两者都需要调用getMap()拿到threadLocals,再通过threadLocals去取值或者赋值,而threadLocals是Thread的成员变量,不同线程实例有自己的hreadLocals,天然形成了线程隔离,这也就能回答文章开头提出的问题。

ThreadLocalMap

通过源码分析知道ThreadLoal的各种操作与ThreadLocalMap紧密相关,那接下来我们看看ThreadLocalMap的构造函数

构造函数

代码片段9
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    //注释1 初始化数组
    table = new Entry[INITIAL_CAPACITY];
    //注释2 通过hash与INITIAL_CAPACITY取模(对一个2的整次幂的减一数进行&操作相当于取模)获取firstKey在数组的位置i
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    //注释3 插入数据
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    注释4 设置扩容阈值
    setThreshold(INITIAL_CAPACITY);
}

通过构造函数底层的存储结构是个数组,通过对 ThreadLocal 的hash值与默认容量取模获取到在数组的位置,插入、更新、删除、数据,并且在必要的时候进行扩容

Entry

看下Entry这个类

static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

可以看到Entry继承自WeakReference,是个弱引用,也就是再内存不足的时候,触发GC会被回收

插入数据

从代码片段4我们知道,ThreadLocal#set()函数,会调用到代码片段4注释4出把自身当做key,传入到ThreadLocalMap#set(),如下:

代码片段10
private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
        //注释1 从当前位置往后查找适合插入的位置
    for (Entry e = tab[i];
         e != null;
         //注释2 nextIndex到达尾部之后,会回到开头
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
        //注释3 数组中已经存在,直接更新返回
            e.value = value;
            return;
        }
        //注释4 Entry不为空,key为空,说明key已经被GC
        if (k == null) {
        // 注释5 把过期数据数据替换为新值
            replaceStaleEntry(key, value, i);
            return;
        }
    }
        //注释6 找到一个空的位置,插入数据
    tab[i] = new Entry(key, value);
    // 注释7 size 加1
    int sz = ++size;
    //注释8 清理过期数据,判断是否是需要扩容
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

以上代码可知,ThreadLocalMap#set()的关键是找到合适的位置,下面结合图形分析插入过程,假设当前 ThreadLocaMap 中数据为空,如下图: image.png

插入entry0,假设通过hash取模计算的i值为0,在注释1处由于0这个位置的数据是空的,执行插入,插入后数据结构如下:

image.png 继续插入entry2,entry3,entry5,数据结构如下:

image.png 假设这个时候继续插入一个entryA,假设通过key的hash与数组取模得到的索引i为2,这个位置已经有了一个entry2,产生了hash冲突,注释1处继续往下遍历,直到找到一个空的位置4,执行插入,插入后数据结构如下: image.png 如果这个时候执行了删除数据entry3的操作,比如删除数据,数据结构如下:

image.png 这个时候执行插入entryB的操作,假设通过key的hash与数组取模得到的索引i为3,走到注释4,替换过期数据,执行后数据结构如下:

image.png 继续插入entry7,数据结构如下:

image.png 继续插入entryC,假设通过key的hash与数组取模得到的索引i为7,由于7这个位置不空,产生了hash冲突,这个时候注意注释2处的nextIndex(),代码如下:

private static int nextIndex(int i, int len) {
    return ((i + 1 < len) ? i + 1 : 0);
}

可以看到从当前位置遍历到末尾时,会从头开始遍历,知道找到空的位置,插入entryC的过程,通过执行完注释1处的循环,找到了插入位置1,执行插入,插入完成之后数据结构如下:

image.png 通过以上的插入操作过程中产生的hash冲突后的解决策略,可以知道ThreadLocaMap的冲突解决策略是开发地址法 插入的整个流程图如下: image.png

读取数据

代码片段11
private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
    //注释1
        return e;
    else
    //注释2
        return getEntryAfterMiss(key, i, e);
}

注释1处通过key的hash与数组取模得索引i,如果i位置的entry的key与要查找的key值相等,直接返回,否则执行注释2,通过上节中的插入数据我们知道 ThreadLocaMap 的冲突解决策略是开发地址法,i值不一定是元素的插入位置,需要进一步查找,进入到getEntryAfterMiss(),如下:

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        // 注释1 
        if (k == key)
            return e;
         // 注释2
        if (k == null)
            expungeStaleEntry(i);
        else
         // 注释3
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

遍历逻辑与插入数据大同小异,注释1处表示已经找到了,注释2处表示是过期数据,清理过期数据,注释3处继续往下遍历,知道找到一个空值,结束循环

小结

通过以上对ThreadLocalMap的源码分析,相信对ThreadLocalMap有了初步了解,其实它的底层数据结构就是一个数组,通过开放地址法解决hash冲突。

总结

考虑到本文行文已经过长,决定对ThreadLocalMap的删除数据、清除过期数据以及扩容等细节实现,单独开篇,敬请期待