ThreadLocal知识点整理.md

476 阅读8分钟

一、概述

ThreadLocal用于实现线程的数据隔离,也就是每个线程独立持有数据的变量,其他线程无法访问和修改。

  • API接口:get()、set(xxx)、remove()
  • 对象的弱引用
  • 实际应用
    • Spring实现事务隔离级别的源码
      • Spring采用Threadlocal的方式,来保证单个线程中的数据库操作使用的是同一个数据库连接,同时,采用这种方式可以使业务层使用事务时不需要感知并管理connection对象,通过传播级别,巧妙地管理多个事务配置之间的切换,挂起和恢复。
      • Spring的事务主要是ThreadLocal和AOP去做实现的
    • SimpleDataFormat多线程场景下会解析时间错误,使用ThreadLocal来保证每个线程只会有一个SimpleDataFormat
      • 不安全原因:当时我们使用SimpleDataFormat的parse()方法,内部有一个Calendar对象,调用SimpleDataFormat的parse()方法会先调用Calendar.clear(),然后调用Calendar.add(),如果一个线程先调用了add()然后另一个线程又调用了clear(),这时候parse()方法解析的时间就不对了。
    • 后台管理平台项目中,用户登录后携带的信息,比如当前用户名、权限等等,可以在项目的整个周期内获取到。

数据结构简介

  • 每个线程单独持有一个ThreadLocalMap对象,其他线程无法获取,实现了数据隔离,多线程数据安全。
  • ThreadLocalMap是一个Entry数组,Entry是继承WeakReference(弱引用)的,Entry是一个key、value结构,key是ThreadLocal对象,value是T对象

为什么用数组来表示?

  • 因为每个线程里面可以设置多个ThreadLocal对象,而这些都维护在ThreadLocalMap中,所以需要数组来表示,且还需要维护扩容。

JVM引用

  • 强引用:当内存不足时触发GC,宁愿抛出OOM也不会回收强引用的内存
  • 弱引用:触发GC后便会回收弱引用的内存

二、优缺点

优点

  • 简单实现线程之间的隔离,线程安全

  • 可以在任何方法中轻松获取到该对象

  • 不需要加锁,执行效率高

  • 更加节省内存,节省开销

  • 免去传参的繁琐,降低代码耦合度

内存泄漏

ThreadLocal的key是弱引用,正常情况下当Thread运行结束后,ThreadLocal中的key和value都会被回收,那么不会存在内存泄漏问题。

但是如果线程一直未结束(比如线程池引入使得线程重用),那么就会导致key和value没有被回收导致内存泄漏。

处理办法

手动调用remove()方法:删除ThreadLocalMap中这个key和value的引用

  • JDK的设计已经考虑到了这个问题,所以在set()、remove()、resize()方法中会扫描到key为null的Entry,并且把对应的value设置为null,这样value对象就可以被回收。
  • 阿里规约:使用完ThreadLocal后,要调用remove()方法

ThreadLocal的空指针异常问题

如果get方法返回值为基本类型,则会报空指针异常,如果是包装类型就不会出错。这是因为基本类型和包装类型存在装箱和拆箱的关系,造成空指针问题的原因在于使用者。

共享对象问题

如果在每个线程中ThreadLocal.set()进去的东西本来就是多个线程共享的同一对象,比如static对象,那么多个线程调用ThreadLocal.get()获取的内容还是同一个对象,还是会发生线程安全问题。

优先使用框架的支持,而不是自己创造

例如在Spring框架中,如果可以使用RequestContextHolder,那么就不需要自己维护ThreadLocal,因为自己可能会忘记调用remove()方法等,造成内存泄漏。

ThreadLocal对象存储在哪里?

虽然ThreadLocal都是被各个线程所持有,但是线程持有的是对象的引用,ThreadLocal对象都存储在堆上,且其对应的对象是存储在堆上的。

三、测试用例

新建多个ThreadLocal对象


public class ThreadLocalTest {

    private static final ThreadLocal<String> threadLocal = ThreadLocal.withInitial(() -> "threadLocal");
    private static final ThreadLocal<Object> threadLocal2 = ThreadLocal.withInitial(() -> "threadLocal2");

    public static void main(String[] args) throws InterruptedException {

        Thread t1 = new Thread(() -> {
            threadLocal.set("test");
            System.out.println(Thread.currentThread().getName() + "\t" + threadLocal.get());
            threadLocal.remove();
        });
        t1.start();
        t1.join();

        // 在运行到这里前,ThreadLocal已经初始化过了,比如java.lang.StringCoding#decoder
        threadLocal.set("test1");
        threadLocal2.set("test2");

        System.out.println(Thread.currentThread().getName() + "\t" + threadLocal.get());
        threadLocal.remove();

        System.out.println(Thread.currentThread().getName() + "\t" + threadLocal2.get());
        threadLocal2.remove();
    }
}

四、父子线程数据传递


