阅读 22
ThreadLocal

ThreadLocal

介绍

除了通过锁来控制资源访问外,我们可以通过ThreadLocal增加资源来保证所有对象的安全性。ThreadLocal从名字可以看出,这是线程的局部变量。只有当前线程可以访问,是线程安全的。

ThreadLocal实现原理

ThreadLocal如何保证这些对象只被当前线程访问?下面我们带着这个简单的问题来详细分析下。

ThreadLocal的set方法

public void set(T value) {
    Thread t = Thread.currentThread();
    //获取当前线程的ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null)
        //线程中ThreadLocalMap不为空则将value添加到对应Map中
        map.set(this, value);
    else
        //对当前线程来说,首次创建ThreadLocalMap
        createMap(t, value);
}
复制代码

我们来看下createMap方法:

void createMap(Thread t, T firstValue) {
    //新建一个ThreadLocalMap,key为当前ThreadLocal实例,value为设置的数据,ThreadLocalMap赋值给当前线程的threadLocals变量
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}
复制代码

从上面代码我们就可以看到ThreadLocalMap是维护在当前线程的threadLocals变量中,所以只有当前线程可以访问

Thread的threadLocals变量声明如下:

public class Thread implements Runnable {
···
/* ThreadLocal values pertaining to this thread. This map is maintained
 * by the ThreadLocal class. 
 */
ThreadLocal.ThreadLocalMap threadLocals = null;

/*
 * InheritableThreadLocal values pertaining to this thread. This map is
 * maintained by the InheritableThreadLocal class.
 */
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
···
复制代码

通过以上源码,我们可以看到通过ThreadLocalMap进行数据保存。那么ThreadLocalMap是一个什么样的结构那?

ThreadLocalMap结构

我们先来看下ThreadLocalMap的源码:

static class ThreadLocalMap {

    /**
     * The entries in this hash map extend WeakReference, using
     * its main ref field as the key (which is always a
     * ThreadLocal object).  Note that null keys (i.e. entry.get()
     * == null) mean that the key is no longer referenced, so the
     * entry can be expunged from table.  Such entries are referred to
     * as "stale entries" in the code that follows.
     */
    //静态内部类,继承WeakReference(弱引用),保存元素的key和value
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            //key为弱引用
            super(k);
            value = v;
        }
    }
  
/**
 * The initial capacity -- MUST be a power of two.
 */
private static final int INITIAL_CAPACITY = 16;

/**
 * The table, resized as necessary.
 * table.length MUST always be a power of two.
 */
private Entry[] table;

/**
 * The number of entries in the table.
 */
private int size = 0;

/**
 * The next size value at which to resize.
 */
private int threshold; // Default to 0
}

...
复制代码

结构类似如下:

未命名文件 (16).png

ThreadLocalMap底层为数组,元素为Entry,Entry中key为定义的ThreadLocal实例,value为定义的ThreadLocal中存储的数据。ThreadLocalMap在一个线程中不同元素存放的为定义的不同类型的ThreadLocal对象。也就是说ThreadLocalMap存放的是ThreadLocal变量的集合

为什么ThreadLocalMap使用数组,而不是像HashMap那样使用数组+链表的底层数据结构那。个人认为 ThreadLocalMap中每一个元素,也就是Entry存储的就是一种类型的ThreadLocal,我们定义的ThreadLocal类型一般不会太多。所以无需定义成数组+链表的结构,这样获取数据也会更快一些。

另外需要注意的是Entry中的key(也就是ThreadLocal的实例)是弱引用,value是强引用

ThreadLocalMap的set方法

其中比较重要的是ThreadLocalMap的set(ThreadLocal<?> key, Object value)方法:

