你真的了解ThreadLocal吗?

240 阅读6分钟

一, 是什么?怎么用?

是什么?

是每个线程的本地变量,可以存储每个线程独有的变量.

怎么用?

可以为每个线程创建一个独有的变量对象

可以实现线程间的数据隔离

Spring声明式事务中使用ThreadLocal实现数据库隔离


二, 类架构

ThreadLocal架构图

ThreadLocal属性

/**
 * 该值用于给ThreadLocalHashMap中存入值时线性探测插入的bucket位置
 */
private final int threadLocalHashCode = nextHashCode();

/**
 * 下一个要给出的hashCode,每次原子性更新,从0开始
 */
private static AtomicInteger nextHashCode = new AtomicInteger();

/**
 * hashCode增值,使用这个数字可以使key均匀的分布在2的幂次方的数组上
 * 具体可以参考 https://www.javaspecialists.eu/archive/Issue164-Why-0x61c88647.html
 * 因为比较复杂,在这里不展开讨论
 */
private static final int HASH_INCREMENT = 0x61c88647;

ThreadLocalMap属性

/**
 * Map的初始化容量
 */
private static final int INITIAL_CAPACITY = 16;

/**
 * 哈希表,长度始终是2的幂,原因在于2的幂-1的二进制全部为1
 * 便于按位与操作
 * 如16: 10000 - 1 = 1111
 * 按位与后都会在数组中
 */
private Entry[] table;

/**
 * 哈希表的长度
 */
private int size = 0;

/**
 * 扩容阈值,默认为0,扩容大小为哈希表长度的2/3
 */
private int threshold;

三, 实现原理

1, 为什么ThreadLocal可以实现线程隔离?

/**
 * 创建一个ThreadLocal,可以看出在构造器内部并没有处理任何事情
 */
public ThreadLocal() {}
/**
 * 给ThreadLocal中设置值会调用该方法
 */
public void set(T value) {
    //获取当前线程
    Thread t = Thread.currentThread();
    //获取当前线程中的变量threadLocals
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        //当前线程的threadLocals变量为空,创建map设置值
        createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
    //获取t的线程变量
    return t.threadLocals;
}

由上可以看出给ThreadLocal设置值本质上是给Thread的本地变量threadLocals变量设置值,这就是为什么ThreadLocal可以实现线程之间数据隔离


2, 增删查操作

/**
 * 给ThreadLocal中设置值会调用该方法
 */
public void set(T value) {
    //获取当前线程
    Thread t = Thread.currentThread();
    //获取当前线程中的变量threadLocals
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        //当前线程的threadLocals变量为空,创建map设置值
        createMap(t, value);
}
void createMap(Thread t, T firstValue) {
    //给线程设置变量值
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}
/**
 * 设置与key相关的值,key为ThreadLocal
 */
private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    //按位与获取插入的位置
    int i = key.threadLocalHashCode & (len - 1);
    /**
     * 使用线性探测法插入值
     * 从获取到的下标向后遍历,如果当前key等于数组中当前下标的key则直接修改值
     * 如果数组当前下标位置为空则替换掉当前下标的entry
     */
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        if (k == key) {
            e.value = value;
            return;
        }
        if (k == null) {
            //当前位置为空需要替换掉
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    //如果i向后没有与其相同的key或者为空的位置则直接替换掉当前位置的entry
    tab[i] = new Entry(key, value);
    int sz = ++size;
    //从i向后没有空位,并且数量大于阈值需要rehash,删除哈希表中某些为空的entry如果size大于阈值的3/4则需要扩容
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}
/**
 * 替换掉无效的entry
 */
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
        //从后向前寻找空位,寻找到距离staleSlot最远的一个位置
        if (e.get() == null)
            slotToExpunge = i;

    // 向后遍历
    for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        //找到了与ThreadLocal相同的key,替换值
        if (k == key) {
            e.value = value;
            //替换掉当前位置的entry
            tab[i] = tab[staleSlot];
            //交换entry
            tab[staleSlot] = e;

            // 前面没有空位,设置删除的位置
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            //删除后面为空的位置
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }
    // 如果没有发现key,则直接替换
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);
    // staleSlot前有空位需要删除为空的entry
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
/**
 * rehash哈希表
 */
private void rehash() {
    //删除无效的entry并且rehash哈希表
    expungeStaleEntries();
    // 如果数量大于等于阈值的3/4,需要扩容
    if (size >= threshold - threshold / 4)
        resize();
}
/**
 * 删除无效的entry并且rehash哈希表
 */
private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j);
    }
}
/**
 * 双倍扩容之前的哈希表
 */
