ThreadLocal底层原理学习

3,958 阅读14分钟

是什么?

首先ThreadLocal类是一个线程数据绑定类, 有点类似于HashMap<Thread, 你的数据> (但实际上并非如此), 它所有线程共享, 但读取其中数据时又只能是获取线程自己的数据, 写入也只能给线程自己的数据

image.png

注意, 非常重要 我们需要知道

  • 一个线程Thread有一个ThreadLocalMap
  • ThreadLocalMap底层是Entry[]
  • Entry是一个键值对
    • EntrykeyThreadLocal对象
    • EntryvalueThreadLocal<T>

怎么用?

public class ThreadLocalDemo {
	
	private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();
	
	public static void main(String[] args) {
		for (int i = 0; i < 10; i++) {
			new Thread(() -> {
				threadLocal.set("zhazha" + Thread.currentThread().getName());
				String s = threadLocal.get();
				System.out.println("threadName = " + Thread.currentThread().getName()  + " [ threadLocal = "  + threadLocal + "\t data = " + s + " ]");
			}, "threadName" + i).start();
		}
	}
}

从他的输入来看, ThreadLocal是同一个, 数据存的是线程自己的名字, 所以和threadName是一样的名称

threadName = threadName9 [ threadLocal = java.lang.ThreadLocal@43745e1f	 data = zhazhathreadName9 ]
threadName = threadName3 [ threadLocal = java.lang.ThreadLocal@43745e1f	 data = zhazhathreadName3 ]
threadName = threadName7 [ threadLocal = java.lang.ThreadLocal@43745e1f	 data = zhazhathreadName7 ]
threadName = threadName0 [ threadLocal = java.lang.ThreadLocal@43745e1f	 data = zhazhathreadName0 ]
threadName = threadName6 [ threadLocal = java.lang.ThreadLocal@43745e1f	 data = zhazhathreadName6 ]
threadName = threadName1 [ threadLocal = java.lang.ThreadLocal@43745e1f	 data = zhazhathreadName1 ]
threadName = threadName2 [ threadLocal = java.lang.ThreadLocal@43745e1f	 data = zhazhathreadName2 ]
threadName = threadName4 [ threadLocal = java.lang.ThreadLocal@43745e1f	 data = zhazhathreadName4 ]
threadName = threadName5 [ threadLocal = java.lang.ThreadLocal@43745e1f	 data = zhazhathreadName5 ]
threadName = threadName8 [ threadLocal = java.lang.ThreadLocal@43745e1f	 data = zhazhathreadName8 ]

有什么使用场景

我们使用获取到一个保存数据库请求, tomcat会有一个线程去操作数据库保存数据和响应数据给客户, 而操作数据库需要存在一个数据库链接Connection对象, 只要是同一个数据库链接, 就可以得到同一个事务 但一个线程是如何获取同一个Connection从而获取同一个事务 ? 方法其实很简单, 使用 ThreadLocal绑定在线程中, 类似于Map<Thread, Connection>去存储

实际是ThreadLocalMap<ThreadLcoal, Connection>

面试官: ThreadLocalEntry对象的key为何是弱引用?

答案很明显, 防止内存泄漏, 我们来详细分析分析

内存泄漏: 程序中己动态分配的堆内存由于某种原因程序未释放或无法释放,造成系统内存的浪费,导致程序运行速度减慢甚至系统崩溃等严重后果

static class Entry extends WeakReference<ThreadLocal<?>> {
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        // ThreadLocal被设置为弱引用
        super(k);
        // 保存value
        value = v;
    }
}

我们会发现 ThreadLocal 被设置为弱引用, 但是在Entry[]中的Entry元素之间的关系本身是强引用

也就是说只有在 key 触发了 弱引用 机制, 将Entrykey设置为 null

那么什么是弱引用呢?

  • 强引用: 如果引用变量没被指向null则, 引用对象将被停留在堆中, 无法被虚拟机回收Object obj = new Object()
  • 软引用: 如果虚拟机堆内存不够用了(在发生内存溢出之前), 虚拟机可以选择回收软引用对象, 虚拟机提供SoftReference类实现软引用, 一般用于相对比较重要但又可以不用的对象, 比如: 缓存
  • 弱引用: 生于系统回收之前, 死于系统回收完毕之后, 弱引用需要依附于强引用或者软引用才能够防止被虚拟机回收, 比如放到一个引用队列(ReferenceQueue)中或者对象中, 比如: ThreadLocalMapEntry对象, 需要依附于ThreadLocalMap才能够不被删除掉
  • 虚引用: 可以理解为跟强引用对象没了引用变量一样, 随时可以被回收, 只要依附于引用队列中才不会被回收, 通常用于网络通讯的NIO上, 用于引用直接内存, java提供类PhantomReference来实现虚引用

