ThreadLocal 小结

365 阅读2分钟

ThreadLocal 作用

ThreadLocal 提供了线程独有的局部变量,每个线程都可以通过 setget 来修改或获取且不会和其它线程的操作产生冲突,实现线程间的数据隔离。

ThreadLocal 原理

public class ThreadLocal<T> {
    ...  
    public T get() {
    	...
    }
    
    public void set(T value) {
    	...    
    }
    
    static class ThreadLocalMap {
        ...
    }
    ...
}
public class Thread implements Runnable {
    ...
    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;
    ...
}

ThreadLocal 底层实现

  • 每个线程都可以有多个 ThreadLocal 对象
  • 每个线程对应一个 ThreadLocalMap, ThreadLocalMap 为 ThreadLocal 的静态内部类
  • ThrealLocalMap 底层实现为一个 Entry 类型的数组, 数组大小恒为 2 的整数次幂, 初始值16
  • 每个 Entry 的 key 为 ThreadLocal 对象, 值为相对应的值
  • 每个 Entry 在数组中的位置为当前 threadLocal 的哈希值与 数组大小-1 相与
  • 出现哈希冲突采用开放定址法

ThreadLocal set/get 方法

public void set(T value) {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程对应的 ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    // map 不为空,插入值
    if (map != null)
        map.set(this, value);
    else // 否则创建 map 并插入值
        createMap(t, value);
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

private void set(ThreadLocal<?> key, Object value) { // ThreadLocalMap 中的方法
	// 获取当前 ThreadLocalMap 的 Entry 数组
    Entry[] tab = table;
    int len = tab.length;
    // 计算应插入的位置
    int i = key.threadLocalHashCode & (len-1);
	
    // 采用 开放定址法 插入
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }

        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 判断是否需要扩容(resize方法会在rehash中多进行一次判断是否调用)
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}
public T get() {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程对应的 ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        // map 不为空且 entry 存在返回结果
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // 否则创建 map 或者往 map 中添加 entry
    return setInitialValue();
}

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

注: 文中源码为 jdk1.8