ThreadLocal、InheritableThreadLocal、TransmittableThreadLocal相关内容

341 阅读10分钟

1.简介

ThreadLocal :在项目中,有进行使用;在面试中也经常被问到。今天就写一下自己对其的理解;

2.使用场景

我在项目中,有一部分是MNS消息队列对数据进行消费;在对数据进行消费的时候,需要使用一个工具类,做一些参数校验,数据处理相关的工作;那么,怎么解决这个问题呢?很显然使用锁是可以的;但是,使用锁在并发冲突的时候对性能会产生一定的影响;那么,有没有别的思路来解决这个问题呢?假设我们每个线程中,都有一个该工具类不就是解决了这个问题了吗?毕竟“无锁化”才是最高的境界;刚好,ThreadLocal就是解决这个问题的;

以上是我自己遇到的场景,还有一个场景就是,用ThreadLocal保存一些业务内容(用户权限信息、从用户系统获取到的用户名、userID等)
目的:就不需要再进行传参了,如果,想要获取到权限的话,通过ThreadLocal来进行获取即可; 这里,强调的是同一个线程内,不同方法间的共享;

2.1 ThreadLocal的使用

// 创建一个ThreadLocal
ThreadLocal<SimpleDateFormat> simpleDateFormat = new ThreadLocal(){
        @Override
        protected SimpleDateFormat initialValue(){
            return new SimpleDateFormat("yyyy-MM-dd hh:mm:ss");
        }
    };

2.2 ThreadLocal 中 set() 和 initialValue()的关系

initialValue(): 在ThreadLocal第一次get的时候,把对象初始化出来,对象的初始化由我们自己控制; 注意:如果之前调用过set() 那么, 就不会调用该方法了。 set(): 通过set()也可以实现对象Entry的设置;
总结:
1.如果,针对同一个key,调用get()之前,没有调用过set(),就会通过initialValue()来设置值
2.如果,调用get()之前,调用过set(),那么,initialValue()就不会再进行调用了;
待会会有源码分析,现在,先来个实验测试:

image.png
没有set()时, 执行了initialValue()方法;

image.png
调用了set()方法时,没有执行initialValue()方法;

3.ThreadLocal源码分析

在源码分析之前,先介绍一下ThreadLocalMap中内存泄漏的问题。每一个线程中,都有一个ThreadLocalMap,ThreadLocalMap中,对key是弱引用,但是,Entry中,key对Value是强引用;所以,这个value的引用链条如下:

image.png
这个时候,我们可以看到,经历垃圾回收以后,再通过我们的ThreadLocal 可以找到Entry 但是,通过Entry是无法找到我们的Value了; 而Entry中的key对value的引用是强引用,所以,无法进行垃圾回收,只有在该线程进行回收时,才能把该区域进行回收;但是,很不巧现在一般是使用线程池来进行线程的复用,所以,这就造成这个value出现泄漏的可能;所以,根据阿里Java开发手册,一般使用完ThreadLocal的时候,都是需要手动进行remove();

3.1ThreadLocal set()方法分析

 private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            // 根据hash,找到数组中的一个位置
            int i = key.threadLocalHashCode & (len-1);
            // 如果这个位置没有被占用,说明没有冲突,没有冲突,那就不用循环了,直接使用这个位置
            // 如果冲突了,就需要一直往下找,找到一个可以用的位置,这个和HashMap不同,
            // ThreadLocalMap采用的是线性探测法
            for (Entry e = tab[i];e != null;e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                // 进入循环说明已经是冲突了
                // 先判断key是否相同,如果相同的话,就进行覆盖即可
                if (k == key) {
                    e.value = value;
                    return;
                }
                
                // 如果key为null,说明原来的key已经被回收,那么,就需要进行清理
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            
            // 跳出循环,说明通过线性探测法已经找到存储的位置了;
            // 然后把这个Entry放进去就可以了
            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

ThreadLocalMap中的Hash冲突处理
ThreadLocalMap和HashMap不同,HashMap碰到hash冲突的时候,采用的是链表法
ThreadLocalMap碰到hash冲突的时候,采用的是最简单的线性探测法。如果发生了元素冲突,那么就往下找,找到一个合适的槽位来进行存放。具体实现可以看上面的put()源码分析;

3.2ThreadLocal get()方法分析

ThreadLocal变量在单个线程内是可见的,那么它是如何实现的呢?

    public T get() {
        // 获取当前线程
        Thread t = Thread.currentThread();
        // 每个线程中,都有一个自己的ThreadLocalMap
        // ThreadLocalMap里面保存着所有的ThreadLocal变量
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            // ThreadLoaclMap的key 就是当前ThreadLocal对象实例
            // 多个ThreadLocal变量都是放在这个map中的
            // 这个this 就是threadLocal对象实例 调用的时候是threadLocal.get() this 就是指这个threadLocal
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                // 如果,通过threadLocal能找到Entry 那么,就返回Entry的value
                T result = (T)e.value;
                return result;
            }
        }
        // 否则,就执行InitialValue()方法,加入对应的key(ThreadLocal) - value组合;
        return setInitialValue();
    }