弱引用: 如果没有引用对象引用被定义为弱引用对象, 那么 gc 就会将弱引用回收掉

key为弱引用就没有问题了么?

还是存在问题

Entrykey 被设置为 null, 但是整个 Entry 还是存在于 Entry[] 数组中

也就是说实际的对象还是存在于内存中

public static void main(String[] args) throws InterruptedException {
    A a = firstStack();
    System.gc();
    a = null;
    System.gc();
    TimeUnit.SECONDS.sleep(1);
    Thread thread = Thread.currentThread();
    System.out.println(thread); // 在这里打断点,观察thread对象里的ThreadLocalMap数据
}

// 通过是否获取返回值观察A对象里的local对象是否被回收
private static A firstStack() {
    A a = new A();
    System.out.println("value: " + a.get());
    return a;
}

private static class A {
    private ThreadLocal<String> local = ThreadLocal.withInitial(() -> "in class A");

    public String get() {
        return local.get();
    }

    public void set(String str) {
        local.set(str);
    }
}

image.png

image.png

但是整个 Entry 对象还是存在于 table 中 占据着内存位置

解决方法是什么?

有两种方式

  • ThreadLocal自带的清除功能, 但是可能不是那么及时
  • 主动调用remove方法(推荐)

既然ThreadLocal自带了清除功能, 那为什么还需要主动调用remove方法?

我们需要深入底层

底层源码分析

get方法分析

public T get() {
	// 拿到当前线程
	Thread t = Thread.currentThread();
    // 拿到线程的一个字段ThreadLocalMap, 你可以将其看做一个Map集合
	ThreadLocal.ThreadLocalMap map = getMap(t);
    // 如果能够从Thread中拿到 map 集合的话
	if (map != null) {
		// 从 map 集合中找到以 key 为 ThreadLocal对象的 Entry
        // 在这个 getEntry 函数中, 将会通过 hashCode 和 len计算出 i, 如果不为空则返回给 e
        // 如果为空, 则从 Map 中 while 循环
        // 在这个 while 循环中, 拿出 k 对象, 判断 k 是否相等, 相等则直接返回
        // 不相等, 则判断 k == null 吗? 是的话直接删除掉, 因为它过期了
        // 然后向后循环遍历, 再做上面两个步骤
        // 如果都没有则直接返回 null
		ThreadLocal.ThreadLocalMap.Entry e = map.getEntry(this);
		if (e != null) {
            // 从 e 中拿到 value
			T result = (T)e.value;
			return result; // 最后返回 value
		}
	}
	return setInitialValue();
}

最后的删除也非常简单 tab[i] = null

根据上面源码分析发现ThreadLocal底层使用的不是类似Map<Thread, Data> 这种结构而是 每个线程都有一个属于自己的ThreadLocalMap结构

image.png

而他的结构是这样的

getMap方法分析

private Entry getEntry(ThreadLocal<?> key) {
    // 根据k计算在hash桶的位置
    int i = key.threadLocalHashCode & (table.length - 1);
    // 从Entry[]中拿到Entry
    Entry e = table[i];
    // 找到对应的对象了
    if (e != null && e.get() == key)
        // 直接返回
        return e;
    else
        // 没找到直接调用
        return getEntryAfterMiss(key, i, e);
}

getEntryAfterMiss方法分析

查找与给定键关联的条目

在查找entry之前, 会计算 index 也就是位置, 再根据 index 拿到 entry, 然后 entrykey 跟我们传入的 key 进行判断, 最后拿到 entry

但是能够进入下面这个函数, 只能说明我们需要的 entry 并不在对应的 index 位置

说明他被放到了 index 位置的后面

这里跟HashMap拉链法的另一个方法开放寻址方法一样, 在当前位置的下一个位置找到一个空的位置存放Entry

  • key,即要查找的键;

  • i,即键的哈希代码的表索引;

  • e,即表中索引为i的条目。

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    // 该方法使用while循环遍历表中的条目
    // 从索引为i的条目开始
    // 直到找到与给定键关联的条目或遍历完整个表
    while (e != null) {
        ThreadLocal<?> k = e.get();
        // 如果找到与给定键关联的条目,则该方法返回该条目。
        if (k == key)
            return e;
        // 如果发现一个过期条目(即其键为null),
        // 则调用expungeStaleEntry方法将其删除。
        if (k == null)
            expungeStaleEntry(i);
        else
            // 将索引更新为下一个索引,并继续循环。
            i = nextIndex(i, len);
        e = tab[i];
    }
    // 如果未找到与给定键关联的条目,则该方法返回null。
    return null;
}

所以在读取的时候, 如果我们的 key == null, 并且在该 key 位置的之后的位置, 那么才会被清除掉

