为什么ThreadLocal是线程安全的?

2,463 阅读35分钟

对于线程安全,我们所需要做的就是在时间和空间上作出权衡,而今天所谈到的ThreadLocal就是典型的空间换时间的数据结构。

ThreadLocal的使用

在项目中我们可以通过ThreadLocal来存储用户信息,其中一般会在过滤器/拦截器的入口处初始化用户信息,并在执行结束后对其进行清理。这样从请求进来一直到返回,我们只需要通过线程变量ThreadLocal获取用户信息即可,而不用每次都从数据库查出来。

因为ThreadLocal是线程安全的,所以这里我们把它声明为一个单例。

public class UserHolder {
    private final static ThreadLocal<UserInfo> CURRENT_USER = new ThreadLocal<>();

    public static UserInfo get() {
        return CURRENT_USER.get();
    }

    public static void set(UserInfo userInfo) {
        CURRENT_USER.set(userInfo);
    }

    public static void remove() {
        CURRENT_USER.remove();
    }
}

然后我们就可以在拦截器和业务方法进行初始化和获取用户信息了。

public class Filter{

    public void doFilter(Request request,Response response){
        UserInfo userInfo = getUserInfo(request);
        UserHolder.set(userInfo);

        doHandle(request,response);

        UserHolder.remove();
    }
}

public class BusinessService{

    public void serviceA(){
        // 因为ThreadLocal的机制,此处是线程安全的
        UserInfo userInfo = UserHolder.get();
        // 操作userInfo...
    }

}

这样我们就了解了ThreadLocal的基本用法了,接着下面我们就从实现原理的角度来揭开ThreadLocal的神秘面纱。

ThreadLocal的实现原理

这里我们先通过Java官方文档对ThreadLocal进行初步的认识。

/**
 * ...
 *
 * <p>Each thread holds an implicit reference to its copy of a thread-local
 * variable as long as the thread is alive and the {@code ThreadLocal}
 * instance is accessible; after a thread goes away, all of its copies of
 * thread-local instances are subject to garbage collection (unless other
 * references to these copies exist).
 *
 * @author  Josh Bloch and Doug Lea
 * @since   1.2
 */
public class ThreadLocal<T> {
    // ...
}

我截取了信息量比较大的最后一段注释下来,这里大概意思是“每个活着的线程中都会持有一份线程变量的副本”,这句话从本质上阐述了ThreadLocal的设计理念和实现原理(读到这里也许有所迷惑,可以先忽略这句话继续往下阅读)。下面我们从ThreadLocal使用的角度进行阐述。

首先我们会通过ThreadLocal的构造方法创建一个新实例。而在ThreadLocal中仅存在一个默认构造方法。

    /**
     * Creates a thread local variable.
     * @see #withInitial(java.util.function.Supplier)
     */
    public ThreadLocal() {
    }

当然,我们也可以通过withInitial方法创建出带有默认初始值的ThreadLocal,但是为了方便理解其中的原理这部分会在分析完getset之后再进行分析。

接着我们会通过set方法对ThreadLocal进行赋值,也是从这里我们揭开ThreadLocal的第一层面纱。

    /**
     * Sets the current thread's copy of this thread-local variable
     * to the specified value.  Most subclasses will have no need to
     * override this method, relying solely on the {@link #initialValue}
     * method to set the values of thread-locals.
     *
     * @param value the value to be stored in the current thread's copy of
     *        this thread-local.
     */
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

这里会先通过Thread.currentThread获取到当前线程Thread

ThreadcurrentThread方法是一个本地native方法,用于获取当前线程。此处也充分体现了Java万物皆对象的理念,直接通过一个对象来表示线程这种虚无缥缈的概念,这里我们就把它当作是进来的线程就好了。

public class Thread implements Runnable {
    /**
     * Returns a reference to the currently executing thread object.
     *
     * @return  the currently executing thread.
     */
    public static native Thread currentThread();
}

在获取当前线程之后调用了一个getMap方法,以此获取到了一个ThreadLocalMap,这里我们先来看看getMap方法。

    /**
     * Get the map associated with a ThreadLocal. Overridden in
     * InheritableThreadLocal.
     *
     * @param  t the current thread
     * @return the map
     */
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    public class Thread implements Runnable {
         /* ThreadLocal values pertaining to this thread. This map is maintained 
          * by the ThreadLocal class. */
        ThreadLocal.ThreadLocalMap threadLocals = null;
    }

getMap方法中可以看到它是直接返回一个线程对象的成员变量。

这里回归到ThreadLocal#set方法,在获取到线程对象的成员变量ThreadLocalMap之后,它会对此对象进行判空处理:

  • 如果不为空的话调用set方法设置我们要存储的值
  • 如果为空则创建Map对象(此处把value也丢了进去,应该是创建的同时把值也存储进去了)。

到这里笔者在set方法语义上进行了一次解读,如果要继续深入的话,我们必须认识什么是ThreadLocalMap及了解其作用。


    /**
     * ThreadLocalMap is a customized hash map suitable only for
     * maintaining thread local values. No operations are exported
     * outside of the ThreadLocal class. The class is package private to
     * allow declaration of fields in class Thread.  To help deal with
     * very large and long-lived usages, the hash table entries use
     * WeakReferences for keys. However, since reference queues are not
     * used, stale entries are guaranteed to be removed only when
     * the table starts running out of space.
     */
    static class ThreadLocalMap {
        //...
    }

我们找到ThreadLocalMap注解说明,ThreadLocalMap是一个用来存储线程变量值的散列表,所谓散列表即是可以让我们快速的对数据进行定位和查找的一个数据结构。具体这个散列表是怎么实现的,下文再深入了解,此处点到即止。

既然它是一个散列表,此处我们可以暂时把它当作是一个HashMapJava中的散列工具类),这样我们再次回到上面拿ThreadLocalMap的操作上,如果ThreadLocalMap不为空的话,我们就将当前对象实例(ThreadLocal)为键,存储的value为值存储到散列表中去。

阅读到这里或许读者会有所迷惑,即使是HashMap也不是线程安全的,怎么这里ThreadLocal就是线程安全的呢?这里先不着急,在看完createMap后你就会一目了然了。

