ThreadLocal源码分析(Android-30)

248 阅读 · 阅读时长约 6 分钟
  • 以 get、set 为切入点,往下看
  • 边分析边考虑几个问题
    • 如何做到线程分离?

    • 为什么会导致内存泄漏?

    • 为什么这么设计?

    • 如何解决的?

      /*

      • Copyright (c) 1997, 2013, Oracle and/or its affiliates. All rights reserved.
      • DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
      • This code is free software; you can redistribute it and/or modify it
      • under the terms of the GNU General Public License version 2 only, as
      • published by the Free Software Foundation. Oracle designates this
      • particular file as subject to the "Classpath" exception as provided
      • by Oracle in the LICENSE file that accompanied this code.
      • This code is distributed in the hope that it will be useful, but WITHOUT
      • ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
      • FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
      • version 2 for more details (a copy is included in the LICENSE file that
      • accompanied this code).
      • You should have received a copy of the GNU General Public License version
      • 2 along with this work; if not, write to the Free Software Foundation,
      • Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
      • Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
      • or visit www.oracle.com if you need additional information or have any
      • questions. */

      package java.lang; import java.lang.ref.*; import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier;

      /**
       * Per-thread variable: every thread that calls {@link #get()} or
       * {@link #set(Object)} sees its own, independently initialized value.
       *
       * <p>How threads are kept apart: the values are NOT stored in this object.
       * Each {@link Thread} carries its own {@link ThreadLocalMap} (the thread's
       * {@code threadLocals} field, declared in Thread and not shown here), keyed
       * by ThreadLocal instances. This class only ever touches the map of the
       * <em>current</em> thread, so one thread cannot see another's value.
       *
       * @param <T> the type of the per-thread value
       */
      public class ThreadLocal<T> {

          /**
           * Hash code of this ThreadLocal, used as its key hash inside every
           * thread's ThreadLocalMap. Assigned once, at construction, from the
           * shared {@link #nextHashCode} counter.
           */
          private final int threadLocalHashCode = nextHashCode();

          /** Source of hash codes for new ThreadLocals; advanced atomically. */
          private static AtomicInteger nextHashCode =
              new AtomicInteger();

          /**
           * Step between successive hash codes. 0x61c88647 is the 32-bit
           * two's-complement negation of 0x9E3779B9 (about 2^32 divided by the
           * golden ratio), the classic Fibonacci-hashing constant. It turns the
           * sequential creation order of ThreadLocals into hash codes that spread
           * well over power-of-two-sized tables.
           */
          private static final int HASH_INCREMENT = 0x61c88647;

          /** Returns the next hash code and advances the shared counter. */
          private static int nextHashCode() {
              return nextHashCode.getAndAdd(HASH_INCREMENT);
          }

          /**
           * Returns the current thread's initial value for this variable. It is
           * invoked lazily from {@link #get()} (via setInitialValue) when the
           * thread has no entry for this ThreadLocal yet.
           *
           * <p>This base implementation returns {@code null}. Subclasses such as
           * {@link SuppliedThreadLocal} override it.
           *
           * @return the initial value for the current thread
           */
          protected T initialValue() {
              return null;
          }

          /**
           * Creates a thread-local whose initial value is obtained from
           * {@code supplier}.
           *
           * @param <S> the type of the thread-local's value
           * @param supplier produces the initial value; must not be null
           * @return a new thread-local variable
           * @throws NullPointerException if {@code supplier} is null
           */
          public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
              return new SuppliedThreadLocal<>(supplier);
          }

          /**
           * Creates a thread-local variable. Nothing is allocated here; per-thread
           * storage is created lazily on the first get/set.
           */
          public ThreadLocal() {
          }

          /**
           * Returns the current thread's value of this variable. If the thread has
           * no value yet, it is initialized with {@link #initialValue()}.
           *
           * @return the current thread's value of this thread-local
           */
          public T get() {
              Thread t = Thread.currentThread();
              // Each thread owns its map; reading only the current thread's map is
              // what isolates the threads from each other.
              ThreadLocalMap map = getMap(t);
              if (map != null) {
                  // Look up this ThreadLocal's entry in the thread's Entry[] table.
                  ThreadLocalMap.Entry e = map.getEntry(this);
                  if (e != null) {
                      @SuppressWarnings("unchecked")
                      T result = (T)e.value;
                      return result;
                  }
              }
              // Either the thread has no map yet, or the map has no entry for this
              // key: compute, store and return the initial value.
              return setInitialValue();
          }

          /**
           * Variant of {@link #set(Object)} used to establish the initial value.
           * It stores {@code initialValue()} for the current thread, creating the
           * thread's map if it does not exist yet.
           *
           * @return the initial value that was stored
           */
          private T setInitialValue() {
              // null by default, but subclasses (e.g. SuppliedThreadLocal) override it.
              T value = initialValue();
              Thread t = Thread.currentThread();
              ThreadLocalMap map = getMap(t);
              if (map != null)
                  map.set(this, value);
              else
                  createMap(t, value);
              return value;
          }

          /**
           * Sets the current thread's value of this variable.
           *
           * @param value the value to store for the current thread
           */
          public void set(T value) {
              Thread t = Thread.currentThread();
              ThreadLocalMap map = getMap(t);
              if (map != null)
                  // ThreadLocalMap.set also reclaims stale entries it meets while
                  // probing, which limits the leak described on ThreadLocalMap.
                  map.set(this, value);
              else
                  createMap(t, value);
          }

          /**
           * Removes the current thread's value of this variable. A later
           * {@link #get()} re-initializes it via {@link #initialValue()}. Calling
           * this when a value is no longer needed is the reliable way to release
           * it while the thread is still alive.
           */
          public void remove() {
              ThreadLocalMap m = getMap(Thread.currentThread());
              if (m != null)
                  m.remove(this);
          }

          /**
           * Returns the map associated with thread {@code t}.
           *
           * @param t the thread
           * @return the thread's {@code threadLocals} map, or null if not yet created
           */
          ThreadLocalMap getMap(Thread t) {
              return t.threadLocals;
          }

          /**
           * Creates the map for thread {@code t}, seeded with this ThreadLocal's
           * first value.
           *
           * @param t the current thread
           * @param firstValue value for the map's first entry
           */
          void createMap(Thread t, T firstValue) {
              t.threadLocals = new ThreadLocalMap(this, firstValue);
          }

          /**
           * Builds a child thread's map as a copy of {@code parentMap}. Each live
           * value is transformed through {@link #childValue(Object)}.
           *
           * @param parentMap the map of the parent thread
           * @return a new map for the child thread
           */
          static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
              return new ThreadLocalMap(parentMap);
          }

          /**
           * Computes a child thread's value from the parent's when an inherited map
           * is built. The base implementation always throws. It is presumably
           * overridden by an inheritable subclass (e.g. InheritableThreadLocal,
           * which is not shown here).
           *
           * @param parentValue the parent thread's value
           * @return the child's value
           * @throws UnsupportedOperationException always, in this base class
           */
          T childValue(T parentValue) {
              throw new UnsupportedOperationException();
          }

          /**
           * ThreadLocal whose initial value comes from a {@link Supplier}.
           * This is the type returned by {@link #withInitial(Supplier)}.
           */
          static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

              private final Supplier<? extends T> supplier;

              SuppliedThreadLocal(Supplier<? extends T> supplier) {
                  this.supplier = Objects.requireNonNull(supplier);
              }

              @Override
              protected T initialValue() {
                  return supplier.get();
              }
          }

          /**
           * Open-addressing hash map with linear probing, used only to hold
           * thread-local values. One instance hangs off each Thread.
           *
           * <p>Why values can leak: keys are held weakly (see {@link Entry}), so a
           * ThreadLocal that is no longer referenced elsewhere can be garbage
           * collected. The value, however, is an ordinary strong field of the Entry.
           * It stays reachable from the owning thread until the now stale entry
           * (key == null) is expunged, or until the thread's map itself becomes
           * unreachable. To limit this, set/getEntry/remove, cleanSomeSlots,
           * rehash and resize all reclaim stale entries they encounter. Calling
           * {@link ThreadLocal#remove()} avoids the problem entirely.
           */
          static class ThreadLocalMap {

              /**
               * Table entry. The key (the ThreadLocal) is held through the
               * WeakReference superclass, so the entry does not keep the ThreadLocal
               * alive. Once the key has been cleared, {@code get()} returns null and
               * the entry is "stale". The value is a strong reference and is only
               * released when the stale entry is expunged.
               */
              static class Entry extends WeakReference<ThreadLocal<?>> {
                  /** The value associated with this ThreadLocal. */
                  Object value;

                  Entry(ThreadLocal<?> k, Object v) {
                      super(k);
                      value = v;
                  }
              }

              /**
               * The initial capacity -- MUST be a power of two.
               * For a power of two len, len - 1 is all one bits (e.g. 15 = 0b1111),
               * so {@code hash & (len - 1)} equals {@code hash mod len} (for
               * non-negative results) and is cheaper than a division.
               */
              private static final int INITIAL_CAPACITY = 16;

              /**
               * The table, resized as necessary.
               * table.length MUST always be a power of two.
               */
              private Entry[] table;

              /**
               * The number of entries in the table.
               */
              private int size = 0;

              /**
               * The next size value at which to resize.
               */
              private int threshold; // Default to 0

              /**
               * Set the resize threshold to maintain at worst a 2/3 load factor.
               */
              private void setThreshold(int len) {
                  threshold = len * 2 / 3;
              }

              /**
               * Increment i modulo len: the next probe slot, wrapping to 0 past the end.
               */
              private static int nextIndex(int i, int len) {
                  return ((i + 1 < len) ? i + 1 : 0);
              }

              /**
               * Decrement i modulo len: the previous slot, wrapping to len - 1 below 0.
               */
              private static int prevIndex(int i, int len) {
                  return ((i - 1 >= 0) ? i - 1 : len - 1);
              }

              /**
               * Construct a new map initially containing (firstKey, firstValue).
               * ThreadLocalMaps are constructed lazily, so we only create
               * one when we have at least one entry to put in it.
               */
              ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
                  table = new Entry[INITIAL_CAPACITY];
                  // Home slot of the first key: its hash masked to the table size.
                  int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
                  table[i] = new Entry(firstKey, firstValue);
                  size = 1;
                  setThreshold(INITIAL_CAPACITY);
              }

              /**
               * Construct a new map including all Inheritable ThreadLocals
               * from given parent map. Called only by createInheritedMap.
               * Stale parent entries (key already collected) are skipped.
               *
               * @param parentMap the map associated with parent thread.
               */
              private ThreadLocalMap(ThreadLocalMap parentMap) {
                  Entry[] parentTable = parentMap.table;
                  int len = parentTable.length;
                  setThreshold(len);
                  table = new Entry[len];

                  for (int j = 0; j < len; j++) {
                      Entry e = parentTable[j];
                      if (e != null) {
                          @SuppressWarnings("unchecked")
                          ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                          if (key != null) {
                              Object value = key.childValue(e.value);
                              Entry c = new Entry(key, value);
                              int h = key.threadLocalHashCode & (len - 1);
                              while (table[h] != null)
                                  h = nextIndex(h, len);
                              table[h] = c;
                              size++;
                          }
                      }
                  }
              }

              /**
               * Get the entry associated with key.  This method
               * itself handles only the fast path: a direct hit of existing
               * key. It otherwise relays to getEntryAfterMiss.  This is
               * designed to maximize performance for direct hits, in part
               * by making this method readily inlinable.
               *
               * @param  key the thread local object
               * @return the entry associated with key, or null if there is no
               *         such entry
               */
              private Entry getEntry(ThreadLocal<?> key) {
                  int i = key.threadLocalHashCode & (table.length - 1);
                  Entry e = table[i];
                  if (e != null && e.get() == key)
                      return e;
                  // Miss in the home slot: it is empty, holds a different key
                  // (collision), or holds a stale entry. Probe the slow path.
                  else
                      return getEntryAfterMiss(key, i, e);
              }

              /**
               * Version of getEntry method for use when key is not found in
               * its direct hash slot. Stale entries met while probing are expunged.
               *
               * @param  key the thread local object
               * @param  i the table index for key's hash code
               * @param  e the entry at table[i]
               * @return the entry associated with key, or null if there is no
               *         such entry
               */
              private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
                  Entry[] tab = table;
                  int len = tab.length;

                  while (e != null) {
                      ThreadLocal<?> k = e.get();
                      if (k == key)
                          return e;
                      if (k == null)
                          // Stale entry: expunge it. expungeStaleEntry re-packs
                          // the rest of the run, so slot i is re-read, not skipped.
                          expungeStaleEntry(i);
                      else
                          i = nextIndex(i, len);
                      e = tab[i];
                  }
                  return null;
              }

              /**
               * Set the value associated with key.
               *
               * @param key the thread local object
               * @param value the value to be set
               */
              private void set(ThreadLocal<?> key, Object value) {

                  // We don't use a fast path as with get() because it is at
                  // least as common to use set() to create new entries as
                  // it is to replace existing ones, in which case, a fast
                  // path would fail more often than not.

                  Entry[] tab = table;
                  int len = tab.length;
                  int i = key.threadLocalHashCode & (len-1);

                  for (Entry e = tab[i];
                       e != null;
                       e = tab[i = nextIndex(i, len)]) {
                      ThreadLocal<?> k = e.get();

                      if (k == key) {
                          // Existing entry for this key: overwrite its value.
                          e.value = value;
                          return;
                      }

                      if (k == null) {
                          // Stale entry: the weak key was cleared by the GC after
                          // the ThreadLocal stopped being strongly reachable.
                          // Reuse this slot and expunge nearby stale entries so
                          // their values no longer stay reachable.
                          replaceStaleEntry(key, value, i);
                          return;
                      }
                  }

                  tab[i] = new Entry(key, value);
                  int sz = ++size;
                  // Opportunistically expunge stale entries; only rehash when
                  // nothing was removed and the threshold has been reached.
                  if (!cleanSomeSlots(i, sz) && sz >= threshold)
                      rehash();
              }

              /**
               * Remove the entry for key: clear its weak key, then expunge it
               * like any other stale entry.
               */
              private void remove(ThreadLocal<?> key) {
                  Entry[] tab = table;
                  int len = tab.length;
                  int i = key.threadLocalHashCode & (len-1);
                  for (Entry e = tab[i];
                       e != null;
                       e = tab[i = nextIndex(i, len)]) {
                      if (e.get() == key) {
                          e.clear();
                          expungeStaleEntry(i);
                          return;
                      }
                  }
              }

              /**
               * Stores key/value in place of the stale entry at
               * {@code staleSlot}. It also expunges other stale entries in the same
               * run, where a run is the stretch of non-null slots bounded by nulls.
               *
               * <p>If {@code key} already exists later in the run, its value is
               * updated and it is swapped into {@code staleSlot}, which is closer to
               * its home slot.
               *
               * @param key the thread local object
               * @param value the value to associate with key
               * @param staleSlot index of the first stale entry encountered while
               *        probing for key
               */
              private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                             int staleSlot) {
                  Entry[] tab = table;
                  int len = tab.length;
                  Entry e;

                  // Back up to check for prior stale entry in current run.
                  // We clean out whole runs at a time to avoid continual
                  // incremental rehashing due to garbage collector freeing
                  // up refs in bunches (i.e., whenever the collector runs).
                  int slotToExpunge = staleSlot;
                  // Scan backwards to the start of the run, remembering the
                  // earliest stale entry seen.
                  for (int i = prevIndex(staleSlot, len);
                       (e = tab[i]) != null;
                       i = prevIndex(i, len))
                      if (e.get() == null)
                          slotToExpunge = i;

                  // Find either the key or trailing null slot of run, whichever
                  // occurs first
                  for (int i = nextIndex(staleSlot, len);
                       (e = tab[i]) != null;
                       i = nextIndex(i, len)) {
                      ThreadLocal<?> k = e.get();

                      // If we find key, then we need to swap it
                      // with the stale entry to maintain hash table order.
                      // The newly stale slot, or any other stale slot
                      // encountered above it, can then be sent to expungeStaleEntry
                      // to remove or rehash all of the other entries in run.
                      if (k == key) {
                          e.value = value;

                          // Swap: the stale entry moves to slot i and the found
                          // entry moves into staleSlot.
                          tab[i] = tab[staleSlot];
                          tab[staleSlot] = e;

                          // Start expunge at preceding stale entry if it exists
                          if (slotToExpunge == staleSlot)
                              slotToExpunge = i;
                          cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                          return;
                      }

                      // If we didn't find stale entry on backward scan, the
                      // first stale entry seen while scanning for key is the
                      // first still present in the run.
                      if (k == null && slotToExpunge == staleSlot)
                          slotToExpunge = i;
                  }

                  // If key not found, put new entry in stale slot
                  tab[staleSlot].value = null;
                  tab[staleSlot] = new Entry(key, value);

                  // If there are any other stale entries in run, expunge them
                  if (slotToExpunge != staleSlot)
                      cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
              }

              /**
               * Expunges the stale entry at {@code staleSlot}, nulling its value
               * so it can be collected. It then walks forward through the run until
               * a null slot is reached. On the way it expunges further stale entries
               * and re-inserts live entries that are not in their home slot.
               *
               * @param staleSlot index of a slot holding a stale (non-null) entry
               * @return the index of the first null slot after staleSlot
               */
              private int expungeStaleEntry(int staleSlot) {
                  Entry[] tab = table;
                  int len = tab.length;

                  // expunge entry at staleSlot
                  tab[staleSlot].value = null;
                  tab[staleSlot] = null;
                  size--;

                  // Rehash until we encounter null
                  Entry e;
                  int i;
                  for (i = nextIndex(staleSlot, len);
                       (e = tab[i]) != null;
                       i = nextIndex(i, len)) {
                      ThreadLocal<?> k = e.get();
                      if (k == null) {
                          e.value = null;
                          tab[i] = null;
                          size--;
                      } else {
                          // Recompute the entry's home slot.
                          int h = k.threadLocalHashCode & (len - 1);
                          if (h != i) {
                              // Not in its home slot: move it to the first free
                              // slot at or after h.
                              tab[i] = null;

                              // Unlike Knuth 6.4 Algorithm R, we must scan until
                              // null because multiple entries could have been stale.
                              while (tab[h] != null)
                                  h = nextIndex(h, len);
                              tab[h] = e;
                          }
                      }
                  }
                  return i;
              }

              /**
               * Heuristically scans slots after {@code i} for stale entries. About
               * log2(n) slots are examined. Whenever a stale entry is found, the
               * budget is reset to the table length and that run is expunged.
               *
               * @param i a slot not holding a stale entry; scanning starts after it
               * @param n scan budget: roughly log2(n) slots are visited unless a
               *        stale entry is found
               * @return true if any stale entry was removed
               */
              private boolean cleanSomeSlots(int i, int n) {
                  boolean removed = false;
                  Entry[] tab = table;
                  int len = tab.length;
                  do {
                      i = nextIndex(i, len);
                      Entry e = tab[i];
                      if (e != null && e.get() == null) {
                          n = len;
                          removed = true;
                          i = expungeStaleEntry(i);
                      }
                  } while ( (n >>>= 1) != 0);
                  return removed;
              }

              /**
               * Re-pack and/or re-size the table. First scan the entire
               * table removing stale entries. If this doesn't sufficiently
               * shrink the size of the table, double the table size.
               */
              private void rehash() {
                  expungeStaleEntries();

                  // Use lower threshold for doubling to avoid hysteresis:
                  // grow once size reaches about 3/4 of threshold.
                  if (size >= threshold - threshold / 4)
                      resize();
              }

              /**
               * Double the capacity of the table. Stale entries are dropped, with
               * their values nulled, and live entries are re-inserted by hash into
               * the new table.
               */
              private void resize() {
                  Entry[] oldTab = table;
                  int oldLen = oldTab.length;
                  int newLen = oldLen * 2;
                  Entry[] newTab = new Entry[newLen];
                  int count = 0;

                  for (int j = 0; j < oldLen; ++j) {
                      Entry e = oldTab[j];
                      if (e != null) {
                          ThreadLocal<?> k = e.get();
                          if (k == null) {
                              // Stale entry: drop the value reference so it can be
                              // collected.
                              e.value = null; // Help the GC
                          } else {
                              // Home slot in the new table, then linear probing to
                              // the first free slot.
                              int h = k.threadLocalHashCode & (newLen - 1);
                              while (newTab[h] != null)
                                  h = nextIndex(h, newLen);
                              newTab[h] = e;
                              count++;
                          }
                      }
                  }

                  // Install the new table with its threshold and live-entry count.
                  setThreshold(newLen);
                  size = count;
                  table = newTab;
              }

              /**
               * Expunge all stale entries (key == null) in the table.
               */
              private void expungeStaleEntries() {
                  Entry[] tab = table;
                  int len = tab.length;
                  for (int j = 0; j < len; j++) {
                      Entry e = tab[j];
                      if (e != null && e.get() == null)
                          expungeStaleEntry(j);
                  }
              }
          }
      }