否则无法被删除, 导致内存泄漏

setInitialValue方法分析

其中的table数组在上面的 setInitialValue() 方法创建详细源码在这

private T setInitialValue() {
    // 这个方法在我们的用例中没写, 所以默认放回 null
    T value = initialValue();
    Thread t = Thread.currentThread();
    // 获取线程单独的 ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // 如果map不为空,说明当前线程已经有了ThreadLocalMap对象,那么就调用map.set(this, value)方法,将当前ThreadLocal对象和value作为键值对存入map中。
        map.set(this, value);
    } else {
        // 线程ThreadLocalMap 没被创建, 需要创建出来, 
        // 其中的 table 数组在这里被创建
        createMap(t, value);
    }
    // 判断当前ThreadLocal对象是否是TerminatingThreadLocal的实例
    // 然后将 ThreadLocal 注册到TerminatingThreadLocal的容器中
    // 这个类主要用于清除资源使用
    if (this instanceof TerminatingThreadLocal) {
        TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
    }
    // 返回value作为本地变量的初始值
    return value;
}

TerminatingThreadLocal是一个内部类,它继承了ThreadLocal,并实现了一个接口Terminatable。这个接口定义了一个terminate()方法,用来在线程结束时清理本地变量的资源。

如果当前ThreadLocal对象是TerminatingThreadLocal的实例,那么就调用TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this)方法,将其注册到一个静态列表中,以便在线程结束时调用其terminate()方法。

createMap方法分析

他会在ThreadLocalMap中调用构造方法初始化

// 其中 firstValue是我们的值
void createMap(Thread t, T firstValue) {
    // 关注下 this , 它是ThreadLocal
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    // 创建一个Entry类型的数组table,大小为INITIAL_CAPACITY,这是一个常量,表示初始容量,其值为16。
    table = new Entry[INITIAL_CAPACITY];
    // 根据firstKey的threadLocalHashCode属性,计算出一个索引i, 通过下面的计算可以保证i在0到15之间。
    // threadLocalHashCode这个属性是ThreadLocal对象的哈希码,它是在创建ThreadLocal对象时随机生成的
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    // 将firstKey和firstValue封装成一个Entry对象,并存入table[i]位置。Entry是一个静态内部类,它继承了WeakReference<ThreadLocal<?>>,表示对ThreadLocal对象的弱引用。弱引用可以避免内存泄漏,当ThreadLocal对象没有被其他强引用引用时,它就可以被垃圾回收器回收。Entry对象还有一个value属性,用来存储本地变量的值。
    table[i] = new Entry(firstKey, firstValue);
    // 初始时, 将size属性设置为1,表示当前ThreadLocalMap中有一个键值对。
    size = 1;
    // 调用setThreshold(INITIAL_CAPACITY)方法,设置阈值为初始容量的2/3,即10。阈值表示当ThreadLocalMap中的键值对数量达到这个值时,就需要扩容。	
    setThreshold(INITIAL_CAPACITY);
}

set方法分析

输入: ThreadLocalvalue, 也就是新的 keyvalue

输出: 无

