浅析ThreadLocal

164 阅读4分钟

1. ThreadLocal简介与基本适用

在项目中经常适用SimpleDateFormat对象实例来解析字符串类型的日期。但是在多线程的情况下会出现解析错误,如下面这段代码。

    public class ParseDate implements Runnable{
    private static final SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
    int i = 0;
    public ParseDate(int i) {this.i = i;}
    @Override
    public void run() {
        try {
            Date date = sdf.parse("2015-03-11 12:29:"+i%60);
            System.out.println(i+":"+date);
        } catch (ParseException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        ExecutorService es = Executors.newFixedThreadPool(10);
        for (int i = 0; i<1000; i++) {
            es.execute(new ParseDate(i));
        }
    }
}

会出抛出下面的异常,主要是由于多线程下parse()方法不是线程安全的,因此在线程池中共享这个对象必然导致错误。 image.png 一种方法是在parse()方法前后加锁。从ThreadLocal的名字可以看出,这是线程的一个局部变量。也就是只有当前线程才可以访问

public class ParseDate implements Runnable{
    private static ThreadLocal<SimpleDateFormat> threadLocal = new ThreadLocal<>();
    int i = 0;
    public ParseDate(int i) {this.i = i;}
    @Override
    public void run() {
        try {
            if (threadLocal.get() == null) {
                threadLocal.set(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));
            }
            Date date = threadLocal.get().parse("2015-03-11 12:29:"+i%60);
            System.out.println(i+":"+date);
        } catch (ParseException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        ExecutorService es = Executors.newFixedThreadPool(10);
        for (int i = 0; i<1000; i++) {
            es.execute(new ParseDate(i));
        }
    }
}

2. ThreadLocal的实现原理

2.1 ThreadLocal

ThreadLocal中主要包含两个方法:get()和set()。

  • get

首先获得当前线程对象的ThreadLocalMap,然后通过将自己作为key取得内部的实际数据。

    public T get() {
        //获取当前线程
        Thread t = Thread.currentThread();
        //获取当前线程的threadLocals(Thread类的成员变量)
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            //ThreadLocalMap的key是ThreadLocal对象,value就是存储的变量
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

首先获得当前线程对象的ThreadLocalMap,然后将当前ThreadLocal实例对象作为key存入ThreadLocalMap中。

  • set
    public void set(T value) {
       //获取当前调用这行代码的线程对象
        Thread t = Thread.currentThread();
        //获取当前线程的threadLocals(Thread类的成员变量)
        ThreadLocalMap map = getMap(t);
        if (map != null)
            //Entry中的key为ThreadLocal对象,v为value
            map.set(this, value);
        else
            //初始化当前线程的threadLocals变量,并放入数据
            createMap(t, value);
    }

2.png

ThreadLocal中的get和set方法都是对当前线程持有的ThreadLocalMap进行操作,不同的线程持有不同的Thread对象,而不同的Thread对象又持有不同的ThreadLocalMap对象,这样就实现了线程之间的相互隔离。而ThreadLocalMap类为ThreadLocal的静态内部类,接下来讲下ThreadLocalMap的实现。

2.2 ThreadLocalMap

ThreadLocalMap是一个类似与Map的数据结构。其中Entry的实现结构如下:

        static class Entry extends WeakReference<ThreadLocal<?>> {
            Object value;
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
  • set
        private void set(ThreadLocal<?> key,  Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            //从i往后找
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                //重新赋值
                if (k == key) {
                    e.value = value;
                    return;
                }
                //原来的key被回收了,替换和回收失效位置的元素
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            //没有找到旧值,直接new一个entry到i位置上
            tab[i] = new Entry(key, value);
            int sz = ++size;
            //清理数组中i到sz的key失效的元素。
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                //如果没有key失效,并且sz大于等于阈值做一次rehash(因为数组长度变了)
                rehash();
        }

ThreadLocal中解决hash冲突使用的是线性探测法,即一直往下找,直到有合适的插入位置。而HashMap中使用的是拉链法。

  • getEntry
    private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            //通过hash定位插入位置,并且判断获得的key是否和传入的key相等
            if (e != null && e.get() == key)
                return e;
            else
                //直接定位没找到,说明没有set或者发生了hash碰撞
                return getEntryAfterMiss(key, i, e);
        }

2.3 InheritableThreadLocal

父子线程之间如果需要传递值就需要用到InheritableThreadLocal。

ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

2.3.1 使用

    public static void main(String[] args) {
        ThreadLocal<String> tl = new InheritableThreadLocal<>();
        tl.set("haha");
        new Thread(){
            @Override
            public void run() {
                System.out.println(tl.get());
            }
        }.start();
    }

2.3.2 原理

public class InheritableThreadLocal<T> extends ThreadLocal<T> {
  
    protected T childValue(T parentValue) {
        return parentValue;
    }


    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }


    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

InheritableThreadLocal继承自ThreadLocal,实现了三个方法childValue(),getMap(),createMap()。getMap()方法就是获取当前线程的inheritableThreadLocals变量,createMap()方法就是给inheritableThreadLocals变量做初始化。

那么InheritableThreadLocal究竟是如何实现父子线程直接的变量共享呢?

从Thread类中的init()方法找到了答案。

 private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
        if (name == null) {
            throw new NullPointerException("name cannot be null");
        }

        this.name = name;
        //获取父线程(也就是当前线程)
        Thread parent = currentThread();

        //省略
        //如果父类的inheritableThreadLocals不为null,并且inheritThreadLocals为true
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        //设置子线程的inheritableThreadLocals
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
        //省略
    }
    //进行一次拷贝
    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }

可以看出:父线程在new Thread()创建子线程的时候,将父线程中的inheritableThreadLocals拷贝到子线程的inheritableThreadLocals中。

由于大多数开发使用的是线程池,这时候创建的线程只是复用了线程池中已有的线程,并不会new新的线程,因此使用InheritableThreadLocal无效。

3. ThreadLocal内存泄漏

如果Entry中的可以被设计为强引用,那么即使ThreadLocal对象设置为null,仍然不能被回收。因为ThreadLocalMap对其也是强引用。

ThreadLocalMapEntry的key被设计为弱引用主要为了帮助GC回收ThreadLocal对象,而Entry 中的value是强引用。只被弱引用关联的对象在GC时无论内存是否足够,都会进行回收因此只要ThreadLocal对象设置为null,Entry中的key在下次GC时就会被自动回收。那么value对象如何回收呢?

ThreadLocalMap在每次getEntry和set方法中,都会将这些stale entry的value设置为null,使得原来value指向对象可以被回收。

因此防止内存泄漏的关键在于每次使用完ThreadLocal对象后,需要手动调用remove方法。如果没有调用remove方法,ThreadLocal对象会一直存在,直到当前线程被销毁。

综上所述,防止内存泄漏的方法:

  1. key设计为弱引用
  2. get和set方法时会清理stale entry
  3. 主动调用ThreadLocal的remove方法