private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            //当前位置key为空,则直接GC
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                //获取新位置
                int h = k.threadLocalHashCode & (newLen - 1);
                //从当前位置向后寻找一个为空的位置
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                //重新插入entry
                newTab[h] = e;
                count++;
            }
        }
    }
    //重新设置阈值和哈希表属性
    setThreshold(newLen);
    size = count;
    table = newTab;
}
/**
 * 删除某些entry以减半的方式删除
 */
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
        //无条件右移1位           /2
    } while ((n >>>= 1) != 0);
    return removed;
}
/**
 * 删除具体位置的entry,并且rehash之后的entry
 */
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    // 删除当前下标的entry
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;
    Entry e;
    int i;
    //从当前位置向后遍历,寻找为空的或者rehash之后的entry
    for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            //rehash之后的位置不是当前的位置需要删除当前的entry
            if (h != i) {
                //help GC
                tab[i] = null;
                //从h向后遍历,寻找一个为空的位置
                while (tab[h] != null)
                    h = nextIndex(h, len);
                //插入entry
                tab[h] = e;
            }
        }
    }
    return i;
}
/**
 * 获取ThreadLocal中存放的值
 */
public T get() {
    //1, 获取当前线程
    Thread t = Thread.currentThread();
    //2, 获取当前线程中的threadLocals变量,也就是ThreadLocalMap对象
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        //3, 获取之前设置的值并返回
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    //3, 如果当前线程的threadLocals变量为空,则还没有初始化,需要进行初始化
    return setInitialValue();
}
/**
 * 设置初始化值
 */
private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}
/**
 * 初始化值为null
 */
protected T initialValue() {
    return null;
}
/**
 * 删除变量
 */
public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}
/**
 * 删除某个entry
 */
private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len - 1);
    //从i向后遍历
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            //help GC
            e.clear();
            //删除当前entry并且rehash后面的entry
            expungeStaleEntry(i);
            return;
        }
    }
}

三, 存在问题

1, 内存泄漏问题

因为使用ThreadLocal本质上是使用ThreadLocalMap,在使用完ThreadLocal后,无法手动删除ThreadLocalMap中的key(ThreadLocal的引用),所以可能会引起内存泄漏问题,但是代码在设计的时候就考虑到了这一点,所以将ThreadLocalMap中的key(ThreadLocal)设置为了弱引用(WeakReference),即很容易被GC掉,但即便如此,我们还是要在使用完后调用ThreadLocal的remove方法,手动删除ThreadLocal引用,避免内存泄漏.

2, 子线程不能访问父线程变量

可以使用InheritableThreadLocal

原理

ThreadLocal<String> local = new InheritableThreadLocal<>();
local.set("main local variable");
public void set(T value) {
    //获取当前线程
    Thread t = Thread.currentThread();
    //因为InheritableThreadLocal重写了ThreadLocal的getMap方法,所以下面调用的为InheritableThreadLocal中的getMap,获取的为Thread类的inheritableThreadLocals变量
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        //当前线程的threadLocals变量为空,创建map设置值
        createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
    return t.inheritableThreadLocals;
}
//创建子线程
Thread thread = new Thread(() -> {
    local.set("child thread variable");
    System.out.println("child thread get local variable : " + local.get());
});
public Thread(Runnable target) {
    init(null, target, "Thread-" + nextThreadNum(), 0);
}
private void init(ThreadGroup g, Runnable target, String name, long stackSize) {
    init(g, target, name, stackSize, null, true);
}
private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
    //获取当前线程,这里为父线程,子线程还没有被创建出来
    Thread parent = currentThread();
    ...
        //这里的parent.inheritableThreadLocals在父线程set的时候已经初始化了
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            //所以子类的ThreadLocal也是使用的inheritableThreadLocals,不是之前的threadLocals
            this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
}
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
    return new ThreadLocalMap(parentMap);
}
private ThreadLocalMap(ThreadLocalMap parentMap) {
    Entry[] parentTable = parentMap.table;
    int len = parentTable.length;
    setThreshold(len);
    table = new Entry[len];
	
    for (int j = 0; j < len; j++) {
        Entry e = parentTable[j];
        if (e != null) {
            //向上转型
            @SuppressWarnings("unchecked")
            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
            if (key != null) {
                //调用InheritableThreadLocal对象的childValue方法,返回e.value
                Object value = key.childValue(e.value);
                Entry c = new Entry(key, value);
                //插入到子线程的inheritableThreadLocals,也就是将父线程的inheritableThreadLocals数据复制到子线程中
                int h = key.threadLocalHashCode & (len - 1);
                while (table[h] != null)
                    h = nextIndex(h, len);
                table[h] = c;
                size++;
            }
        }
    }
}