ThraadLocal 源码解析

57 阅读5分钟
public class ThreadLocal<T> {

    private final int threadLocalHashCode = nextHashCode();

	private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    /**
     * 下一个哈希码,初始值为 0
     */
    private static AtomicInteger nextHashCode =
        new AtomicInteger();

    /**
     * 对于长度为 2 的次幂,可以实现完美散列
     */
    private static final int HASH_INCREMENT = 0x61c88647;

	// 获得当前线程的 ThreadLocalMap
	ThreadLocalMap getMap(Thread t) {
       	return t.threadLocals;
    }

	// 新建一个 ThreadLocalMap
	void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

	// 返回当前线程的 threadlocal 值,
    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                T result = (T)e.value;
                return result;
            }
        }
		// 如果 ThreadLocalMap == null,则初始化 ThreadLocalMap
        return setInitialValue();
    }

	// 初始化 ThreadLocalMap,key 为当前的 threadlocal,value == null
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }
	protected T initialValue() {
        return null;
    }

	public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

	public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
    }

    static class ThreadLocalMap {

		// ThreadLocal 为弱引用
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * 初始容量
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * 存放数据的数组,其大小必须是2的次幂
         */
        private Entry[] table;

        /**
         * 数组的实际大小
         */
        private int size = 0;

        /**
         * 扩容的阈值
         */
        private int threshold; // Default to 0

        /**
         * 定义为长度的 2/3
         */
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        /**
         * 下一个索引,使用线性探测法解决哈希冲突
         */
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

        /**
         * 上一个索引
         */
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }

        /**
         * 构造方法,至少包含一个值,因此可以看出ThreadLocalMap是延迟加载
         */
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

		// 获取 key 对应的value
		private Entry getEntry(ThreadLocal<?> key) {
			int i = key.threadLocalHashCode & (table.length - 1);
			Entry e = table[i];
			if (e != null && e.get() == key)
				return e;
			else
				// 没有找到相同的key
				return getEntryAfterMiss(key, i, e);
		}

		// 使用线性探测法,向后找到对应的 key
		private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
			Entry[] tab = table;
			int len = tab.length;

			while (e != null) {
				ThreadLocal<?> k = e.get();
				if (k == key)
					return e;
				if (k == null)
					// 清理过期 key
					expungeStaleEntry(i);
				else
					i = nextIndex(i, len);
				e = tab[i];
			}
			return null;
		}

        private void set(ThreadLocal<?> key, Object value) {

			Entry[] tab = table;
			int len = tab.length;
			// 计算索引值
			int i = key.threadLocalHashCode & (len-1);

			// 如果当前索引位置有元素,会使用线性探测法,找到 entry 为空的位置
			for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
				ThreadLocal<?> k = e.get();

				// 相同的ThreadLocal,新值替换旧值
				if (k == key) {
					e.value = value;
					return;
				}
				// table[i]上的key为空,说明被回收了(上面的弱引用中提到过)。
				// 这个时候说明 table[i] 可以重新使用,用新的 key-value 将其替换,
				// 并删除其他无效的entry
				if (k == null) {
					replaceStaleEntry(key, value, i);
					return;
				}
			}
			// 找到 entry 为空的位置,将元素插入
			tab[i] = new Entry(key, value);
			int sz = ++size;
			// cleanSomeSlots用于清除那些e.get()==null,
			// 也就是 table[index] != null && table[index].get()==null
			// 之前提到过,这种数据key关联的对象已经被回收,所以这个Entry(table[index])可以被置null。
			// 如果没有清除任何entry,并且当前使用量达到了负载因子所定义(长度的2/3),那么进行rehash()
			if (!cleanSomeSlots(i, sz) && sz >= threshold)
				rehash();
		}

		// 
		private void replaceStaleEntry(ThreadLocal<?> key, Object value,
									   int staleSlot) {
			Entry[] tab = table;
			int len = tab.length;
			Entry e;

			int slotToExpunge = staleSlot;
			// 一直向前找,
			// 		如果遇到 entry = null 的槽,结束循环
			// 		如果 entry != null,那么迭代的时候发现 key = null 的槽,也就是发现了过期键
			// 		更新 slotToExpunge 的值。
			for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
				if (e.get() == null)
					slotToExpunge = i;

			// 向后查找,直到 entry 为空,或找到与当前 key 相同的 entry
			for (int i = nextIndex(staleSlot, len); 
					(e = tab[i]) != null; 
					i = nextIndex(i, len)) {

				ThreadLocal<?> k = e.get();

				// 如果向后找到了 key 相同的槽,当前过期槽(staleSlot)与新槽(i)交换位置,
				if (k == key) {
					e.value = value;

					tab[i] = tab[staleSlot];
					tab[staleSlot] = e;

					// 当 slotToExpunge == staleSlot 时,代表我们在向前查找时没有找到过期数据
					if (slotToExpunge == staleSlot)
						slotToExpunge = i;

					cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
					return;
				}

				// k == null说明当前遍历的 Entry 是一个过期数据,则更新 slotToExpunge 的值
				if (k == null && slotToExpunge == staleSlot)
					slotToExpunge = i;
			}

			// 没有找到相同的 key,则新建一个槽替换当前过期槽
			tab[staleSlot].value = null;
			tab[staleSlot] = new Entry(key, value);

			// 如果我们向前向后遍历时发现了其它的过期数据,也就是 slotToExpunge != staleSlot
			// 清理过期数据
			if (slotToExpunge != staleSlot)
				cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
		}

		// 探测式清理,线性清理,向后遍历,直到遇到空 entry 时结束
        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // 先将当前位置数据清空
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
			// 向后遍历,直到遇到空 entry
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
				// 如果当前的 key == null,则当前 entry 为过期数据,需要清空
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
					// 计算当期位置是否发生偏移,如果发生偏移,
					// 重新计算slot位置,
					// 目的是让正常数据尽可能存放在正确位置或离正确位置更近的位置
					// 也就是更接近 i= key.hashCode & (tab.len - 1)的位置。
					// 这种优化会提高整个散列表查询性能。
                    if (h != i) {
                        tab[i] = null;

                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

		// 启发式清理,循环 log2n 次
        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
				// 当出现过期 key 时,重置 n = len
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);
                }
			// n >>>= 1 = n / 2
            } while ( (n >>>= 1) != 0);
            return removed;
        }

        /**
         * Remove the entry for key.
         */
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
					// 探测式清理
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

        /**
         * Re-pack and/or re-size the table. First scan the entire
         * table removing stale entries. If this doesn't sufficiently
         * shrink the size of the table, double the table size.
         */
        private void rehash() {
            expungeStaleEntries();

            // 过期 key 清理完了,size >= threshold - threshold / 4,扩容
            if (size >= threshold - threshold / 4)
                resize();
        }

		// 遍历整个table,清理过期 key
        private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }

        /**
         * Double the capacity of the table.
         */
        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;

			// 将旧数组中的元素放到新数组中
            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }

            setThreshold(newLen);
            size = count;
            table = newTab;
        }
    }
}