- 以get、set为切入点,往下阅读源码
- 边分析边思考以下几个问题:
-
如何做到线程隔离(每个线程拿到各自独立的值)?
-
为什么会导致内存泄漏?
-
为什么这么设计(Entry的key为什么使用弱引用)?
-
内存泄漏问题又是如何解决(缓解)的?
/*
- Copyright (c) 1997, 2013, Oracle and/or its affiliates. All rights reserved.
- DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
- This code is free software; you can redistribute it and/or modify it
- under the terms of the GNU General Public License version 2 only, as
- published by the Free Software Foundation. Oracle designates this
- particular file as subject to the "Classpath" exception as provided
- by Oracle in the LICENSE file that accompanied this code.
- This code is distributed in the hope that it will be useful, but WITHOUT
- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
- version 2 for more details (a copy is included in the LICENSE file that
- accompanied this code).
- You should have received a copy of the GNU General Public License version
- 2 along with this work; if not, write to the Free Software Foundation,
- Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
- Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
- or visit www.oracle.com if you need additional information or have any
- questions. */
package java.lang; import java.lang.ref.*; import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier;
public class ThreadLocal {
private final int threadLocalHashCode = nextHashCode(); private static AtomicInteger nextHashCode = new AtomicInteger(); private static final int HASH_INCREMENT = 0x61c88647; private static int nextHashCode() { return nextHashCode.getAndAdd(HASH_INCREMENT); } protected T initialValue() { return null; } public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) { return new SuppliedThreadLocal<>(supplier); } public ThreadLocal() { //构造方法,什么都没做 } public T get() { Thread t = Thread.currentThread();//获取当前线程 ThreadLocalMap map = getMap(t);//获取该线程独有的成员ThreadLocal.ThreadLocalMap if (map != null) {//命中 //从Entry[]数组table中根据hashCode值获取Entry ThreadLocalMap.Entry e = map.getEntry(this); if (e != null) { @SuppressWarnings("unchecked") T result = (T)e.value; return result; } } return setInitialValue();//创建该线程的ThreadLocal.ThreadLocalMap } private T setInitialValue() { T value = initialValue();//initialValue()这个方法的返回null Thread t = Thread.currentThread();//获取当前线程 ThreadLocalMap map = getMap(t); if (map != null) map.set(this, value); else createMap(t, value);//这里会new ThreadLocalMap(this, value); return value; } public void set(T value) { Thread t = Thread.currentThread();//获取当前线程 ThreadLocalMap map = getMap(t);//获取该线程独有的成员ThreadLocal.ThreadLocalMap if (map != null) map.set(this, value);//注意这里,插入时会做优化操作,尽可能的去避免内存泄漏 else createMap(t, value);//这里会new ThreadLocalMap(this, value); } public void remove() { ThreadLocalMap m = getMap(Thread.currentThread()); if (m != null) m.remove(this); } ThreadLocalMap getMap(Thread t) { return t.threadLocals; } void createMap(Thread t, T firstValue) { t.threadLocals = new ThreadLocalMap(this, firstValue); } static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) { return new ThreadLocalMap(parentMap); } T childValue(T parentValue) { throw new UnsupportedOperationException(); } static final class SuppliedThreadLocal<T> extends ThreadLocal<T> { private final Supplier<? extends T> supplier; SuppliedThreadLocal(Supplier<? 
extends T> supplier) { this.supplier = Objects.requireNonNull(supplier); } @Override protected T initialValue() { return supplier.get(); } } static class ThreadLocalMap { //注意这里,Entry继承WeakReference,垃圾回收时,可能会回收掉Entry的key,导致内存泄漏 static class Entry extends WeakReference<ThreadLocal<?>> { /** The value associated with this ThreadLocal. */ Object value; Entry(ThreadLocal<?> k, Object v) { super(k); value = v; } } /** * The initial capacity -- MUST be a power of two. */ //ThreadLocal.ThreadLocalMap的初始容量 //必须为2的幂,具体为什么,这里大概提一下。 //因为2的幂次-1,转为二进制位全为1,例如;7:111,15:1111 //在放入table[]时,通过hashCode&(2的幂-1), 可以尽可能的不重复和均匀插入 private static final int INITIAL_CAPACITY = 16; /** * The table, resized as necessary. * table.length MUST always be a power of two. */ private Entry[] table; /** * The number of entries in the table. */ private int size = 0; /** * The next size value at which to resize. */ private int threshold; // Default to 0 /** * Set the resize threshold to maintain at worst a 2/3 load factor. */ private void setThreshold(int len) { threshold = len * 2 / 3; } /** * Increment i modulo len. */ //获取下一个index,如超过len,从头开始 private static int nextIndex(int i, int len) { return ((i + 1 < len) ? i + 1 : 0); } /** * Decrement i modulo len. */ //获取上一个index,如小于0,从最后开始 private static int prevIndex(int i, int len) { return ((i - 1 >= 0) ? i - 1 : len - 1); } /** * Construct a new map initially containing (firstKey, firstValue). * ThreadLocalMaps are constructed lazily, so we only create * one when we have at least one entry to put in it. */ ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) { table = new Entry[INITIAL_CAPACITY]; //通过计算,获取插入的index位置 int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); table[i] = new Entry(firstKey, firstValue); size = 1; setThreshold(INITIAL_CAPACITY);//设置阈值 } /** * Construct a new map including all Inheritable ThreadLocals * from given parent map. Called only by createInheritedMap. 
* * @param parentMap the map associated with parent thread. */ private ThreadLocalMap(ThreadLocalMap parentMap) { Entry[] parentTable = parentMap.table; int len = parentTable.length; setThreshold(len); table = new Entry[len]; for (int j = 0; j < len; j++) { Entry e = parentTable[j]; if (e != null) { @SuppressWarnings("unchecked") ThreadLocal<Object> key = (ThreadLocal<Object>) e.get(); if (key != null) { Object value = key.childValue(e.value); Entry c = new Entry(key, value); int h = key.threadLocalHashCode & (len - 1); while (table[h] != null) h = nextIndex(h, len); table[h] = c; size++; } } } } /** * Get the entry associated with key. This method * itself handles only the fast path: a direct hit of existing * key. It otherwise relays to getEntryAfterMiss. This is * designed to maximize performance for direct hits, in part * by making this method readily inlinable. * * @param key the thread local object * @return the entry associated with key, or null if no such */ private Entry getEntry(ThreadLocal<?> key) { int i = key.threadLocalHashCode & (table.length - 1);//hash获取index Entry e = table[i]; if (e != null && e.get() == key)//命中 return e; else//hash得到到index没有值,或者key==null(被垃圾回收掉) return getEntryAfterMiss(key, i, e); } /** * Version of getEntry method for use when key is not found in * its direct hash slot. * * @param key the thread local object * @param i the table index for key's hash code * @param e the entry at table[i] * @return the entry associated with key, or null if no such */ private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) { Entry[] tab = table; int len = tab.length; while (e != null) { ThreadLocal<?> k = e.get(); if (k == key) return e; if (k == null) expungeStaleEntry(i);//删除 else i = nextIndex(i, len); e = tab[i]; } return null; } /** * Set the value associated with key. 
* * @param key the thread local object * @param value the value to be set */ private void set(ThreadLocal<?> key, Object value) { // We don't use a fast path as with get() because it is at // least as common to use set() to create new entries as // it is to replace existing ones, in which case, a fast // path would fail more often than not. Entry[] tab = table;//获取当前Entry[] int len = tab.length;//拿到长度 int i = key.threadLocalHashCode & (len-1);//计算出插入的index for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { ThreadLocal<?> k = e.get(); if (k == key) {//i上有key,且相等 e.value = value;//直接覆盖原来的值 return; } //注意这里,e不为null,但是key为null //因为Entry的key为weakReference弱引用 //弱引用在垃圾回收时,只要扫描到,就会回收 if (k == null) {//这里对内存泄漏做了优化 replaceStaleEntry(key, value, i); return; } } tab[i] = new Entry(key, value); int sz = ++size; if (!cleanSomeSlots(i, sz) && sz >= threshold) rehash(); } /** * Remove the entry for key. */ private void remove(ThreadLocal<?> key) { Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1); for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { if (e.get() == key) { e.clear(); expungeStaleEntry(i); return; } } } private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) { Entry[] tab = table; int len = tab.length; Entry e; // Back up to check for prior stale entry in current run. // We clean out whole runs at a time to avoid continual // incremental rehashing due to garbage collector freeing // up refs in bunches (i.e., whenever the collector runs). 
int slotToExpunge = staleSlot;//需要清除的index //从后往前轮询 for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len)) if (e.get() == null)//找到key被回收的Entry slotToExpunge = i; // Find either the key or trailing null slot of run, whichever // occurs first //从前往后轮询 for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { ThreadLocal<?> k = e.get(); if (k == key) { e.value = value;//value覆盖 tab[i] = tab[staleSlot];//把被回收的entry赋值给下标为i的entry tab[staleSlot] = e;//赋值给原被回收的entry位置 // Start expunge at preceding stale entry if it exists if (slotToExpunge == staleSlot) slotToExpunge = i; cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); return; } // If we didn't find stale entry on backward scan, the // first stale entry seen while scanning for key is the // first still present in the run. if (k == null && slotToExpunge == staleSlot) slotToExpunge = i; } // If key not found, put new entry in stale slot tab[staleSlot].value = null; tab[staleSlot] = new Entry(key, value); // If there are any other stale entries in run, expunge them if (slotToExpunge != staleSlot) cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); } //删除已被回收的entry private int expungeStaleEntry(int staleSlot) { Entry[] tab = table; int len = tab.length; // expunge entry at staleSlot tab[staleSlot].value = null; tab[staleSlot] = null; size--; // Rehash until we encounter null Entry e; int i; for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { ThreadLocal<?> k = e.get(); if (k == null) { e.value = null; tab[i] = null; size--; } else { //从新hash,获取插入table的index int h = k.threadLocalHashCode & (len - 1); if (h != i) {//位置有变化 tab[i] = null;//原位置赋null // Unlike Knuth 6.4 Algorithm R, we must scan until // null because multiple entries could have been stale. 
while (tab[h] != null)//当前位置有值 h = nextIndex(h, len);//往后查找,直到该位置为null tab[h] = e; } } } return i; } //清理 private boolean cleanSomeSlots(int i, int n) { boolean removed = false; Entry[] tab = table; int len = tab.length; do { i = nextIndex(i, len); Entry e = tab[i]; if (e != null && e.get() == null) {//被回收的entry n = len; removed = true; i = expungeStaleEntry(i); } } while ( (n >>>= 1) != 0); return removed; } /** * Re-pack and/or re-size the table. First scan the entire * table removing stale entries. If this doesn't sufficiently * shrink the size of the table, double the table size. */ private void rehash() { expungeStaleEntries(); // Use lower threshold for doubling to avoid hysteresis if (size >= threshold - threshold / 4)//大于等于阈值的3/4 resize(); } /** * Double the capacity of the table. */ //扩容 private void resize() { //新table长度为旧table的两倍 Entry[] oldTab = table; int oldLen = oldTab.length; int newLen = oldLen * 2; Entry[] newTab = new Entry[newLen]; int count = 0; for (int j = 0; j < oldLen; ++j) { Entry e = oldTab[j]; if (e != null) { ThreadLocal<?> k = e.get(); if (k == null) { //key为空,value也赋null,使其可以被gc扫描回收 e.value = null; // Help the GC } else { int h = k.threadLocalHashCode & (newLen - 1);//通过新的table,获取hash值 while (newTab[h] != null)//依次往后查询获取entry==null的位置 h = nextIndex(h, newLen); newTab[h] = e; count++; } } } //重设阈值、大小以及table setThreshold(newLen); size = count; table = newTab; } /** * Expunge all stale entries in the table. */ //删除所有key==null的entry private void expungeStaleEntries() { Entry[] tab = table; int len = tab.length; for (int j = 0; j < len; j++) { Entry e = tab[j]; if (e != null && e.get() == null) expungeStaleEntry(j); } } }
}
-