java并发编程(3)-ThreadLocal原理剖析

255 阅读7分钟

java并发编程系列前文

  1. java并发编程(1)-并发编程基础(上)
  2. java并发编程(2)-并发编程基础(下)

经过前两章基础知识的铺垫,java并发编程系列从现在开始就要进入高级篇的内容。
大家假想这么一个场景。你有一个变量并且你想让这个变量在多线程环境中可以将其隔离,各个线程对这个变量的修改操作的影响范围只局限于线程本地,不会影响到其他线程。说的比较抽象,我们举两个实际应用中可能遇到的场景说明。

  1. 项目中需要管理数据库连接,希望在多线程环境下,每个线程使用的连接是隔离的,彼此互不影响的。
  2. web项目中多线程处理用户请求,为了记录用户在一个会话中的请求链路,就需要让同一会话的所有请求都在一个线程中处理。这时候你希望线程之间的会话数据是隔离的,即线程A看不到线程B的会话,线程B看不到线程A的会话,他们对会话的操作也不会影响到其他线程。

看过java并发系列前两章内容或者对并发有一定基础的同学很容易就能想到两种方案,加锁或者是通过在线程本地内创建局部变量去解决。

以上方案都有缺点,加锁对性能消耗大,使用局部变量又会使代码臃肿,实现不够优雅。接下来我们看看如何在java中以更优雅的方式解决多线程下变量隔离的问题。

ThreadLocal

ThreadLocal是java给我们提供的类,它提供了线程本地变量的功能。访问ThreadLocal的每个线程,都会在线程本地产生一份副本,后续的读写操作都是基于线程本地的副本,帮我们完美的解决了变量隔离的问题。

使用ThreadLocal

首先我们来看看如何使用使用ThreaLocal

public class ThreadLocalTest {

    private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();

    static void print(String str) {
        // 打印当前线程本地内存中的threadLocal的值
        System.out.println(str + ":" + threadLocal.get());
        // 清除当前线程本地内存中的threadLocal的值
        threadLocal.remove();
    }

    public static void main(String[] args) throws InterruptedException {
        // 设置threadLocal值
        threadLocal.set("init");
        Thread thread1 = new Thread(() -> {
            // 打印threadLocal当前值
            System.out.println("current threadLocal:" + threadLocal.get());
            // 设置线程1的本地变量的值
            threadLocal.set("thread1 local variable");
            // 调用print方法
            print("thread1");
            System.out.println("thread1 remove after : " + threadLocal.get());
        });
        thread1.start();
        Thread.sleep(1000);
        System.out.println("threadLocal value : " + threadLocal.get());

        threadLocal.remove();
        System.out.println("threadLocal value : " + threadLocal.get());
    }
}

在上面的例子中,我们在类中创建了一个ThreadLocal的静态常量,常量名是threadLocal。同时还有一个print方法,这个方法的作用在于打印threadLocal中的数据后调用remove方法清除数据。

我们在进入main方法时会对threadLocal进行设值,然后创建一个线程,在线程中我们先打印threadLocal中当前保存的数据,然后调用set方法对它进行设值,调用print方法,在print方法执行结束后打印threadLocal当前值。

为了保证控制台打印的顺序,我们sleep1秒以后在main方法中再打印threadLocal中的数据。

current threadLocal:null
thread1:thread1 local variable
thread1 remove after : null
threadLocal value : init
threadLocal value : null

最后我们的打印结果是这样。

从结果上来看,可以看到代码17行并没有打印出我们先前在main方法中设置的init字符串而是打印出null,而在26行又打印出了init字符串。

看过我们并发编程基础的第一章的同学应该记得,我们当时说过每个进程都有一个主线程。
之所以出现一会打印不出init,一会又能打印出init的原因是因为,mian方法所在线程处于主线程。
又因为我们对ThreadLocal的读写操作都是针对当前线程本地工作内存。所以14行的set方法只针对于主线程的本地工作内存,所以我们在第17行线程内部没有get到具体线程,因为这时我们已经位于thread1工作线程内部了。

ThreadLocal实现原理

知其然而知其所以然,现在让我们分别去看看ThreadLocal的set,get以及remove方法都是如何实现的

set
public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        map.set(this, value);
    } else {
        createMap(t, value);
    }
}

可以从源码中看到,首先去获取了当前线程,然后通过getMap方法去获取ThreadLocalMap,接着判断ThreadLocalMap是否不为空,如果是则直接把value设置进去,否则创建ThreadLocalMap并且把value设置进去。

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

可以看到getMap方法内部实际就是把Thread中的threadLocals返回出去。
相信此刻大家已经开始有所领悟。没错,ThreadLocal实现变量隔离的方法很精妙,直接把变量存到Thread对象中。所以本篇我们对ThreadLocal的操作实际上都是在操作Thread中的ThreadLocalMap。
ThreadLocalMap大家有兴趣可以自己去看看,其实从命名看就能看出它就是个Map结构,key是ThreadLocal,很显然这么设计是因为一个线程可能拥有多个ThreadLocal。如果你进源码中看一下,就会发现它内部中的Entry是弱引用,也就是说在内存不够时,即便还存活也会被GC掉。

get
public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

看过set方法后,get方法的前面代码只要你有所基础都明白它了什么,我们直接去看看setInitialValue方法内部干了什么。

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        map.set(this, value);
    } else {
        createMap(t, value);
    }
    if (this instanceof TerminatingThreadLocal) {
        TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
    }
    return value;
}

可以看到setInitialValue整体逻辑和set方法雷同,以至于我并不想过多的说什么。唯一可以说的一点是initialValue方法默认返回是null,但是我们可以通过继承ThreadLocal,重写initialValue方法改变默认值。 最后总结一下get方法干了什么,实际就是尝试去ThreadLocalMap中获取数据,获取不到就往ThreadLocalMap中插入默认值,key是当前ThreadLocal对象,value是默认值默认为null,然后把默认值返回出去。

remove
public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null) {
        m.remove(this);
    }
}

remove方法就贴下代码不说了,不然感觉我有水字数嫌疑,良心不安。。。

InheritableThreadLocal

在部分场景下,可能需要子线程可以访问到父线程的threadLocal,这时候我们就可以使用InheritableThreadLocal了。

public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    
    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

可以看到InheritableThreadLocal继承ThreadLocal,重写了ThreadLocal中的三个方法,大家需要注意InheritableThreadLocal的set,get以及remove方法中用的getMap以及createMap是自身重写的方法。

核心实现

在InheritableThreadLocal中,我们并没有看到哪里有特殊的代码逻辑实现了子线程可以访问父线程的threadLocal数据的功能,所以说明核心逻辑不在这,其实大家也能猜出,底层数据是在Thread中的,所以要实现这个功能,核心逻辑那也必然会在Thread中。

public Thread() {
    init(null, null, "Thread-" + nextThreadNum(), 0);
}

private void init(ThreadGroup g, Runnable target, String name,
                  long stackSize, AccessControlContext acc,
                  boolean inheritThreadLocals) {
    if (name == null) {
        throw new NullPointerException("name cannot be null");
    }

    this.name = name;

    Thread parent = currentThread();
    ...
    ...
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    
}

可以看到核心逻辑就在Thread的构造方法中,在它的构造方法中会去调用init方法,在init方法中通过将当前调用线程作为当前线程的父线程,注意两者区别。
通过将父线程的inheritableThreadLocals数据复制到子线程来达到子线程访问父线程的线程本地变量。
感谢大家看完本篇文章,本篇文章参照《java并发编程之美》。