ThreadLocal 基于源码学习

240 阅读6分钟

概括

  • ThreadLocal的功能是可以存取一个线程间隔离的变量。
    • ThreadLocal底层通过字典ThreadLocalMap实现,每一个线程Thread都自带了一个属性ThreadLocalMap,不同线程的ThreadLocal都会存取在对应线程的ThreadLocalMap中,从而实现了线程隔离。
  • ThreadLocalMap是ThreadLocal的内部类
    • ThreadLocalMap是通过基于哈希表的数据结构存取键值对。键是通过一个ThreadLocal的静态方法产生的随机数,且该产生方法并发安全,因此确保了所有的ThreadLocal键都是唯一的。而值是Entity实例。
    • ThreadLocalMap发生哈希碰撞,则向后探索直到遇到一个空的位置,并放置。
    • ThreadLocalMap元素数量达到大小的一半时会发生扩容操作,新数组大小是原来的两倍。
  • Entity是ThreadLocalMap的内部类
    • Entity是一个键值对,该键值是ThreadLocal本身,值则是变量,Entity是真正存储变量的位置。
    • Entity的键是一个指向ThreadLocal的弱引用,当外部的TheadLocal不被使用时,不需要手动删除ThreadLocalMap中的Entity,Entity会随着外部ThreadLocal失效而被ThreadLocalMap的逻辑清除。
  • 弱引用的作用:弱引用本质上是降低程序员负担的同时帮助垃圾回收器工作。
    • 当外部ThreadLocal失效时,因为Entity关联着ThreadLocal,而导致ThreadLocal无法被回收,当这里是弱引用时则没有这个问题。

Entity

/**
 * ThreadLocalMap核心存储值得类
 */
static class Entry extends WeakReference<ThreadLocal<?>> {
    Object value;  // 存放变量值

    Entry(ThreadLocal<?> k, Object v) {
        super(k); // 这里ThreadLocal是弱引用
        value = v;
    }
}

ThreadLocal核心逻辑追踪

哈希键值生成

/**
 * 用于存储当前ThreadLocal的哈希值
 */
private final int threadLocalHashCode = nextHashCode();

/**
 * 用于被所有ThreadLocal生成哈希值的起始值
 * 被下面nextHashCode方法用到
 */
private static AtomicInteger nextHashCode = new AtomicInteger();

/**
 * 哈希值增加的步幅
 * 被下面nextHashCode方法用到
 */
private static final int HASH_INCREMENT = 0x61c88647;

/**
 * 用于所有ThreadLocal生成哈希值的方法。
 */
private static int nextHashCode() {
    return nextHashCode.getAndAdd(HASH_INCREMENT);
}

存储变量

set方法

/**
 * 存储ThreadLocal的值
 */
public void set(T value) {
    Thread t = Thread.currentThread();
    // 从当前线程t中获取独属当前线程的LocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value); // 把value值和ThreadLocal放入到LocalMap
    else
        createMap(t, value);  // 创建一个新的LocalMap
}

createMap方法

/**
 * 初始化一个ThreadLocalMap,并传入第一个值
 */
void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

getMap方法

/**
 * 获取ThreadLocalMap
 * @param Thread 一般是当前线程
 */
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

ThreadLocalMap.set方法

  • 该方法是ThreadLocalMap内部类的方法。
/**
 * 把ThreadLocal和值存进ThreadLocalMap中。
 */