再次回到ThreadLocal#set方法中,对于获取ThreadLocalMap为空的情况就会通过方法createMap进行创建,其中在创建的同时会将相应的键值对赋进去。这里我们来到createMap一探究竟:

    /**
     * Create the map associated with a ThreadLocal. Overridden in
     * InheritableThreadLocal.
     *
     * @param t the current thread
     * @param firstValue value for the initial entry of the map
     */
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

经过笔者把一些旁支删减掉之后,此处的思路就变成清晰很多了。在判断线程中的成员变量为空时,我们就通过ThreadLocalMap构造方法新建一个实例(包含存储的键值对)并赋值到线程的成员变量上。

这里我们应该知道为什么ThreadLocal是线程安全的了。这里由笔者来总结一下:

  1. 每当我们使用ThreadLocal的时候,其实我们就是新建了一个散列表的键(ThreadLocal),这个键是所有线程共享的。
  2. 当我们要设置线程变量时,我们会在所在线程Thread拿出成员变量——散列表TheadLocalMap进行赋值(这个散列表是当前线程独享的,因为每个请求都是一个线程)
    • 如果我们所在的线程已经存在散列表,我们直接往散列表赋值
    • 如果我们所在的线程不存在散列表,新建一个散列表实例并且往里面赋值

上文也提及了可以把线程当作是Thread对象实例,也就是说每个请求线程都是一个独立的Thread对象,而对象之间的成员变量是不共享的。

最后再次强调,在ThreadLocalMap中是将ThreadLocal作为键进行存储的,也就是说程序中的每个ThreadLocal在同一个线程中是共享同一个ThreadLocalMap的(其中不同ThreadLocal对应着不同的键)

既然我们已经把值设置到线程变量里面了,接下来我们就在业务代码中把值取出来——get方法

    /**
     * Returns the value in the current thread's copy of this
     * thread-local variable.  If the variable has no value for the
     * current thread, it is first initialized to the value returned
     * by an invocation of the {@link #initialValue} method.
     *
     * @return the current thread's value of this thread-local
     */
    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

    /**
     * Variant of set() to establish initialValue. Used instead
     * of set() in case user has overridden the set() method.
     *
     * @return the initial value
     */
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

    /**
     * Returns the current thread's "initial value" for this
     * thread-local variable.  This method will be invoked the first
     * time a thread accesses the variable with the {@link #get}
     * method, unless the thread previously invoked the {@link #set}
     * method, in which case the {@code initialValue} method will not
     * be invoked for the thread.  Normally, this method is invoked at
     * most once per thread, but it may be invoked again in case of
     * subsequent invocations of {@link #remove} followed by {@link #get}.
     *
     * <p>This implementation simply returns {@code null}; if the
     * programmer desires thread-local variables to have an initial
     * value other than {@code null}, {@code ThreadLocal} must be
     * subclassed, and this method overridden.  Typically, an
     * anonymous inner class will be used.
     *
     * @return the initial value for this thread-local
     */
    protected T initialValue() {
        return null;
    }

对于get方法实际上与set方法整体实现思路上基本相同,这里直接给出整个实现思路:

  • 获取当前所在线程的散列表ThreadLocalMap,然后通过将ThreadLocal作为键从ThreadLocalMap中获取相应的键值对,并返回给调用方
    • 如果散列表为空或者以ThreadLocal作为键的值为空,则调用方法setInitialValue初始化相应值并返回(默认初始化为空)
    • 如果散列表不为空且查找以ThreadLocal作为键的值不为空,则直接返回相应的值。

从上述流程也解释了为什么对没有赋值的ThreadLocal对象调用get方法会返回一个空。当然,我们也可以通过重写initialValue设置返回的默认值。而官方也对于这种情况给出了一个带有默认值的构造类SuppliedThreadLocal,我们可以通过方法withInitial进行创建。

    /**
     * Creates a thread local variable. The initial value of the variable is
     * determined by invoking the {@code get} method on the {@code Supplier}.
     *
     * @param <S> the type of the thread local's value
     * @param supplier the supplier to be used to determine the initial value
     * @return a new thread local variable
     * @throws NullPointerException if the specified supplier is null
     * @since 1.8
     */
    public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
        return new SuppliedThreadLocal<>(supplier);
    }

    /**
     * An extension of ThreadLocal that obtains its initial value from
     * the specified {@code Supplier}.
     */
    static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

        private final Supplier<? extends T> supplier;

        SuppliedThreadLocal(Supplier<? extends T> supplier) {
            this.supplier = Objects.requireNonNull(supplier);
        }

        @Override
        protected T initialValue() {
            return supplier.get();
        }
    }

对于SuppliedThreadLocal类仅仅是通过继承ThreadLocal并覆盖其中的初始化方法initialValue来达到目的的。这样我们就可以对没有赋值的键返回提供的默认值了。

对于其他的特性,SuppliedThreadLocalThreadLocal是相同的。

最后在使用完ThreadLocal之后我们需要对其执行remove操作。

这一步很关键,因为使用不当会造成内存泄漏,更严重的则会因为ThreadLocal串场造成线上事故(下文进行分析)。

    /**
     * Removes the current thread's value for this thread-local
     * variable.  If this thread-local variable is subsequently
     * {@linkplain #get read} by the current thread, its value will be
     * reinitialized by invoking its {@link #initialValue} method,
     * unless its value is {@linkplain #set set} by the current thread
     * in the interim.  This may result in multiple invocations of the
     * {@code initialValue} method in the current thread.
     *
     * @since 1.5
     */
     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }

对于remove方法,ThreadLocal则是在获取当前所在线程的散列表后(如果不为空),对其中以ThreadLocal作为键的键值对进行删除。

到这里已经解答了开篇的问题为什么ThreadLocal是线程安全的?,如果没有想继续深入了解的话真的可以点到即止了。下面我也给出了ThreadLocal的原理设计图。为了方便理解,这里我把ThreadLocalMap的一些实现细节省略掉了:


             │     │
             │     │
             │ │   │
             │ │   │
             │ │   │    ┌──────────────────┐
    request1 │ │   │    │ Thread1          │
             │ │   │    │                  │  ┌────────────┬─────────────┐
             │ │◄──┼────┤   ThreadLocalMap─┼─►│ThreadLocal1│ stored value│
             │ │   │    │                  │  ├────────────┼─────────────┤
             │ │   │    └──────────────────┘  │ThreadLocal2│ stored value│
             │ ▼   │                          ├────────────┼─────────────┤
             │     │                          │ThreadLocal3│ stored value│
             │   │ │                          └────────────┴─────────────┘
             │   │ │
             │   │ │    ┌──────────────────┐
             │   │ │    │ Thread2          │
    request2 │   │◄├────┤                  │  ┌────────────┬─────────────┐
             │   │ │    │   ThreadLocalMap─┼─►│ThreadLocal1│ stored value│
             │   │ │    │                  │  ├────────────┼─────────────┤
             │   │ │    └──────────────────┘  │ThreadLocal2│ stored value│
             │   ▼ │                          ├────────────┼─────────────┤
             │     │                          │ThreadLocal3│ stored value│
             │     │                          └────────────┴─────────────┘
          xxxxxxxxxxxxx
           xxxxxxxxxxx
            xxxxxxxxx
             xxxxxxx
              xxxxx
               xxx
                x

