ThreadLocal原理解析

1,088 阅读4分钟

ThreadLocal相信大家日常开发中都经常使用。下面我们看几个常见的问题,看看大家是否对它足够了解。

问题

  1. ThreadLocal存储在jvm的哪个区域
  2. 为什么用Entry数组而不是Entry对象
  3. ThreadLocal里的对象一定是线程安全的吗
  4. 为什么Entry key设计成弱引用
  5. 会导致内存泄漏么
  6. ThreadThreadLocalThreadLocalMapEntry在内存中的关系是什么样的

如果以上问题,你都知道,可以忽略本文。

内存关系梳理

首先我们new 一个ThreadLocal对象

public class ThreadLocalTest {
    public static void main(String[] args) {
        ThreadLocal<String> threadLocal = new ThreadLocal<>();
    }
}

此时内存关系如图

2.png

往ThreadLocal里面set一个值,同时查看相关源码

public class ThreadLocalTest {
    public static void main(String[] args) {
        ThreadLocal<String> threadLocal = new ThreadLocal<>();
        threadLocal.set("1"); //从这里进入看源码
    }
}

=======以下为ThreadLocal源码===========================

public void set(T value) {
    //1.获取当前运行的线程(demo里面为main线程)
    Thread t = Thread.currentThread();
    //获取t.threadLocals字段(ThreadLocalMap对象)
    ThreadLocalMap map = getMap(t);
    //此时map为null,走else逻辑
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

void createMap(Thread t, T firstValue) {
    //2.创建ThreadLocalMap对象,赋值给t.threadLocals。key为当前threadLocal,value为"1"。
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

//ThreadLocalMap的构造函数
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    //3.新建Entry数组
    table = new Entry[INITIAL_CAPACITY];
    //计算数组下标
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    //4.新建Entry对象,放入数组对应位置
    table[i] = new Entry(firstKey, firstValue);
    //数组里面元素个数
    size = 1;
    //设置扩容阈值
    setThreshold(INITIAL_CAPACITY);
}

//Entry类定义
static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        //5.key为弱引用
        super(k);
        value = v;
    }
}

根据上面源码中,标记的1~5点,我们内存图变成了下面的样子。

3.png

一句话总结就是:Thread维护了ThreadLocalMap,ThreadLocalMap里维护了Entry数组,而Entry里存的是以ThreadLocal(弱引用)为key,传入的值为value的键值对。

说完了set方法,我们看看调用get方法如何获取对应的值。

