A Brief Analysis of ThreadLocal


ThreadLocal

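Starting from a concrete problem, this article walks through the source code to analyze the design ideas and implementation of ThreadLocal.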

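1. How to keep data independent across threads

When developing web application endpoints, handling a request often requires keeping some values, such as traceId and userId, available across classes and methods. We need a strategy that supports this.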

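  • Option 1: use a static field in a shared class

Pros: globally accessible. Cons: each request is handled by its own thread, and a shared static field cannot keep each thread's data independent.

  • Option 2: pass the values as method parameters

Pros: the values travel with the call chain and each thread's data is independent. Cons: every method signature along the way has to carry them, which is cumbersome.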

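Now suppose we were the ones designing the thread class:

  • Option 3: add a field to the Thread class that holds the thread's own data; to store, get the current thread and assign the field; to read, get the current thread and read the field

Pros: the values travel with the call chain and each thread's data is independent. Cons: if the data needs an initial value, every thread has to initialize it separately.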

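Following the principle that "any problem in software engineering can be solved by adding another layer of indirection", we can wrap the previous option in one more layer.

  • Option 4: introduce a wrapper class that performs the reads and writes on the Thread field and carries a default value; when the field in the Thread has no value, the wrapper assigns the default

Pros: each thread's data is independent, and all threads can share one initial value. Cons: the thread class cannot hold more than one object.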
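To make options 3 and 4 concrete, here is a minimal sketch. `ContextThread` and `ContextHolder` are hypothetical names invented for this illustration; they are not JDK classes, and the sketch only works on threads of the custom type.

```java
public class Option4Sketch {
    /** Hypothetical thread type with a single context slot (option 3). */
    static class ContextThread extends Thread {
        Object context; // only read and written by the thread itself
        ContextThread(Runnable task) { super(task); }
    }

    /** Hypothetical wrapper that adds one shared default value (option 4). */
    static class ContextHolder<T> {
        private final T defaultValue;
        ContextHolder(T defaultValue) { this.defaultValue = defaultValue; }

        @SuppressWarnings("unchecked")
        T get() {
            ContextThread t = (ContextThread) Thread.currentThread();
            if (t.context == null) t.context = defaultValue; // lazy default
            return (T) t.context;
        }

        void set(T value) {
            ((ContextThread) Thread.currentThread()).context = value;
        }
    }

    static String[] runDemo() throws InterruptedException {
        ContextHolder<String> traceId = new ContextHolder<>("no-trace");
        String[] seen = new String[2];
        ContextThread a = new ContextThread(() -> {
            traceId.set("trace-A");
            seen[0] = traceId.get();
        });
        ContextThread b = new ContextThread(() -> seen[1] = traceId.get());
        a.start(); b.start();
        a.join(); b.join();
        return seen;
    }

    public static void main(String[] args) throws InterruptedException {
        String[] seen = runDemo();
        System.out.println(seen[0] + " / " + seen[1]); // trace-A / no-trace
    }
}
```

Because every `ContextHolder` writes to the same `context` slot, a second holder (say, for userId) would overwrite the first one's value. That is exactly the drawback listed for option 4.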

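2. What is ThreadLocal

Option 4 is already getting close. ThreadLocal is a more complete solution. Start with the class diagram:

As the diagram shows, the Thread class maintains a ThreadLocalMap field, threadLocals, that holds the thread's data.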

public class Thread implements Runnable {
    ...
    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;
    ...
}

ThreadLocalMap maintains an Entry array, so it can store multiple objects (one Entry per ThreadLocal), which fixes the drawback of option 4.

public class ThreadLocal<T> {
    ...
    static class ThreadLocalMap {
        ...
        private Entry[] table;
        
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
        ...
    }
    ...
}
public abstract class Reference<T> {
    ...
    /**
     * Returns this reference object's referent.  If this reference object has
     * been cleared, either by the program or by the garbage collector, then
     * this method returns <code>null</code>.
     *
     * @return   The object to which this reference refers, or
     *           <code>null</code> if this reference object has been cleared
     */
    public T get() {
        return this.referent;
    }
    ...
}

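Note that Entry extends WeakReference, so the key is actually a weak reference to the ThreadLocal. An object that is only weakly reachable is reclaimed the next time the garbage collector processes the region it lives in. This is not necessarily the next young GC, because the object may already have been promoted to the old generation. The point of this design is to avoid tying a ThreadLocal's lifetime to the thread's lifetime: once no strong reference to the ThreadLocal remains, it can be collected, and the Entry's key then reads as null. Later set, get, or remove calls on the same thread's map clean up the stale entries they come across, clearing the value so that it can be collected too.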
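The "stale entry" state can be reproduced deterministically without waiting for a garbage collection. The sketch below uses a simplified stand-in for ThreadLocalMap.Entry (the class and the `isStale` helper are illustrative, not JDK code) and calls `clear()` to simulate what the collector does to the weak key.

```java
import java.lang.ref.WeakReference;

public class StaleEntrySketch {
    /** Simplified stand-in for ThreadLocalMap.Entry: weak key, strong value. */
    static class Entry extends WeakReference<Object> {
        Object value;
        Entry(Object key, Object value) {
            super(key);
            this.value = value;
        }
    }

    /** An entry is stale when its weakly referenced key is gone. */
    static boolean isStale(Entry e) {
        return e.get() == null;
    }

    public static void main(String[] args) {
        Object key = new Object();
        Entry e = new Entry(key, "payload");
        System.out.println("stale before clear: " + isStale(e));

        e.clear(); // simulates the collector clearing the weak reference
        System.out.println("stale after clear: " + isStale(e));
        // The value is still strongly held by the entry until someone expunges it.
        System.out.println("value still held: " + e.value);
    }
}
```

Only the key is weak. The value stays strongly reachable through the Entry until a cleanup pass sets it to null, which is why the later sections on memory leaks matter.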

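Access to the values in ThreadLocalMap is encapsulated by the ThreadLocal class, which also lets it supply a uniform initial value. The source is as follows:

2.1 The set method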
public class ThreadLocal<T> {    
    ...
	/**
     * Sets the current thread's copy of this thread-local variable
     * to the specified value.  Most subclasses will have no need to
     * override this method, relying solely on the {@link #initialValue}
     * method to set the values of thread-locals.
     *
     * @param value the value to be stored in the current thread's copy of
     *        this thread-local.
     */
    public void set(T value) {
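        // Get the current thread
        Thread t = Thread.currentThread();
        // Get the current thread's threadLocals field
        ThreadLocalMap map = getMap(t);
        // If the map already exists, store into it; otherwise create it
        if (map != null)
            // "this" is the current ThreadLocal object, used as the key
            map.set(this, value);
        else
            // Create the ThreadLocalMap;
            // the Thread's threadLocals field holds the reference to it
            createMap(t, value);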
    }

    /**
     * Get the map associated with a ThreadLocal. Overridden in
     * InheritableThreadLocal.
     *
     * @param  t the current thread
     * @return the map
     */
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    
    /**
     * Create the map associated with a ThreadLocal. Overridden in
     * InheritableThreadLocal.
     *
     * @param t the current thread
     * @param firstValue value for the initial entry of the map
     */
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
    ...
}

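ThreadLocal's set method actually delegates to ThreadLocalMap's set method. The slot in the Entry array is computed from the key's hash code, and collisions are resolved by linear probing (open addressing).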

    static class ThreadLocalMap {        
        ...
		/**
         * Set the value associated with key.
         *
         * @param key the thread local object
         * @param value the value to be set
         */
        private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
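            // Compute the slot index from the key's hash code
            int i = key.threadLocalHashCode & (len-1);
            // Probe linearly to see whether the key is already present
            for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                // Key already present: overwrite the old value
                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    // Stale entry (its key was collected): reuse the slot
                    // for the new entry and clean up stale entries nearby
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            // Empty slot found: store the value, keyed by the current ThreadLocal
            tab[i] = new Entry(key, value);
            int sz = ++size;
            // Heuristically scan some slots for stale entries; if none were
            // removed and size has reached the threshold, rehash
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();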
        }
        ...
    }
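The probing loop above is easier to follow in isolation. The following is a minimal sketch of open addressing with linear probing. It is a toy map written for this article, not the JDK implementation: it has no resizing and no stale-entry handling, and the hash is passed in explicitly so that collisions can be forced.

```java
public class LinearProbingSketch {
    // Power-of-two capacity so that (hash & (len - 1)) yields a valid index.
    private final Object[] keys = new Object[16];
    private final Object[] values = new Object[16];

    /** Next slot, wrapping around at the end of the array. */
    private static int nextIndex(int i, int len) {
        return (i + 1 < len) ? i + 1 : 0;
    }

    /** Stores the value and returns the slot used. Assumes the table never fills up. */
    int put(Object key, int hash, Object value) {
        int len = keys.length;
        int i = hash & (len - 1);
        // Walk forward until we find the same key or an empty slot.
        while (keys[i] != null && keys[i] != key) {
            i = nextIndex(i, len);
        }
        keys[i] = key;
        values[i] = value;
        return i;
    }

    /** Looks the key up along the same probe sequence; null if absent. */
    Object get(Object key, int hash) {
        int len = keys.length;
        for (int i = hash & (len - 1); keys[i] != null; i = nextIndex(i, len)) {
            if (keys[i] == key) return values[i];
        }
        return null;
    }

    public static void main(String[] args) {
        LinearProbingSketch map = new LinearProbingSketch();
        System.out.println(map.put("a", 5, "A"));   // 5
        System.out.println(map.put("b", 21, "B"));  // 21 & 15 == 5, taken, so 6
        System.out.println(map.put("c", 15, "C"));  // 15
        System.out.println(map.put("d", 31, "D"));  // 31 & 15 == 15, taken, wraps to 0
        System.out.println(map.get("b", 21));       // B
    }
}
```

Like ThreadLocalMap, the sketch compares keys by identity (`==`) rather than `equals`, since a ThreadLocal is only ever looked up by the very same instance.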

2.2 The get method
public class ThreadLocal<T> {
    ...
    /**
     * Returns the value in the current thread's copy of this
     * thread-local variable.  If the variable has no value for the
     * current thread, it is first initialized to the value returned
     * by an invocation of the {@link #initialValue} method.
     *
     * @return the current thread's value of this thread-local
     */
    public T get() {
        // Get the current thread
        Thread t = Thread.currentThread();
        // Get the current thread's ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        // Nothing found: store and return the initial value.
        // This is where ThreadLocal supplies a uniform initial value.
        return setInitialValue();
    }
    
    /**
     * Get the map associated with a ThreadLocal. Overridden in
     * InheritableThreadLocal.
     *
     * @param  t the current thread
     * @return the map
     */
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    
    /**
     * Variant of set() to establish initialValue. Used instead
     * of set() in case user has overridden the set() method.
     *
     * @return the initial value
     */
    // Essentially the same as set()
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }
    ...
}
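The fallback to setInitialValue() is what gives every thread the same starting value without sharing state. A short sketch using `ThreadLocal.withInitial` (available since Java 8); the field names and values are made up for the illustration:

```java
public class InitialValueSketch {
    // Two independent thread-locals: each thread's map holds one entry per ThreadLocal.
    static final ThreadLocal<Integer> COUNTER = ThreadLocal.withInitial(() -> 0);
    static final ThreadLocal<String> NAME = ThreadLocal.withInitial(() -> "anonymous");

    static String[] runDemo() throws InterruptedException {
        String[] seen = new String[2];
        Thread a = new Thread(() -> {
            COUNTER.set(COUNTER.get() + 41); // first get() returns the initial 0
            NAME.set("worker-a");
            seen[0] = NAME.get() + ":" + COUNTER.get();
        });
        // Thread b never calls set(), so it only ever sees the initial values.
        Thread b = new Thread(() -> seen[1] = NAME.get() + ":" + COUNTER.get());
        a.start(); a.join();
        b.start(); b.join();
        return seen;
    }

    public static void main(String[] args) throws InterruptedException {
        String[] seen = runDemo();
        System.out.println(seen[0]); // worker-a:41
        System.out.println(seen[1]); // anonymous:0
    }
}
```

Thread b starts only after thread a has finished, yet it still sees the initial values, because a's writes went into a's own ThreadLocalMap.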

ThreadLocal's get method actually delegates to ThreadLocalMap's getEntry method.

    static class ThreadLocalMap {
        ...
        /**
         * Get the entry associated with key.  This method
         * itself handles only the fast path: a direct hit of existing
         * key. It otherwise relays to getEntryAfterMiss.  This is
         * designed to maximize performance for direct hits, in part
         * by making this method readily inlinable.
         *
         * @param  key the thread local object
         * @return the entry associated with key, or null if no such
         */
        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            // Look in the slot given by the hash first; this usually hits
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                // Otherwise keep probing, mirroring how set() resolves collisions
                return getEntryAfterMiss(key, i, e);
        }
        ...
    }
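Why does the direct lookup usually hit? The excerpts above do not show how threadLocalHashCode is generated. As I recall the OpenJDK source, each new ThreadLocal takes the next value of a shared counter that advances by the constant 0x61c88647, and the map starts with 16 slots. Treat both numbers as assumptions to check against your JDK version. The sketch below reproduces only that arithmetic:

```java
import java.util.Arrays;

public class HashSpreadSketch {
    // Assumed from the OpenJDK source; not shown in the excerpts in this article.
    static final int HASH_INCREMENT = 0x61c88647;

    /** Slot indices for the first `count` hash codes in a table of the given power-of-two length. */
    static int[] slots(int count, int tableLength) {
        int[] result = new int[count];
        int hash = 0;
        for (int n = 0; n < count; n++) {
            result[n] = hash & (tableLength - 1);
            hash += HASH_INCREMENT; // int overflow simply wraps around
        }
        return result;
    }

    public static void main(String[] args) {
        // [0, 7, 14, 5, 12, 3, 10, 1, 8, 15, 6, 13, 4, 11, 2, 9]
        System.out.println(Arrays.toString(slots(16, 16)));
    }
}
```

Sixteen consecutive hash codes land in sixteen different slots, so collisions stay rare as long as the table is not crowded.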

2.3 The remove method
public class ThreadLocal<T> {    
    ...
	/**
     * Removes the current thread's value for this thread-local
     * variable.  If this thread-local variable is subsequently
     * {@linkplain #get read} by the current thread, its value will be
     * reinitialized by invoking its {@link #initialValue} method,
     * unless its value is {@linkplain #set set} by the current thread
     * in the interim.  This may result in multiple invocations of the
     * {@code initialValue} method in the current thread.
     *
     * @since 1.5
     */
     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }
    ...
}

This in turn calls ThreadLocalMap's remove method.

    static class ThreadLocalMap {
        ...
        /**
         * Remove the entry for key.
         */
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    // Clear the weak reference so the key reads as null
                    e.clear();
                    // Expunge the now-stale entry: null out its value and free the slot
                    expungeStaleEntry(i);
                    return;
                }
            }
        }
        ...
    }
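The Javadoc above notes that a get() after remove() reinitializes the value. A small sketch that counts initializer calls makes this visible; the counter and the names are illustrative only:

```java
import java.util.concurrent.atomic.AtomicInteger;

public class RemoveSketch {
    static final AtomicInteger INIT_CALLS = new AtomicInteger();

    static final ThreadLocal<String> USER = ThreadLocal.withInitial(() -> {
        INIT_CALLS.incrementAndGet(); // count how often the initializer runs
        return "guest";
    });

    static String demo() {
        USER.set("alice");
        String beforeRemove = USER.get(); // explicit value, initializer not called
        USER.remove();
        String afterRemove = USER.get();  // entry is gone, so the initializer runs
        USER.remove();                    // leave the thread clean
        return beforeRemove + "," + afterRemove;
    }

    public static void main(String[] args) {
        System.out.println(demo());           // alice,guest
        System.out.println(INIT_CALLS.get()); // 1
    }
}
```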

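3. How weak-reference reclamation works

As the figure above shows, the thread object holds a reference to its ThreadLocalMap, and the ThreadLocalMap holds a reference to the Entry array. Calling set on a ThreadLocal actually stores the value in an Entry's value field, while the Entry's key is a weak reference to the ThreadLocal.

The normal flow is:

(1) A ThreadLocal object calls set, which stores the value in an Entry of the current thread's ThreadLocalMap. If the thread has no ThreadLocalMap yet, one is created with that value as its first entry;

(2) Once ThreadLocal1 is no longer in use, its reference is popped off the stack and the strong link at ① is severed. ThreadLocal object 1 is then reachable only through the weak link from the Entry's key. At the next garbage collection that covers the region holding it, the link at ③ is cleared and ThreadLocal object 1 is reclaimed, so the Entry's key now reads as null;

(3) A later get or set on ThreadLocal2 in the same thread can run into such stale entries and expunge them: the value is set to null and the slot is released, so the Entry and its value become collectable. The cleanup is opportunistic, so a single call does not necessarily remove every stale entry. The thread is still running at this point; because the ThreadLocal's lifetime is not tied to the Thread's lifetime, a ThreadLocal can be reclaimed early once it is no longer used;

(4) When the thread finishes, the link at ② is severed, and the Thread object, the ThreadLocalMap object, and the Entry array all become eligible for collection.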

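In practice, however, a ThreadLocal is rarely defined inside a single thread, because that would defeat its purpose as one shared handle through which every thread reaches its own copy of the data. As the official Javadoc puts it: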

{@code ThreadLocal} instances are typically private static fields in classes that wish to associate state with a thread

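4. Common problems

4.1 Memory leaks

A ThreadLocal is usually declared private static, highlighted in yellow in the figure. The ThreadLocal object is therefore always strongly referenced and is never collected automatically, so the value in its Entry is never collected either. Only when the thread ends is the Entry array reclaimed. With a thread pool, however, worker threads stay alive after a task finishes, and if the value is a large object this can eventually lead to an OutOfMemoryError.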

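4.2 Stale data

With a thread pool, a thread is returned to the pool after use. The next task that runs on that thread will find that the data the previous task left in the thread's ThreadLocalMap is still there.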
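The stale-data problem is easy to reproduce with a single-threaded pool, where both tasks are guaranteed to run on the same worker thread. This is a minimal sketch; `USER_ID` and the value "user-1" are made up for the illustration.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class StaleValueSketch {
    static final ThreadLocal<String> USER_ID = new ThreadLocal<>();

    /** Returns what the second task sees after the first task forgot to call remove(). */
    static String leakedValue() throws Exception {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            // Task 1: simulates a request that sets a value and never cleans up.
            Runnable firstRequest = () -> USER_ID.set("user-1");
            pool.submit(firstRequest).get();

            // Task 2: an unrelated request that runs on the same worker thread.
            Callable<String> secondRequest = USER_ID::get;
            return pool.submit(secondRequest).get();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(leakedValue()); // user-1
    }
}
```

The second task never called set, yet it reads "user-1". In a web application that would mean one request observing another user's identity.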

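In my view, the weak reference in ThreadLocal feels somewhat over-engineered. The designers wanted to pass parameters through the thread context, yet did not want those parameters to share the thread's lifetime, and the resulting design does not fit the typical private static usage well.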

5. How to use it

public class RequestHeaderContext {

    private static final ThreadLocal<Request.Header> threadLocalHeaderMap = new ThreadLocal<>();

    public static Request.Header getHeaderContext() {
        return threadLocalHeaderMap.get();
    }

    public static void setHeaderContext(Request.Header headerContext) {
        threadLocalHeaderMap.set(headerContext);
    }

    public static void clear() {
        threadLocalHeaderMap.remove();
    }
}

Always call remove after use to release the value explicitly and avoid the problems described above.
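One common way to make sure remove always runs is to pair set with a try/finally at the outermost point of the request, for example in a filter or interceptor. The sketch below uses a plain String trace ID instead of the Request.Header type above so that it is self-contained; the `handle` method stands in for whatever entry point your framework provides.

```java
public class CleanupSketch {
    static final ThreadLocal<String> TRACE_ID = new ThreadLocal<>();

    /** Hypothetical request entry point: set on the way in, remove on the way out. */
    static String handle(String traceId) {
        TRACE_ID.set(traceId);
        try {
            return process();
        } finally {
            // Runs even if process() throws, so a pooled thread never keeps the value.
            TRACE_ID.remove();
        }
    }

    /** Any code further down the call chain can read the value without parameters. */
    static String process() {
        return "handled " + TRACE_ID.get();
    }

    public static void main(String[] args) {
        System.out.println(handle("t-1"));  // handled t-1
        System.out.println(TRACE_ID.get()); // null
    }
}
```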