Thread-Local Variables


Summary

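  This article explains how Thread, ThreadLocal, and ThreadLocalMap relate to one another in multithreaded programming. It then walks through the source of InheritableThreadLocal. After that, we analyze some incorrect usage examples and explore TransmittableThreadLocal and its context-propagation mechanism.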

 

ThreadLocal

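  Every Thread object has a threadLocals field of type ThreadLocalMap. ThreadLocalMap is a hash map specialized for this purpose. Its keys are ThreadLocal instances, held through weak references, and its values are the thread-local variables.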

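  The threadLocals field is actually maintained by ThreadLocal itself. Every read or write of a thread-local variable goes through a ThreadLocal instance, which looks up the map that belongs to the current thread.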

public class Thread implements Runnable {
    // threadLocals is maintained by ThreadLocal
    ThreadLocal.ThreadLocalMap threadLocals = null;

    // threadLocals is cleared when the thread actually exits
    private void exit() {
        // ......
        threadLocals = null;
        // ......
    }
}
public class ThreadLocal<T> {
    // Get the threadLocals map of the given thread
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    // Create the threadLocals map of the given thread
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
    // Set the value
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
    // Get the value (falls back to setInitialValue() when there is no entry)
    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }
    // Remove the value
    public void remove() {
        ThreadLocalMap m = getMap(Thread.currentThread());
        if (m != null)
            m.remove(this);
    }
}
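  The sketch below makes this isolation visible. The class and method names are illustrative and do not come from the JDK. The main thread and a worker thread use the same ThreadLocal instance, yet each one reads only the value it set itself, because every lookup goes to the calling thread's own threadLocals map.

```java
import java.util.concurrent.atomic.AtomicReference;

class ThreadLocalIsolationDemo {
    // One ThreadLocal instance; each thread stores its value in its own threadLocals map
    private static final ThreadLocal<String> CURRENT_USER = new ThreadLocal<>();

    /** Returns {main thread's value after the worker ran, what the worker observed}. */
    static String[] run() {
        CURRENT_USER.set("main-user");
        AtomicReference<String> seenByWorker = new AtomicReference<>();
        Thread worker = new Thread(() -> {
            String before = CURRENT_USER.get(); // null: the worker's map has no entry yet
            CURRENT_USER.set("worker-user");
            seenByWorker.set(before + " -> " + CURRENT_USER.get());
            CURRENT_USER.remove(); // tidy up (exit() would also drop the map)
        });
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        String seenByMain = CURRENT_USER.get(); // unaffected by the worker's set()
        CURRENT_USER.remove();
        return new String[] {seenByMain, seenByWorker.get()};
    }

    public static void main(String[] args) {
        String[] r = run();
        System.out.println("main sees:   " + r[0]); // main-user
        System.out.println("worker saw:  " + r[1]); // null -> worker-user
    }
}
```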

 

InheritableThreadLocal

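  InheritableThreadLocal is a subclass of ThreadLocal that adds inheritance of thread-local variables. Suppose the creating (parent) thread holds InheritableThreadLocal values when a new Thread object is constructed. The child thread then receives an initial copy of those values, produced by childValue(). The copy is made in the Thread constructor, not when start() is called.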

public class Thread implements Runnable {
    // inheritableThreadLocals is maintained by InheritableThreadLocal
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
    // No-arg constructor
    public Thread() {
        init(null, null, "Thread-" + nextThreadNum(), 0);
    }
    // Delegates to the full init method
    private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize) {
        init(g, target, name, stackSize, null, true);
    }
    // When initializing a new thread, copy inheritableThreadLocals from the current (parent) thread
    private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
        // ...
        Thread parent = currentThread();
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
                    ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
        // ...
    }
    // inheritableThreadLocals is cleared when the thread actually exits
    private void exit() {
        // ...
        inheritableThreadLocals = null;
        // ...
    }
}
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
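    // Called while the child Thread is being constructed; by default the child gets the parent's reference as-is
    protected T childValue(T parentValue) {
        return parentValue;
    }
    // Redirects get()/set()/remove() to the inheritableThreadLocals field instead of threadLocals
    ThreadLocalMap getMap(Thread t) {
        return t.inheritableThreadLocals;
    }
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }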
}
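  Here is a minimal sketch of inheritance; the names are hypothetical. childValue() is overridden so that the copy handed to the child is easy to recognize. The sketch also shows that a later set() in the child changes only the child's own copy, never the parent's.

```java
import java.util.concurrent.atomic.AtomicReference;

class InheritableDemo {
    // childValue() is overridden so the copy handed to a child thread is distinguishable
    private static final InheritableThreadLocal<String> TRACE_ID = new InheritableThreadLocal<String>() {
        @Override
        protected String childValue(String parentValue) {
            // Runs in the parent thread while the child Thread object is being constructed
            return parentValue + "-child";
        }
    };

    /** Returns {value seen by the child, parent's value after the child changed its own copy}. */
    static String[] run() {
        TRACE_ID.set("trace-42");
        AtomicReference<String> seenByChild = new AtomicReference<>();
        Thread child = new Thread(() -> {
            seenByChild.set(TRACE_ID.get());
            TRACE_ID.set("changed-in-child"); // only changes the child's own copy
        });
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        String seenByParent = TRACE_ID.get();
        TRACE_ID.remove();
        return new String[] {seenByChild.get(), seenByParent};
    }

    public static void main(String[] args) {
        String[] r = run();
        System.out.println("child:  " + r[0]); // trace-42-child
        System.out.println("parent: " + r[1]); // trace-42
    }
}
```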

 

Incorrect Usage Examples

  Read the code example below, predict the program's output, and explain why.

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class SimpleThreadLocalPollutionExample {

    private final static ThreadLocal<String> USERNAME_THREADLOCAL = ThreadLocal.withInitial(() -> null);

    private static void login(String username, String password) {
        if ("123456".equals(password)) {
            USERNAME_THREADLOCAL.set(username);
        }
    }

    private static String getCurrentUsername() {
        return USERNAME_THREADLOCAL.get();
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService executorService = Executors.newFixedThreadPool(1);
        // Users queue up to log in; key: username, value: the password the user typed.
        // LinkedHashMap keeps insertion order so the run is deterministic
        Map<String, String> userList = new LinkedHashMap<>();
        userList.put("zhangsan", "123456");
        userList.put("lisi", "123456");
        userList.put("wangwu", "111111");
        for (Map.Entry<String, String> entry : userList.entrySet()) {
            executorService.submit(() -> {
                // Simulate a login
                login(entry.getKey(), entry.getValue());
                String username;
                if ((username = getCurrentUsername()) != null) {
                    // Greet the user after a successful login
                    System.out.println("hello, " + username);
                }
            });
        }
        // Let the submitted tasks finish, then allow the JVM to exit
        executorService.shutdown();
    }
}

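  Failing to call ThreadLocal's remove() in time leads to memory leaks and to thread-local variable pollution. In the example above, wangwu's login fails. Even so, getCurrentUsername() still returns the username that an earlier successful login left on the same pooled thread, so the program prints a greeting for a user who never logged in.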

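  The memory leak comes from how ThreadLocalMap stores its entries. Each Entry holds a weak reference to its ThreadLocal key but a strong reference to its value. While the thread stays alive, the value remains reachable through Thread → threadLocals → Entry → value and cannot be garbage-collected. Pool threads usually live as long as the application, so this can last indefinitely. Pollution typically occurs with thread pools: a later task running on a reused thread reads stale data left behind by an earlier one.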
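  A common fix is to clear the ThreadLocal in a finally block, sketched below. It assumes each task owns the value only while the task runs. The class name and the cleanUp flag are illustrative. The flag lets the same code produce both the polluted run and the clean run on a single reused pool thread.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

class ThreadLocalCleanupDemo {
    private static final ThreadLocal<String> USERNAME = new ThreadLocal<>();

    /** Runs the login tasks on a one-thread pool; cleanUp toggles the finally { remove() }. */
    static List<String> run(boolean cleanUp) {
        Map<String, String> users = new LinkedHashMap<>(); // insertion order keeps the run deterministic
        users.put("zhangsan", "123456");
        users.put("lisi", "123456");
        users.put("wangwu", "111111");
        List<String> greetings = Collections.synchronizedList(new ArrayList<>());
        ExecutorService pool = Executors.newFixedThreadPool(1); // one reused worker thread
        for (Map.Entry<String, String> e : users.entrySet()) {
            pool.submit(() -> {
                try {
                    if ("123456".equals(e.getValue())) {
                        USERNAME.set(e.getKey());
                    }
                    String name = USERNAME.get();
                    if (name != null) {
                        greetings.add("hello, " + name);
                    }
                } finally {
                    if (cleanUp) {
                        USERNAME.remove(); // leave the pooled thread clean for the next task
                    }
                }
            });
        }
        pool.shutdown();
        try {
            pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(ex);
        }
        return new ArrayList<>(greetings);
    }

    public static void main(String[] args) {
        System.out.println("without remove(): " + run(false)); // wangwu is greeted as lisi
        System.out.println("with remove():    " + run(true));  // only zhangsan and lisi
    }
}
```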

public class InheritableThreadLocalExample {

    // 定义一个 InheritableThreadLocal 变量
    private static final InheritableThreadLocal<String> CONTEXT_HOLDER = new InheritableThreadLocal<>();

    public static void main(String[] args) throws InterruptedException {
        // Create the child thread (it is started later)
        Thread childThread = new Thread(() -> {
            System.out.println("Child thread value: " + CONTEXT_HOLDER.get());
        });
        // Set the InheritableThreadLocal value in the parent thread, after the child was created
        CONTEXT_HOLDER.set("Parent thread value");
        childThread.start();
        // Wait for the child thread to finish
        childThread.join();
    }
}

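  InheritableThreadLocal takes its copy when the child thread is created, and later changes in the parent never reach the child. In the example above, the child thread is constructed before the parent calls set(), so it prints "Child thread value: null". Pool threads are created once and then reused. A pooled task therefore sees whatever values existed when its worker thread was created, not the values present when the task was submitted. That is why thread pools generally rely on TransmittableThreadLocal to pass context.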
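  The pooled case can be sketched as follows; the names are illustrative. The sketch relies on ThreadPoolExecutor creating its worker lazily on the first submission, in the submitting thread, through the default thread factory. A custom ThreadFactory that creates threads elsewhere could behave differently.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class InheritableInPoolDemo {
    private static final InheritableThreadLocal<String> CONTEXT = new InheritableThreadLocal<>();

    /** Returns what two tasks, submitted with different parent values, observe on one pool thread. */
    static String[] run() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Callable<String> readContext = CONTEXT::get;
        try {
            CONTEXT.set("request-1");
            // First submission: the worker thread is created now, in this thread, and copies "request-1"
            String first = pool.submit(readContext).get();
            CONTEXT.set("request-2");
            // Second submission: the same worker is reused, and its copy is still "request-1"
            String second = pool.submit(readContext).get();
            return new String[] {first, second};
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            CONTEXT.remove();
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        String[] seen = run();
        System.out.println("task 1: " + seen[0]); // request-1
        System.out.println("task 2: " + seen[1]); // request-1, not request-2
    }
}
```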

 

TransmittableThreadLocal

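  TransmittableThreadLocal (TTL for short below) is part of transmittable-thread-local, a library open-sourced by Alibaba, and it extends Java's InheritableThreadLocal class. Ordinary ThreadLocal variables are not shared between threads; every thread has its own independent copy. TTL instead passes the values of the thread that submits a task to the thread that executes it, even when that executing thread comes from a pool and already existed.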

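  To make sure context is correctly passed from one thread to another, TTL requires the task or the executor to be decorated. There are three common approaches; the first two are shown below: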

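import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;
import com.alibaba.ttl.threadpool.TtlExecutors;

public class TtlDecorationExample {
    public static void main(String[] args) {
        TransmittableThreadLocal<String> ttl = new TransmittableThreadLocal<>();
        ttl.set("hello, ttl");
        // Approach 1: decorate the Runnable with TtlRunnable (TtlCallable for a Callable)
        ExecutorService pool1 = Executors.newFixedThreadPool(1);
        pool1.execute(TtlRunnable.get(() -> System.out.println(ttl.get())));
        // Approach 2: decorate the thread pool with TtlExecutors
        ExecutorService pool2 = Executors.newFixedThreadPool(1);
        Executor ttlExecutor = TtlExecutors.getTtlExecutor(pool2);
        ttlExecutor.execute(() -> System.out.println(ttl.get()));
        // Shut the pools down so the JVM can exit
        pool1.shutdown();
        pool2.shutdown();
    }
}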

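  In addition, a Java Agent can decorate thread pools directly, with no changes to business code. Be aware, though, that it may stop working when several Java agents are used together.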

-javaagent:xx/xx/transmittable-thread-local-<version>.jar

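TIPS: A quick aside on the Decorator pattern. It dynamically attaches additional responsibilities to an object and focuses on extending or enhancing its behavior.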
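  Here is a minimal illustration of the pattern, with hypothetical names; LoggingRunnable is not part of TTL. The decorator implements the same Runnable interface as the object it wraps and adds behavior around the delegated call, so callers cannot tell the difference. TtlRunnable, analyzed later, has the same shape.

```java
import java.util.ArrayList;
import java.util.List;

class DecoratorDemo {
    /** Decorator: same interface as the wrapped object, extra behaviour around the delegate call. */
    static final class LoggingRunnable implements Runnable {
        private final Runnable delegate;
        private final List<String> log;

        LoggingRunnable(Runnable delegate, List<String> log) {
            this.delegate = delegate;
            this.log = log;
        }

        @Override
        public void run() {
            log.add("before"); // extra responsibility added by the decorator
            try {
                delegate.run(); // the original behaviour is unchanged
            } finally {
                log.add("after");
            }
        }
    }

    static List<String> run() {
        List<String> log = new ArrayList<>();
        Runnable task = () -> log.add("task");
        // Callers still see a Runnable, so the decorated task can go anywhere a Runnable can
        new LoggingRunnable(task, log).run();
        return log;
    }

    public static void main(String[] args) {
        System.out.println(run()); // [before, task, after]
    }
}
```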


 

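  The key to reading the TTL source is understanding its two core mechanisms: how context is stored and read, and how context is transmitted between threads. Let's examine them one at a time.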

How Context Is Stored and Read
public final T get() {
    T value = super.get();
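    // disableIgnoreNullValueSemantics: whether the "ignore null values" semantics are disabled (default false). With the default semantics, the TTL is registered in holder via addThisToHolder() only when value is non-null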
    if (disableIgnoreNullValueSemantics || null != value) addThisToHolder();
    return value;
}
public final void set(T value) {
    // With the default semantics (disableIgnoreNullValueSemantics == false), setting null is treated as remove()
    if (!disableIgnoreNullValueSemantics && null == value) {
        // may set null to remove value
        remove();
    } else {
        super.set(value);
        addThisToHolder();
    }
}
public final void remove() {
    removeThisFromHolder();
    super.remove();
}

  Reading TTL's get(), set(), and remove(), we find that all the bookkeeping happens in addThisToHolder() and removeThisFromHolder(). Let's look at those next.

private void addThisToHolder() {
    if (!holder.get().containsKey(this)) {
        holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
    }
}
private void removeThisFromHolder() {
    holder.get().remove(this);
}
// Note about the holder:
// 1. holder self is a InheritableThreadLocal(a *ThreadLocal*).
// 2. The type of value in the holder is WeakHashMap<TransmittableThreadLocal<Object>, ?>.
//    2.1 but the WeakHashMap is used as a *Set*:
//        the value of WeakHashMap is *always* null, and never used.
//    2.2 WeakHashMap support *null* value.
private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
    new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
        @Override
        protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
            return new WeakHashMap<>();
        }

        @Override
        protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
            return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
        }
    };

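  The record of which TTL instances are in use therefore lives in holder. holder is an InheritableThreadLocal whose value type is WeakHashMap<TransmittableThreadLocal<Object>, ?>, and it overrides the initialValue() and childValue() methods of InheritableThreadLocal. holder does not store the values themselves. super.set() keeps each value in the thread's own ThreadLocalMap as usual, and holder only tracks the TTL instances so they can be enumerated when a snapshot is taken. The childValue() override gives a child thread its own copy of that map instead of sharing the parent's.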

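  WeakHashMap is a map whose keys are weak references. TTL uses it as a set: the key is the TTL instance and the value is always null. Because the keys are weak, holder does not keep a TTL instance alive once nothing else references it, which avoids a memory leak.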
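  The following is a heavily simplified, hypothetical imitation that makes this bookkeeping concrete. TrackedThreadLocal is not a TTL class, and it omits TTL's copy and callback hooks as well as the registration on get(). Values stay in the ThreadLocal itself. A static InheritableThreadLocal holding a WeakHashMap only records which instances currently hold a value on this thread, using null as the dummy map value.

```java
import java.util.HashSet;
import java.util.Set;
import java.util.WeakHashMap;

/** Hypothetical, simplified imitation of TTL's holder bookkeeping (not the TTL class itself). */
class TrackedThreadLocal<T> extends InheritableThreadLocal<T> {
    // Per-thread "set" of instances that hold a value; WeakHashMap used as a Set, value always null
    private static final InheritableThreadLocal<WeakHashMap<TrackedThreadLocal<?>, Object>> HOLDER =
            new InheritableThreadLocal<WeakHashMap<TrackedThreadLocal<?>, Object>>() {
                @Override
                protected WeakHashMap<TrackedThreadLocal<?>, Object> initialValue() {
                    return new WeakHashMap<>();
                }

                @Override
                protected WeakHashMap<TrackedThreadLocal<?>, Object> childValue(
                        WeakHashMap<TrackedThreadLocal<?>, Object> parentValue) {
                    // A child thread gets its own copy of the set, not the parent's instance
                    return new WeakHashMap<>(parentValue);
                }
            };

    @Override
    public void set(T value) {
        if (value == null) { // mirrors TTL's default semantics: setting null means remove
            remove();
            return;
        }
        super.set(value); // the value itself lives in the thread's inheritableThreadLocals map
        HOLDER.get().put(this, null); // only membership is recorded here
    }

    @Override
    public void remove() {
        HOLDER.get().remove(this);
        super.remove();
    }

    /** The instances registered on the current thread. */
    static Set<TrackedThreadLocal<?>> registeredOnCurrentThread() {
        return new HashSet<>(HOLDER.get().keySet());
    }

    public static void main(String[] args) {
        TrackedThreadLocal<String> user = new TrackedThreadLocal<>();
        TrackedThreadLocal<String> tenant = new TrackedThreadLocal<>();
        user.set("zhangsan");
        System.out.println(registeredOnCurrentThread().size()); // 1 (tenant was never set)
        tenant.set("acme");
        System.out.println(registeredOnCurrentThread().size()); // 2
        user.set(null); // treated as remove()
        System.out.println(registeredOnCurrentThread().size()); // 1
    }
}
```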

How Context Is Transmitted Between Threads
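This part of the code is more involved, so we first introduce three core types nested in `TransmittableThreadLocal`: `Snapshot`, `Transmitter`, and `Transmittee`. Together they capture, replay, and restore the contents of holder and of any registered ordinary ThreadLocals. Details follow: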
/**
 * Data snapshot: a wrapper that maps each Transmittee to the value it captured
 */
private static class Snapshot {
    // key: a registered Transmittee; value: what that Transmittee captured (or backed up)
    final HashMap<Transmittee<Object, Object>, Object> transmittee2Value;

    public Snapshot(HashMap<Transmittee<Object, Object>, Object> transmittee2Value) {
        this.transmittee2Value = transmittee2Value;
    }
}
/**
 * Transmitter: captures, replays, clears, and restores holder through the registered Transmittees
 */
public static class Transmitter {

    // Registered Transmittees, kept in a thread-safe CopyOnWriteArraySet whose iterators work on a snapshot
    private static final Set<Transmittee<Object, Object>> transmitteeSet = new CopyOnWriteArraySet<>();

    static {
        registerTransmittee(ttlTransmittee);
        registerTransmittee(threadLocalTransmittee);
    }
    /**
     * Register a Transmittee
     */
    public static <C, B> boolean registerTransmittee(@NonNull Transmittee<C, B> transmittee) {
        return transmitteeSet.add((Transmittee<Object, Object>) transmittee);
    }
    /**
     * Capture: call capture() on every Transmittee in transmitteeSet and wrap the results in a Snapshot
     */
    public static Object capture() {
        final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = new HashMap<>(transmitteeSet.size());
        for (Transmittee<Object, Object> transmittee : transmitteeSet) {
            try {
                transmittee2Value.put(transmittee, transmittee.capture());
            } catch (Throwable t) {
                if (logger.isLoggable(Level.WARNING)) {
                    logger.log(Level.WARNING, "exception when Transmitter.capture for transmittee " + transmittee +
                            "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
                }
            }
        }
        return new Snapshot(transmittee2Value);
    }
    /**
     * Replay: call replay() on every Transmittee with its captured value; the returned Snapshot backs up the values the current thread had before
     */
    public static Object replay(@NonNull Object captured) {
        final Snapshot capturedSnapshot = (Snapshot) captured;

        final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = new HashMap<>(capturedSnapshot.transmittee2Value.size());
        for (Map.Entry<Transmittee<Object, Object>, Object> entry : capturedSnapshot.transmittee2Value.entrySet()) {
            Transmittee<Object, Object> transmittee = entry.getKey();
            try {
                Object transmitteeCaptured = entry.getValue();
                transmittee2Value.put(transmittee, transmittee.replay(transmitteeCaptured));
            } catch (Throwable t) {
                if (logger.isLoggable(Level.WARNING)) {
                    logger.log(Level.WARNING, "exception when Transmitter.replay for transmittee " + transmittee +
                            "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
                }
            }
        }
        return new Snapshot(transmittee2Value);
    }
    /**
     * Clear: call clear() on every Transmittee; the returned Snapshot is a backup for a later restore()
     */
    public static Object clear() {
        final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = new HashMap<>(transmitteeSet.size());
        for (Transmittee<Object, Object> transmittee : transmitteeSet) {
            try {
                transmittee2Value.put(transmittee, transmittee.clear());
            } catch (Throwable t) {
                if (logger.isLoggable(Level.WARNING)) {
                    logger.log(Level.WARNING, "exception when Transmitter.clear for transmittee " + transmittee +
                            "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
                }
            }
        }
        return new Snapshot(transmittee2Value);
    }
    /**
     * Restore: hand each Transmittee its entry from the earlier backup and call its restore()
     */
    public static void restore(@NonNull Object backup) {
        for (Map.Entry<Transmittee<Object, Object>, Object> entry : ((Snapshot) backup).transmittee2Value.entrySet()) {
            Transmittee<Object, Object> transmittee = entry.getKey();
            try {
                Object transmitteeBackup = entry.getValue();
                transmittee.restore(transmitteeBackup);
            } catch (Throwable t) {
                if (logger.isLoggable(Level.WARNING)) {
                    logger.log(Level.WARNING, "exception when Transmitter.restore for transmittee " + transmittee +
                            "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
                }
            }
        }
    }
}

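  Transmitter, then, essentially manages the transmitteeSet container, and each of its methods just delegates to the Transmittees inside it. transmitteeSet holds Transmittee strategies, not thread-local values. The values are read from holder, and from the registered ThreadLocals, at the moment capture() is called in the submitting thread.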

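  Next we need to look at the element type Transmittee and at the two elements registered in the container by default, ttlTransmittee and threadLocalTransmittee. ttlTransmittee handles TransmittableThreadLocal variables. threadLocalTransmittee handles ordinary ThreadLocal variables that have been registered with Transmitter.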

public interface Transmittee<C, B> {
    C capture();
    B replay(@NonNull C captured);
    B clear();
    void restore(@NonNull B backup);
}
private static final Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>> ttlTransmittee =
    new Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>>() {
        // capture: iterate over holder and call copyValue() on each TTL; by default copyValue() just copies the reference
        @Override
        public HashMap<TransmittableThreadLocal<Object>, Object> capture() {
            final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<>(holder.get().size());
            for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
                ttl2Value.put(threadLocal, threadLocal.copyValue());
            }
            return ttl2Value;
        }
        // clear: reuse replay() with an empty map, which removes every TTL value and returns a backup
        @Override
        public HashMap<TransmittableThreadLocal<Object>, Object> clear() {
            return replay(new HashMap<>(0));
        }
        // restore: drop from holder the TTLs that are not in backup, then write the backed-up values with setTtlValuesTo()
        @Override
        public void restore(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
            // invoke the overridable callback hook (afterExecute(), since the flag is false)
            doExecuteCallback(false);
            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();
                // clear the TTL values that is not in backup
                // avoid the extra TTL values after restore
                if (!backup.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }
            // restore TTL values
            setTtlValuesTo(backup);
        }
        // replay: while iterating holder, back up the current thread's TTL values and drop those not in captured; then write the captured values with setTtlValuesTo()
        @Override
        public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
            final HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<>(holder.get().size());
            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();
                // backup
                backup.put(threadLocal, threadLocal.get());
                // clear the TTL values that is not in captured
                // avoid the extra TTL values after replay when run task
                if (!captured.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }
            // set TTL values to captured
            setTtlValuesTo(captured);
            // call beforeExecute callback
            doExecuteCallback(true);
            return backup;
        }
    };

private static void setTtlValuesTo(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
        TransmittableThreadLocal<Object> threadLocal = entry.getKey();
        threadLocal.set(entry.getValue());
    }
}

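  The other default element, threadLocalTransmittee, is structured similarly: it iterates over the registered ordinary ThreadLocals and reads or writes their values.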
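  The capture/replay/restore cycle can be sketched without the library. MiniTransmitter below is hypothetical and simplified. It handles only registered plain ThreadLocals and copies references, as TTL's default copyValue() does. Unlike ttlTransmittee, it does not remove values that are missing from the snapshot. A plain Thread stands in for a reused pool worker that still carries a stale value.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;

/** Hypothetical, simplified capture/replay/restore for registered plain ThreadLocals (not TTL code). */
final class MiniTransmitter {
    private static final Set<ThreadLocal<Object>> REGISTERED = new CopyOnWriteArraySet<>();
    static final ThreadLocal<String> USER = new ThreadLocal<>();

    @SuppressWarnings("unchecked")
    static void register(ThreadLocal<?> threadLocal) {
        REGISTERED.add((ThreadLocal<Object>) threadLocal);
    }

    /** Runs in the submitting thread: records the current values (references only). */
    static Map<ThreadLocal<Object>, Object> capture() {
        Map<ThreadLocal<Object>, Object> captured = new HashMap<>();
        for (ThreadLocal<Object> tl : REGISTERED) {
            captured.put(tl, tl.get());
        }
        return captured;
    }

    /** Runs in the executing thread: backs up its own values, then installs the captured ones. */
    static Map<ThreadLocal<Object>, Object> replay(Map<ThreadLocal<Object>, Object> captured) {
        Map<ThreadLocal<Object>, Object> backup = new HashMap<>();
        for (Map.Entry<ThreadLocal<Object>, Object> e : captured.entrySet()) {
            backup.put(e.getKey(), e.getKey().get());
            setOrRemove(e.getKey(), e.getValue());
        }
        return backup;
    }

    /** Runs in the executing thread after the task: puts its original values back. */
    static void restore(Map<ThreadLocal<Object>, Object> backup) {
        for (Map.Entry<ThreadLocal<Object>, Object> e : backup.entrySet()) {
            setOrRemove(e.getKey(), e.getValue());
        }
    }

    private static void setOrRemove(ThreadLocal<Object> tl, Object value) {
        if (value == null) {
            tl.remove();
        } else {
            tl.set(value);
        }
    }

    /** Returns {value seen by the task, worker's own value after restore()}. */
    static String[] demo() {
        register(USER);
        USER.set("zhangsan"); // the submitting thread's context
        Map<ThreadLocal<Object>, Object> captured = capture();
        String[] seen = new String[2];
        Thread worker = new Thread(() -> {
            USER.set("stale-value"); // pretend an earlier task left this behind
            Map<ThreadLocal<Object>, Object> backup = replay(captured);
            try {
                seen[0] = USER.get(); // the task sees the submitter's value
            } finally {
                restore(backup);
            }
            seen[1] = USER.get(); // the worker's own value is back
        });
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        USER.remove();
        return seen;
    }

    public static void main(String[] args) {
        String[] seen = demo();
        System.out.println("inside task:   " + seen[0]); // zhangsan
        System.out.println("after restore: " + seen[1]); // stale-value
    }
}
```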

  With the responsibilities of these core classes in mind, let's walk through how TtlExecutors decorates the thread pool in the earlier example.

Executor ttlExecutor = TtlExecutors.getTtlExecutor(Executors.newFixedThreadPool(1));
public final class TtlExecutors {
    // 省略部分代码......

    public static Executor getTtlExecutor(Executor executor) {
        // Return the executor unchanged if the TTL Java agent is loaded, the executor is null, or it is already decorated
        if (TtlAgent.isTtlAgentLoaded() || null == executor || executor instanceof TtlEnhanced) {
            return executor;
        }
        // Otherwise wrap it in ExecutorTtlWrapper to enhance it
        return new ExecutorTtlWrapper(executor, true);
    }
}

class ExecutorTtlWrapper implements Executor, TtlWrapper<Executor>, TtlEnhanced {
    // 省略部分代码......

    @Override
    public void execute(@NonNull Runnable command) {
        // Decorating the executor actually means decorating each submitted Runnable with TtlRunnable
        executor.execute(TtlRunnable.get(command, false, idempotent));
    }
}
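  The decoration shape of getTtlExecutor() and ExecutorTtlWrapper can be sketched generically. Everything here is hypothetical. Enhanced plays the role of TtlEnhanced, marking executors that are already decorated. The task decorator is passed in as a UnaryOperator<Runnable> rather than being hard-wired to TtlRunnable. A direct executor (Runnable::run) keeps the demo single-threaded and deterministic.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.function.UnaryOperator;

/** Hypothetical marker interface, playing the role of TtlEnhanced. */
interface Enhanced {
}

/** Hypothetical executor decorator: scheduling is delegated unchanged, only each task is wrapped. */
final class DecoratingExecutor implements Executor, Enhanced {
    private final Executor delegate;
    private final UnaryOperator<Runnable> taskDecorator;

    private DecoratingExecutor(Executor delegate, UnaryOperator<Runnable> taskDecorator) {
        this.delegate = delegate;
        this.taskDecorator = taskDecorator;
    }

    /** Same guard idea as getTtlExecutor(): null or already-decorated executors are returned unchanged. */
    static Executor wrap(Executor executor, UnaryOperator<Runnable> taskDecorator) {
        if (executor == null || executor instanceof Enhanced) {
            return executor;
        }
        return new DecoratingExecutor(executor, taskDecorator);
    }

    @Override
    public void execute(Runnable command) {
        delegate.execute(taskDecorator.apply(command)); // decorate the task, not the scheduling
    }

    /** Returns the event log of one task submitted through an executor that was "wrapped" twice. */
    static List<String> demo() {
        List<String> log = new ArrayList<>();
        Executor direct = Runnable::run; // runs tasks in the caller's thread
        UnaryOperator<Runnable> decorator = task -> () -> {
            log.add("decorated");
            task.run();
        };
        Executor once = wrap(direct, decorator);
        Executor twice = wrap(once, decorator); // already Enhanced, so returned as-is
        twice.execute(() -> log.add("task"));
        log.add("same instance: " + (once == twice));
        return log;
    }

    public static void main(String[] args) {
        System.out.println(demo()); // [decorated, task, same instance: true]
    }
}
```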

  Decorating a thread pool, then, still comes down to decorating each task with TtlRunnable; TtlCallable plays the same role for a Callable. Let's dig into the logic of TtlRunnable.

public final class TtlRunnable implements Runnable, TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments {
    // 省略部分代码......
    
    @Override
    public void run() {
        // <1> Fetch the snapshot captured from the submitting thread
        final Object captured = capturedRef.get();
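        // If releaseTtlValueReferenceAfterRun is set, clear the reference so the snapshot can be collected; a second run() then fails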
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }
        // <2> Replay the snapshot onto the executing thread, keeping a backup of its previous values
        final Object backup = replay(captured);
        try {
            // Run the wrapped task
            runnable.run();
        } finally {
            // <3> Restore the executing thread's values from the backup
            restore(backup);
        }
    }
}

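  TtlRunnable thus enhances Runnable with three main steps. It captures a snapshot of the submitting thread's data, replays the snapshot onto the executing thread, and restores the executing thread's data from the backup.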

  The code that captures the snapshot is shown below.

// 1.1 The data snapshot
private final AtomicReference<Object> capturedRef;

public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
        if (null == runnable) return null;
        // idempotent decides what happens if runnable is already decorated: true returns it as-is, false throws
        if (runnable instanceof TtlEnhanced) {
            if (idempotent) return (TtlRunnable) runnable;
            else throw new IllegalStateException("Already TtlRunnable!");
        }
        // 1.2 Decorate via TtlRunnable's private constructor
        return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}
private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        // 1.3 Take the snapshot with Transmitter#capture, in the thread that creates the TtlRunnable
        this.capturedRef = new AtomicReference<>(capture());
        // 省略部分代码......
}

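  The other two steps call Transmitter's static methods directly. replay() runs before the wrapped run() and loads the captured values onto the executing thread. restore() puts the executing thread's original values back afterwards. This matters when pool threads are reused, because it keeps later tasks from being affected.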
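  A TtlRunnable-style decorator for a single ThreadLocal, combining all three steps, might look like the sketch below. ContextRunnable is hypothetical and is not TTL code. It captures in the constructor, which runs in the submitting thread, and replays and restores around the wrapped run(). Like releaseTtlValueReferenceAfterRun, it releases the snapshot after one run. The snapshot is wrapped in an Optional so that "no value" can be told apart from "already released". Unlike InheritableThreadLocal on a reused pool thread, each task sees the value that was current when it was submitted.

```java
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical TtlRunnable-style decorator for one ThreadLocal (a sketch, not TTL's implementation). */
final class ContextRunnable implements Runnable {
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    // Snapshot taken in the thread that creates the decorator, i.e. the submitting thread
    private final AtomicReference<Optional<String>> capturedRef;
    private final Runnable delegate;

    ContextRunnable(Runnable delegate) {
        this.delegate = delegate;
        this.capturedRef = new AtomicReference<>(Optional.ofNullable(CONTEXT.get()));
    }

    @Override
    public void run() {
        // Release the snapshot so it can be collected; a second run() finds null and fails
        Optional<String> captured = capturedRef.getAndSet(null);
        if (captured == null) {
            throw new IllegalStateException("snapshot already released after run");
        }
        String backup = CONTEXT.get();      // replay, part 1: back up the executing thread's value
        setOrRemove(captured.orElse(null)); // replay, part 2: install the submitter's value
        try {
            delegate.run();
        } finally {
            setOrRemove(backup);            // restore the executing thread's own value
        }
    }

    private static void setOrRemove(String value) {
        if (value == null) {
            CONTEXT.remove();
        } else {
            CONTEXT.set(value);
        }
    }

    /** Returns what three tasks on a single reused pool thread observe. */
    static String[] demo() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        String[] seen = new String[3];
        try {
            CONTEXT.set("request-1");
            pool.submit(new ContextRunnable(() -> seen[0] = CONTEXT.get())).get();
            CONTEXT.set("request-2");
            pool.submit(new ContextRunnable(() -> seen[1] = CONTEXT.get())).get();
            // An undecorated task shows that restore() left the worker without a value
            Runnable plain = () -> seen[2] = CONTEXT.get();
            pool.submit(plain).get();
            return seen;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            CONTEXT.remove();
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        String[] seen = demo();
        System.out.println(seen[0] + ", " + seen[1] + ", " + seen[2]); // request-1, request-2, null
    }
}
```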