基本功之ThreadLocal

286 阅读7分钟

目录

1 问题由来

2 ThreadLocal是什么

3 底层原理

3.1 ThreadLocal

3.1.1 原理

3.1.2 源码

3.1.3 内存泄漏

3.2 InheritableThreadLocal

3.1.1 原理

3.1.2 源码

3.1.3 踩坑

1 问题由来

前些日子做接口优化时,引入了diff4j组件(公司内部开发的新老接口对比的组件),通过异步线程去比对新老接口的返回结果是否一致。线上观察发现有部分不一致的情况。

通过新老接口返回的数据和日志发现,获取「操作人的二级品类权限集合」不一致,导致查询结果不同。通过代码分析,发现在异步线程中获取当前操作人,用的是封装好的ThreadLocalContext。

image.png image.png image.png

image.png

image.png

众所周知,ThreadLocal是线程私有的,User信息是当前主线程的threadLocal中存储的,且异步线程是新建的线程池,并没有和主线程同属于一个池子,那异步线程中获取threadLocal中的user信息,理论上应该一直获取不到才对,为什么线上有时候能获取到,有时候还获取错误呢?

2 ThreadLocal是什么

ThreadLocal:线程本地变量。

ThreadLocal为每个线程创建一份变量副本,线程之间不需要竞争共享变量,通过访问本地化资源提升效率,适用于变量在线程间隔离,而在方法或类间共享的场景。

3 底层原理

如果让你设计ThreadLocal,该如何实现?

最容易想到的实现方式就是,用一个Map存储,key为线程名,value为变量。但是多线程访问ThreadLocal.get方法时,ThreadLocal本身也就变成共享变量,还需要保证ThreadLocal的线程安全...

3.1 ThreadLocal

3.1.1 原理

每个线程Thread内部保存一个 map变量,map的key为ThreadLocal对象,value为线程私有的变量。

Thread内部的Map数据是通过ThreadLocal来get和set的。

3.1.2 源码

image.png

image.png

ThreadLocal#set方法

代码块

Java

    public void set(T value) {

        Thread t = Thread.currentThread();//当前线程

        ThreadLocalMap map = getMap(t);//获取当前线程t的ThreadLocalMap

        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);//如果当前线程的map为空,则初始化一个map,并赋值value
    }

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    void createMap(Thread t, T firstValue) {

        t.threadLocals = new ThreadLocalMap(this, firstValue);

    }

ThreadLocal#get方法

image.png ThreadLocal#remove方法

代码块

Java

     public void remove() {

         ThreadLocalMap m = getMap(Thread.currentThread());

         if (m != null)

             m.remove(this);//将a1那根实线(引用)断掉

     }

Class ThreadLocalMap

类似HashMap,都是基于数组实现。但是对于hash冲突的解决方法不一样。HashMap是通过链表(红黑树)法来解决冲突,而ThreadLocalMap是通过开放寻址法来解决冲突。

代码块

Java

 static class ThreadLocalMap {

      static class Entry extends WeakReference<ThreadLocal<?>> {

            /** The value associated with this ThreadLocal. */

            Object value;

            Entry(ThreadLocal<?> k, Object v) {

                super(k);

                value = v;

            }

        }

      /**

         * The initial capacity -- MUST be a power of two.

         */
        private static final int INITIAL_CAPACITY = 16;
        /**
         * The table, resized as necessary.
         * table.length MUST always be a power of two.
         */
        private Entry[] table;
        /**
         * The number of entries in the table.
         */
        private int size = 0;
        /**
         * The next size value at which to resize.
         */
        private int threshold; // Default to 0
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }
   
    private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];
            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }

            }
        }
   
   private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }
   /**
         * Version of getEntry method for use when key is not found in
         * its direct hash slot.
         *
         * @param  key the thread local object
         * @param  i the table index for key's hash code
         * @param  e the entry at table[i]
         * @return the entry associated with key, or null if no such
         */
   private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;
            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