ThreadLocalMap实现原理

经过上面的分析,我们基本已经了解ThreadLocal的整个工作流程了,然而如果想更进一步理解ThreadLocal,则需要对ThreadLocalMap进行深度学习了。

我们先来看ThreadLocalMap的内部数据结构是怎么构成的。

/**
 * ThreadLocalMap is a customized hash map suitable only for
 * maintaining thread local values. No operations are exported
 * outside of the ThreadLocal class. The class is package private to
 * allow declaration of fields in class Thread.  To help deal with
 * very large and long-lived usages, the hash table entries use
 * WeakReferences for keys. However, since reference queues are not
 * used, stale entries are guaranteed to be removed only when
 * the table starts running out of space.
 */
static class ThreadLocalMap {
    /**
     * The entries in this hash map extend WeakReference, using
     * its main ref field as the key (which is always a
     * ThreadLocal object).  Note that null keys (i.e. entry.get()
     * == null) mean that the key is no longer referenced, so the
     * entry can be expunged from table.  Such entries are referred to
     * as "stale entries" in the code that follows.
     */
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    /**
     * The table, resized as necessary.
     * table.length MUST always be a power of two.
     */
    private Entry[] table;
}

看过HashMap源码的你应该会知道典型的散列表内部实现结构是怎么样的,就是数组+散列映射的方式,而这里同样也是通过这种典型的散列表来实现的。

但是在这里有一个很值得关注的点,它将我们的键(即ThreadLocal)指定为弱引用。对于这里笔者翻看了对此有所谈及的一些注释说明,其中让我在ThreadLocalMap中找到了,原文是这样描述的:

To help deal with very large and long-lived usages, the hash table entries use WeakReferences for keys. However, since reference queues are not used, stale entries are guaranteed to be removed only when the table starts running out of space.

大致意思是用弱引用指向键(即ThreadLocal)是为了解决ThreadLocalMap大量和长时间使用的问题。阅读到这里貌似会有点迷惑,我们不妨先来了解一下强引用、软引用、弱引用和虚引用的概念:

  • 强引用:如果一个对象具有强引用,那垃圾回收器绝不会回收它。
  • 软引用:如果一个对象只具有软引用,则内存空间充足时,垃圾回收器就不会回收它;如果内存空间不足了,就会回收这些对象的内存。
  • 弱引用:在垃圾回收器线程扫描它所管辖的内存区域的过程中,一旦发现了只具有弱引用的对象,不管当前内存空间足够与否,都会回收它的内存。
  • 虚引用:虚引用顾名思义,就是形同虚设。与其他几种引用都不同,虚引用并不会决定对象的生命周期。如果一个对象仅持有虚引用,那么它就和没有任何引用一样,在任何时候都可能被垃圾回收器回收。

关于reference queue的相关信息这里就不再阐述了,有兴趣可阅读以下链接:Weak references cleared atomically with placement on reference queue ?

基于上述弱引用的概念,如果这个对象只存在弱引用时它就会被垃圾回收器回收,也就是说如果键(即ThreadLocal)外部的强引用都被删除了,比如方法执行结束后方法栈被清空后,键(即ThreadLocal)就只存在在Entry中所指定的弱引用了,那么在下次GC的时候它就会被回收,即获取键(即ThreadLocal)的值就会变成空了。

基于此,我们再通过以下例子对弱引用修饰的作用作进一步阐释。

首先假设把键(即ThreadLocal)设为强引用,看看这会产生什么问题:

  • 通过上文,我们知道ThreadLocalMap是每个线程独享的,每个ThreadLocal相当于这个ThreadLocalMap的一个键。而因为ThreadLocalMapThread的一个成员变量,所以它的生命周期是跟Thread(即线程)走的。如果每个请求进来都创建一个线程,而请求结束的时候都把栈清空和线程销毁掉,伴随着ThreadLocalMapThreadLocal和相对应的Entry都会被回收,这时候其实使用强引用是没有任何问题的。但是在一般项目中是通过线程池处理请求的,即我们的线程会被复用,这时候如果对应的栈被清空后,但是这个键(即ThreadLocal)还存在一个强引用ThreadLocalMap中无法被释放,就会造成内存泄漏。

而如果我们现在采用了弱引用去指向键(即ThreadLocal)这个问题就得到解决了:

  • 在线程执行完之后栈被情况了,键(即ThreadLocal)就只有本身的弱引用而不存在其他的强引用,在下次GC时就会被回收,键(即ThreadLocal)所对应的值就会变成空。

到这里是不是觉得已经万事大吉了?其实不然,就算把键(ThreadLocal)设置成弱引用还是会存在内存泄漏:

  1. 对于Entry的删除ThreadLocalMap采用的是一种惰性删除的策略,即在下次程序进行getsetremove操作的时候进行删除。在线程池中的线程如果迟迟没有被调用,那Entry是不会被清空的,这就造成了内存泄漏。
  2. 一般对于ThreadLocal这种类的通常用法是把它声明为一个静态变量在整个项目的生命周期中都存在着(像开篇的例子),这相当于这个键(即ThreadLocal)将永远被强引用指向着,如果单靠它本身的弱引用特性,不但会造成内存泄漏,更严重的会造成生产事故(复用线程池中的线程时,顺便把这个键(即ThreadLocal)所对应的值也复用了)。

为了方便下文进行分析,我们把将要惰性删除的Entry称为过期的Entry

具体怎么去避免这种情况呢?我们文末再给出答案,我们先来看看它的散列表是怎么来实现的(相信看完实现原理之后你就会找到答案了)