我们刚刚set()的时候,分析出来当有hash冲突的时候,采用的是线性探测法;所以,这个时候,在get()的时候,就有讲究了;因为 ThreadLocalMap.Entry e = map.getEntry(this);是获取到当前的Entry,所以,解决线性探测法的相关逻辑肯定是在这里;

        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            // 如果,直接找到了,就直接返回
            if (e != null && e.get() == key)
                return e;
            else
                // 如果,没有找到,那么就得向后找了,因为涉及到线性探测相关技术
                return getEntryAfterMiss(key, i, e);
        }

下面是 getEntryAfterMiss() 的实现:

        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;
            // 因为采用线性探测,所以,如果第一个元素不对,后面的元素也有可能是对的
            // 直到找到null为止
            // 整个e是entry,也就是一个弱引用
            while (e != null) {
                ThreadLocal<?> k = e.get();
                // 如果找到了,就返回
                if (k == key)
                    return e;
                // 如果key为null,说明弱引用已经被回收了
                // 那么就要在这里回收里面的value了
                if (k == null)
                    expungeStaleEntry(i);
                else
                    // 如果key不是要找的那个,那说明有hash冲突,这里是处理冲突,找下一个entry
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

用来回收value的是expungeStaleEntry(i)函数;查看相关函数;

private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // 把threadLocal为null, value为null的元素进行清除
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // 往后进行判断,所有的key为null的Entry都进行清理
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);(e = tab[i]) != null;i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    //这里进行rehash 为了解决ThreadLocalMap中线性探测,不能留空的问题;就是中间不能空出来一部分
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

真正用来回收value的是expungeStaleEntry()方法,在remove()和set()和get()方法中,都会直接或者间接的使用到这个方法;所以,使用这三个函数时,可以对key为null的value进行清理;

4.InheritableThreadLocal

4.1 代码实现

在实际开发中,可能会有这样的需求,父线程中开了一个子线程,但是,希望子线程可以访问主线程中的ThreadLocal对象,也就是说有些数据需要进行父子线程之间的传递;先来个代码演示:

    public static void main(String[] args) {
        ThreadLocal<String> Simple = new ThreadLocal();
        Simple.set("我是阿宝");
        System.out.println(Simple.get());
        new Thread(() -> {
            System.out.println(Simple.get());
        }).start();
    }

结果如下:

我是阿宝
null

很显然,在子线程中是无法获取到父线程的ThreadLocal,为了解决这个问题:我们可以使用InheritableThreadLocal;

    public static void main(String[] args) {
        ThreadLocal<String> Simple = new InheritableThreadLocal<>();
        Simple.set("我是阿宝");
        System.out.println(Simple.get());
        new Thread(() -> {
            System.out.println(Simple.get());
        }).start();
    }

结果如下:

我是阿宝
我是阿宝

4.2 底层实现

为什么它可以实现这样的功能呢?我们可以通过源码来进行分析:

public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    protected T childValue(T parentValue) {
        return parentValue;
    }
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

我们可以看到它是继承了ThreadLocal然后重写了三个方法;我们找到对应的ThreadLocal这三个方法:

public class ThreadLocal<T> {
    T childValue(T parentValue) {
        throw new UnsupportedOperationException();
    }
    
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
   
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
}

我们可以看到InheritableTheadLocal中createMap(),以及getMap()方法处理的对象和ThreadLocal方法中的处理对象不太一样了。在TheadLocal中处理的是threadLocals,而InheritableThreadLocal中用的是inheritableTheadLocals;我们这个时候,在来看看Thread对象的 init()方法,看看对不同的对象是如何处理的;

    private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
        if (name == null) {
            throw new NullPointerException("name cannot be null");
        }

        this.name = name;
        // 先获取到父类线程
        Thread parent = currentThread();
        SecurityManager security = System.getSecurityManager();
        if (g == null) {

            if (security != null) {
                g = security.getThreadGroup();
            }

            if (g == null) {
                g = parent.getThreadGroup();
            }
        }

        g.checkAccess();

        if (security != null) {
            if (isCCLOverridden(getClass())) {
                security.checkPermission(SUBCLASS_IMPLEMENTATION_PERMISSION);
            }
        }

        g.addUnstarted();

        this.group = g;
        this.daemon = parent.isDaemon();
        this.priority = parent.getPriority();
        if (security == null || isCCLOverridden(parent.getClass()))
            this.contextClassLoader = parent.getContextClassLoader();
        else
            this.contextClassLoader = parent.contextClassLoader;
        this.inheritedAccessControlContext =
                acc != null ? acc : AccessController.getContext();
        this.target = target;
        setPriority(priority);
        // 如果父类线程的inheritableThreadLocals 不为null(本例是不为null的)
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
                // 把父类线程的inherThreadLocals 通过createInheritedMap方法进行操作,将threadLoaclMap赋值给子线程
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
        this.stackSize = stackSize;

        tid = nextThreadID();
    }
    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }
    
    private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];

            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        // 注意:这里的value还是指向父类的,如果是线程不安全的类,子类和父类共同访问的话,这里也会有问题;
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }

