ThreadLocal源码解析

189 阅读3分钟
简介

ThreadLocal是线程的局部变量,每个线程都是一个Thread对象,Thread中有类型为ThreadLocal.ThreadLocalMap的成员变量。ThreadLocal.ThreadLocalMap是每个线程实例的独有变量,所以不存在并发安全问题

ThreadLocal<Integer> threadLocal = new ThreadLocal<>();
1.ThreadLocalMap解析

ThreadLocalMap为ThreadLocal的静态内部类,Entry为ThreadLocalMap的静态内部类,ThreadLocalMap采用Entry数组来存储元素的键值,遇到哈希冲突时使用线性探测法来解决。

  • ThreadLocalMap的构造方法
static class ThreadLocalMap {

    //Entry的构造方法
    static class Entry extends WeakReference<ThreadLocal<?>> {
        //每个线程对应ThreadLocal的独有值
        Object value;
        //Entry节点的key为ThreadLocal的实例对象,所有线程公用同一ThreadLocal实例的引用
        //同时由于继承了WeakReference<ThreadLocal<?>> ,作为key的ThreadLocal的实例对象是弱引用
        //value为每个线程独有的实例对象
        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }
   
    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        //ThreadLocalMap的初始容量,必须为2的整数幂
        table = new ThreadLocal.ThreadLocalMap.Entry[INITIAL_CAPACITY];
        //对key的hash值取余确定当前元素在Entry数组中的位置
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        size = 1;
        table[i] = new ThreadLocal.ThreadLocalMap.Entry(firstKey, firstValue);
        //设置扩容阈值
        setThreshold(INITIAL_CAPACITY);
    }
}
  • 赋值方法
private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    
    //寻找对应的key修改value,不断通过线性探测法使i = i + 1往下找
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        //判断哈希值相同的key,是否为同一对象
        if (k == key) {
            e.value = value;
            return;
        }
        if (k == null) {
            //清除那些key的实例对象已经被GC,但是value的实例对象还存活的Entry防止内存泄漏
            //下面会说到
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    //当前节点为空,则创建Entry对象存储键值
    tab[i] = new Entry(key, value);
    int sz = ++size;
    //cleanSomeSlots同replaceStaleEntry方法,同时判断是否需要进行扩容
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}
  • 取值方法
private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            //当作为key的键对象被回收后,要对这个Entry对象进行清除
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}
  • replaceStaleEntry,cleanSomeSlots方法都是要遍历Entry数组,清除其中key为null的节点,核心方法为expungeStaleEntry
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    //将要清除的Entry节点的引用置空,使键值对象在下次GC中被回收
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // 遍历数组进行清除直到Entry节点为空
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}
2.ThreadLocal解析

ThreadLocal的核心逻辑都在ThreadLocalMap中,ThreadLocal中的方法都是对ThreadLocalMap方法的封装调用

  • 赋值方法
public void set(T value) {
    Thread t = Thread.currentThread();
    //获取当前Thread对象的ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

void createMap(Thread t, T firstValue) {
    //调用ThreadLocalMap的构造方法为Thread对象的ThreadLocal.ThreadLocalMap变量赋值
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}
  • 取值方法
public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    //如果threadLocals为空,则对其进行初始化
    return setInitialValue();
}

private T setInitialValue() {
    //设定的默认值
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}
3.总结

上面只是把关于ThreadLocal的set及get方法进行了解析,其实理解ThreadLocal的使用原理首先要明白线程对象使用ThreadLocal中用Map来保存键值,Map的底层存储结构为Entry数组,每个线程对象都有自己的Map,多个线程间的关联在于对于同一个ThreadLocal属性会共享键对象的引用,value保存的值对象是每个线程独有的,然后还需要理解为什么键对象要用弱引用,以及JVM对于弱引用的处理。