private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
    //数组下标
    int i = key.threadLocalHashCode & (len-1);
    //通过数组下标依次向后遍历
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        //存在相同的ThreadLocal对象,则覆盖value,并退出循环
        if (k == key) {
            e.value = value;
            return;
        }
        //key为弱引用,k == null表示key被垃圾回收
        if (k == null) {
            //初始化Entry放在该位置(替换旧的Entry(key被GC))
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    //未找到e!=null的需要替换的元素,则在i为空位置的地方放入新的Entry
    tab[i] = new Entry(key, value);
    //数组长度自增
    int sz = ++size;
    //cleanSomeSlots为清除一些弱引用key被GC的Entry元素
    //如果cleanSomeSlots返回true表示清除了部分被GC的元素,则不会进行rehash。如果cleanSomeSlots返回false表示没有Entry的key被GC(没有元素被清除),这时数组长度已经达到阈值threshold则会进行扩容rehash
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        //扩容rehash
        rehash();
}
复制代码

set的时候当Entry!=null&k==null的话则会进行替换旧的Entry的操作,会执行replaceStaleEntry方法。

//key和value为要添加的元素,staleSlot为key被GC的需要被替换的元素
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // Back up to check for prior stale entry in current run.
    // We clean out whole runs at a time to avoid continual
    // incremental rehashing due to garbage collector freeing
    // up refs in bunches (i.e., whenever the collector runs).
    //slotToExpunge为需要被清除的下标
    int slotToExpunge = staleSlot;
    //从staleSlot下标向前遍历,查找e!=null&key==null最小下标并记录到slotToExpunge
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // Find either the key or trailing null slot of run, whichever
    // occurs first
    //从staleSlot下标位置向后遍历
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // If we find key, then we need to swap it
        // with the stale entry to maintain hash table order.
        // The newly stale slot, or any other stale slot
        // encountered above it, can then be sent to expungeStaleEntry
        // to remove or rehash all of the other entries in run. 
        //e!=null&k == key 找到相同的key则替换value值
        if (k == key) {
            e.value = value;
            //staleSlot位置key被GC,则把staleSlot位置元素和i位置替换(因为采用的开放地址法进行处理hash冲突,前面的元素失效,则后面hash算法相同向后移动的元素需要向前移动)
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            //staleSlot前面没有key被GC的元素,则设置向后遍历的i设置到slotToExpunge(i为大于staleSlot的首个无效元素)
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            //清理失效元素
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // If we didn't find stale entry on backward scan, the
        // first stale entry seen while scanning for key is the
        // first still present in the run.
        //staleSlot位置向前遍历未找到失效Entry,向后遍历出现k==null的Entry则把slotToExpunge设置为i(向后遍历的第一个失效的元素下标)
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // If key not found, put new entry in stale slot
    //如果没找到相同的key,则把失效的Entry的value置为null并将新的Entry设置到staleSlot下标位置
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // If there are any other stale entries in run, expunge them
    //表明还有其它失效的Entry
    if (slotToExpunge != staleSlot)
        //清除失效元素
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
复制代码

从上面代码我们可以看到在set的时候也在清除Entry中key被回收的元素,其中涉及到两个方法,一个是expungeStaleEntry,一个是cleanSomeSlots。我们先来看下expungeStaleEntry方法:

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // expunge entry at staleSlot
    //清除staleSlot位置的元素并数组长度减1
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        //e!=null&k==null 的时候说明key弱引用被回收,则清除该元素
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            //计算key的首个下标(没有hash冲突的下标)
            int h = k.threadLocalHashCode & (len - 1);
            //key的首个下标不等于当前的i下标则表示存在hash冲突,i位置是遇到hash冲突向后寻找递增的下标
            if (h != i) {
                //重新编排,将i位置数据设置为空
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                //从h下标向后遍历寻找第一个为null的下标并把e设置到该位置(为了防止前面存在失效的数据,因为在获取元素的时候如果Entry为null则不会继续向下面寻找)。该步操作也相当于数据向前移动
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    //返回staleSlot下标向后第一个Entry为null的下标
    return i;
}
复制代码

expungeStaleEntry方法也就是清除staleSlot下标后的Entry!=null&key==null的元素。对Entry!=null&key!=null的这个元素如果在前面从hash算法算出的下标(a)到该元素位置(b)之间存在null的Entry,则会把该元素放到a到b之间为null的最小的下标位置。(这是由于获取元素的时候如果判断Entry为null,则不会继续向下遍历,所以需要进行元素的移动)。

接下来我们再来看下cleanSomeSlots方法。

private boolean cleanSomeSlots(int i, int n) {
    //是否发生清除的标识
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            //再次进行清除失效的元素,expungeStaleEntry可看上面的方法解析
            i = expungeStaleEntry(i);
        }
        //二进制无符号向右移动一位(除以2)。n为数组的长度,满足(n >>>= 1) != 0则会进行执行循环体。执行部分遍历而不是全部遍历进行清除失效的元素。也是性能方面的权衡吧
    } while ( (n >>>= 1) != 0);
    return removed;
}
复制代码