在充分了解到ThreadLocalMap存储的数据结构之后,我们再顺着其使用的思路对其进行进一步的分析,这里我们从ThreadLocalMap的构造方法开始。

static class ThreadLocalMap {
    /**
     * The initial capacity -- MUST be a power of two.
     */
    private static final int INITIAL_CAPACITY = 16;
 
    /**
     * The number of entries in the table.
     */
    private int size = 0;

    /**
     * The next size value at which to resize.
     */
    private int threshold; // Default to 0

    /**
     * Construct a new map initially containing (firstKey, firstValue).
     * ThreadLocalMaps are constructed lazily, so we only create
     * one when we have at least one entry to put in it.
     */
    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        table = new Entry[INITIAL_CAPACITY];
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        table[i] = new Entry(firstKey, firstValue);
        size = 1;
        setThreshold(INITIAL_CAPACITY);
    }

    /**
     * Construct a new map including all Inheritable ThreadLocals
     * from given parent map. Called only by createInheritedMap.
     *
     * @param parentMap the map associated with parent thread.
     */
    private ThreadLocalMap(ThreadLocalMap parentMap) {
        //..
    }

    /**
     * Set the resize threshold to maintain at worst a 2/3 load factor.
     */
    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }

}

ThreadLocalMap有两个构造方法,其中一个是私密方法,它的用处对理解ThreadLocalMap的实现原理用处不大,所以这里先忽略它。也就是说ThreadLocalMap相当于只有一个构造方法,我们将围绕着这个来展开。

这里我们回顾一下上文中是如何使用到ThreadLocalMap的:在设置线程变量时,如果在ThreadLocalMap不存在的情况下,则通过构造方法(这里谈及的ThreadLocalMap构造方法)进行构造且赋值。

接着我们再来看看ThreadLocalMap构造方法中初始化了哪些变量:

  1. 散列表数组的初始化大小INITIAL_CAPACITY16,此处也强调了必须是2的次幂
  2. 散列表中元素的数量size
  3. 散列表数组的扩容阈值threshold,其默认的负载因子2/3,即0.66
  • 为什么负载因子是2/3?此处并没有给出说明,类似的HashMap也没有给出为什么默认负载因子是3/4,即0.75
  • 为什么是数组容量必须是2的次幂呢?具体原因笔者在之前关于HashMap的博客已经做了相关说明,即只有容量为2的次幂时,索引映射算法才能从mod运算转化为位运算。换句话说,只有b2的次幂,算法a mod b ==> a & (b-1)才能成立。

接着,我们再来看看ThreadLocalMap是如何做索引映射的?从代码可以看出它是通过将ThreadLocalhashcode与数组容量进行取模作运算映射出来的:

int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1)

ThreadLocalhashcode是怎么生成的呢?我们回到ThreadLocal看一下:

public class ThreadLocal<T> {
    /**
     * ThreadLocals rely on per-thread linear-probe hash maps attached
     * to each thread (Thread.threadLocals and
     * inheritableThreadLocals).  The ThreadLocal objects act as keys,
     * searched via threadLocalHashCode.  This is a custom hash code
     * (useful only within ThreadLocalMaps) that eliminates collisions
     * in the common case where consecutively constructed ThreadLocals
     * are used by the same threads, while remaining well-behaved in
     * less common cases.
     */
    private final int threadLocalHashCode = nextHashCode();

    /**
     * The next hash code to be given out. Updated atomically. Starts at
     * zero.
     */
    private static AtomicInteger nextHashCode =
        new AtomicInteger();

    /**
     * The difference between successively generated hash codes - turns
     * implicit sequential thread-local IDs into near-optimally spread
     * multiplicative hash values for power-of-two-sized tables.
     */
    private static final int HASH_INCREMENT = 0x61c88647;

    /**
     * Returns the next hash code.
     */
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }
}

关于这里的哈希算法,从表面看这里是通过静态共享变量nextHashCode进行递增(区间为HASH_INCREMENT)操作,每个进来的ThreadLocal都会从这里得到一个hashcode。实际上此处使用的是Fibbonachi hashing算法乘法哈希算法中的一种)生成hashcode,经过与标准的乘法哈希算法(h(k)=(ak mod W)/(W/M))进行比较发现是有一些出入,但总体来说来还是乘法哈希算法(此处相当于把a0x61c88647k{1,2,3..,n}代入公式ak进行计算)。

关于魔数0x61c88647是也是一个需要关注到的点,具体资料可以参考下面链接:

然后在构造方法中,计算出数组索引后就立刻赋值给相应的槽(这里可立即插入元素,因为插入的元素是整个ThreadLocalMap中第一个元素,所以是不会有散列冲突问题的发生)。但是如果是后面插入的元素就有可能会产生散列冲突的问题了,那ThreadLocalMap是如何解决的呢?可以看到threadLocalHashCode变量上的这么一段注释:

ThreadLocals rely on per-thread linear-probe hash maps attached to each thread (Thread.threadLocals and inheritableThreadLocals).

在这里我们可以看到它是通过linear-probe来解决哈希碰撞的,也就是说ThreadLocalMap是通过探查算法为线性探查(linear-probe)的开放寻址法来解决哈希碰撞的。

所谓线性探查就是通过散列函数寻找到对应的槽位i,如果存在继续探查i+1i+2以此类推,去到数组最后一个绕回到第一个槽位继续探查,直到找到为止。线性探查实现起来是比较简单的,但是它也存在比较致命的缺点,随着连续被占用的槽不断地增加,平均查找时间也随之增加,我们称它为一次群集。

在构造出ThreadLocalMap我们就通过ThreadLocalMap#set方法进行赋值:

    static class ThreadLocalMap {

        /**
         * Set the value associated with key.
         *
         * @param key the thread local object
         * @param value the value to be set
         */
        private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }


        /**
         * 线性探查算法
         * Increment i modulo len.
         */
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }
    }

此处我们先别绕进去,先对最简单的流程梳理清楚,暂时忽略replaceStaleEntry方法和插入之后的rehash方法。通过对这些细节去除之后整个流程就清晰很多了,很典型的开放寻址法:

  1. 通过散列方法计算出槽的位置,如果没有被占用(即为null)直接插入即可
  2. 如果计算出来的槽是被占用的,我们就开启线性探查(nextIndex方法)往后一位一位的寻找
    • 如果在探查过程中发现所对应的key已经存在了,直接覆盖要存储的值
    • 如果在探查过程中一直到null都没有找到所对应的key,这时直接插入到null所在位置(即插入到没被占用的槽中)

