从源码实现理解ThreadLocal和InheritableThreadLocal

138 阅读7分钟

文章目录


1. 简介

线程同步除了使用锁机制外,还可以使用ThreadLocal来避免线程不安全的出现。ThreadLocal提供线程本地变量,如果创建一个ThreadLocal变量,那么访问这个变量的每个线程都会有一个该变量的一个副本。在实际的多线程操作时,每个线程操作的都是自己本地内存中的变量,彼此之间不会出现干扰,从而避免了线程安全问题的出现。


在这里插入图片描述


2. 案例

public class ThreadLocalDemo {
    private static ThreadLocal<String> local = new ThreadLocal<>();

    public static void main(String[] args) {
        new Thread(() -> {
           local.set("local value1");
                print("thread1");
                System.out.println("after remove:" + local.get());
        }).start();


        new Thread(() -> {
            // 设置当前线程中本地内存的变量值
            local.set("local value2");
            print("thread2");
            System.out.println("after remove:" + local.get());
        }).start();
    }

    public static void print(String str){
        // 获取当前线程它本地内存中的变量值
        System.out.println(str + ":" + local.get());
        // 清除本地内存中的本地变量
        local.remove();
    }
}
thread1:local value1
thread2:local value2
after remove:null
after remove:null

3. Thread类

Thread的类图如下所示:


在这里插入图片描述

Thread的源码定义为:

public 
class Thread implements Runnable {
    
	/* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;

    /*
     * InheritableThreadLocal values pertaining to this thread. This map is
     * maintained by the InheritableThreadLocal class.
     */
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
}

Thread类的成员变量中定义了两个ThreadLocal.ThreadLocalMap类型的变量threadLocals和inheritableThreadLocals,并且初始化都是null,只有当第一次调用ThreadLocal的set()get()时才会创建。threadLocals用于存放每个线程的本地变量,set()会将value添加到调用线程的threadLocals中,通过get()可以获取保存的变量。只要调用线程不终止,threadLocals中的变量会一直存在,直到调用remove()主动的删除变量。


4. ThreadLocal类

下面我们接着看一下ThreadLocal类的源码,重点关注其中比较重要的几个方法:

public class ThreadLocal<T> {
    protected T initialValue() {
        return null;
    }
    
     public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }
    
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }
    
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
    
    public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }
    
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    
     void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
    
    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }
}

而其中使用的ThreadLocalMap是ThreadLocal中定义的静态内部类。

4.1 set()

ThreadLocal中set()的源码为:

 public void set(T value) {
     // 获取当前线程
     Thread t = Thread.currentThread();
     // 获取Thread的成员变量ThreadLocalMap对象
     ThreadLocalMap map = getMap(t);
     // 如果当前map为null,则将value保存在map中
     if (map != null)
         map.set(this, value);
     else
         // 否则,懒加载模式创建map,然后将当前线程变量和value放到map中
         createMap(t, value);
 }

其中getMap()的实现为:

ThreadLocalMap getMap(Thread t) {
    // 获取线程自己的threadLocals变量,并绑定到调用线程的threadLocals上
    return t.threadLocals;
}

方法的返回值就是每个线程的threadLocals变量,它就是ThreadLocalMap类型的对象。如果map不为空,调用set()直接存放到threadLocals中,key就是当前线程的引用,值为传入的value;否则调用createMap(),它的源码为:

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

它会新建一个ThreadLocalMap对象,并将其作为当前调用线程的threadLocals变量的初始化值。

4.2 get()

get()的源码为:

public T get() {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程的threadLocals变量
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // 获取当前线程的Entry对象
        ThreadLocalMap.Entry e = map.getEntry(this);
        // 如果Entry对象不为null,则获取保存的值
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // 否则初始化当前线程的threadLocals变量
    return setInitialValue();
}

setInitialValue()的源码实现为:

private T setInitialValue() {
    // threadLocals初始为null
    T value = initialValue();
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 查找当前线程引用对应的线程变量
    ThreadLocalMap map = getMap(t);
    // 如果threadLocals不为null,直接添加本地变量
    // key为当前线程引用,value为本地变量值
    if (map != null)
        map.set(this, value);
    else	
        // 创建对应的threadLocals变量
        createMap(t, value);
    return value;
}
protected T initialValue() {
    return null;
}

4.3 remove()

remove()的源码为:

public void remove() {
    //获取当前线程绑定的threadLocals
    ThreadLocalMap m = getMap(Thread.currentThread());
    // 如果线程的threadLocals变量不为null,则直接删除 
    if (m != null)
        m.remove(this);
}

4.4 不支持继承性

同一个ThreadLocal变量在父线程中被设置值后,在子线程中是获取不到的

public class ThreadLocalDemo2 {
    private static ThreadLocal<String> local = new ThreadLocal<>();

    public static void main(String[] args) {
        local.set("main value");
        new Thread(() -> 
            System.out.println(Thread.currentThread().getName() + " : " + local.get())
        ).start();

        System.out.println(Thread.currentThread().getName() + " : " + local.get());
    }
}
main : main value
Thread-0 : null

从输出中可以看出,主线程和自定义线程各自本地内存中的变量是不同的,这再一次验证了ThreadLocal的原理。


5. InheritableThreadLocal类

InheritableThreadLocal类实现了子线程可以访问父线程的本地变量,首先改变下上面的代码,看下效果:

public class ThreadLocalDemo3 {
    private static InheritableThreadLocal<String> local = new InheritableThreadLocal<>();

    public static void main(String[] args) {
        local.set("main value");
        new Thread(() ->
            System.out.println(Thread.currentThread().getName() + " : " + local.get())
        ).start();

        System.out.println(Thread.currentThread().getName() + " : " + local.get());
    }
}
main : main value
Thread-0 : main value