cleanSomeSlots方法会对expungeStaleEntry返回的需要替换的下标(staleSlot)后面(经过hash冲突元素的向前移动后)的首个为空的Entry后面的元素继续进行清除无效的元素。

从ThreadLocal的set逻辑中很大一部分是对Entry中key为null已被GC数据的清除。这样做可以及时清除无效元素。

ThreadLocal的get方法

上面分析了ThreadLocal的get方法,下面我们来分析下它的set方法。

public T get() {
    Thread t = Thread.currentThread();
    //获取当前线程Thread的ThreadLocalMap变量
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        //当前线程的ThreadLocalMap存在,则根据ThreadLocal实例获取保存数据的Entry
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    //当前线程ThreadLocalMap为空,则进行初始化(会初始化一个value为null的Entry)
    return setInitialValue();
}
复制代码

上面主要方法为getEntry,我们来看下我们是怎么在ThreadLocalMap中获取到对应key的Entry的。

private Entry getEntry(ThreadLocal<?> key) {
    //定位数组下标
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    //定位数组下标对应的key和参数key一致则直接返回当前Entry
    if (e != null && e.get() == key)
        return e;
    else
        //定位数组下标非当前key则继续向后遍历寻找
        return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
    //向后遍历寻找相同key的Entry
    while (e != null) {
        ThreadLocal<?> k = e.get();
        //找到key相同的Entry则直接返回
        if (k == key)
            return e;
         //key为null(被垃圾回收),则调用expungeStaleEntry进行清除key被垃圾回收的Entry(回收value)
        if (k == null)
            expungeStaleEntry(i);
        else
            //向后遍历下一个下标
            i = nextIndex(i, len);
        e = tab[i];
    }
    //如果直到遍历到Entry为null的元素还未找到,则数组中不存在该元素,直接返回null。
    //这个地方注意下条件为e == null就直接会返回而不会继续遍历(这就是为什么在set的时候要对hash冲突的数据进行一次重排列(可能会涉及到数据的向前移动,因为前面的key被垃圾回收了)),如果不这样的话,那么会导致在get的时候数组中存在当前要查找的key,但是却返回null的错误结果。
    return null;
}
复制代码

上面我们对set和get方法做了相关解析,我们会发现在set和get方法的时候都会调用expungeStaleEntry来对被垃圾回收的key相关Entry进行清除,来释放数组空间。

我们可以看到会对Entry弱引用key进行GC,下面我们再具体理下ThreadLocal的回收机制。

ThreadLocal的回收机制

如下图所示:

未命名文件 (17).png

我们可以看到使用ThreadLocal实例作为ThreadLocalMap的key,但只是弱引用。当ThreadLocal实例的外部强引用被回收时,ThreadLocalMap中的key就会变成null(被GC)。当系统进行ThreadLocalMap清理时,就会将这些垃圾数据回收。上面我们分析的set和get方法时都会发现在适当时候都会进行清理这些垃圾数据。

ThreadLocal内存溢出问题

ThreadLocalMap中的key为弱引用,当外部强引用丢失则key会被回收,但是value是强引用。如果线程Thread不退出,则value的强引用就会一直存在。例如对线程池来说,核心线程可能会一直存在(系统生命周期一直存在,例如 Executors.newFixedThreadPool),那么value可能一直不会被回收,则可能会引发内存泄漏的问题。