set方法的基本面讲完了,我们再回到散列表的治理上去。先看rehash这个方法,它是用来作散列表的扩容操作的,因为当超过负载因子查询效率会急速下滑,所以就有了扩容这步操作了(散列表很常规的操作)。在散列表中扩容操作都是很耗费性能的,所以这里在判断负载因子之前,还做了一次优化,上文提及的entry是惰性删除的,所以我们一般是先把需要惰性删除的entry去除之后再计算得到有效的负载因子,再来决定是否扩容。具体清除操作可以看到方法cleanSomeSlots

   static class ThreadLocalMap {
        /**
         * Heuristically scan some cells looking for stale entries.
         * This is invoked when either a new element is added, or
         * another stale one has been expunged. It performs a
         * logarithmic number of scans, as a balance between no
         * scanning (fast but retains garbage) and a number of scans
         * proportional to number of elements, that would find all
         * garbage but would cause some insertions to take O(n) time.
         *
         * @param i a position known NOT to hold a stale entry. The
         * scan starts at the element after i.
         *
         * @param n scan control: {@code log2(n)} cells are scanned,
         * unless a stale entry is found, in which case
         * {@code log2(table.length)-1} additional cells are scanned.
         * When called from insertions, this parameter is the number
         * of elements, but when from replaceStaleEntry, it is the
         * table length. (Note: all this could be changed to be either
         * more or less aggressive by weighting n instead of just
         * using straight log n. But this version is simple, fast, and
         * seems to work well.)
         *
         * @return true if any stale entries have been removed.
         */
        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }
    }

这个方法并没有清除所有过期的槽,只是按比例的去清除,这样做的原因注释也说明了:

  • 如果不进行清理速度会快,但是不会有垃圾回收(因为entryvalue都没有被释放)
  • 如果一次清理掉全部,可能会导致O(n)的时间复杂度

因此这里做个一个权衡,只是按比例做了清理,即每次调用方法都只会执行log2(n)次,n为入参。具体的清理流程如下:

  1. 首先获取i的下一个槽,对应的算法和线性探查是同一个nextIndex
  2. 如果确定为过期(即key为空)就执行清理操作(方法expungeStaleEntry
  3. 清理完之后更新参数i,用于下一次循环进行清理(循环次数为log2(n)

其中参数i表示没有持有过期entry的索引槽,所以清理操作会在i后开始(因为i没有过期,所以忽略对i的清理)。

接着,我们再来看真正的清理方法expungeStaleEntry

    static class ThreadLocalMap {
        /**
         * Expunge a stale entry by rehashing any possibly colliding entries
         * lying between staleSlot and the next null slot.  This also expunges
         * any other stale entries encountered before the trailing null.  See
         * Knuth, Section 6.4
         *
         * @param staleSlot index of slot known to have null key
         * @return the index of the next null slot after staleSlot
         * (all between staleSlot and this slot will have been checked
         * for expunging).
         */
        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        } 
    }

方法expungeStaleEntry是真正清除的槽的方法,下面梳理出整个清除流程:

  1. 传入的参数代表是key为空的槽(即过期的槽),所以直接把entry设为null清除掉
  2. 接着对当前槽之后的元素进行rehash操作,直到遇到槽为null停止
    • 如果期间遇到keynull(即已经过期了),顺便把对应的entry也清除了
    • 如果entry是正常的,则再通过hashcode做一次映射操作
      • 如果算出来的索引槽与当前不相同,则清空当前槽(将槽设置为null),然后重新计算槽的位置并进行插入
      • 如果算出来的索引槽与当前相同,则继续执行循环
  3. 返回槽为null的索引(此处用于外层cleanSomeSlots继续作循环清理)

此处为什么要进行rehash操作呢?其实这涉及到开放寻址法查询和删除机制所决定的,在元素删除之后为了不影响查询,通常会用一个标记去标识这个被删除的元素而不是置为空。但是这样就会导致一个问题,如果我们前期插入了大量的元素之后又大量删除,这样就会拉高了整个散列表的查询时间复杂度。而这里做了一个优化就是删除之后把之后的元素进行了一次rehash操作,把正常的entry推到前面去,这样就避免了上述所提到的问题了。

cleanSomeSlots方法和expungeStaleEntry方法这两个方法对于整个开放寻址的治理很重要,这里我们先来总结一下这两个方法分别的作用:

  • expungeStaleEntry方法:对从入参第staleSlot个槽(包含)一直到第一个空槽(null)进行清理和rehash,并返回rehash后下一个空槽的位置索引。
  • cleanSomeSlots方法:从入参第i个槽(不包含)后开始,进行log2(n)expungeStaleEntry操作,每次执行expungeStaleEntry的入参依赖于上一次expungeStaleEntry返回的。

最后在对过期entry清理后,如果发现负载因子还是大于阈值时,就需要进行rehash操作了,下面我们来看一下:

    static class ThreadLocalMap {

        /**
         * Re-pack and/or re-size the table. First scan the entire
         * table removing stale entries. If this doesn't sufficiently
         * shrink the size of the table, double the table size.
         */
        private void rehash() {
            expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            if (size >= threshold - threshold / 4)
                resize();
        }

        /**
         * Expunge all stale entries in the table.
         */
        private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }
    }

这里我们看到rehash操作里面还会进行了一次对过期槽的清理工作,如果说在进来之前我们已经按比例的清除了一次,那么就会进行了一次全范围的清除了(从代码也可以看到已经遍历所有的槽去清理了)。假如这样还是超过了阈值,那只能执行resize了:

    static class ThreadLocalMap {

        /**
         * Double the capacity of the table.
         */
        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;

            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }

            setThreshold(newLen);
            size = count;
            table = newTab;
        }
    }

整个resize流程还是十分好理解了,简单来说就是新建一个大小为原来两倍的数组,然后再逐个逐个通过散列函数把原来数组中的元素迁移到新的数组中。虽然调用resize前把所有槽都遍历一次去清理了,但是这里还有对key进行判空处理,因为在扩容期间可能也有一些过期的ThreadLocal产生。