那它是如何实现子线程来访问父线程的本地变量呢?首先看下InheritableThreadLocal类的源码:

public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

InheritableThreadLocal继承了ThreadLocal类,并添加了自己的三个方法childvalue()getMap()createMap()

5.1 createMap()

createMap()方法这里new的同样是ThreadLocalMap对象,但获取的对象是作为inheritableThreadLocals的初始值。

5.2 getMap()

getMap()这里获取的的就是当前线程的inheritableThreadLocals变量。

5.3 childvalue()

childvalue()来获取父线程的本地变量值。那它是如何实现的呢?理解实现的过程,首先需要看Thread类中的init()的源码 :

 private void init(ThreadGroup g, Runnable target, String name,long stackSize) {
     init(g, target, name, stackSize, null, true);
 }
private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
    if (name == null) {
        throw new NullPointerException("name cannot be null");
    }

    this.name = name;
	// 获取当前线程
    Thread parent = currentThread();
    // 安全校验
    SecurityManager security = System.getSecurityManager();
    if (g == null) {
        if (security != null) {
            g = security.getThreadGroup();
        }
        if (g == null) {
            g = parent.getThreadGroup();
        }
    }


    g.checkAccess();

    if (security != null) {
        if (isCCLOverridden(getClass())) {
            security.checkPermission(SUBCLASS_IMPLEMENTATION_PERMISSION);
        }
    }

    g.addUnstarted();

    this.group = g;
    this.daemon = parent.isDaemon();
    this.priority = parent.getPriority();
    if (security == null || isCCLOverridden(parent.getClass()))
        this.contextClassLoader = parent.getContextClassLoader();
    else
        this.contextClassLoader = parent.contextClassLoader;
    this.inheritedAccessControlContext =
        acc != null ? acc : AccessController.getContext();
    this.target = target;
    setPriority(priority);
    // 如果父线程的inheritableThreadLocal不为null
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        // 设置子线程中的inheritableThreadLocals为父线程的inheritableThreadLocals
        this.inheritableThreadLocals =
        ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

    this.stackSize = stackSize;

    tid = nextThreadID();
}

如果父线程inheritableThreadLocal不为null,则通过createInheritedMap()将父线程的inheritableThreadLocals作为构造函数的参数创建一个新的ThreadLocalMap变量,然后赋给子线程。

static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
    return new ThreadLocalMap(parentMap);
}

private ThreadLocalMap(ThreadLocalMap parentMap) {
    Entry[] parentTable = parentMap.table;
    int len = parentTable.length;
    setThreshold(len);
    table = new Entry[len];

    for (int j = 0; j < len; j++) {
        Entry e = parentTable[j];
        if (e != null) {
            @SuppressWarnings("unchecked")
            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
            if (key != null) {
                //调用重写的方法
                Object value = key.childValue(e.value);
                Entry c = new Entry(key, value);
                int h = key.threadLocalHashCode & (len - 1);
                while (table[h] != null)
                    h = nextIndex(h, len);
                table[h] = c;
                size++;
            }
        }
    }
}

在构造函数中将父线程的inheritableThreadLocals成员变量的值赋值到新的ThreadLocalMap对象中,返回之后赋值给子线程的inheritableThreadLocals。总之,InheritableThreadLocals类通过重写getMap和createMap两个方法将本地变量保存到了具体线程的inheritableThreadLocals变量中,当线程通过InheritableThreadLocals实例的set或者get方法设置变量的时候,就会创建当前线程的inheritableThreadLocals变量。而父线程创建子线程的时候,ThreadLocalMap中的构造函数会将父线程的inheritableThreadLocals中的变量复制一份到子线程的inheritableThreadLocals变量中。


6. ThreadLocalMap

ThreadLocalMap的类图如下所示:
在这里插入图片描述

实现源码:

static class ThreadLocalMap {
        static class Entry extends WeakReference<ThreadLocal<?>> {
            // value就是和ThreadLocal绑定的
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
        ...
}

ThradLocalMap类内部又定义了一个静态内部类Entry,它类似于Map集合中的Entry对象,类型为键值对形式,真正起到了ThradLoacalMap存放数据的能力。Entry本身继承了弱引用WeakReference类,因此键值对的key为ThreadLocal的弱引用,value为ThreadLocal的set()传入的value。
在这里插入图片描述

WeakReference的源码为:

public class WeakReference<T> extends Reference<T> {

    public WeakReference(T referent) {
        super(referent);
    }

    public WeakReference(T referent, ReferenceQueue<? super T> q) {
        super(referent, q);
    }
}

而WeakReference又继承了Reference,它是一个抽象类,所有类型的引用都是它的子类。

弱引用的特点简单来说就是发现即回收,这就导致了ThreadLocal可能会造成内存泄漏,这如何理解呢?ThreadLocal存放数据依赖于ThreadLocalMap这个内部类,而ThreadLocalMap又依赖于它的静态内部类Entry,Entry对象的键为当前TreadLocal的弱引用this,值为通过set()传入的value。因此,当Java虚拟机发现了这个弱引用后,垃圾收集器执行GC时就会对其进行回收,而这个弱引用和value并不是同生共死的。虽然key被回收了,但只要当前线程一直存在且没有调用ThreadLocal的remove(),value就会一直存在于内存空间中,但是失去了key再也无法获取到它,此时就发生了内存泄漏。

关于不同引用之间的区别,可查看:垃圾回收相关概念 + 引用分析


7. 参考

Java虚拟机中的垃圾回收和引用分析
Java中的ThreadLocal详解
一针见血 ThreadLocal