「源码学习」ThreadLocal 类

68 阅读3分钟

ThreadLocal是JDK的一个非常重要的类,它可以理解为线程的的本地变量(副本),只有线程自己可以访问,保证各个线程之间的变量互不干扰。

在学习Thread的源码的时候,了解到ThreadLocal类提供了get/set方法。

ThreadLocal方法

set

public void set(T value) {
        //获取当前线程
        Thread t = Thread.currentThread();
        //获取当前线程的ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        if (map != null)
            //如果存在,直接存储value
            map.set(this, value);
        else
           //如果不存在,先创建map,再存储value
            createMap(t, value);
    }

get

public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            //把当前ThreadLocal作为key,获取map中的entry
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                //返回entry的value
                return result;
            }
        }
        //map为空,则返回null
        return setInitialValue();
    }

ThreadLocalMap

ThreadLocalMap和HashMap的结构类似,HashMap是由数组+链表实现的,但是ThreadLocalMap没有链表结构。

还有Entry,它的key是继承自WeekReference,属于弱引用类型,在GC的时候,key很有可能被回收掉。

为么ThreadLocal不直接使用HashMap呢?

HashMapThreadLocalMap
数据结构数组+链表+红黑树数组
引用类型强引用Key是弱引用
Hash冲突链表+红黑树解决开放式寻址解决
性能表现大数据量,表现优越小数据量性能更好

一般情况下,ThreadLocal存储的数据量不会很大,被remove后,被回收器回收,数据的存储接口,节省时间和提高效率。

ThreadLocalMap底层是个动态数组,所以源码很简单

static class ThreadLocalMap {

        /**
         * 存放键值对,Key是ThreadLocal,Entry的key是对ThreadLocal的弱引用。
         */
        static class Entry extends WeakReference<ThreadLocal<?>> {
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        // 下面的参数和HashMap很相似
        /**
         * 数组初始容量
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * Entry数组,用于存储 <ThreadLocal<?> k, Object v>键值对
         */
        private Entry[] table;

        /**
         * Entry元素数量
         */
        private int size = 0;

        /**
         * 类似于 HashMap 扩容因子机制
         */
        private int threshold; // Default to 0
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }

        /**
         * 构造方法
         */
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

        private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];

            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }

        /**
         * 根据 ThreadLocal对象,获取Entry实例
         */
        private Entry getEntry(ThreadLocal<?> key) {  
	//计算entry table索引
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }

        /**
         * set()方法,key为ThreadLocal类型
         */
        private void set(ThreadLocal<?> key, Object value) {

            Entry[] tab = table;
            int len = tab.length;
            // 计算数组下标
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
		// 如果key相等(==判断内存地址相等),覆盖原来的value
                if (k == key) {
                    e.value = value;
                    return;
                }
		// 如果key为null,用新key、value覆盖,同时清理历史key=null的陈旧数据
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
	    //没有找到ThreadLocal,也不存在过期的key(key == null),则直接在索引下标位置设置entry
            tab[i] = new Entry(key, value);
            int sz = ++size;
            // 超过数组阈值,进行扩容
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

        /**
         * Remove the entry for key.
         */
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

        /**
         * 扩容操作
         */
        private void rehash() {
        // 扫描整个容器,删除过时的数据
            expungeStaleEntries();

       // 进行扩容操作
            if (size >= threshold - threshold / 4)
                resize();
        }
	//扫描整个Entry,删除过期数据
	private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }

        /**
         * 扩容为原容量的两倍,重新调整数组
         */
        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;
	     // 遍历Entry
            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    // 如果key=null,把value也置null,有利于GC回收对象
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
			//计算新的hash
                        int h = k.threadLocalHashCode & (newLen - 1);
			//如果新的下标存在元素,那就继续往后寻找下一个为null的下标
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }
	// 设置新的阈值为以前两倍
            setThreshold(newLen);
            size = count;
            table = newTab;
        }
    }

ThreadLocal使用

public class ThreadLocalTest {

    public static ThreadLocal<String> local = new ThreadLocal<>();

    public static void main(String[] args) {
        Thread threadA = new Thread(new Runnable() {
            @Override
            public void run() {
                local.set("thread-a set local value");
                try {
                    Thread.sleep(5000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(local.get());
            }
        });
        Thread threadB = new Thread(new Runnable() {
            @Override
            public void run() {
                local.set("thread-b set local value");
                System.out.println(local.get());
                local.remove();
                System.out.println("thread-b remove local value");
            }
        });

        threadA. start();
        threadB. start();
    }
}

日志输出

thread-b set local value
thread-b remove local value
thread-a set local value

ThreadLocal会存在内存泄漏?

会存在,在源码中看到,key是弱引用,但是value是强引用,不会被GC回收。ThreadLocalMap清理数据是在set()、get()、remove()方法中触发,长时间不主动触发清理数据的话,会存在内存泄漏的问题。