到这里,我们回到上面set方法中hold住的一个点,就是在线性探查的过程中,如果发现key为空(即过期的槽),我们同样可以把新的覆盖上去的,具体我们看到replaceStaleEntry方法(经过上面的分析,现在再来看这个方法相信你会容易很多):

   static class ThreadLocalMap {

        /**
         * Replace a stale entry encountered during a set operation
         * with an entry for the specified key.  The value passed in
         * the value parameter is stored in the entry, whether or not
         * an entry already exists for the specified key.
         *
         * As a side effect, this method expunges all stale entries in the
         * "run" containing the stale entry.  (A run is a sequence of entries
         * between two null slots.)
         *
         * @param  key the key
         * @param  value the value to be associated with key
         * @param  staleSlot index of the first stale entry encountered while
         *         searching for key.
         */
        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            // 用于标注开始清理的位置
            // Back up to check for prior stale entry in current run.
            // We clean out whole runs at a time to avoid continual
            // incremental rehashing due to garbage collector freeing
            // up refs in bunches (i.e., whenever the collector runs).
            int slotToExpunge = staleSlot;
            // 通过向前循环,寻找从`null`开始第一个过期的位置
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            // Find either the key or trailing null slot of run, whichever
            // occurs first
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();

                // If we find key, then we need to swap it
                // with the stale entry to maintain hash table order.
                // The newly stale slot, or any other stale slot
                // encountered above it, can then be sent to expungeStaleEntry
                // to remove or rehash all of the other entries in run.
                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // Start expunge at preceding stale entry if it exists
                    if (slotToExpunge == staleSlot)
                        // 此处先记着目前遍历的节点的第i个槽,第i个槽是要已存在要替换的key
                        // 这里表示向前寻找第一个过期的槽失败且到向后寻找这一步位置也没找到过期的槽
                        // 因为原本staleSlot传进的位置是过期的,上面覆盖值之后把位置i和位置staleSlot做了交换
                        // 在交换之后位置i是过期的,所以这里相当于的第i个位置开始(包含)清理
                        // 这里同时也解释了为什么要做交换,我的观点是为了减少之后元素移动的次数
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // 此处表示向前寻找第一个过期的槽失败,开始向后寻找了
                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // If key not found, put new entry in stale slot
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // 找到就执行清理;找不到需要清理的槽就跳过。staleSlot除外,因为上面已经对staleSlot进行覆盖了
            // If there are any other stale entries in run, expunge them
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }
    }

如果单单看replace操作并不难,就是简单的对staleSlot的槽进行覆盖就好,但是这里同时执行了对过期的槽进行清理。怎么清理呢?可以看到注释有这么一段: As a side effect, this method expunges all stale entries in the "run" containing the stale entry. (A run is a sequence of entries between two null slots.)。意思为在运行期间,该方法清理两个null之间包含的过期entry。此处用一个变量slotToExpunge标注开始执行的位置,整体看上去还是有点绕的,这里我们不按照一步一步的去分析,我们只需记住两个点就能把这段代码看懂了:

  1. 首先赋值元素
    • 如果在散列中不存在key,直接在对应的staleSlot设置值
    • 如果在散列中找到对应的key,覆盖原有的值并且与staleSlot位置的槽做交换(此处为什么做交换?我的观点是为了尽量少的移动元素)
  2. 清除过期的槽,寻找区间内第一个过期的槽,并从这个槽开始进行清理操作
    • 通过向前循环,寻找从null开始第一个过期槽的位置
    • 如果向前循环没有的话,就开始往后找
    • 最后如果找不到(即slotToExpunge != staleSlot的情况,staleSlot本身是被替换的,所以不算在内),就不做清理

最后提一下cleanSomeSlots(expungeStaleEntry(slotToExpunge), len)这个调用,关于cleanSomeSlots方法和expungeStaleEntry方法的作用可以回顾一下上文提及的,这里其实cleanSomeSlots方法里面就会调用expungeStaleEntry方法的,为什么这里外面还调用了一次呢?是因为cleanSomeSlots方法是会忽略当前传入参数的位置的(参数上也有说明是传进来的位置表示是不需要清理的),所以在外面再额外调用一次。而这样使用而不是重新写一个专门方法的原因应该是对为了对代码实现更多的复用(可能是懒。。。),当然如果这个场景比较多得话我相信肯定会新建一个专门方法的。(反正我平时也是这样的。。。)

到这里其实基本上对ThreadLocalMap的实现原理已经阐述完毕了,后面的get方法和remove方法其实原理是一样的,这里我们简单来说一下:


   static class ThreadLocalMap {

        /**
         * Get the entry associated with key.  This method
         * itself handles only the fast path: a direct hit of existing
         * key. It otherwise relays to getEntryAfterMiss.  This is
         * designed to maximize performance for direct hits, in part
         * by making this method readily inlinable.
         *
         * @param  key the thread local object
         * @return the entry associated with key, or null if no such
         */
        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }

        /**
         * Version of getEntry method for use when key is not found in
         * its direct hash slot.
         *
         * @param  key the thread local object
         * @param  i the table index for key's hash code
         * @param  e the entry at table[i]
         * @return the entry associated with key, or null if no such
         */
        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

   }
  • 通过散列计算出数组索引
    • 如果槽为空,直接返回null
    • 如果槽的key与要寻找的key相同,直接返回
    • 如果槽的key不相等开始作线性探查,直到找到对应的key返回相应数据或找不到对应的key返回空

期间也有做对过期槽的清理工作,此处的清理工作是不会影响到后续的线性探查的。因为清理的时候会将后续正常的entry提到前面(位置i及它后面)来。

我们再来看remove方法:

   static class ThreadLocalMap {

        /**
         * Remove the entry for key.
         */
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }
   }

同样的思路

  • 通过散列计算出数组索引
    • 如果槽为空,直接结束方法
    • 如果槽的key与要寻找的key相同,调用expungeStaleEntry清理槽及其数据(连带其后相邻的槽也去检查一遍)
    • 如果槽的key不相等开始作线性探查,直到找到对应的key则执行上一步操作,否则结束remove操作