public T get() {
    //获取当前线程
    Thread t = Thread.currentThread();
    //通过线程对象获取ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        //获取key为threadLocal对应的Entry
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

private Entry getEntry(ThreadLocal<?> key) {
    //获取数组下标
    int i = key.threadLocalHashCode & (table.length - 1);
    //获得数组对应元素
    Entry e = table[i];
    if (e != null && e.get() == key)
        //返回key为threadLocal对应的Entry
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

get方法获取值大概图示

3_1.png

一切看上去很美好。但是弱引用只要发生GC,引用链就会断掉。所以内存变为如下图。

4.png

这个时候大家是不是会有个疑问,调用get方法的时候,是否能正确获得对应的值。 我们写一个小Demo来测试一下。

public class WeakRefTest {
    public static void main(String[] args) throws InterruptedException {
        //Foo(1)对象只有弱引用
        WeakReference<Foo> wr1 = new WeakReference<>(new Foo(1));
        //Foo(2)对象有弱引用和强引用(这种情况和ThreadLocal类似)
        Foo foo = new Foo(2);
        WeakReference<Foo> wr2 = new WeakReference<>(foo);

        System.out.println("gc前wr1 "+wr1.get());
        System.out.println("gc前wr2 "+wr2.get());
        System.gc();
        Thread.sleep(100);
        System.out.println("gc后wr1 "+wr1.get());
        System.out.println("gc后wr2 "+wr2.get());
    }
    static class Foo{
        public int a;
        Foo(int a){
            this.a = a;
        }
        @Override
        public String toString() {
            return "Foo{" +
                    "a=" + a +
                    '}';
        }
    }
}

运行结果
gc前wr1 Foo{a=1}
gc前wr2 Foo{a=2}
gc后wr1 null
gc后wr2 Foo{a=2}

根据上面Demo运行情况,我们画一下内存图。 GC前

5.png wr1,wr2都能正常访问。

GC后

6.png wr1为null,wr2能正常访问。

结论:引用的断开不会影响我们引用的寻址功能。引用的断开只会导致引用链断开导致对象被GC回收,但是!此时若有一个强引用引用着,那么弱引用就可以在无引用链的情况下继续访问该对象。(这里扩展一下。若对象的地址强制改变,弱引用将无法继续跟踪)

通过上面的Demo我们回头看看ThreadLocal的内存图,能得知就算引用链断裂,get()方法也能获得值。

7.png

为什么ThreadLocalMap中的key要设置成弱引用

public class ThreadLocalTest {
    public static void main(String[] args) {
        ThreadLocal<String> threadLocal = new ThreadLocal<>();
        threadLocal.set("1");
        threadLocal = null;
    }
}

上面的代码块我们期望通过设置threadLocal=null,从而让threadLocal被回收。 假设此时不用弱引用,内存如下图

8.png 如果使用弱引用,ThreadLocal对象就能被正确的回收。这就是使用弱引用的原因。

Entry的key内存泄漏

我们将threadLocal成功回收后,内存图画出来

9.png 如上图,由于我们将ThreadLocal对象的成功回收,我们的key变为null了。但是我们的value依旧存在,因此这一组数据的value由于key为null的原因也无法访问导致内存泄漏。

ThreadLocal针对内存泄漏的优化

//set方法优化
private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    //遇到hash冲突的时候,沿着数组向后遍历
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }
        
        //发现为null的key,直接替换之前的值
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    //cleanSomeSlots方法,看看有没有为null的key需要清理
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            //发现key为null,删除对应的entry
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

//get方法优化

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        //获取entry
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        //处理key为null的情况
        return getEntryAfterMiss(key, i, e);
}


private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            //发现key为null,删除对应的entry
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

//remove方法
public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            //key设置为null
            e.clear();
            //删除对应的entry
            expungeStaleEntry(i);
            return;
        }
    }
}

remove是我们主动触发,清理Entry的方式。可以加速我们泄漏的内存回收。因此,如果当栈中的引用变为null时,我们可以再次调用remove()方法,将ThreadLocalMap中的Entry进行清理。(更具时效性)

线程退出时优化

private void exit() {
    if (group != null) {
        group.threadTerminated(this);
        group = null;
    }
    /* Aggressively null out all reference fields: see bug 4006245 */
    target = null;
    /* Speed the release of some of these resources */
    //threadLocalMap的引用断了,加速回收
    threadLocals = null;
    inheritableThreadLocals = null;
    inheritedAccessControlContext = null;
    blocker = null;
    uncaughtExceptionHandler = null;
}

10.png

最后我们回头看文章开头提出的问题(只有2,3没有在文中直接回答,其他都已经有了答案)

  1. 为什么用Entry数组而不是Entry对象 因为你业务代码能new好多个ThreadLocal对象。但是一个线程里,ThreadLocalMap是同一个,而不是多个,不管你new几次ThreadLocal,ThreadLocalMap在一个线程里就一个,因为ThreadLocalMap的引用是在Thread里的,所以它里面的Entry数组存放的是一个线程里你new出来的多个ThreadLocal对象。

  2. ThreadLocal里的对象一定是线程安全的吗 未必,如果在每个线程中ThreadLocal.set()进去的东西本来就是多线程共享的同一个对象,比如static对象,那么多个线程的ThreadLocal.get()获取的还是这个共享对象本身,还是有并发访问线程不安全问题。