/**
         * Increment i modulo len.
         */
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }
     private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                if (k == key) {
                    e.value = value;
                    return;
                }
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }
    private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;
            for (int i = nextIndex(staleSlot, len);

                 (e = tab[i]) != null;

                 i = nextIndex(i, len)) {

                ThreadLocal<?> k = e.get();


                if (k == key) {

                    e.value = value;

                    tab[i] = tab[staleSlot];


                    tab[staleSlot] = e;
                    if (slotToExpunge == staleSlot)

                        slotToExpunge = i;

                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);

                    return;

                }

                if (k == null && slotToExpunge == staleSlot)

                    slotToExpunge = i;

            }

            tab[staleSlot].value = null;

            tab[staleSlot] = new Entry(key, value);

            if (slotToExpunge != staleSlot)

                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);

        }

 }

image.png

image.png

ThreadLocalMap#set方法

image.png

3.1.3 内存泄漏

内存泄漏,就是指程序中已经无用的内存无法被释放,造成系统内存的浪费

a. 为什么Enrty对key的引用设置成弱引用?(引用-强软弱虚

image.png

根据上面这个引用链可以看到,如果Entry对key不是弱引用,而是强引用,那么只要线程在,那么相关的ThreadLocal对象肯定就会一直在。实际情况中,线程又经常是以线程池的方式来使用的。比如 Tomcat/Jetty 的线程池处理了一堆请求,而线程池中的线程一般是不会被清理掉的,所以这个引用链就会一直在,那么 ThreadLocal 对象即使没有用了,也会随着线程的存在,而一直存在着!所以如果是强引用,会造成内存泄漏。但是如果用弱引用来实现,ThreadLocal在没有其他强引用的情况下,那就会被GC回收就会避免部分内存泄漏的情况。

b. key设置为弱引用之后,还会有内存泄漏问题?

上面说到的是Entry与key的引用,如果ThreadLoca对象没有其他强引用,则GC之后就会回收(ThreadLocal对象不会造成内存泄漏),但是value对象会一直存在(Entry的key会设置为null),并且成为了无用的垃圾数据,也没办法进行回收,因为Entry对象被线程一直强引用。这种情况value就可能会造成内存泄漏。但是设计者肯定也知道这个问题,也做了很多清理这种key为null的Entry操作,一定程度的避免内存泄漏。譬如上面代码中的ThreadLocalMap#getEntry方法,扩容的时候,都会执行expungStaleEntry清理无用的Entry。(这里可以将value也设置成弱引用吗?)

image.png 但是被动等着这些操作执行时清理无用的Entry,就只能防止一部分内存泄漏,因为有可能后面就没人调用了。所以正确的操作应该是用完了之后,调用一下remove方法,手工把Entry从map中移除,GC时就会把这个Entry回收。

代码块

Java

threadLocal.set("XXX");

try{

  //do sth

}finally{

  threadLocal.remove();

}

3.2 InheritableThreadLocal

3.1.1 原理

InheritableThreadLocal继承ThreadLocal,当子线程创建时,可以从父线程继承所有的inheritable thread-local variables到子线程。

3.1.2 源码

image.png

image.png

image.png

image.png

从源码中可以看到,Thread类中不止有一个threadLocals属性,还有一个inheritableThreadLocals属性。该属性是在创建线程时,继承了父线程的thread-local variables(也给下面埋了坑)。

3.1.3 踩坑

看源码发现,InheritableThreadLocal适用于在子线程创建时将父线程中的threadLocal副本传递给子线程。但是生产环境中用的线程池复用,这会导致新的threadLocal副本无法放入到子线程中。

回到最初的那个问题上,我们的http请求都是在Jetty的工作线程池中,而对比逻辑引入的Diff4J,新建了MDP线程池,在每一次请求进来之后,都会通过异步线程获取当前登录人(user存储在inheriateThreadLocal),根据当前登录人的二级品类权限列表查询计费单。

MDP线程池中的线程在第一次创建时,会继承父线程的thread-local变量值userA。由于是池化资源,所以下次线程B请求再进来时,复用了线程A1,而线程A1没有触发新建线程,所以线程A1不会继承父线程B的thread-local values,而是保留了thread-local:userA。所

以线上的现象是B用户登录进来之后,新接口却使用的是A用户的二级品类权限列表。

image.png

解决方案1:将ThreadLocal的参数提取到主线程获取,通过增加入参将operator变量传进子线程

解决方案2: 阿里开源项目TransmittableThreadLocal结合TtlExecutors线程池使用

扩展:

阿里开源TransmittableThreadLocal juejin.cn/post/706437…

Netty实现的FastThreadLocal juejin.cn/post/700508…