ThreadLocal 学习

489 阅读10分钟

JDK 中的定义

ThreadLocal 提供了线程本地的实例。它与普通变量的区别在于,每个使用该变量的线程都会初始化一个完全独立的实例副本。ThreadLocal 变量通常被private static修饰。当一个线程结束时,它所使用的所有 ThreadLocal 相对的实例副本都可被回收。

使用 ThreadLocal 实现线程 ID 自增器:

private static class ThreadId {
    private static final AtomicInteger nextId = new AtomicInteger(0);

    protected static ThreadLocal<Integer> threadId = new ThreadLocal<Integer>() {
        @Override
        protected Integer initialValue() {
            return nextId.getAndIncrement();
        }
    };

    public static int get() {
        return threadId.get();
    }
}

ThreadLocal 适用于每个线程需要自己独立的实例且该实例需要在多个方法中被使用,也即变量在线程间隔离而在方法或类间共享的场景。

ThreadLocal原理

Thread维护ThreadLocal与实例的映射

每个 Thread 会维护一个 ThreadLocal 与实例的映射 Map。需要注意如果不删除这些引用(映射),则这些 ThreadLocal 不能被回收,可能会造成内存泄漏。

ThreadLocal实现

Map 由 ThreadLocal 类的静态内部类 ThreadLocalMap 提供。Map 的实例维护某个 ThreadLocal 与具体实例的映射。与 HashMap 不同的是,ThreadLocalMap 的每个 Entry 都是一个对键的弱引用,这一点从super(k)可看出。另外,每个 Entry 都包含了一个对 值 的强引用。

static class Entry extends WeakReference<ThreadLocal<?>> {
  /** The value associated with this ThreadLocal. */
  Object value;
  Entry(ThreadLocal<?> k, Object v) {
    super(k);
    value = v;
  }
}

使用弱引用的原因在于,当没有强引用指向 ThreadLocal 变量时,它可被回收,从而避免 ThreadLocal 不能被回收而造成的内存泄漏的问题。

但是,这里又可能出现另外一种内存泄漏的问题。ThreadLocalMap 维护 ThreadLocal 变量与具体实例的映射,当 ThreadLocal 变量被回收后,该映射的键变为 null,该 Entry 无法被移除。从而使得实例被该 Entry 引用而无法被回收造成内存泄漏。

读取实例

public T get() {
  Thread t = Thread.currentThread();
  ThreadLocalMap map = getMap(t);
  if (map != null) {
    ThreadLocalMap.Entry e = map.getEntry(this);
    if (e != null) {
      @SuppressWarnings("unchecked")
      T result = (T)e.value;
      return result;
    }
  }
  return setInitialValue();//刚开始时线程中的ThreadLocalMap为null走这里
}

读取实例时,线程首先通过getMap(t)方法获取自身的 ThreadLocalMap。从如下该方法的定义可见,该 ThreadLocalMap 的实例是 Thread 类的一个字段,即由 Thread 维护 ThreadLocal 对象与具体实例的映射

ThreadLocalMap getMap(Thread t) {
  return t.threadLocals;
}

ThreadLocalMap 初始化

private T setInitialValue() {
  T value = initialValue();
  Thread t = Thread.currentThread();
  ThreadLocalMap map = getMap(t);
  if (map != null)
    map.set(this, value);
  else
    createMap(t, value);
  return value;
}

该方法为 private 方法,无法被重载。

首先,通过initialValue()方法获取初始值。该方法为 public 方法,且默认返回 null。所以典型用法中常常重载该方法。上例中即在内部匿名类中将其重载。

protected T initialValue() {
    return null;
}

然后拿到该线程对应的 ThreadLocalMap 对象,若该对象不为 null,则直接将该 ThreadLocal 对象与对应实例初始值的映射添加进该线程的 ThreadLocalMap中。若为 null,则先创建该 ThreadLocalMap 对象再将映射添加其中。

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    table = new Entry[INITIAL_CAPACITY];
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    table[i] = new Entry(firstKey, firstValue);//根据hash码给映射找位置
    size = 1;
    setThreshold(INITIAL_CAPACITY);
}

