Source Code: ThreadLocal (Thread-Private Variables)

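A Brief Introduction

ThreadLocal is a container for declaring a thread-private variable. When one ThreadLocal is shared by several threads, it guarantees that each thread gets its own independent instance. For example: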

public class Main {

    public static void main(String[] args) {
        BizOperator operator = new BizOperator();
        Runnable task = operator::operate;
        // Each of the four threads below sees a different threadPriVar value when running BizOperator::operate
        new Thread(task).start();
        new Thread(task).start();
        new Thread(task).start();
        new Thread(task).start();
    }

    static class BizOperator {
        ThreadLocal<Object> threadPriVar = ThreadLocal.withInitial(Object::new);

        void operate() {
            // Every thread gets a different someVar here
            Object someVar = threadPriVar.get();
        }
    }
}
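
To check the claim rather than take it on faith, here is a minimal sketch that counts how many distinct objects four threads actually observe. The class name DistinctPerThreadDemo and the helper countDistinctInstances are made up for this illustration; only ThreadLocal.withInitial comes from the example above.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

class DistinctPerThreadDemo {
    // Declared the same way as BizOperator.threadPriVar in the example above.
    static final ThreadLocal<Object> THREAD_PRI_VAR = ThreadLocal.withInitial(Object::new);

    /** Starts threadCount threads and returns how many distinct objects they observed. */
    static int countDistinctInstances(int threadCount) {
        // Object does not override equals/hashCode, so this set compares by identity.
        Set<Object> seen = ConcurrentHashMap.newKeySet();
        Thread[] threads = new Thread[threadCount];
        for (int i = 0; i < threadCount; i++) {
            threads[i] = new Thread(() -> seen.add(THREAD_PRI_VAR.get()));
            threads[i].start();
        }
        try {
            for (Thread t : threads) {
                t.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for workers", e);
        }
        return seen.size();
    }

    public static void main(String[] args) {
        System.out.println(countDistinctInstances(4)); // prints 4
    }
}
```

The set keeps every observed object strongly reachable, so the count is not affected by garbage collection.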

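Source Overview

Before we start, consider a question: if this requirement landed on your desk, how would you implement it? A simple approach is to keep a Map<Thread, VALUE> inside ThreadLocal and just set and get values through it. Pseudocode: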

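import java.util.HashMap;
import java.util.Map;

// Naive sketch: one map per variable, keyed by the accessing thread.
// HashMap is not thread-safe, so this is for illustration only.
public class ThreadLocalHomeMade {
    Map<Thread, Object> threadLocalMap = new HashMap<>();

    public Object get() {
        return threadLocalMap.get(Thread.currentThread());
    }

    public void set(Object val) {
        threadLocalMap.put(Thread.currentThread(), val);
    }
}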

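Hmm, that doesn't look right. This ThreadLocal holds a reference to every Thread that has ever accessed the variable, and it is hard to know when those references can be released. That makes garbage collection harder and leaks memory to varying degrees. We could ease the problem with Reference types, though.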
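
One way to act on the Reference idea is to let the map hold its Thread keys weakly. The sketch below assumes java.util.WeakHashMap wrapped in Collections.synchronizedMap; the class name WeakThreadLocalHomeMade is made up here, and this is not how the JDK does it.

```java
import java.util.Collections;
import java.util.Map;
import java.util.WeakHashMap;

class WeakThreadLocalHomeMade<T> {
    // WeakHashMap holds keys weakly, so an entry alone does not keep a finished Thread reachable.
    // The synchronizedMap wrapper is needed because WeakHashMap is not thread-safe.
    private final Map<Thread, T> values = Collections.synchronizedMap(new WeakHashMap<>());

    public T get() {
        return values.get(Thread.currentThread());
    }

    public void set(T val) {
        values.put(Thread.currentThread(), val);
    }

    public static void main(String[] args) {
        WeakThreadLocalHomeMade<String> var = new WeakThreadLocalHomeMade<>();
        var.set("main-value");
        System.out.println(var.get()); // prints "main-value"
    }
}
```

This only eases the problem. Every access still goes through one shared lock, and if a stored value strongly references its own Thread, the key never becomes weakly reachable and the entry stays.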

Let's follow the source and see how the JDK implements it.

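ThreadLocal::get leads to ThreadLocal::getMap, which leads to Thread.threadLocals. There it is: the JDK stores all of a thread's private variables in the current thread's threadLocals field. I see, so the JDK stores all variables per thread, whereas my implementation stores every thread's private value per variable.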
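
To make the per-thread layout concrete, here is a rough model of it. We cannot add fields to java.lang.Thread, so the sketch assumes a made-up CarrierThread subclass that owns the map, and a made-up PerThreadVar that acts only as a key into it. It shows the shape of the design, not the JDK's actual code.

```java
import java.util.HashMap;
import java.util.Map;

class PerThreadDemo {
    public static void main(String[] args) throws InterruptedException {
        PerThreadVar<String> var = new PerThreadVar<>();
        CarrierThread t = new CarrierThread(() -> {
            var.set("hello");
            System.out.println(var.get()); // prints "hello"
        });
        t.start();
        t.join();
    }
}

/** Hypothetical stand-in for Thread: it owns the map, in the spirit of Thread.threadLocals. */
class CarrierThread extends Thread {
    // Only the owning thread touches this map, so no synchronization is needed.
    final Map<PerThreadVar<?>, Object> locals = new HashMap<>();

    CarrierThread(Runnable task) {
        super(task);
    }
}

/** Hypothetical stand-in for ThreadLocal: just a key into the current thread's map. */
class PerThreadVar<T> {
    @SuppressWarnings("unchecked")
    public T get() {
        return (T) currentLocals().get(this);
    }

    public void set(T value) {
        currentLocals().put(this, value);
    }

    private static Map<PerThreadVar<?>, Object> currentLocals() {
        // Assumption of this sketch: the calling code always runs on a CarrierThread.
        return ((CarrierThread) Thread.currentThread()).locals;
    }
}
```

With this layout, when a thread ends and becomes unreachable, its map and everything in it go with it. The variable itself never holds a reference to any thread.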

As some expert once said, design is fundamentally about data structures. So let's look at how ThreadLocalMap is implemented.

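/*
Judging from its fields alone, this is essentially a simplified HashMap
*/
class ThreadLocalMap {
        // Initial capacity
        private static final int INITIAL_CAPACITY = 16;
        // The array that actually stores the data
        private Entry[] table;
        // Number of entries currently stored
        private int size = 0;
        // Once size reaches this value, the table may be resized
        private int threshold; // Default to 0
}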

Now let's look at the source of Entry.

        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
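
Because Entry extends WeakReference, its key can disappear while its value is still strongly held. The sketch below reproduces the same shape with Object as the key type and calls clear() to simulate what the collector does, so the behavior is deterministic. The names StaleEntryDemo and isStale are made up for this illustration.

```java
import java.lang.ref.WeakReference;

class StaleEntryDemo {
    /** Same shape as ThreadLocalMap.Entry, with Object as the key type to keep the sketch small. */
    static class Entry extends WeakReference<Object> {
        Object value;

        Entry(Object k, Object v) {
            super(k);
            value = v;
        }
    }

    /** True when the key is gone; the value is still strongly held by the entry. */
    static boolean isStale(Entry e) {
        return e.get() == null;
    }

    public static void main(String[] args) {
        Object key = new Object();
        Entry e = new Entry(key, "payload");
        e.clear(); // simulates the collector clearing the weak reference to the key
        System.out.println(isStale(e) + " " + e.value); // prints "true payload"
    }
}
```

An entry in this state is what the JDK code calls a stale entry: the key reads as null, but the value lingers until someone clears the slot.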

This is very similar to HashMap: the ThreadLocal plays the role of the key, and the value is the thread's private variable. The slot index is computed as key.threadLocalHashCode & (table.length - 1)

Digging Deeper
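
To see what that index formula produces, here is a small sketch. It relies on a detail not shown above: as far as I can tell from the JDK source, each new ThreadLocal takes its threadLocalHashCode from a shared counter that starts at 0 and advances by HASH_INCREMENT = 0x61c88647. The class HashSpreadDemo and the method slotsFor are made up for this illustration.

```java
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;

class HashSpreadDemo {
    // Assumed to match ThreadLocal.HASH_INCREMENT in the JDK source.
    static final int HASH_INCREMENT = 0x61c88647;

    /** Returns the slot index of each of the first count hash codes in a table of tableLength slots. */
    static int[] slotsFor(int count, int tableLength) {
        AtomicInteger nextHashCode = new AtomicInteger(); // starts at 0
        int[] slots = new int[count];
        for (int n = 0; n < count; n++) {
            int threadLocalHashCode = nextHashCode.getAndAdd(HASH_INCREMENT);
            // tableLength must be a power of two for the mask to work as a modulo.
            slots[n] = threadLocalHashCode & (tableLength - 1);
        }
        return slots;
    }

    public static void main(String[] args) {
        // prints [0, 7, 14, 5, 12, 3, 10, 1, 8, 15, 6, 13, 4, 11, 2, 9]
        System.out.println(Arrays.toString(slotsFor(16, 16)));
    }
}
```

With the initial capacity of 16, the first 16 hash codes land in 16 different slots, so consecutively created ThreadLocals rarely collide.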

ThreadLocal::get

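Following the path ThreadLocal::get -> ThreadLocalMap::getEntry, we reach the first piece of core code, ThreadLocalMap::getEntryAfterMiss. When a hash collision occurs, this code uses linear probing to find the entry for the ThreadLocal.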

      Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                // Hash collisions are possible, so after computing the index we still compare the ThreadLocal itself
                if (k == key)
                    return e;
                // The ThreadLocal may already have been collected by the JVM
                if (k == null)
                    expungeStaleEntry(i);
                // Linear probing
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }
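
The probing order itself is easy to miss in the code above, so here is a stripped-down sketch of it over a plain String array. nextIndex is written the way I understand ThreadLocalMap.nextIndex to work (step forward by one, wrap to 0 at the end); LinearProbeDemo and find are made-up names. The sketch assumes the table always has at least one null slot, otherwise the loop would never end.

```java
class LinearProbeDemo {
    // Step forward by one slot and wrap around to 0 at the end of the table.
    static int nextIndex(int i, int len) {
        return (i + 1 < len) ? i + 1 : 0;
    }

    /** Returns the slot holding key, probing from idealSlot, or -1 once a null slot ends the probe. */
    static int find(String[] table, String key, int idealSlot) {
        int i = idealSlot;
        while (table[i] != null) {
            if (table[i].equals(key)) {
                return i;
            }
            i = nextIndex(i, table.length);
        }
        return -1;
    }

    public static void main(String[] args) {
        // "a", "b" and "c" all have ideal slot 2; "b" and "c" were pushed onward, "c" wrapping to slot 0.
        String[] table = { "c", null, "a", "b" };
        System.out.println(find(table, "c", 2)); // prints 0
        System.out.println(find(table, "d", 2)); // prints -1
    }
}
```

A null slot is the only thing that stops a probe, which is why removing an entry carelessly can make later entries unreachable.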

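Next comes the second piece of core code, ThreadLocalMap::expungeStaleEntry. It runs when a probe hits a stale entry (one whose ThreadLocal has been collected): it clears that slot and rehashes the entries that follow, up to the next null slot.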

int expungeStaleEntry(int staleSlot) {
    ThreadLocal.ThreadLocalMap.Entry[] tab = table;
    int len = tab.length;

    // expunge entry at staleSlot
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    ThreadLocal.ThreadLocalMap.Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        // The ThreadLocal has been collected
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // Ideal slot in the table
            int h = k.threadLocalHashCode & (len - 1);
            // Linear probing may have pushed the entry away from its ideal slot
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}
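
The rehash step is easier to follow on a toy table. The sketch below mirrors the loop above on an Integer array, where the ideal slot of a key is key & (len - 1) and the made-up constant STALE (-1) stands in for an entry whose ThreadLocal has been collected. It is a simplified model under those assumptions, not the JDK code.

```java
import java.util.Arrays;

class ExpungeDemo {
    static final int STALE = -1; // hypothetical marker for an entry whose key is gone

    static int nextIndex(int i, int len) {
        return (i + 1 < len) ? i + 1 : 0;
    }

    /** Clears staleSlot, then cleans and rehashes entries up to the next null slot; returns that slot. */
    static int expungeStaleEntry(Integer[] table, int staleSlot) {
        int len = table.length;
        table[staleSlot] = null;
        int i;
        for (i = nextIndex(staleSlot, len); table[i] != null; i = nextIndex(i, len)) {
            int k = table[i];
            if (k == STALE) {
                table[i] = null;
            } else {
                int h = k & (len - 1); // ideal slot of this key
                if (h != i) {
                    // Move the entry as close to its ideal slot as the table now allows.
                    table[i] = null;
                    while (table[h] != null) {
                        h = nextIndex(h, len);
                    }
                    table[h] = k;
                }
            }
        }
        return i;
    }

    public static void main(String[] args) {
        // Keys 9 and 17 both have ideal slot 1 and were displaced by the now stale entry there.
        Integer[] table = { null, STALE, 9, 17, null, null, null, null };
        expungeStaleEntry(table, 1);
        System.out.println(Arrays.toString(table)); // prints [null, 9, 17, null, null, null, null, null]
    }
}
```

Without the rehash, slot 1 would simply become null, and a later lookup of 9 starting at slot 1 would stop right there and miss it.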

ThreadLocal::set

Following the call path ThreadLocal::set -> ThreadLocalMap::set:

private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    ThreadLocal.ThreadLocalMap.Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (ThreadLocal.ThreadLocalMap.Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }
        // The key of this entry may already have been garbage collected.
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new ThreadLocal.ThreadLocalMap.Entry(key, value);
    int sz = ++size;
    // Rehash only if no stale entries were cleaned up and size has reached threshold
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}