聊聊 ThreadLocal 那些事

1,249 阅读7分钟

前言

这篇文章聊聊 ThreadLocal,我们经常会在一些开源中间件的源码中见到它的身影,比较常见的用途是保存上下文信息,还有就是保证了线程安全。

实际上,ThreadLocal 为每个线程提供一个单独的变量,确是一种保证线程安全的手段,ThreadLocal 创建的变量只能被当前线程访问,其他线程不得干涉。

ThreadLocal API

使用 ThreadLocal 其实非常简单,直接看下面的示例:


public class ThreadLocalSimpleDateFormat {
    private static final ThreadLocal<SimpleDateFormat> formatter = ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyyMMss HHmm"));

    public String formatDate(Date date) {
        return formatter.get().format(date);
    }
}    
  • 使用 withIntial() 静态方法 初始化变量
  • 调用 get() 方法获取当前线程的对象值

我们都知道 SimpleDateFormat 不是线程安全,所以上面使用 ThreadLocal 的方式来保证线程安全,保证每个线程都有特定的 SimpleDateFormat。

当然不只这两个方法,还有常见的 set()、initialValue()方法,使用起来都比较简单。

什么时候使用TheadLocal

有时候我们会遇到将某个变量和线程进行绑定起来的场景,一种方法是定义一个Map,其中key是线程ID,value是对象,这样每个线程都有属于自己的对象,同一个线程只能对应一个对象。

比如,定义一个计数器,每个线程需要有一个属于自己的计数器,保证线程安全。


public class Counter {
    private AtomicLong counter = new AtomicLong();

    private static Map<Long, Counter> counterMap = new ConcurrentHashMap<>();

    public static Counter getCounter() {
        long id = Thread.currentThread().getId();
        counterMap.putIfAbsent(id, new Counter());
        return counterMap.get(id);
    }

    public long incrementAndGet() {
        return counter.incrementAndGet();
    }
}

而 Java 提供了更简单的方法,也就是 ThreadLocal 工具类,使用 ThreadLocal 类可以直接给每个线程提供单独的计数器


public class ThreadLocal_Counter {
    private ThreadLocal<AtomicLong> threadLocal = new ThreadLocal<>();
    
    public long incrementAndGet() {
        return threadLocal.get().incrementAndGet();
    }
}

显然使用 ThreadLocal 来得更简洁方便,之前提到使用 ThreadLocal 构建线程独享的 SimpleDateFormat 也是同样的道理,主要是保证线程安全。

ThreadLocal如何实现

探究ThreadLocal是如何实现的,从它的set()方法可以看出端倪,如下:

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}
  • 获取当前线程
  • 根据当前线程作为参数获取一个 ThreadLocalMap 的对象
  • 如果 ThreadLocalMap 对象不为空,则将 ThreadLocal 本身作为 key,value 为值,否则直接创建这个 ThreadLocalMap 对象并设置值

重点在于这个 ThreadLocalMap 对象,走进 getMap 方法 :

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

实际上是获取的 Thread 对象的 threadLocals 变量,也就是:

class Thread implements Runnable {

    ThreadLocal.ThreadLocalMap threadLocals = null;
}

每个线程都有一个 ThreadLocalMap 变量,将 value 值放入其中,自然是只能在本线程中访问。

那这 ThreadLocalMap 究竟是怎样一种构造,它的作用与 Map 有些类似,请继续往下看:

static class ThreadLocalMap {
       
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
        // 其他代码省略
}

ThreadLocalMap 是 Thread 的静态内部类, 而它的内部还有一个静态内部类 (禁止套娃) ,Entry类就是 ThreadLocalMap 用来进行 key-value 存储的, key是 ThreadLocal,value 是值。

Entry 继承自 WeakReference,WeakReference 表示弱引用,那么 Entry 的 key (ThreadLocal) 就是 一个弱引用。弱引用的意思是:如果在发生垃圾回收时,若这个对象只被弱引用指向,那么就会被回收。key 被回收了,而value 值却没有被回收,导致value一直存在,从而引发内存泄漏,这就是 ThreadLocal内存泄漏的问题,所以我们需要手动调用 remove() 方法去清除value值。

值得注意的是,如果 key有被强引用指向,那么在垃圾回收的时候是不会被回收的。

ThreadLocal的注意事项

static修饰ThreadLocal

在《阿里巴巴Java开发手册》中给出了使用 ThreadLocal 的建议:

【参考】ThreadLocal无法解决共享对象的更新问题,ThreadLocal对象建议使用static修饰。这个变量是针对一个线程内所有操作共享的,所以设置为静态变量,所有此类实例共享此静态变量 ,也就是说在类第一次被使用时装载,只分配一块存储空间,所有此类的对象(只要是这个线程内定义的)都可以操控这个变量。

