ThreadLocal源码分析

116 阅读3分钟

使用场景

ThreadLocal一般用来隔离不同线程之间的数据,用于实现线程安全

使用示例

public static void main(String[] args) {
    ThreadLocal<String> threadLocal = new ThreadLocal<>();
    threadLocal.set("ThreadLocal");
    System.out.println(threadLocal.get());
    new Thread(() -> System.out.println(threadLocal.get())).start();
}

Output:
ThreadLocal
null

可以看到在main线程中设置了threadLocal的值并可以正常获取,而在另外一个线程中取不到main线程设置的值,从而做到在不同线程之间隔离数据

源码分析

JDK版本:1.8

set(T value)

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }

        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

可以发现set(T value)方法首先获取当前线程,然后获取了该线程所持有的ThreadLocalMap,当map非空时,将threadLocal自身作为键,value作为值设置到Entry[]中;而当map不存在时,创建一个ThreadLocalMap并由Thread持有

get()

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

get()方法也是需要先获取当前线程持有的ThreadLocalMap,当map非空时,遍历Entry数组获取key为当前threadLocal的Entry,获取到的话则返回该Entry的值,否则设置初始化值null并返回

Entry[]

通过源码我们可以发现每一个线程会持有一个ThreadLocalMap,而ThreadLocalMap实际上是一个Entry数组,初始化大小为16,每一个Entry存储了threadLocal和对应的值。与HashMap不同的是,ThreadLocalMap不存在next指针,因此在Entry的key,即threadLocal的hash冲突的时候只能通过寻找下一个位置(开放寻址法)存储来解决hash冲突,因此不建议一个线程同时绑定太多ThreadLocal,否则会造成太多hash冲突从而导致效率降低。

线程绑定多个ThreadLocalMap的情况

public static void main(String[] args) {
    ThreadLocal<String> threadLocal = new ThreadLocal<>();
    ThreadLocal<String> threadLocal2 = new ThreadLocal<>();
    ThreadLocal<String> threadLocal3 = new ThreadLocal<>();
    threadLocal.set("ThreadLocal");
    threadLocal2.set("ThreadLocal2");
    threadLocal3.set("ThreadLocal3");
    System.out.println(threadLocal.get());
    System.out.println(threadLocal2.get());
    System.out.println(threadLocal3.get());
}

Output:
ThreadLocal
ThreadLocal2
ThreadLocal3

关于Entry中key是WeakReference的看法

(以下看法仅是本人个人看法,如有错误,欢迎指正)

通过以上的源码分析我们可以发现,Thread中引用了ThreadLocalMap,而ThreadLocalMap是用Entry数组实现的,数组中的每个Entry都引用了一个threadLocal和保存的值。而一个线程的生命周期可能是很长的,如果调用threadLocal的set方法之后没有remove掉,那么Entry是一直引用着threadLocal的,这里可能会造成内存泄漏。通过Entry的key是WeakReference的实现,能够在内存不足时让JVM回收掉这部分内存。