Java 多线程进阶-01:ThreadLocal

3 阅读5分钟

前言

ThreadLocal 是 Java 中的一个线程本地变量工具类,它提供了一种线程级别的数据隔离机制,使得每个线程都可以拥有自己独立的变量副本,避免了线程间共享变量带来的线程安全问题。具体来说,ThreadLocal 允许你创建的变量在每个线程中都有其自己的副本,每个线程都可以独立地修改自己的副本,而不会影响到其他线程。

ThreadLocal 主要用于解决线程范围内的数据共享问题,例如在多线程环境下实现用户身份认证、数据库连接管理、线程池任务分配等场景下,可以使用 ThreadLocal 来保存每个线程的相关数据,避免了使用全局变量或者在方法间传递参数的方式,使得代码更加简洁清晰。基本用法参考如下:

// 创建 ThreadLocal 实例
ThreadLocal<String> threadLocal = new ThreadLocal<>();

// 在不同线程中设置和获取
threadLocal.set("主线程的值");
new Thread(() -> {
    threadLocal.set("子线程的值");
    System.out.println(threadLocal.get());  // 输出:子线程的值
}).start();

System.out.println(threadLocal.get());  // 输出:主线程的值

本文主要介绍ThreadLocal的底层原理以及基本使用方法。

原理

Thread 成员变量 ThreadLocalMap

每个Thread线程内部存在成员变量ThreadLocalMap,不同线程在使用线程局部变量 ThreadLocal 的时候,其实是将同一个 ThreadLocal 实例变量作为各自ThreadLocalMap 的key,然后设置不同的value值。ThreadLocalMap 结构可以参考如下:

// Thread 类中有个 ThreadLocalMap 字段
public class Thread implements Runnable {
    // 每个线程都有自己的 ThreadLocalMap
    ThreadLocal.ThreadLocalMap threadLocals = null;
}

// ThreadLocalMap 是一个定制化的 HashMap
static class ThreadLocalMap {
    // Entry 继承自 WeakReference(弱引用)
    static class Entry extends WeakReference<ThreadLocal<?>> {
        Object value;  // 存储的实际值
        Entry(ThreadLocal<?> k, Object v) {
            super(k);  // 弱引用指向 ThreadLocal
            value = v;
        }
    }
    
    private Entry[] table;
    private int size = 0;
}

在使用ThreadLocalMap 的时候,大致的内存结构图可以参考如下:

Thread-1              Thread-2              Thread-3
  │                     │                     │
  ▼                     ▼                     ▼
ThreadLocalMap       ThreadLocalMap       ThreadLocalMap
  ├─ Entry[0]          ├─ Entry[0]          ├─ Entry[0]
  ├─ Entry[1]          ├─ Entry[1]          ├─ Entry[1]
  └─ ...               └─ ...               └─ ...

ThreadLocalMap的关键方法

set() 方法原理

public void set(T value) {
    Thread t = Thread.currentThread();
    // 获取当前线程的 ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    
    if (map != null) {
        // 以当前 ThreadLocal 实例为 key 存储值
        map.set(this, value);
    } else {
        // 首次调用,创建 ThreadLocalMap
        createMap(t, value);
    }
}

// ThreadLocalMap.set() 核心逻辑
private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    
    // 计算索引(优化过的 hash 算法)
    int i = key.threadLocalHashCode & (len-1);
    
    // 线性探测解决哈希冲突
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        
        if (k == key) {  // 找到相同的 key,更新值
            e.value = value;
            return;
        }
        
        if (k == null) {  // 遇到过期 Entry(key 被回收)
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    
    // 创建新 Entry
    tab[i] = new Entry(key, value);
    int sz = ++size;
    
    // 清理过期 Entry 并检查是否需要扩容
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

get() 方法原理

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    
    if (map != null) {
        // 以当前 ThreadLocal 为 key 获取 Entry
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    
    // 如果 map 不存在或值不存在,返回初始值
    return setInitialValue();
}

// ThreadLocalMap.getEntry()
private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    
    if (e != null && e.get() == key) {
        return e;  // 直接命中
    } else {
        // 线性探测查找
        return getEntryAfterMiss(key, i, e);
    }
}

