public class ThreadLocal<T> {
private final int threadLocalHashCode = nextHashCode()
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT)
}
/**
* 下一个哈希码,初始值为 0
*/
private static AtomicInteger nextHashCode =
new AtomicInteger()
/**
* 对于长度为 2 的次幂,可以实现完美散列
*/
private static final int HASH_INCREMENT = 0x61c88647
// 获得当前线程的 ThreadLocalMap
ThreadLocalMap getMap(Thread t) {
return t.threadLocals
}
// 新建一个 ThreadLocalMap
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue)
}
// 返回当前线程的 threadlocal 值,
public T get() {
Thread t = Thread.currentThread()
ThreadLocalMap map = getMap(t)
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this)
if (e != null) {
T result = (T)e.value
return result
}
}
// 如果 ThreadLocalMap == null,则初始化 ThreadLocalMap
return setInitialValue()
}
// 初始化 ThreadLocalMap,key 为当前的 threadlocal,value == null
private T setInitialValue() {
T value = initialValue()
Thread t = Thread.currentThread()
ThreadLocalMap map = getMap(t)
if (map != null)
map.set(this, value)
else
createMap(t, value)
return value
}
protected T initialValue() {
return null
}
public void set(T value) {
Thread t = Thread.currentThread()
ThreadLocalMap map = getMap(t)
if (map != null)
map.set(this, value)
else
createMap(t, value)
}
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread())
if (m != null)
m.remove(this)
}
static class ThreadLocalMap {
// ThreadLocal 为弱引用
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value
Entry(ThreadLocal<?> k, Object v) {
super(k)
value = v
}
}
/**
* 初始容量
*/
private static final int INITIAL_CAPACITY = 16
/**
* 存放数据的数组,其大小必须是2的次幂
*/
private Entry[] table
/**
* 数组的实际大小
*/
private int size = 0
/**
* 扩容的阈值
*/
private int threshold
/**
* 定义为长度的 2/3
*/
private void setThreshold(int len) {
threshold = len * 2 / 3
}
/**
* 下一个索引,使用线性探测法解决哈希冲突
*/
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0)
}
/**
* 上一个索引
*/
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1)
}
/**
* 构造方法,至少包含一个值,因此可以看出ThreadLocalMap是延迟加载
*/
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY]
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1)
table[i] = new Entry(firstKey, firstValue)
size = 1
setThreshold(INITIAL_CAPACITY)
}
// 获取 key 对应的value
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1)
Entry e = table[i]
if (e != null && e.get() == key)
return e
else
// 没有找到相同的key
return getEntryAfterMiss(key, i, e)
}
// 使用线性探测法,向后找到对应的 key
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table
int len = tab.length
while (e != null) {
ThreadLocal<?> k = e.get()
if (k == key)
return e
if (k == null)
// 清理过期 key
expungeStaleEntry(i)
else
i = nextIndex(i, len)
e = tab[i]
}
return null
}
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table
int len = tab.length
// 计算索引值
int i = key.threadLocalHashCode & (len-1)
// 如果当前索引位置有元素,会使用线性探测法,找到 entry 为空的位置
for (Entry e = tab[i]
ThreadLocal<?> k = e.get()
// 相同的ThreadLocal,新值替换旧值
if (k == key) {
e.value = value
return
}
// table[i]上的key为空,说明被回收了(上面的弱引用中提到过)。
// 这个时候说明 table[i] 可以重新使用,用新的 key-value 将其替换,
// 并删除其他无效的entry
if (k == null) {
replaceStaleEntry(key, value, i)
return
}
}
// 找到 entry 为空的位置,将元素插入
tab[i] = new Entry(key, value)
int sz = ++size
// cleanSomeSlots用于清除那些e.get()==null,
// 也就是 table[index] != null && table[index].get()==null
// 之前提到过,这种数据key关联的对象已经被回收,所以这个Entry(table[index])可以被置null。
// 如果没有清除任何entry,并且当前使用量达到了负载因子所定义(长度的2/3),那么进行rehash()
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash()
}
//
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table
int len = tab.length
Entry e
int slotToExpunge = staleSlot
// 一直向前找,
// 如果遇到 entry = null 的槽,结束循环
// 如果 entry != null,那么迭代的时候发现 key = null 的槽,也就是发现了过期键
// 更新 slotToExpunge 的值。
for (int i = prevIndex(staleSlot, len)
if (e.get() == null)
slotToExpunge = i
// 向后查找,直到 entry 为空,或找到与当前 key 相同的 entry
for (int i = nextIndex(staleSlot, len)
(e = tab[i]) != null
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get()
// 如果向后找到了 key 相同的槽,当前过期槽(staleSlot)与新槽(i)交换位置,
if (k == key) {
e.value = value
tab[i] = tab[staleSlot]
tab[staleSlot] = e
// 当 slotToExpunge == staleSlot 时,代表我们在向前查找时没有找到过期数据
if (slotToExpunge == staleSlot)
slotToExpunge = i
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len)
return
}
// k == null说明当前遍历的 Entry 是一个过期数据,则更新 slotToExpunge 的值
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i
}
// 没有找到相同的 key,则新建一个槽替换当前过期槽
tab[staleSlot].value = null
tab[staleSlot] = new Entry(key, value)
// 如果我们向前向后遍历时发现了其它的过期数据,也就是 slotToExpunge != staleSlot
// 清理过期数据
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len)
}
// 探测式清理,线性清理,向后遍历,直到遇到空 entry 时结束
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table
int len = tab.length
// 先将当前位置数据清空
tab[staleSlot].value = null
tab[staleSlot] = null
size--
// Rehash until we encounter null
Entry e
int i
// 向后遍历,直到遇到空 entry
for (i = nextIndex(staleSlot, len)
(e = tab[i]) != null
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get()
// 如果当前的 key == null,则当前 entry 为过期数据,需要清空
if (k == null) {
e.value = null
tab[i] = null
size--
} else {
int h = k.threadLocalHashCode & (len - 1)
// 计算当期位置是否发生偏移,如果发生偏移,
// 重新计算slot位置,
// 目的是让正常数据尽可能存放在正确位置或离正确位置更近的位置
// 也就是更接近 i= key.hashCode & (tab.len - 1)的位置。
// 这种优化会提高整个散列表查询性能。
if (h != i) {
tab[i] = null
while (tab[h] != null)
h = nextIndex(h, len)
tab[h] = e
}
}
}
return i
}
// 启发式清理,循环 log2n 次
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false
Entry[] tab = table
int len = tab.length
do {
i = nextIndex(i, len)
Entry e = tab[i]
// 当出现过期 key 时,重置 n = len
if (e != null && e.get() == null) {
n = len
removed = true
i = expungeStaleEntry(i)
}
// n >>>= 1 = n / 2
} while ( (n >>>= 1) != 0)
return removed
}
/**
* Remove the entry for key.
*/
private void remove(ThreadLocal<?> key) {
Entry[] tab = table
int len = tab.length
int i = key.threadLocalHashCode & (len-1)
for (Entry e = tab[i]
e != null
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear()
// 探测式清理
expungeStaleEntry(i)
return
}
}
}
/**
* Re-pack and/or re-size the table. First scan the entire
* table removing stale entries. If this doesn't sufficiently
* shrink the size of the table, double the table size.
*/
private void rehash() {
expungeStaleEntries()
// 过期 key 清理完了,size >= threshold - threshold / 4,扩容
if (size >= threshold - threshold / 4)
resize()
}
// 遍历整个table,清理过期 key
private void expungeStaleEntries() {
Entry[] tab = table
int len = tab.length
for (int j = 0
Entry e = tab[j]
if (e != null && e.get() == null)
expungeStaleEntry(j)
}
}
/**
* Double the capacity of the table.
*/
private void resize() {
Entry[] oldTab = table
int oldLen = oldTab.length
int newLen = oldLen * 2
Entry[] newTab = new Entry[newLen]
int count = 0
// 将旧数组中的元素放到新数组中
for (int j = 0
Entry e = oldTab[j]
if (e != null) {
ThreadLocal<?> k = e.get()
if (k == null) {
e.value = null
} else {
int h = k.threadLocalHashCode & (newLen - 1)
while (newTab[h] != null)
h = nextIndex(h, newLen)
newTab[h] = e
count++
}
}
}
setThreshold(newLen)
size = count
table = newTab
}
}
}