jdk中是怎么避免的那?在ThreadLocal的实现原理中我们介绍了set方法和get方法在特定的时候都会进行清除key==null&&Entry!=null的数据expungeStaleEntry方法清除垃圾数据)。但是它并不能完全避免内存泄漏。例如我们set的时候Entry的key不存在hash冲突,get的时候访问的都是一直存在的ThreadLocal,那么ThreadLocalMap里面的get不到的但是key已经被垃圾回收的Entry就不会被清除,后期就会可能导致内存泄漏。

所以我们需要及时回收对象,养成良好的习惯,在不需要ThreadLocal变量的时候调用ThreadLocal.remove()方法将这个变量移除

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    //获取数组下标
    int i = key.threadLocalHashCode & (len-1);
    //根据数组下标向后遍历
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
         //找到key相等的Entry进行清除操作
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}
复制代码

InheritableThreadLocal

如果在主线程中运行一个子线程,我们如果想要子线程可以访问到主线程的ThreadLocal对象数据的话,我们可以使用InheritableThreadLocal。

那么子线程是怎么获得主线程的ThreadLocal对象的那,那是因为在线程初始化的时候将主线程的InheritableThreadLocal设置给了子线程的InheritableThreadLocal。我们来看下Thread初始化相关代码就明白了。

if (inheritThreadLocals && parent.inheritableThreadLocals != null)
    this.inheritableThreadLocals =
        ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
复制代码

ThreadLocal使用案例

合理使用ThreadLocal可以带来很大的性能提高。例如典型的案例就是在多线程下产生随机数,我们来看下这个案例:

public class ThreadLocalDemo {
    public static final int GEN_COUNT = 10000000;
    public static final int THREAD_COUNT = 4;
    static ExecutorService exe = Executors.newFixedThreadPool(THREAD_COUNT);
    public static Random rnd = new Random(123);

    public static ThreadLocal<Random> tRnd = ThreadLocal.withInitial(() -> new Random(123));

    public static class RndTask implements Callable<Long>{
        //0:多个线程访问同一个Random 1:多个线程使用ThreadLocal包装的Random实例
        private int mode = 0;

        public RndTask(int mode) {
            this.mode = mode;
        }

        public Random getRandom(){
            if(mode == 0){
                return rnd;
            } else if (mode == 1){
                return tRnd.get();
            } else {
                return null;
            }
        }

        @Override
        public Long call() throws Exception {
            long b = System.currentTimeMillis();
            for (int i =0; i < GEN_COUNT; i++){
                getRandom().nextInt();
            }
            long e = System.currentTimeMillis();
            System.out.println(Thread.currentThread().getName() + "spend:" + (e - b) + "ms");
            return e - b;
        }
    }

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        Future<Long>[] futs = new Future[THREAD_COUNT];
        for (int i = 0; i < THREAD_COUNT; i++){
            futs[i] = exe.submit(new RndTask(0));
        }
        int totalTIme = 0;
        for(int i = 0; i < THREAD_COUNT; i++){
            totalTIme +=futs[i].get();
        }
        System.out.println("多线程访问同一个Random实例:" + totalTIme + "ms");
        totalTIme = 0;
        for (int i = 0; i < THREAD_COUNT; i++){
            futs[i] = exe.submit(new RndTask(1));
        }
        for(int i = 0; i < THREAD_COUNT; i++){
            totalTIme +=futs[i].get();
        }
        System.out.println("使用ThreadLocal包装Random实例:" + totalTIme + "ms");

    }
}
复制代码

输出为:

pool-1-thread-2spend:3247ms
pool-1-thread-1spend:3288ms
pool-1-thread-3spend:3290ms
pool-1-thread-4spend:3294ms
多线程访问同一个Random实例:13119ms
pool-1-thread-1spend:151ms
pool-1-thread-3spend:152ms
pool-1-thread-2spend:152ms
pool-1-thread-4spend:154ms
使用ThreadLocal包装Random实例:609ms
复制代码

我们可以看到多线程访问同一个Random实例花费了13s多,使用ThreadLocal包装Random实例只花费了0.6秒多,可见在该实例中使用ThreadLocal对性能有很大的提升。

参考书籍:《Java高并发程序设计(第2版)》

文章分类
后端
文章标签