/**
 * table增长的阀值是table长度的2/3;
 */
private void setThreshold(int len) {
    threshold = len * 2 / 3;
}

这里并不需要考虑 ThreadLocalMap 的线程安全问题。因为每个线程有且只有一个 ThreadLocalMap 对象,并且只有该线程自己可以访问它,其它线程不会访问该 ThreadLocalMap,也即该对象不会在多个线程中共享,也就不存在线程安全的问题。

ThreadLocalMap 不为 null 时,则尝试从中取值

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

获取与ThreadLocal关联的条目。 此方法本身仅处理快速路径:直接命中现有ThreadLocal。 否则它会转发到getEntryAfterMiss方法。 这旨在使直接命中的性能最大化。如果对应位置的值为null,这就存在如下几种可能:

  • key对应的值确实为null
  • 由于位置冲突,key对应的值存储的位置并不在i位置上,即i位置上的null并不属于key的值。

因此,需要一个函数再次去确认key对应的value的值,即getEntryAfterMiss函数:

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

private static int nextIndex(int i, int len) {
    return ((i + 1 < len) ? i + 1 : 0);
}

这就是擦除(expunge)过期Entry的方法了,传进来的staleSlot就是有null 键值的table数组的下标。

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 擦除掉这个脏数据
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // 这就是前面的rehash 操作了,停止条件是Entry e = null
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            // 又发现一个过期的,擦除
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // 重新计算哈希值,并且移动Entry
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

设置实例

public void set(T value) {
  Thread t = Thread.currentThread();
  ThreadLocalMap map = getMap(t);
  if (map != null)
    map.set(this, value);
  else
    createMap(t, value);
}

该方法先获取该线程的 ThreadLocalMap 对象,然后直接将 ThreadLocal 对象(即代码中的 this)与目标实例的映射添加进 ThreadLocalMap 中。当然,如果映射已经存在,就直接覆盖。另外,如果获取到的 ThreadLocalMap 为 null,则先创建该 ThreadLocalMap 对象。

