Java中的ThreadLocal原理剖析

63 阅读8分钟

ThreadLocal简介

当多线程访问同一个共享变量的时候容易出现并发问题,特别是多个线程对一个变量进行写入的时候,为了保证线程安全,一般使用者在访问共享变量的时候需要进行额外的同步措施才能保证线程安全性。ThreadLocal是除了加锁这种同步方式之外的一种保证一种规避多线程访问出现线程不安全的方法,当我们在创建一个变量后,如果每个线程对其进行访问的时候访问的都是线程自己的变量这样就不会存在线程不安全问题。

不过需要注意的是ThreadLocal不能保证共享变量在多线程中的并发安全性,它是通过每个线程各自持有一个实例来避免共享,自然也不存在并发安全问题了。

ThreadLocal是JDK包提供的,它提供线程本地变量,如果创建一个ThreadLocal变量,那么访问这个变量的每个线程都会有这个变量的一个副本,在实际多线程操作的时候,操作的是自己本地内存中的变量,从而规避了线程安全问题。

ThreadLocal 实现原理

ThreadLocalMap初始化

ThreadLocalMap是ThreadLocal中的一个内部类,是ThreadLocal的核心,本质上是一个存放数据的Map结构,其定义如下:

主要的字段:

  • table:Entry类型的数组,用于存储数据的数组
  • size:数组元素的个数
  • threshold:下一次容器扩容的大小
  • INITIAL_CAPACITY:数组容器的初始大小,默认为16

调用ThreadLocal.set(Object value)时,可能存在两种情况:

  • 存储数据的Map不存在
  • 存储数据的Map存在
public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            map.set(this, value);
        } else {
            createMap(t, value);
        }
    }

如果根据当前线程去Map去查询不到对应的值,则创建Map:

void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

创建Map:

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

流程:

  • Map中的数组容器大小设置为INITIAL_CAPACITY,也就是16
  • 对第一个需要存入的数据进行hash取模,判断该数据需要存储在数组哪个位置
  • 将数据key和value构造为一个Eentry对象,并存入数组中
  • 数组大小设置为1
  • 下次扩容大小设置为16

ThreadLocalMap存储的时候需要把当前ThreadLocal对象作为Key,要存入的值作为Value存放到ThreadLocalMap中。

向ThreadLocal中设置值

已经存在Map的情况下set值流程:

private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

从ThreadLocal中获取值

获取值的方法很简单,看get()方法的源码:

public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

根据当前线程获取对应的ThreadLocalMap,如果Map存在,则以当前的ThreadLocal为key从ThreadLocalMap获取值返回;如果Map不存在,则调用setInitialValue()方法。该方法的定义如下:

private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            map.set(this, value);
        } else {
            createMap(t, value);
        }
        if (this instanceof TerminatingThreadLocal) {
            TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
        }
        return value;
    }

从代码来看,其实就是初始化一下ThreadLocalMap,这里和set()方法初始化ThreadLocalMap类似,然后返回一个默认的初始值null,为什么不用set()方法呢?注释里解释说是为了防止set()被重载,导致set()方法的初始化失效。

ThreadLocalMap中如何解决Hash冲突

常见的解决Hash冲突的方法有:

  • 链式地址法
  • 开放寻址法
  • 再Hash法
  • 建立公共溢出池法

其中再Hash法算法复杂度高,建立公共溢出池法浪费内存,用的比较多的是前两种。

HashMap是通过链式地址法来解决hash冲突,这种方法的思想是一旦发生了冲突,就去寻找下一个空的三列地址,只要散列表足够大,空的散列地址总能找到,并把数据存入数组中。

比如有个长度为10的数组,已经有元素{12, 33, 4,5,15,25},用散列函数f(key)=key mod 元素后,可以得到前面四个元素{12, 33, 4,5}在当前数组中位置如下:

数组index0123456789
key123345
valuevaluevaluevaluevalue