小总结:Thread先获取父线程,当父线程的inherThreadLocals不为空的时候,就会创建一个新的ThreadLocalMap 把父线程的inherThreadLocals 传给该ThreadLocalMap; 用inheritableThreadLocals指向该ThreadLocalMap;

4.3 inherThreadLocals真的没问题了吗?

我们把inherThreadLocals和线程池来搭配一下哈;

    public static void main(String[] args) throws InterruptedException {
        ThreadLocal<String> Simple = new InheritableThreadLocal<>();
        Simple.set("我是阿宝");
        ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(1, 1, 10, TimeUnit.SECONDS, new ArrayBlockingQueue<>(10));
        threadPoolExecutor.submit(() -> {
            System.out.println(Simple.get());
        });
        Thread.sleep(100);
        System.out.println("--------------------");
        Simple.set("我是阿宝二号");
        threadPoolExecutor.submit(()->{
            System.out.println(Simple.get());
        });
    }

结论如下:

我是阿宝
--------------------
我是阿宝

很尴尬,阿宝二号没有出现;
原因很简单,线程池内部涉及重复利用,不会再重新执行init()初始化方法,而是之间使用已经创建过的线程(里面有while(true)的轮询),所以,这里的值不会二次产生变化;其实,不仅仅是使用线程池,我使用线程也可以实现这种现象;

    public static void main(String[] args) throws InterruptedException {
        ThreadLocal<Double> Simple = new InheritableThreadLocal<>();
        Simple.set(11.22);

        new Thread(() -> {

            try {
                // 创建线程把父线程中ThreadLocalMap内容复制到子线程中
                System.out.println(Simple.get());
                // 等待200ms让子线程进行更改value
                Thread.sleep(200);
                // 获取到的value还是为11.22
                System.out.println(Simple.get());
                System.out.println("子线程运行结束");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }).start();
        Thread.sleep(100);
        // 父线程进行设值
        Simple.set(22.33);
        System.out.println("主线程运行结束");
    }

运行结果为:

11.22
主线程运行结束
11.22
子线程运行结束

所以,可以得出结论:只有在执行init()方法时,才可以让父线程中的ThreadLocalMap 复制到子线程中;上面的流程图如下:

image.png
这里解释一下,为什么父线程和子线程都指向一个对象时,该对象发生改变时,子线程获取到的对象却没有发生改变的原因;

5.TransmittableThreadLocal

阿里巴巴开源解决方案github.com/alibaba/tra… 根据github相关的内容:改写一下自己代码:

    public static void main(String[] args) throws InterruptedException {
        ThreadLocal<String> Simple = new TransmittableThreadLocal<>();
        Simple.set("我是阿宝");
        ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(1, 1, 10, TimeUnit.SECONDS, new ArrayBlockingQueue<>(10));
        ExecutorService ttlExecutorService = TtlExecutors.getTtlExecutorService(threadPoolExecutor);
        ttlExecutorService.submit(() -> {
            System.out.println(Simple.get());
        });
        Thread.sleep(100);
        System.out.println("--------------------");
        Simple.set("我是阿宝二号");
        ttlExecutorService.submit(()->{
            System.out.println(Simple.get());
        });
    }

结果如下: 符合预期

我是阿宝
--------------------
我是阿宝二号

具体实现细节及其源码等哪天有空再进行查看; 主要相关方法

  1. get方法调用时,先获取父线程相关数据判断是否有数据,然后在holder中把自身也加进去
  2. set方法调用是,现在父线程汇总设置,再本地判断holder是否为删除或者新增
  3. remvoe调用时,先删除自身,再删除父线程中的数据
public final T get() {
        T value = super.get();
        if (this.disableIgnoreNullValueSemantics || null != value) {
            this.addThisToHolder();
        }

        return value;
    }

    public final void set(T value) {
        if (!this.disableIgnoreNullValueSemantics && null == value) {
            this.remove();
        } else {
            super.set(value);
            this.addThisToHolder();
        }

    }
    public final void remove() {
        this.removeThisFromHolder();
        super.remove();
    }