前面分析过 ThreadLocal 是 ThreadLocalMap 中 Entry 的 key,而用 static 修饰 ThreadLocal,保证了 ThreadLocal 有强引用在,也就是 Entry 的 key有被强引用指向,会一直存在,垃圾回收的时候不会被回收,这样就不容易导致内存泄漏的问题

ThreadLocal结合线程池的问题

当 ThreadLocal 配合线程池使用的时候,我们需要及时对 ThreadLocal 进行清理,清除与本线程绑定的 value 值,否则会出现意料之外的结果。

下面来看一个代码示例,来看看没有调用remove方法和有调用remove下的结果差异。

private static ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(() -> 0);

public static void main(String[] args) {
    ExecutorService executorService = Executors.newFixedThreadPool(2);
    for (int i = 0; i < 5; i++) {
        executorService.execute(()->{
            Integer before = threadLocal.get();
            threadLocal.set(before + 1);
            Integer after = threadLocal.get();
            System.out.println("before: " + before + ",after: " + after);
        });
    }
    executorService.shutdown();
}

首先我们没有调用 remove 方法进行清理,它的打印结果是:

before: 0,after: 1
before: 0,after: 1
before: 1,after: 2
before: 2,after: 3
before: 3,after: 4

可以看到出现了 before 不为0的情况,这是因为线程在执行完任务被复用了,被复用的线程使用了上一个线程操作的value对象,从而导致不符合预期。

然后我们加上调用remove方法的逻辑:

try {
    Integer before = threadLocal.get();
    threadLocal.set(before + 1);
    Integer after = threadLocal.get();
    System.out.println("before: " + before + ",after: " + after);
} finally {
    threadLocal.remove();
}

这次输出的结果回归正常了:

before: 0,after: 1
before: 0,after: 1
before: 0,after: 1
before: 0,after: 1
before: 0,after: 1

总结来说,使用 ThreadLocal 的时候要及时调用 remove() 方法进行清理。

ThreadLocal的应用

文章最开头就提到,ThreadLocal 被频繁运用到开源中间件中,比如RocketMQ、Dubbo、Zuul等等,下面就来学习下开源中间件是如何使用 ThreadLocal的。

Zuul

最近有研究 API 网关的实现原理,Zuul 1.x 算的上一款比较优秀的网关,在它的源码实现中的RequestContext类就用到了 ThreadLocal,保存线程上下文信息。

public class RequestContext extends ConcurrentHashMap<String, Object> {

    protected static Class<? extends RequestContext> contextClass = RequestContext.class;
    
    private static RequestContext testContext = null;
    // 使用 ThreadLocal 保存线程上下文信息
    protected static final ThreadLocal<? extends RequestContext> threadLocal = new ThreadLocal<RequestContext>() {
        @Override
        protected RequestContext initialValue() {
            try {
                return contextClass.newInstance();
            } catch (Throwable e) {
                throw new RuntimeException(e);
            }
        }
    };
    
    public static RequestContext getCurrentContext() {
        if (testContext != null) return testContext;

        RequestContext context = threadLocal.get();
        return context;
    }
    ...

通过 ThreadLocal 保存上下文信息,在任意地方调用getCurrentContext()方法就可以获取当前线程的RequestContext,然后再从 RequestContext 获取 Request或者Response进行相应处理。

RocketMQ

在 RocketMQ的源码实现中也有用到 ThreadLocal,代码如下:

public class ThreadLocalIndex {
    private final ThreadLocal<Integer> threadLocalIndex = new ThreadLocal<Integer>();
    private final Random random = new Random();

    public int getAndIncrement() {
        Integer index = this.threadLocalIndex.get();
        if (null == index) {
            index = Math.abs(random.nextInt());
            if (index < 0)
                index = 0;
            this.threadLocalIndex.set(index);
        }

        index = Math.abs(index + 1);
        if (index < 0)
            index = 0;

        this.threadLocalIndex.set(index);
        return index;
    }
}

ThreadLocalIndex 主要用于生产者发送消息的时候,熟悉 RocketMQ 的小伙伴都知道,生产者首先拉取 Topic 的路由信息,一个 Topic 有多个 MessageQueue (消息队列),发送消息时需要选择一个消息队列进行发送,一般采用轮询的方式选择,此时不同的生产者线程需要有自己负责的轮询顺序,所以使用 ThreadLocalIndex来保证。

小结

本文介绍了 ThreadLocal 的一些常见知识点,再次总结一点:为了保证安全和结果的准确性,我们需要在使用 ThreadLocal 后及时调用 remove()方法进行清理工作。

同时,欢迎关注我新开的公众号,定期分享Java后端知识!

pjmike

参考资料