private void set(ThreadLocal<?> key, Object value) {

    Entry[] tab = table;
    int len = tab.length;
    // 根据键的哈希码和数组长度计算出索引位置i。
    int i = key.threadLocalHashCode & (len-1);

    // 从索引位置i开始向后遍历数组,查找是否存在与键相同或者为空的条目
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        // 如果找到相同的键,则直接替换其值,并返回。
        if (k == key) {
            e.value = value;
            return;
        }
		// 如果找到一个条目,但其对应的键为null(说明该条目已经被垃圾回收),
        // 则调用 replaceStaleEntry 方法来替换掉这个过期的条目,并返回。
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    // 如果在遍历完数组后仍然没有找到相同或者空的条目,那么在索引位置 i 处创建一个新的 Entry 来存储传入的 key 和 value,并增加 size 计数。
    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 如果数组中存在过多的过期条目或者 size 计数超过了阈值,就会调用 cleanSomeSlots 和 rehash 方法来清理和可能扩容数组。
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

set函数也是, 将会检测后面的entry是否过期, 然后删除掉

replaceStaleEntry源码分析

这个方法主要作用: 替换掉过期的条目(即键为null的条目),并维护哈希表的顺序

  • ThreadLocal<?> key:新的键,即要存储的线程本地变量的ThreadLocal对象。
  • Object value:新的值,即要存储的线程本地变量的值。
  • int staleSlot:过期元素的位置,即哈希表中键为null的元素的索引。
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // 从传入的 staleSlot 位置的前一个位置开始向前扫描哈希表数组,
    // 找到最早出现的过期条目,并记录其位置为 slotToExpunge
    // 这是为了后续清理整个哈希槽中的过期条目做准备。
    // prevIndex 是在一个循环数组中, 往数组之前查找
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i; // 找到过期位置 slotToExpunge

    // 从传入的 staleSlot 位置的后一个位置开始向后扫描哈希表数组,
    // 寻找与传入的 key 相同或者为空的条目。
    // 在循环数组中, 往循环数组之后查找符合的元素
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // 如果找到与传入的 key 相同的键
        if (k == key) {
            // 就将该条目的值更新为传入的 value
            e.value = value;
			// 然后将该条目与传入的 staleSlot 位置的条目进行位置交换。
            // 这样做是为了确保哈希表中每个键都在其应该在的位置。
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // 如果找到一个条目,但其对应的键为空(说明该条目已经被垃圾回收),
        // 则记录当前位置为 slotToExpunge。
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // 如果在向后扫描的过程中没有找到与传入的 key 相同的键,
    // 就将传入的 key 和 value 创建一个新的 Entry 放在 staleSlot 位置。
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // 如果在向后扫描的过程中发现了其他的过期条目,或者在向前扫描时已经发现了过期条目(slotToExpunge 不等于 staleSlot)
    if (slotToExpunge != staleSlot)
        // 调用 expungeStaleEntry 方法来清理这些过期条目,
        // 并重新分配哈希槽中的有效条目。这是为了保持哈希表的性能和一致性。
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

cleanSomeSlots源码分析

此方法用于清理ThreadLocalMap表中的过期条目

两个参数:

  • i,即要开始清理的插槽的索引

  • n,即要清理的插槽数

方法返回一个布尔值,指示是否删除了任何过期条目

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        // Entry[] 数组拿出来的item, 然后判断 key 是否等于 null
        // 如果是则调用 expungeStaleEntry 去删除 item
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ((n >>>= 1) != 0);
    return removed;
}

expungeStaleEntry源码分析

清除哈希表中过期的条目。方法接受一个参数:staleSlot,表示已知的过期条目的位置。

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 将已过期的条目删除掉
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // 下面是重hash的过程
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        // 表示过期的条目
        if (k == null) {
            // 删除条目
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // 如果没有过期
            // 根据 key 计算出在 hash 桶的位置在哪里?
            int h = k.threadLocalHashCode & (len - 1);
            // 如果计算出的位置 `h` 不等于当前遍历到的位置 `i`
            // 说明 hash 的位置错了
            if (h != i) {
                // 将 hash 桶中的数据删除掉
                tab[i] = null;

                // 判断 key 计算出来的 hash 位置的 hash 桶是否有数据
                // 如果有数据, 则计算下一个位置
                while (tab[h] != null)
                    h = nextIndex(h, len);
                // 将旧的数据存放在新计算出来的位置中
                tab[h] = e;
            }
        }
    }
    // 方法返回下一个空条目的位置。这样,在调用方可以知道在 `staleSlot` 和返回值之间的所有条目都已经检查过并清除了过期条目。
    return i;
}

rehash源码分析

private void rehash() {
    // 删除所有过期的 entry
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
    // 重新 hash
    if (size >= threshold - threshold / 4)
        resize();
}

private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    // 遍历所有 entry[]
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            // 删除过期entry
            expungeStaleEntry(j);
    }
}

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    // 创建一个新的 Entry[]
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    // 遍历旧的 Entry
    for (Entry e : oldTab) {
        if (e != null) {
            ThreadLocal<?> k = e.get();
            // 删除掉过期的 entry
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                // key 不为空
                // 根据 key 计算 位置
                int h = k.threadLocalHashCode & (newLen - 1);
                // 根据位置拿到 entry
                while (newTab[h] != null)
                    // 向这个位置之后的位置遍历, 直到找到 entry == null 才停止
                    h = nextIndex(h, newLen);
                // 找到了空的位置, 将 e 放到空的位置中
                // 至此 一个 entry 的 rehash 完成
                newTab[h] = e;
                count++;
            }
        }
    }
	
    // 设置域值
    // threshold = len * 2 / 3;
    setThreshold(newLen);
    size = count;
    // 最后完成整个 rehash 的过程
    table = newTab;
}

resize方法创建一个新的条目数组,其大小是原数组的两倍。然后遍历旧数组中的所有条目,对于每个非空条目,如果其键不为空,则根据键重新计算其在新数组中的位置,并将其放入新数组中。最后,更新阈值并将表设置为新数组。

remove源码分析

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    // 计算位置
    int i = key.threadLocalHashCode & (len-1);
    // 遍历
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        // 找到位置
        if (e.get() == key) {
            // reference = null 也就是置空
            e.clear();
            // 删除 entry
            expungeStaleEntry(i);
            return;
        }
    }
}