到这里整个ThreadLocal真的就完全分析完了。通过对细节的分析,对于上图我进行了修改,添加一些实现的细节:

                                 +---------------------------------+       +---------------------------------+        +---------------------------------+
                                 |                                 |       |                                 |        |                                 |
                                 |  StrongReference=>ThreadLocal1  |       |  StrongReference=>ThreadLocal2  |        |  StrongReference=>ThreadLocal3  |
                                 |                                 |       |                                 |        |                                 |
                                 +----------------+----------------+       +--------------+------------------+        +------+--------------------------+
                                                  |                                       |                                  |
                                                  |                                       +-----------------------------+    |
                                                  |                                                                     |    |
                                                  +----------------------------------------------------------------+    |    |
                                                                                                                   |    |    |
                                                                          ┌────────────────────────────────────┐   |    |    |
                                                                          │              Entry                 │   |    |    |
                                                                          ├──────┬─────────────────────────────┤   |    |    |
                                                                     +----> key  │ WeakReference=>ThreadLocal1<----+    |    |
                                                                     |    ├──────┼─────────────────────────────┤   |    |    |
                                                                     |    │ value│ stored value                │   |    |    |
                                                                     |    └──────┴─────────────────────────────┘   |    |    |
                                                                     |                                             |    |    |
                                                                     |    ┌────────────────────────────────────┐   |    |    |
                                                                     |    │              Entry                 │   |    |    |
                                                                     |    ├──────┬─────────────────────────────┤   |    |    |
         │     │                                                     | +--> key  │ WeakReference=>ThreadLocal2<---------+    |
         │     │                                                     | |  ├──────┼─────────────────────────────┤   |    |    |
         │ │   │                           ┌───────────────────────┐ | |  │ value│ stored value                │   |    |    |
         │ │   │                           │table[]                │ | |  └──────┴─────────────────────────────┘   |    |    |
         │ │   │    ┌──────────────────┐   ├───────────────────────┤ | |                                           |    |    |
request1 │ │   │    │ Thread1          │   │  ThreadLocal1=>index1 +-+ |  ┌────────────────────────────────────┐   |    |    |
         │ │   │    │                  │   ├───────────────────────┤   |  │              Entry                 │   |    |    |
         │ │◄──┼────┤   ThreadLocalMap+--->│  ThreadLocal2=>index2 +---+  ├──────┬─────────────────────────────┤   |    |    |
         │ │   │    │                  │   ├───────────────────────┤      │ key  │ WeakReference=>ThreadLocal3<--------------+
         │ │   │    └──────────────────┘   │  ThreadLocal3=>index3 +----->├──────┼─────────────────────────────┤   |    |    |
         │ ▼   │                           └───────────────────────┘      │ value│ stored value                │   |    |    |
         │     │                                                          └──────┴─────────────────────────────┘   |    |    |
         │   │ │                                                                                                   |    |    |
         │   │ │                           ┌───────────────────────┐                                               |    |    |
         │   │ │    ┌──────────────────┐   │table[]                │      ┌────────────────────────────────────┐   |    |    |
         │   │ │    │ Thread2          │   ├───────────────────────┤      │              Entry                 │   |    |    |
request2 │   │◄├────┤                  │   │  ThreadLocal1=>index1 +----->├──────┬─────────────────────────────┤   |    |    |
         │   │ │    │   ThreadLocalMap+--->├───────────────────────┤      │ key  │ WeakReference=>ThreadLocal1<----+    |    |
         │   │ │    │                  │   │  ThreadLocal2=>index2 +---+  ├──────┼─────────────────────────────┤        |    |
         │   │ │    └──────────────────┘   ├───────────────────────┤   |  │ value│ stored value                │        |    |
         │   ▼ │                           │  ThreadLocal3=>index3 +-+ |  └──────┴─────────────────────────────┘        |    |
         │     │                           └───────────────────────┘ | |                                                |    |
         │     │                                                     | |  ┌────────────────────────────────────┐        |    |
      xxxxxxxxxxxxx                                                  | |  │              Entry                 │        |    |
       xxxxxxxxxxx                                                   | |  ├──────┬─────────────────────────────┤        |    |
        xxxxxxxxx                                                    | +->│ key  │ WeakReference=>ThreadLocal2<---------+    |
         xxxxxxx                                                     |    ├──────┼─────────────────────────────┤             |
          xxxxx                                                      |    │ value│ stored value                │             |
           xxx                                                       |    └──────┴─────────────────────────────┘             |
            x                                                        |                                                       |
                                                                     |    ┌────────────────────────────────────┐             |
                                                                     |    │              Entry                 │             |
                                                                     |    ├──────┬─────────────────────────────┤             |
                                                                     +--->│ key  │ WeakReference=>ThreadLocal3<--------------+
                                                                          ├──────┼─────────────────────────────┤
                                                                          │ value│ stored value                │
                                                                          └──────┴─────────────────────────────┘



扩展:InheritableThreadLocal

ThreadLocal这个类只是在当前线程有效的,如果我们开了一个子线程去执行任务时是无法使用到ThreadLocal里面存储的值的。那如何解决呢?Java提供了给我们一个继承子类InheritableThreadLocal专门用于这种场景,下面我们来看看:

/**
 * This class extends <tt>ThreadLocal</tt> to provide inheritance of values
 * from parent thread to child thread: when a child thread is created, the
 * child receives initial values for all inheritable thread-local variables
 * for which the parent has values.  Normally the child's values will be
 * identical to the parent's; however, the child's value can be made an
 * arbitrary function of the parent's by overriding the <tt>childValue</tt>
 * method in this class.
 *
 * <p>Inheritable thread-local variables are used in preference to
 * ordinary thread-local variables when the per-thread-attribute being
 * maintained in the variable (e.g., User ID, Transaction ID) must be
 * automatically transmitted to any child threads that are created.
 *
 * @author  Josh Bloch and Doug Lea
 * @see     ThreadLocal
 * @since   1.2
 */

public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    /**
     * Computes the child's initial value for this inheritable thread-local
     * variable as a function of the parent's value at the time the child
     * thread is created.  This method is called from within the parent
     * thread before the child is started.
     * <p>
     * This method merely returns its input argument, and should be overridden
     * if a different behavior is desired.
     *
     * @param parentValue the parent thread's value
     * @return the child thread's initial value
     */
    protected T childValue(T parentValue) {
        return parentValue;
    }

    /**
     * Get the map associated with a ThreadLocal.
     *
     * @param t the current thread
     */
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    /**
     * Create the map associated with a ThreadLocal.
     *
     * @param t the current thread
     * @param firstValue value for the initial entry of the table.
     */
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

