A Walkthrough of the ThreadLocal Source Code

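Introduction to ThreadLocal<T>

In Java concurrent programming, access to a shared resource is usually synchronized with the synchronized keyword or with CAS-based classes. ThreadLocal offers a different approach: instead of sharing one instance, each thread gets its own copy of the variable, so in some scenarios a concurrent program can be designed without any synchronization. The Looper class in Android, for example, relies on ThreadLocal in its implementation.
In essence, ThreadLocal is a utility class for thread-local copies: it maps each thread to its own private variable, so the variables of different threads do not interfere with one another.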
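
To make the "one copy per thread" idea concrete, here is a minimal sketch. The class name PerThreadCopyDemo, the field CONTEXT, and the thread names are made up for illustration; only the ThreadLocal calls themselves come from the JDK.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class PerThreadCopyDemo {
    // One shared ThreadLocal instance; every thread sees only its own value through it.
    private static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    /** Runs two worker threads and returns what each one read back from CONTEXT. */
    static Map<String, String> run() {
        Map<String, String> seen = new ConcurrentHashMap<>();
        Runnable task = () -> {
            String name = Thread.currentThread().getName();
            CONTEXT.set("value-of-" + name);   // writes only this thread's copy
            seen.put(name, CONTEXT.get());     // reads only this thread's copy
            CONTEXT.remove();                  // clean up before the thread ends
        };
        Thread a = new Thread(task, "worker-a");
        Thread b = new Thread(task, "worker-b");
        a.start();
        b.start();
        try {
            a.join();
            b.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return seen;
    }

    public static void main(String[] args) {
        System.out.println(run());
        // The main thread never called set(), so its own copy is still empty.
        System.out.println(CONTEXT.get()); // null
    }
}
```

Neither worker needs a lock, because neither can observe the other's value.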

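ThreadLocal class overview

ThreadLocal storage structure

From the structure diagram we can draw the following conclusions:

  • Every Thread has a threadLocals field of type ThreadLocalMap.
  • threadLocals is backed by an array in which each slot stores an Entry; the Entry's key is a ThreadLocal object and its value is that thread's copy of the variable.
  • ThreadLocal acts as a utility class that provides methods for maintaining a thread's threadLocals object, supporting set, get, and remove operations.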

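ThreadLocal source code analysis

As the class overview above shows, the main methods ThreadLocal provides are set(T), get(), remove(), and initialValue(). We start the analysis with these commonly used methods.

The get() method

get() returns the value of the current thread's copy of the variable. The source is: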

    /**
     * Returns the value in the current thread's copy of this
     * thread-local variable.  If the variable has no value for the
     * current thread, it is first initialized to the value returned
     * by an invocation of the {@link #initialValue} method.
     *
     * @return the current thread's value of this thread-local
     */
    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

The first line of the source obtains the current thread object, and getMap is then called to obtain the ThreadLocalMap. Let's follow getMap to see what it does.

    /**
     * Get the map associated with a ThreadLocal. Overridden in
     * InheritableThreadLocal.
     *
     * @param  t the current thread
     * @return the map
     */
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

This method simply returns the thread object's threadLocals field. Following it further, Thread.java contains this definition:

    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;

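After the analysis above, we can summarize the steps of get():

  • Get the current thread's ThreadLocalMap object, threadLocals.
  • Using the current ThreadLocal object as the key, look up the corresponding Entry (a key/value pair) in the map. The relevant code: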
    /**
         * Get the entry associated with key.  This method
         * itself handles only the fast path: a direct hit of existing
         * key. It otherwise relays to getEntryAfterMiss.  This is
         * designed to maximize performance for direct hits, in part
         * by making this method readily inlinable.
         *
         * @param  key the thread local object
         * @return the entry associated with key, or null if no such
         */
        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }
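  • If the map is null or contains no entry for this ThreadLocal, that is, the current thread has no copy of the variable yet, setInitialValue() is called to install the default value (customizable through initialValue, which returns null by default). Here is its logic: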
    /**
     * Variant of set() to establish initialValue. Used instead
     * of set() in case user has overridden the set() method.
     *
     * @return the initial value
     */
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

In short, this method obtains the default value and stores it in the map, creating the map first if necessary.

The set() method

    /**
     * Sets the current thread's copy of this thread-local variable
     * to the specified value.  Most subclasses will have no need to
     * override this method, relying solely on the {@link #initialValue}
     * method to set the values of thread-locals.
     *
     * @param value the value to be stored in the current thread's copy of
     * this thread-local.
     */
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

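set() works as follows:

  • Get the current thread's ThreadLocalMap object, map.
  • If map != null, store the new value in the map; otherwise create the map to hold the first value. Here is the createMap method: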
    /**
     * Create the map associated with a ThreadLocal. Overridden in
     * InheritableThreadLocal.
     *
     * @param t the current thread
     * @param firstValue value for the initial entry of the map
     */
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

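The initialValue() method

This method supplies the default value, that is, the value returned by get() when the current thread has no entry in the map. We can customize it when creating the ThreadLocal. The source is: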

    /**
     * Returns the current thread's "initial value" for this
     * thread-local variable.  This method will be invoked the first
     * time a thread accesses the variable with the {@link #get}
     * method, unless the thread previously invoked the {@link #set}
     * method, in which case the {@code initialValue} method will not
     * be invoked for the thread.  Normally, this method is invoked at
     * most once per thread, but it may be invoked again in case of
     * subsequent invocations of {@link #remove} followed by {@link #get}.
     *
     * <p>This implementation simply returns {@code null}; if the
     * programmer desires thread-local variables to have an initial
     * value other than {@code null}, {@code ThreadLocal} must be
     * subclassed, and this method overridden.  Typically, an
     * anonymous inner class will be used.
     *
     * @return the initial value for this thread-local
     */
    protected T initialValue() {
        return null;
    }

We can define a custom default value like this:

    ThreadLocal<String> threadLocal = new ThreadLocal<String>() {
        @Override
        protected String initialValue() {
            return "default";
        }
    };
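
Since Java 8 the same default can also be supplied without subclassing, through the static factory ThreadLocal.withInitial(Supplier). The sketch below is illustrative only: the class name InitialValueDemo and the call counter are hypothetical helpers, added to show that the supplier runs on the first get() and not on later ones.

```java
import java.util.concurrent.atomic.AtomicInteger;

final class InitialValueDemo {
    /** Returns "first,second,initCalls" for two consecutive get() calls on one thread. */
    static String demo() {
        // Counts how often the initial-value supplier runs (illustration only).
        AtomicInteger initCalls = new AtomicInteger();
        ThreadLocal<String> name = ThreadLocal.withInitial(() -> {
            initCalls.incrementAndGet();
            return "default";
        });

        String first = name.get();   // no entry yet: the supplier runs and the result is stored
        String second = name.get();  // entry exists: the stored value is returned directly
        name.remove();
        return first + "," + second + "," + initCalls.get();
    }

    public static void main(String[] args) {
        System.out.println(demo()); // default,default,1
    }
}
```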

The remove() method

    /**
     * Removes the current thread's value for this thread-local
     * variable.  If this thread-local variable is subsequently
     * {@linkplain #get read} by the current thread, its value will be
     * reinitialized by invoking its {@link #initialValue} method,
     * unless its value is {@linkplain #set set} by the current thread
     * in the interim.  This may result in multiple invocations of the
     * {@code initialValue} method in the current thread.
     *
     * @since 1.5
     */
     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }

remove() is straightforward: it removes the current thread's entry for this ThreadLocal from the map.
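
The interplay between set(), remove(), and initialValue() described in the Javadoc can be observed with a small experiment. This is a sketch; the class name RemoveDemo and the "init#n" strings are invented for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

final class RemoveDemo {
    static List<String> run() {
        AtomicInteger initCalls = new AtomicInteger();
        ThreadLocal<String> local = new ThreadLocal<String>() {
            @Override
            protected String initialValue() {
                // Each invocation yields a distinct value so re-initialization is visible.
                return "init#" + initCalls.incrementAndGet();
            }
        };

        List<String> out = new ArrayList<>();
        local.set("explicit");
        out.add(local.get());  // "explicit": set() came first, so initialValue() was skipped
        local.remove();
        out.add(local.get());  // "init#1": the entry is gone, so get() re-initializes
        local.remove();
        out.add(local.get());  // "init#2": every remove() followed by get() initializes again
        local.remove();
        return out;
    }

    public static void main(String[] args) {
        System.out.println(run()); // [explicit, init#1, init#2]
    }
}
```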

ThreadLocalMap in detail

ThreadLocalMap class diagram:

ThreadLocalMap stores its key-value data in Entry objects. The key is the ThreadLocal object, which is fixed in the Entry constructor. The source is:

    /**
         * The entries in this hash map extend WeakReference, using
         * its main ref field as the key (which is always a
         * ThreadLocal object).  Note that null keys (i.e. entry.get()
         * == null) mean that the key is no longer referenced, so the
         * entry can be expunged from table.  Such entries are referred to
         * as "stale entries" in the code that follows.
         */
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

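Entry extends WeakReference (an object reachable only through weak references survives only until the next GC), but only the key is weakly referenced; the value is held by an ordinary strong reference.

ThreadLocalMap member variables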

    static class ThreadLocalMap {
        /**
         * The initial capacity -- MUST be a power of two.
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * The table, resized as necessary.
         * table.length MUST always be a power of two.
         */
        private Entry[] table;

        /**
         * The number of entries in the table.
         */
        private int size = 0;

        /**
         * The next size value at which to resize.
         */
        private int threshold; // Default to 0
    }

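From the source we can see that ThreadLocalMap has a simple structure: its core is an Entry[] array holding the variable copies stored by the owning thread. Entry differs considerably from HashMap's Node: a Node stores a next pointer so that colliding nodes form a linked list, whereas an Entry is a plain key-value pair. How, then, are collisions resolved? The answer is in the set method.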

    /**
         * Set the value associated with key.
         *
         * @param key the thread local object
         * @param value the value to be set
         */
        private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

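From set() we can see that the key's threadLocalHashCode first determines the slot in table. If that slot already holds an entry with the same key, its value is updated; if it holds a stale entry (null key), the entry is replaced through replaceStaleEntry. Otherwise the loop header e = tab[i = nextIndex(i, len)] is the crucial part: it calls nextIndex to move on to the next slot. Its source is: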

    /**
     * Increment i modulo len.
     */
    private static int nextIndex(int i, int len) {
        return ((i + 1 < len) ? i + 1 : 0);
    }

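This probes with a fixed step of one slot, wrapping around at the end of the array, to find the next position that can hold the entry. In other words, ThreadLocalMap resolves hash collisions by open addressing with linear probing (fixed step): it stores the Entry in the next free slot rather than chaining entries in a list.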
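
The probing behaviour can be reproduced with a deliberately simplified table. This is a sketch, not the JDK implementation: the class LinearProbingSketch is hypothetical, it uses String keys, it takes the hash as an explicit argument so that collisions can be forced, and it never resizes (so it assumes the table never fills up).

```java
final class LinearProbingSketch {
    private final String[] keys = new String[8];    // length must be a power of two
    private final String[] values = new String[8];

    // Same wrap-around step as ThreadLocalMap.nextIndex.
    private static int nextIndex(int i, int len) {
        return (i + 1 < len) ? i + 1 : 0;
    }

    /** Stores the pair, starting at hash & (len - 1); returns the slot actually used. */
    int put(int hash, String key, String value) {
        int len = keys.length;
        int i = hash & (len - 1);
        // Step forward one slot at a time until we find the key or an empty slot.
        while (keys[i] != null && !keys[i].equals(key)) {
            i = nextIndex(i, len);
        }
        keys[i] = key;
        values[i] = value;
        return i;
    }

    /** Looks the key up along the same probe sequence; null if absent. */
    String get(int hash, String key) {
        int len = keys.length;
        for (int i = hash & (len - 1); keys[i] != null; i = nextIndex(i, len)) {
            if (keys[i].equals(key)) {
                return values[i];
            }
        }
        return null;
    }

    public static void main(String[] args) {
        LinearProbingSketch t = new LinearProbingSketch();
        System.out.println(t.put(7, "a", "1"));   // 7: home slot is free
        System.out.println(t.put(15, "b", "2"));  // 0: slot 7 is taken, probe wraps to 0
        System.out.println(t.put(23, "c", "3"));  // 1: slots 7 and 0 are taken
        System.out.println(t.put(7, "a", "9"));   // 7: same key, value updated in place
        System.out.println(t.get(7, "a"));        // 9
    }
}
```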

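Efficiency problems introduced by fixed-step probing

If a large number of ThreadLocal objects are stored in one ThreadLocalMap, collisions become more likely, including secondary collisions during probing, which reduces efficiency to some extent. How can we mitigate this? A recommended practice:
Store only one variable per thread, so that the key every thread puts into its map is the same ThreadLocal. If a thread has to hold several variables through several ThreadLocal objects, each of them becomes a separate key in the map, which raises the chance of hash collisions. When several pieces of data need to be stored, wrap them in a single holder class instead.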
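
How often probing is needed depends on how the keys spread over the table. As a rough illustration, the sketch below computes the home slots for a run of consecutively created ThreadLocal objects. It rests on assumptions not shown in the article: the constant 0x61c88647 is the HASH_INCREMENT used by OpenJDK's ThreadLocal to derive each instance's threadLocalHashCode, and the sequence is assumed to start at 0 (in a real JVM other ThreadLocal instances have usually been created earlier, so the actual values differ). The class name HashSpreadSketch is hypothetical.

```java
import java.util.Arrays;

final class HashSpreadSketch {
    // Assumption: the per-instance hash step taken from OpenJDK's ThreadLocal.
    private static final int HASH_INCREMENT = 0x61c88647;

    /** Home slots of `count` consecutively created keys in a table of `tableLength` slots. */
    static int[] slots(int count, int tableLength) {
        int[] result = new int[count];
        int hash = 0; // assumed starting point, see the note above
        for (int n = 0; n < count; n++) {
            result[n] = hash & (tableLength - 1); // same masking as in ThreadLocalMap
            hash += HASH_INCREMENT;               // int overflow wraps, as in the JDK
        }
        return result;
    }

    public static void main(String[] args) {
        // [0, 7, 14, 5, 12, 3, 10, 1, 8, 15, 6, 13, 4, 11, 2, 9]
        System.out.println(Arrays.toString(slots(16, 16)));
    }
}
```

Under these assumptions, sixteen consecutive keys land in sixteen different slots of the initial table. Collisions become relevant mainly when keys are created non-consecutively, when stale entries accumulate, or when the table is close to its resize threshold.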

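Investigating ThreadLocal memory leaks

Many articles mention that using ThreadLocal incorrectly can cause memory leaks. In this section we look closely at how such a leak arises. First, consider how the objects are laid out in memory when a ThreadLocal is in use:

How ThreadLocal works: every Thread maintains a ThreadLocalMap. The key of each mapping is a weak reference to the ThreadLocal itself, and the value is the actual thread-local Object being stored.

In other words, ThreadLocal does not store the thread's value itself; it is only a tool that maintains the Map inside the Thread and helps with storing and retrieving. Note the dashed line in the diagram above: it denotes a weak reference, and an object that is reachable only through weak references survives only until the next GC.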

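Why ThreadLocal leaks memory

Inside a ThreadLocalMap, the ThreadLocal object is referenced by the Entry's key only through a weak reference. If nothing outside holds a strong reference to the ThreadLocal, it is reclaimed at the next GC (note: as long as a strong reference to the ThreadLocal still exists, the referent of the WeakReference is not reclaimed by a GC). The Entry then has a null key, and no caller can reach its value through ThreadLocal any more, because there is no key left to look it up with. If the current thread keeps running long enough, a chain of strong references remains: Thread --> ThreadLocalMap --> Entry (key is null, value still refers to another object) --> Object. As a result the Entry is not reclaimed, and neither is the Object it refers to.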
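
The state of such an Entry can be reproduced deterministically. In the sketch below, the class StaleEntrySketch and its nested Entry are hypothetical stand-ins with the same shape as ThreadLocalMap.Entry, and WeakReference.clear() is called by hand to simulate what the garbage collector does once the key is no longer strongly reachable (a real GC run cannot be forced reliably).

```java
import java.lang.ref.WeakReference;

final class StaleEntrySketch {
    /** Same shape as ThreadLocalMap.Entry: weakly referenced key, strongly held value. */
    static final class Entry extends WeakReference<Object> {
        Object value;

        Entry(Object key, Object value) {
            super(key);
            this.value = value;
        }
    }

    /** Returns {key visible before clearing, key visible after clearing, value still held}. */
    static boolean[] run() {
        Object key = new Object();
        Entry entry = new Entry(key, new byte[1024]);
        boolean keyVisibleBefore = entry.get() == key;

        // Stand-in for the collector clearing the weak reference to an unreachable key.
        entry.clear();

        boolean keyVisibleAfter = entry.get() != null;
        boolean valueStillHeld = entry.value != null; // the strong reference is untouched
        return new boolean[] {keyVisibleBefore, keyVisibleAfter, valueStillHeld};
    }

    public static void main(String[] args) {
        boolean[] r = run();
        System.out.println(r[0] + " " + r[1] + " " + r[2]); // true false true
    }
}
```

The last flag is the leak: the key is gone, but the value stays reachable for as long as the Entry stays in the table.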

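Mitigation

The designers of ThreadLocal anticipated this situation. Operations such as get(), set(), and remove() look for Entry nodes with a null key along the slots they probe, set the Entry's value to null, and clear its table slot so that the GC can reclaim them. Let's look at the relevant methods.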

    /**
         * Get the entry associated with key.  This method
         * itself handles only the fast path: a direct hit of existing
         * key. It otherwise relays to getEntryAfterMiss.  This is
         * designed to maximize performance for direct hits, in part
         * by making this method readily inlinable.
         *
         * @param  key the thread local object
         * @return the entry associated with key, or null if no such
         */
        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }
    
    
        /**
         * Version of getEntry method for use when key is not found in
         * its direct hash slot.
         *
         * @param  key the thread local object
         * @param  i the table index for key's hash code
         * @param  e the entry at table[i]
         * @return the entry associated with key, or null if no such
         */
        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

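From the source above we can see that get() calls getEntry(), and getEntry() falls back to getEntryAfterMiss() whenever the entry in the direct slot is null or its key is not the requested one. Inside getEntryAfterMiss(), expungeStaleEntry (that is, delete or erase) is executed when k == null. Let's continue with that method.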

        /**
         * Expunge a stale entry by rehashing any possibly colliding entries
         * lying between staleSlot and the next null slot.  This also expunges
         * any other stale entries encountered before the trailing null.  See
         * Knuth, Section 6.4
         *
         * @param staleSlot index of slot known to have null key
         * @return the index of the next null slot after staleSlot
         * (all between staleSlot and this slot will have been checked
         * for expunging).
         */
        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

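From the source we can see that after removing the current Entry, the loop keeps scanning forward until it reaches an empty slot. Any further Entry with a null key found on the way is removed as well, and live entries are rehashed into the freed slots, which helps prevent memory leaks.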
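
To watch the rehashing step at work, the logic can be replayed on a toy table. This is a sketch under stated assumptions: ExpungeSketch is a hypothetical class, each Entry carries an explicit hash field so that three keys can be forced into the same home slot, put() never resizes, and clear() again stands in for the garbage collector. Only expungeStaleEntry follows the JDK method shown above.

```java
import java.lang.ref.WeakReference;

final class ExpungeSketch {
    static final class Entry extends WeakReference<Object> {
        final int hash;   // stored explicitly so the sketch can force collisions
        Object value;

        Entry(Object key, int hash, Object value) {
            super(key);
            this.hash = hash;
            this.value = value;
        }
    }

    final Entry[] table = new Entry[8];
    int size;

    static int nextIndex(int i, int len) {
        return (i + 1 < len) ? i + 1 : 0;
    }

    /** Simplified insert: linear probing only, no stale-entry handling, no resize. */
    void put(Object key, int hash, Object value) {
        int i = hash & (table.length - 1);
        while (table[i] != null) {
            i = nextIndex(i, table.length);
        }
        table[i] = new Entry(key, hash, value);
        size++;
    }

    /** Mirrors the structure of ThreadLocalMap.expungeStaleEntry. */
    int expungeStaleEntry(int staleSlot) {
        int len = table.length;
        table[staleSlot].value = null;   // drop the strong reference to the value
        table[staleSlot] = null;
        size--;

        Entry e;
        int i;
        for (i = nextIndex(staleSlot, len); (e = table[i]) != null; i = nextIndex(i, len)) {
            if (e.get() == null) {       // another stale entry: expunge it too
                e.value = null;
                table[i] = null;
                size--;
            } else {
                int h = e.hash & (len - 1);
                if (h != i) {            // live entry away from its home slot: move it closer
                    table[i] = null;
                    while (table[h] != null) {
                        h = nextIndex(h, len);
                    }
                    table[h] = e;
                }
            }
        }
        return i;
    }

    static String run() {
        ExpungeSketch map = new ExpungeSketch();
        Object a = new Object(), b = new Object(), c = new Object();
        map.put(a, 0, "A");   // lands in slot 0
        map.put(b, 0, "B");   // collides, lands in slot 1
        map.put(c, 0, "C");   // collides twice, lands in slot 2
        map.table[1].clear(); // simulate the collector reclaiming key b

        int next = map.expungeStaleEntry(1);
        boolean keysAlive = map.table[0].get() == a && map.table[1].get() == c;
        return "next=" + next + " size=" + map.size + " slot1=" + map.table[1].value
                + " slot2=" + map.table[2] + " keysAlive=" + keysAlive;
    }

    public static void main(String[] args) {
        System.out.println(run()); // next=3 size=2 slot1=C slot2=null keysAlive=true
    }
}
```

The entry for c moves from slot 2 into the slot freed by b. Without that move, a later lookup for c would stop at the now-empty slot 1 and miss it.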

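Why ThreadLocal uses weak references

Many people believe that the leak occurs because the key of the Entry in ThreadLocal is a weak reference. In fact, it occurs because the value is not actively cleared once the Entry's key has become null. So why does Entry use a weak reference at all? The comment in the official source says: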

To help deal with very large and long-lived usages, the hash table entries use WeakReferences for keys.

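Let's consider two cases:

  • If the key were a strong reference: once the object that referenced the ThreadLocal has been reclaimed, the ThreadLocalMap would still hold a strong reference to the ThreadLocal. Unless the entry is removed manually, the ThreadLocal cannot be reclaimed, and the Entry leaks.
  • With the key as a weak reference: once the object that referenced the ThreadLocal has been reclaimed, the ThreadLocalMap holds only a weak reference to it, so the ThreadLocal can be reclaimed even without manual removal. The value can then be cleared when a later set, get, or remove on that ThreadLocalMap encounters the stale entry.

Comparing the two cases: because a ThreadLocalMap lives as long as its Thread, failing to remove the corresponding key manually can lead to a leak either way. The weak reference adds one extra safeguard: the ThreadLocal itself does not leak, and the corresponding value can be cleared by a later set, get, or remove call on the ThreadLocalMap.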

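Summary

Every time you finish using a ThreadLocal, call its remove() method to clear the data.
With a thread pool, failing to clean up a ThreadLocal in time is not only a memory-leak problem; more seriously, it can cause business-logic errors, because a pooled thread can carry the previous task's value into the next task. So treat ThreadLocal like a lock that must be unlocked after use: clean up as soon as you are done.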
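
The thread-pool hazard and the try/finally remedy can be shown side by side. This is a minimal sketch: the class PooledThreadCleanupDemo, the field CURRENT_USER, and the user names are hypothetical, and a single-thread executor is used so that all tasks are guaranteed to run on the same worker thread.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

final class PooledThreadCleanupDemo {
    private static final ThreadLocal<String> CURRENT_USER = new ThreadLocal<>();

    /** Returns {value seen after a leaky task, value seen after a well-behaved task}. */
    static String[] run() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        Callable<String> readUser = CURRENT_USER::get;
        try {
            // A task that sets a value and forgets to remove it.
            Runnable leakyTask = () -> CURRENT_USER.set("alice");
            pool.submit(leakyTask).get();
            // The next task reuses the worker thread and inherits the leftover value.
            String leaked = pool.submit(readUser).get();

            // A task that cleans up in finally, like unlocking a lock.
            Runnable tidyTask = () -> {
                CURRENT_USER.set("bob");
                try {
                    // ... work that reads CURRENT_USER ...
                } finally {
                    CURRENT_USER.remove();
                }
            };
            pool.submit(tidyTask).get();
            String afterCleanup = pool.submit(readUser).get();

            return new String[] {leaked, afterCleanup};
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        String[] r = run();
        System.out.println(r[0]); // alice
        System.out.println(r[1]); // null
    }
}
```

The first result is the business-logic bug: a task that never set a user still observes "alice". Wrapping the work in try/finally with remove() keeps each task's state from outliving the task.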