当计算key=15时,发现f(15) = 5,和元素5计算得到的位置冲突。

于是我们应用上面的公式f(15) = (f(15)+1) mod 10 =6。于是将15存入下标为6的位置。

数组index0123456789
key12334515
valuevaluevaluevaluevaluevalue

\

链地址法和开放地址法的优缺点:

开放地址法:

  1. 容易产生堆积问题,不适于大规模的数据存储。
  2. 散列函数的设计对冲突会有很大的影响,插入时可能会出现多次冲突的现象。
  3. 删除的元素是多个冲突元素中的一个,需要对后面的元素作处理,实现较复杂。

链地址法:

  • 处理冲突简单,且无堆积现象,平均查找长度短。
  • 链表中的结点是动态申请的,适合构造表不能确定长度的情况。
  • 删除结点的操作易于实现。只要简单地删去链表上相应的结点即可。
  • 指针需要额外的空间,故当结点规模较小时,开放定址法较为节省空间。

至于ThreadLocalMap为什么采用开放地址法,其实是考虑到ThreadLocal存储的数据量一般不会很大,并且key是弱引用类型,容易及时清除数据,这个时候开放地址法简单的结构会显得更省空间,同时数组的查询效率也是非常高,加上第一点的保障,冲突概率也低。

ThreadLocal中的内存泄露问题

ThreadLocalMap中的数组类型Entry定义:

static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

可以看到Entry中的Key是一个弱引用对象类型,因此当ThreadLocal没有引用时,key就被GC进行自动回收。但是它的Value不是弱引用,因为如果Value也是弱引用的话,会导致GC的时候会被回收,对象就无法访问了。

当Entry中的Key被GC自动回收,但是Value还在Map中,并且其他地方也未使用到Value对象时,就会导致Value在内存进行永久堆积,导致内存泄露。

如何解决内存泄露问题?

既然是由于Value导致的内存泄露,那我们需要考虑的是在适当的时机主动把Value设置为Null,可以通过不需要数据时手动调用ThreadLocal的remove()方法来解决,该方法定义:

public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null) {
             m.remove(this);
         }
     }

本质上就是从ThreadLocalMap中找到当前线程对应的数据,将该数据进行删除,类似HashMap中的remove()方法。

ThreadLocal应用场景

ThreadLocal适用于以下两种场景:

  • 每个线程需要各自单独的实例,互不影响
  • 实例需要在一个线程中不同方法中共享,但不希望被多线程共享。

Session共享

可以尝试使用ThreadLocal替代Session的使用,当用户要访问需要授权的接口的时候,可以现在拦截器中将用户的Token存入ThreadLocal中;之后在本次访问中任何需要用户用户信息的都可以直接冲ThreadLocal中拿取数据。例如自定义获取用户信息的类AuthHolder:

public class AuthNHolder {
    private static final ThreadLocal<Map<String,String>> threadLocal = new ThreadLocal<>();
    public static void map(Map<String,String> map){
        threadLocal.set(map);
    }
    // 获取用户id
    public static String userId(){
        return get("userId");
    }
    // 根据键值获取对应的信息
    public static String get(String key){
        Map<String,String> map = getMap();
        return map.get(key);
    }
    // 用完清空ThreadLocal
    public static void clear(){
        threadLocal.remove();
    }
}

在多个方法中传参

假设我们一个业务需要调用多个方法传参完成业务处理,同时每个方法都需要用到统一个对象时,就可以使用ThreadLocal来代替显式的参数传递,因为使用参数传递会导致代码的耦合度比较高,同一个参数传太长的链条的话看起来也不优雅。

当该对象用ThreadLocal包装过后,就可以保证在该线程中独此一份,同时和其他线程隔离。

例如在Spring的@Transaction事务声明的注解中就使用ThreadLocal保存了当前的Connection对象,避免在本次调用的不同方法中使用不同的Connection对象。