private void set(ThreadLocal<?> key, Object value) {

    Entry[] tab = table;
    int len = tab.length;
    // 首先拿到新的Entry的地址 i。
    int i = key.threadLocalHashCode & (len-1);
    //然后做循环,直到table桶对应位置中没有存放Entry,也就是null
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        //key相同,也就是在这之前该线程曾经设置过该ThreadLocal,那么就直接赋予新的值,然后返回。
        if (k == key) {
            e.value = value;
            return;
        }
        //这里的key值就是ThreadLocal了,它有可能会出现null
        //这是因为Entry继承的是WeakReference,这是弱引用带来的坑
        if (k == null) {
            //出现了null,就要置换过期(stale)了的Entry
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    //执行到这里,说明找到了合适的位置,就把新的Entry放入table数组
    tab[i] = new Entry(key, value);
    int sz = ++size;
    //由于弱引用带来了这个问题,所以先要清除无用数据,才能判断现在的size有没有达到阀值threshhold
    //如果没有要清除的数据,并且达到阀值,那就要执行扩容:rehash()
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

将set操作期间遇到的陈旧条目替换为指定键的条目。

ThreadLocalMap还要置换过期的Entry,过期的条件就是table数组中存放的Entry的key(ThreadLocal)没了,但是Entry的value却还是存在的,以下就是置换过期Entry的代码以及涉及到方法的代码。

//根据发生的条件,传进来的key就是null了
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    //往前找,找到table中第一个脏了的数据的下标
    //清理整个table是为了避免因为垃圾回收带来的连续增长哈希的危险
    //也就是说,哈希表没有清理干净,当GC到来的时候,后果很严重。。。
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // 找到我们传进来的key或者是在传进来下标值后面的第一个过期Entry
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // 如果我们找到了key,那么我们就需要把它跟新的过期数据交换来保持哈希表的顺序
        // 那么剩下的过期Entry呢,就可以交给expungeStaleEntry方法来擦除掉
        // 或者执行rehash方法。
        if (k == key) {
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            // 在这里开始清除过期数据
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // If we didn't find stale entry on backward scan, the
        // first stale entry seen while scanning for key is the
        // first still present in the run.
        if (k == null && slotToExpunge == staleSlot)
            // 如果我们没有在后向查找中找见过期数据,那么slotToExpunge就是第一个过期Entry的下标了
            slotToExpunge = i;
    }

    // 如果以上的查找都没有找见key的话,就放一个Entry<null, value>进tab[staleSlot]
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // 如果有其他的脏数据,依然要擦除
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

private static int prevIndex(int i, int len) {
    return ((i - 1 >= 0) ? i - 1 : len - 1);
}

启发式的扫描过期数据并擦除,启发式是这样的:

  • 如果实在没有过期数据,那么这个算法的时间复杂度就是O(log n)
  • 如果有过期数据,那么这个算法的时间复杂度就是O(n)
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            //找到一个过期数据,就对n重新赋值为len
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

防止内存泄漏

对于已经不再被使用且已被回收的 ThreadLocal 对象,它在每个线程内对应的实例由于被线程的 ThreadLocalMap 的 Entry 强引用,无法被回收,可能会造成内存泄漏。

针对该问题,ThreadLocalMap 的 set 方法中,通过 replaceStaleEntry 方法将所有键为 null 的 Entry 的值设置为 null,从而使得该值可被回收。另外,会在 rehash 方法中通过 expungeStaleEntry 方法将键和值为 null 的 Entry 设置为 null 从而使得该 Entry 可被回收。通过这种方式,ThreadLocal 可防止内存泄漏。

private void set(ThreadLocal<?> key, Object value) {
  Entry[] tab = table;
  int len = tab.length;
  int i = key.threadLocalHashCode & (len-1);
  for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
    ThreadLocal<?> k = e.get();
    if (k == key) {
      e.value = value;
      return;
    }
    if (k == null) {
      replaceStaleEntry(key, value, i);
      return;
    }
  }
  tab[i] = new Entry(key, value);
  int sz = ++size;
  if (!cleanSomeSlots(i, sz) && sz >= threshold)
    rehash();
}

ReHash

计算哈希值的实现代码

private final int threadLocalHashCode = nextHashCode();
private static AtomicInteger nextHashCode = new AtomicInteger();
private static final int HASH_INCREMENT = 0x61c88647;

private static int nextHashCode() {
    return nextHashCode.getAndAdd(HASH_INCREMENT);
}

其实就是一个原子类不停地去加上0x61c88647,这是一个很特别的数,叫斐波那契散列(Fibonacci Hashing),斐波那契又有一个名称叫黄金分割,也就是说将这个数作为哈希值的考量将会使哈希表的分布更为均匀。

重新包装和/或重新调整表格的大小。 首先扫描整个表,删除陈旧的条目。 如果这不足以缩小表的大小,则将表大小加倍。

private void rehash() {
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
    if (size >= threshold - threshold / 4)
        resize();
}

/**
 * 扩容的思想跟HashMap很相似,都是把容量扩大两倍
 * 不同之处还是因为WeakReference带来的
 */
private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                // 就算是扩容,也不能忘了为擦除过期数据做准备
                e.value = null; // Help the GC
            } else {
                //重新计算哈希,并放到newTab原来没有被占用的桶中
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }
    //重新设置阀值,大小,table指针
    setThreshold(newLen);
    size = count;
    table = newTab;
}

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    // 根据key的哈希值确定位置
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        // 哈希值是相等了,还要循环比较key,这是由插入机制带来的
        if (e.get() == key) {
            // 找到了,就调用WeakReference#clear方法
            e.clear();
            //再擦除掉这个Entry就OK了
            expungeStaleEntry(i);
            return;
        }
    }
}