private void set(ThreadLocal<?> key, Object value) {
    // 该map基于数组
    Entry[] tab = table;
    int len = tab.length;
    // 计算出ThreadLocal对应的索引
    int i = key.threadLocalHashCode & (len-1); 
    // 定位索引的槽位
    // 若该槽位被占用,则向后循环遍历致空槽位
    for (Entry e = tab[i];
         e != null; 
         e = tab[i = nextIndex(i, len)]) {  // nextIndex:获取后一个索引
        // 获取entity中的ThreadLocal
        ThreadLocal<?> k = e.get();  
        // 判断存储ThreadLocal是否是当前的ThreadLocal
        if (k == key) { 
            e.value = value;
            return;
        }
        // entity存储的ThreadLocal为空,
        // 说明外部k的强引用已被断开,k的空间已被gc回收。
        if (k == null) { 
            // 把该槽位的值更新为当前的键值对
            replaceStaleEntry(key, value, i); 
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 检测清除部分失效的Entity,判断大小是否超过阈值
    if (!cleanSomeSlots(i, sz) && sz >= threshold) 
        rehash(); // 去除无效节点,判断是否需要数组扩容。
}

ThreadLocalMap.replaceStaleEntry方法

  • 该方法是ThreadLocalMap内部类的方法。
/**
 * 把原来槽位失效的Entity置空。
 * 把新的ThreadLocal和value封装成Entity放进该槽位
 */
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // 从staleSlot位置向前遍历,找出最靠前的失效节点索引。
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // 从staleSlot位置向后移动,找到ThreadLocal的位置,把它移到staleSlot位置。
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            // 后面节点重新进行哈希计算,扫描移除无效节点,判断数组是否需要扩容。
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // k==null 表示当前i这个位置的Entity已失效,即索引i是废弃的节点
        // slotToExpunge == staleSlot 意味着staleSlot索引前面没有废弃节点
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // 若找不到,则生成一个对应值为null的ThreadLocal放到staleSlot位置。
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

ThreadLocalMap.expungeStaleEntry方法

  • 该方法是ThreadLocalMap内部类的方法。
/**
 * 移除指定位置节点
 * 因为ThreadLocalMap的哈希碰撞处理机制是放到后面空白的位置
 * 因此删除某个位置节点后,要判断是否需要把后面的节点往前移。
 */
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 擦除节点
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    Entry e;
    int i;
    // 向后遍历,移除废弃节点,非废弃节点重新计算新的索引位置。
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;
                while (tab[h] != null)
                    // 发生哈希碰撞,索引后移。
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

ThreadLocalMap.cleanSomeSlots方法

  • 该方法是ThreadLocalMap内部类的方法。
/**
 * 清除ThreadLocalMap失效的Entity
 */
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

获取变量

get方法

/**
 * 获取变量
 */
public T get() {
    Thread t = Thread.currentThread();
    // 获取当前线程的ThreadLocalMap。
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // 尝试在LocalMap检索ThreadLocal。
        ThreadLocalMap.Entry e = map.getEntry(this); 
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // 找不到对应的ThreadLocal
    // 创建一个新的ThreadLocal插入到LocalMap,对应value是null值。
    // 返回null值
    return setInitialValue(); 
}

ThreadLocalMap.getEntry方法

/**
 * 在当前线程的ThreadLocalMap中获取变量值
 */
private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        // 运行到这表示放值的是否发生了哈希碰撞
        // 当前值被放到后面的槽位
        return getEntryAfterMiss(key, i, e);
}

ThreadLocalMap.getEntryAfterMiss方法

/**
 * 在ThreadLocalMap检索发生了哈希碰撞的值
 */
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    // 线性遍历
    while (e != null) {
        ThreadLocal<?> k = e.get();
        // 判断是否就是要取得值
        if (k == key)
            return e;
        if (k == null)
            // 移除无效的Entity
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

ThreadLocalMap.setInitialValue方法

/**
 * 设置初始值,该初始值是控制
 */
private T setInitialValue() {
    T value = initialValue(); // 返回null
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    // 判断ThreadLocalMap是否初始化
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

扩容机制

ThreadLocalMap.rehash方法

/**
 * 判断是否需要扩容数组,需要则扩容
 */
private void rehash() {
    // 移除所有无效节点
    expungeStaleEntries();
    // 判断元素数量是否达到一半
    if (size >= threshold - threshold / 4)
        resize();
}

ThreadLocalMap.resize方法

/**
 * 扩容数组
 */
private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2; // 大小扩容为两倍
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    // 遍历旧数组
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            // 判断entity是否无效
            if (k == null) {
                e.value = null;
            } else {
                // 把entity插入到新的数组
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }
	// 重新计算阈值
    setThreshold(newLen);
    size = count;
    table = newTab;
}

ThreadLocalMap.setThreshold方法

/**
 * 重新计算阈值
 */
private void setThreshold(int len) {
    threshold = len * 2 / 3;
}