private static void test() throws InterruptedException {



    ThreadLocal threadLocal = new InheritableThreadLocal<String>();

    threadLocal.set("my test");



    Thread t = new Thread(()->{

        System.out.println(threadLocal.get());

        threadLocal.remove();

        System.out.println(threadLocal.get());

    });



    t.start();

    t.join();



}



// 输出结果:

my test

null

五、源码

Thread类内部持有ThreadLocalMap对象threadLocals,也就是每个线程单独持有引用,所以线程数据都是隔离的,其他线程拿不到

// 每个Thread对象都持有一个ThreadLocalMap的引用
ThreadLocal.ThreadLocalMap threadLocals = null;

(一)Entry节点

ThreadLocalMap是ThreadLocal的内部类,内部节点是Entry数组。Entry继承了WeakReference(弱引用),弱引用独享在GC时会被jvm直接回收。

构造方法

static class ThreadLocalMap {
    // 对ThreadLocal的弱引用
    static class Entry extends WeakReference<ThreadLocal<?>> {
        Object value;
        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    // 初始容量16,容量必须是2的幂次方
    private static final int INITIAL_CAPACITY = 16;
    // Entry数组
    private Entry[] table;
    // 扩容阈值,默认是0,在ThreadLocalMap的set(xxx)时被设置成数组长度的2/3
    private int threshold; // Default to 0

    // 其他源码
}
  • initialValue()初始化方法 该方法用于设置初始值,并且在调用get()方法时才会被触发,所以是 懒加载。 但是如果在get()之前进行了set()操作,这样就不会调用initialValue()。 通常每个线程只能调用一次本方法,但是调用了remove()后就能再次调用

ThreadLocal#get方法

public T get() {
    Thread t = Thread.currentThread();
    // 每个线程获取自己单独的引用ThreadLocalMap 
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // 如果使用get前没有set值,则需要初始化并返回默认值
    return setInitialValue();
}

根据当前线程的hashCode,获取到对应的数组位置=hashCode & (len - 1) 如果对应位置是null节点或者对应的key不相同,那么顺着数组获取下一个节点(一会儿可以看hash冲突时set方法的解决办法)

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else   
        // 如果数组节点为null或者是节点对应的key不相同
        return getEntryAfterMiss(key, i, e);
}

寻找下一个节点

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            // 如果key相同则找到了;
            return e;
        if (k == null)
            // 如果当前key为null,需要删除该无用的节点,后面讲
            expungeStaleEntry(i);
        else
            // 否则一直寻找。
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

(二)hash冲突怎么解决

这个ThreadLocalMap 像HashMap,但是节点并没有next指针,所以这个ThreadLocalMap 没有链表结构,那么怎么解决Hash冲突呢?

ThreadLocal#set方法

1、首先根据当前线程获取到hashCode,然后找到数组位置=hashCode & (len-1) 2、如果对应数组位置的key有值,且key相同,则直接set值(因为同一个ThreadLocal只有一个值) 3、如果对应数组位置的key为null,则在这里set值 4、如果对应数组位置的key有值,且key不同,则顺着数组找下一个节点。重复上面2、3步骤

public void set(T value) {
    Thread t = Thread.currentThread();
    // getMap放当前线程对应的ThreadLocalMap引用
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        // main线程启动时已经初始化了,在某些类中,比如java.lang.StringCoding#decoder,所以最好单独起线程测试
        createMap(t, value);
    }

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

// 核心set方法
private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    // 每个线程的hashCode值,在ThreadLocal初始化时就存在了且不会变化,是成员变量
    int i = key.threadLocalHashCode & (len-1);
    // 从节点i开始往后查找,如果找到相同key或者是null位置,直接赋值。如果没有则一直找
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        if (k == key) {
            // 在i位置往后找到该key,直接赋值并返回
            e.value = value;
            return;
        }
        if (k == null) {
            // 找到数据节点为null,存储key和value
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    // 找到了一个数组节点为null,直接赋值(不会找不到为null的值,因为扩容因子是2/3,后面步骤保证了要扩容)
    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        // 如果i+1后面都是有key而且对应的value不为空、且达到了扩容阈值,则开始扩容
        rehash();
}

// 此处的i是set值时位置的i+1了,也就是上一步赋值的null位置处
// 从i+1位置处,一直往后面找,判断是否含有无用的key和value
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

private void rehash() {
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
    if (size >= threshold - threshold / 4)
        resize();
}

// 扩容:2倍扩容,如果新位置有值,则顺序往后找空位置放元素
private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                // 扩容核心:如果映射后的位置已经有值,则往后找null位置放值
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

private static int nextIndex(int i, int len) {
    return ((i + 1 < len) ? i + 1 : 0);
}

删除无用的旧节点 1、直接将该节点清空,注意是将key和value都设置为null 2、判断该节点后面的其他节点,如果没有遇到null节点,那么就会Rehash

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 删除当前节点的key和value
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null(在没有遇到null前会一直Rehash,也就是没有null节点的话)
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

ThreadLocal#remove方法

public void remove() {
     ThreadLocalMap m = getMap(Thread.currentThread());
     if (m != null)
         m.remove(this);
}