remove() 方法

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null) {
        // 移除当前 ThreadLocal 对应的 Entry
        m.remove(this);
    }
}

// ThreadLocalMap.remove()
private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            // 清除弱引用
            e.clear();
            // 清理过期 Entry
            expungeStaleEntry(i);
            return;
        }
    }
}

使用示例

ThreadLocal 的常用方法包括:

  • set(T value):设置当前线程的局部变量值。
  • get():获取当前线程的局部变量值。
  • remove():移除当前线程的局部变量值。

使用 ThreadLocal 时需要注意以下几点:

  1. 内存泄漏:由于 ThreadLocal 中使用了 ThreadLocalMap 来保存每个线程的局部变量值,如果没有及时清理,会导致内存泄漏问题。因此,使用完毕后应该调用 remove() 方法进行清理。
  2. 线程安全:虽然 ThreadLocal 可以解决多线程环境下的数据共享问题,但并不是线程安全的,需要开发者自行保证线程安全性。
  3. 初始化:如果需要对 ThreadLocal 变量进行初始化,可以通过重写 initialValue() 方法或者使用 withInitial() 方法来实现。

以下是一个使用 ThreadLocal 的简单示例,展示了如何在多线程环境下保存和获取线程局部变量:

public class ThreadLocalExample {
​
    // 创建一个 ThreadLocal 变量,用于保存线程局部变量
    private static ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(() -> 0);
​
    public static void main(String[] args) {

        // 创建两个线程并启动
        Thread t1 = new Thread(() -> {
            // 设置线程局部变量的值
            threadLocal.set(1);
            // 打印线程局部变量的值
            System.out.println("Thread 1 - Local variable: " + threadLocal.get());

        });

        Thread t2 = new Thread(() -> {
            // 设置线程局部变量的值
            threadLocal.set(2);
            // 打印线程局部变量的值
            System.out.println("Thread 2 - Local variable: " + threadLocal.get());

        });​

        t1.start();
        t2.start();
    }
}

在这个示例中,我们创建了一个 ThreadLocal 变量 threadLocal,并使用 withInitial() 方法设置了初始值为 0。然后创建了两个线程 t1t2,每个线程都设置了 threadLocal 的值,并打印了它的值。由于 threadLocal 是线程局部变量,因此每个线程都可以独立地修改和访问它,不会相互影响。

注意事项

内存泄漏

// 潜在的内存泄漏场景
public class MemoryLeakDemo {
    static class BigObject {
        byte[] data = new byte[1024 * 1024];  // 1MB
        
        @Override
        protected void finalize() throws Throwable {
            System.out.println("BigObject 被回收");
        }
    }
    
    public static void main(String[] args) throws InterruptedException {
        ThreadLocal<BigObject> threadLocal = new ThreadLocal<>();
        
        new Thread(() -> {
            threadLocal.set(new BigObject());
            // 线程结束,但 ThreadLocalMap 中的 Entry 可能未清理
        }).start();
        
        System.gc();
        Thread.sleep(1000);
    }
}

线程池内线程重复使用局部变量

// 线程池中 ThreadLocal 的内存泄漏
public class ThreadPoolProblem {
    private static ThreadLocal<String> userContext = new ThreadLocal<>();
    private static ExecutorService executor = Executors.newFixedThreadPool(2);
    
    public static void main(String[] args) throws InterruptedException {
        // 提交10个任务
        for (int i = 0; i < 10; i++) {
            final int taskId = i;
            executor.execute(() -> {
                // 设置 ThreadLocal 值
                userContext.set("User_" + taskId);
                
                try {
                    // 模拟业务处理
                    System.out.println("处理任务 " + taskId + 
                                     ", 用户: " + userContext.get());
                    Thread.sleep(100);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    // 必须清理!否则下一个任务会看到旧值
                    userContext.remove();
                }
            });
        }
        
        executor.shutdown();
        executor.awaitTermination(1, TimeUnit.SECONDS);
    }
}