通过阅读类注释知道InheritableThreadLocal是可以在线程与子线程之间传播线程变量的,同时它也提及到了,一般传播的线程变量的值和父线程是一样的,我们也可以通过覆盖childValue做到不一样(自己做一些修改),当然无论修改与否这个方法childValue都是必须要覆盖的,因为父类ThreadLocal是不支持这个方法的调用的(抛错)。好,那么我们开始分析InheritableThreadLocal,这个类型除了覆盖了3个方法,其他全部行为都和它的父类ThreadLocal一样。

我们先来看后两个方法,上文也提及了无论是getsetremove方法都是通过getMap方法获取线程上的散列表的,其中getset在获取为空的情况下还是帮忙创建一个并把值进去,这里我们代码回忆一下:

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

看到这里应该是能看出差别的,ThreadLocal是从线程Thread的成员变量threadLocals处获取和保存散列表的,而InheritableThreadLocal则在线程Thread的成员变量inheritableThreadLocals中做这些操作,这里我们顺藤摸瓜在看看Thread

public class Thread implements Runnable {
    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;

    /*
     * InheritableThreadLocal values pertaining to this thread. This map is
     * maintained by the InheritableThreadLocal class.
     */
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
}

可以看到Thread实例也专门开设了一个成员变量供我们这样去使用线程变量(继承),而我们的子类通过改变获取的线程成员变量来实现切换。这里还是有问题,那是怎么赋值到子线程中的呢?我们可以看到线程Thread创建的构造方法(这里看最简单的就可以了):

public Thread() {
    init(null, null, "Thread-" + nextThreadNum(), 0);
}

private void init(ThreadGroup g, Runnable target, String name,
                   long stackSize) {
    init(g, target, name, stackSize, null, true);
}

private void init(ThreadGroup g, Runnable target, String name,
                    long stackSize, AccessControlContext acc,
                    boolean inheritThreadLocals) {
    //..
    
    Thread parent = currentThread();


    // ..
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    // ..
}

可以看到在线程创建初始化的过程中有这么一段代码,虽然此处为初始化新线程,但实际上调用currentThread还是获取到的是旧/父线程(只是创建对象,还没有到操作系统调度),所以parent.inheritableThreadLocals获取到的还是旧/父线程的散列表,但是Thread对象确实全新的,即Thread中的inheritableThreadLocals是全新的,同时这里也明确的指定了只有inheritableThreadLocals才会进行复制操作,所以这里也解答了为什么ThreadLocal不能传递线程变量而InheritableThreadLocal却可以传递。接着这里又回到我们上文忽略的方法createInheritedMap和私有的构造方法:


public class ThreadLocal<T> {

    /**
     * Factory method to create map of inherited thread locals.
     * Designed to be called only from Thread constructor.
     *
     * @param  parentMap the map associated with parent thread
     * @return a map containing the parent's inheritable bindings
     */
    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }

    /**
     * Construct a new map including all Inheritable ThreadLocals
     * from given parent map. Called only by createInheritedMap.
     *
     * @param parentMap the map associated with parent thread.
     */
    private ThreadLocalMap(ThreadLocalMap parentMap) {
        Entry[] parentTable = parentMap.table;
        int len = parentTable.length;
        setThreshold(len);
        table = new Entry[len];

        for (int j = 0; j < len; j++) {
            Entry e = parentTable[j];
            if (e != null) {
                @SuppressWarnings("unchecked")
                ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                if (key != null) {
                    Object value = key.childValue(e.value);
                    Entry c = new Entry(key, value);
                    int h = key.threadLocalHashCode & (len - 1);
                    while (table[h] != null)
                        h = nextIndex(h, len);
                    table[h] = c;
                    size++;
                }
            }
        }
    }
}

来到这里就是通过构造方法把父线程中的散列表复制到子线程中(即外部数据->当前对象数据),此处也是childValue发挥作用的时候,需要注意的是childValue是调用父线程的而不是子线程的,即获取值的行为是父线程决定的。而InheritableThreadLocal默认就是复制用相同的值,如果需要自定义的话可以继承InheritableThreadLocal并覆盖childValue进行修改。(注意此处不是覆盖ThreadLocal,覆盖ThreadLocal是无法实现的,因为你无法覆盖getMapcreateMap这两个方法)

最后我们再来谈一下上面涉及的一些相关问题?

  • 过期Entry内存泄露问题?

    上文提到过期Entry的惰性删除机制和在使用过程中通过静态变量去声明ThreadLocal,都有可能会引起内存泄漏,这个问题单靠弱引用是无法被解决的,解决这个问题需要我们在使用完ThreadLocal之后手动的进行删除remove操作

  • 为什么这里是用开放寻址法,而不是链接法?

    在数据量小的情况下选择开放寻址法是没有问题的,因为开放寻址法占用空间小、无新增分配内存消耗等优势确实带来比链接法更好的性能。但是数据量较大的情况下更倾向于使用链接法。因为在数据量比较大时会占用了大量的连续空间,而且可能大部分的数据是空的,这样就大大降低了CPU缓存的命中率。而且对于开放寻址法的删除机制有可能导致查询时间复杂度位于在一个相对较高的花销,这也是需要考虑的。当然ThreadLocalMap自己也做了一定的优化,比如对于添加、删除元素进行了rehash操作,但是这样也会造成元素的大量移动,所以在数据量比较大的情况下,比如使用过多的ThreadLocal,还是需要谨慎。

当然对于想要在线程变量中存储稍微多一点数据还是可以使用上做一些优化的,比如说:

public class ThreadLocalHolder{

    /**
     * The constant threadContext.
     */
    private final static ThreadLocal<Map<String, Object>> THREAD_CONTEXT = new ThreadLocal<Map<String, Object>>() {
        @Override
        protected Map<String, Object> initialValue() {
            return new HashMap<>();
        }
    };

    /**
     * 取得thread context Map的实例。
     *
     * @return thread context Map的实例
     */
    private static Map<String, Object> getContextMap() {
        return THREAD_CONTEXT.get();
    }

    /**
     * Put.
     *
     * @param key   the key
     * @param value the value
     */
    public static void put(String key, Object value) {
        getContextMap().put(key, value);
    }

    /**
     * Remove object.
     *
     * @param key the key
     */
    public static void remove(String key) {
        getContextMap().remove(key);
    }

    /**
     * 清理线程所有被hold住的对象。以便重用!
     */
    public static void remove() {
        THREAD_CONTEXT.remove();
    }
}

这样使用的话其实是只是将ThreadLocalMap当作是一个中间层,真正的散列去到了HashMap

未经本人许可,禁止转载