ThreadLocal那些不为人知的细节

937 阅读14分钟

大家好,我是阿轩。

今天我们来剖析一下ThreadLocal的源码。

说到ThreadLocal,我们在日常的开发工作中用的还是挺多的。

比如,用户登录的时候我们可以通过ThreadLocal把用户的信息保存起来,而不用在每次使用的时候再去查一遍。

Spring中的声明式事务也是通过ThreadLocal来保存数据库的链接,从而使多条SQL语句使用的是同一个数据库链接,保证事务。

好了,话不多说,我们开始。

引言

首先看下ThreadLocal的整体结构

在Thread类中保存了一个ThreadLocalMap的变量

/* ThreadLocal values pertaining to this thread. This map is maintained
 * by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;

ThreadLocalMapThreadLocal的内部类,底层数据结构是一个数组

/**
 * The table, resized as necessary.
 * table.length MUST always be a power of two.
 */
private Entry[] table;

元素是Entry类,这个类又是ThreadLocalMap的内部类,继承了WeakReference

static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

可以看到,Entry中k是弱引用,也就是ThreadLocal,而value仍然是强引用,我们通常所说的内存泄漏原因也就在这个地方,后面再说。

好了,ThreadLocal的整体结构我们介绍完了,下面我们开始看他的核心方法。

set

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

首先是set方法,我们在使用ThreadLocal的时候,肯定是先存然后再取,所以我们先看看他是怎么存的。

首先调用getMap方法获取当前线程所保存的ThreadLocalMap

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

这个刚刚说过,Thread类里保存了一个ThreadLocalMap的变量

如果为空先进行初始化

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

调用ThreadLocalMap的构造方法

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    table = new Entry[INITIAL_CAPACITY];
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    setThreshold(INITIAL_CAPACITY);
}

这里我们就可以看出,ThreadLocalMap的底层数据结构是数组

首先构造了一个默认长度16Entry数组,然后计算数组下标

private final int threadLocalHashCode = nextHashCode();

private static AtomicInteger nextHashCode = new AtomicInteger();

private static final int HASH_INCREMENT = 0x61c88647;

private static int nextHashCode() {
    return nextHashCode.getAndAdd(HASH_INCREMENT);
}

ThreadLocal的hash值是一个叫threadLocalHashCode的变量,调用的是nextHashCode方法,这个方法又是调用一个AtomicInteger静态实例的getAndAdd方法。

注意,这个nextHashCode变量是静态的,也就是说,每次新建一个ThreadLocal实例,他的hashcode都是在之前的基础上再加HASH_INCREMENT的。

下面来看看HASH_INCREMENT这个变量,值是0x61c88647,转换成10进制就是1640531527

看到这里,想必小伙伴们很自然的很有一个疑问,为什么每次hashcode都是在之前基础上再加一个这个值呢?

我们先来看一个小实验

public static void main(String[] args) {
    int a = 0x61c88647;
    int len = 16;
    for (int i = 1; i < len + 1; i++) {
        System.out.println(i + " " + ((a*i) & (len-1)));
    }
}

这段程序是模拟连续创建16个ThreadLocal实例,他的下标分布情况,我们看看结果如何

居然没有一个下标重复的,再试下长度为32看看

一样,没有一个下标重复,是不是很神奇

这里面其实是蕴含了一些数学原理的,我们先看下这个数字是怎么来的

把上面的公式变形一下,(long)((1<<31) * (Math.sqrt(5)-1)/2 * 2);

(Math.sqrt(5)-1)/2这个值是什么?

数字比较好的小伙伴可能立马就想到了,这不就是我们在初中学习的黄金分割吗,0.618

所以,为什么每次hashcode递增1640531527,求出来的下标会均匀分布,原因就在这里,感兴趣的小伙伴可以去研究一下。

我们继续往下看,在初始化完成之后会调用setThreshold方法设置扩容阈值

private void setThreshold(int len) {
    threshold = len * 2 / 3;
}

这里的阈值和HashMap不太一样,HashMap是设置的3/4倍,他这里是2/3

在第一次初始化之后,第二次调用的时候就会调用ThreadLocalMapset方法

private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }

        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

HashMap一样,循环遍历数组,找出符合条件的key,nextIndex是获取数组的下一位

private static int nextIndex(int i, int len) {
    return ((i + 1 < len) ? i + 1 : 0);
}

因为数组是有界的,所以当遍历超过数组范围时会重新回到0下标位。

循环中有2个判断,第一个判断key是否相等,如果相等直接覆盖value值。

第二个判断k是否为空,如果为空,替换当前数组位的值。

这里注意了,当前索引位Entrykey为null,但是value是不为null的,这里说下前面提到的内存泄漏问题。

在Java中引用分为4种,强软弱虚四大法王,强引用就是我们日常工作时用到的引用,比如User a = new User(),a就是强引用,软引用使用SoftReference包装,弱引用使用WeakReference包装,虚引用使用PhantomReference包装。

虚引用一般是用来链接堆外对象,通过虚引用实现对堆外内存的回收弱引用每次发生GC的时候会被回收,而软引用只有在内存不足的时候才会被回收。

ThreadLocalMap中Entry的key就是通过弱引用修饰的,所以每次发生GC时会被回收掉,导致key变成null,而value强引用,不会被回收,但是此时的value已经没有了任何意义,只是白白占着内存,所以也就导致了这部分内存不能被正常使用,造成内存泄漏

好了,我们继续往下看。

其实,从这里就可以看出,ThreadLocal处理hash碰撞是使用的线性探测法,就是如果计算出的索引位被别人占用了,那么就看下一位有没有被占用,一直找到没被占用的或者key为null的。

看下他的替换方法replaceStaleEntry

private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // If key not found, put new entry in stale slot
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // If there are any other stale entries in run, expunge them
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

这个方法有点长,我们一点点细看。

先看第一个for循环,我们现在知道当前数组下标位的key是为null的,待会是要被回收的,那么,我顺带看看前面还有没有Entry的key是null的,如果有的话那我就一并回收了岂不是更好。所以,这个for的作用就是向前遍历,如果找到key==null的,记录下位置,赋值给slotToExpunge。当遇到Entry为null时停下来,否则一直向前遍历,遍历到第一个元素时,会跳到数组的末尾继续往前遍历。

这里可能有小伙伴会想了,如果我一直没遇到Entry为null的,会不会又遍历回自己了?

显然,是不会的。

忘记了吗,当数组的元素个数达到一定的值时是会扩容的,所以,数组中始终会有一些下标位是为null的。

再看第二个for循环,这次是向后开始遍历,如果找到满足条件的key,那么就覆盖value,将当前索引位元素和staleSlot索引位元素替换下,画个图理解一下

因为staleSlot索引位key为null,待会要被清理掉,所以把他和覆盖完value值的i位替换下。然后判断之前向前遍历的时候有没有找到key为null的,如果没找到,就将开始清理的位置设置为i,否则从之前找到的索引位开始清理。

ThreadLocal清理的方法有2个,先看里面那个,expungeStaleEntry方法

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // expunge entry at staleSlot
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

因为清理是从staleSlot开始,所以上来就把staleSlot位的元素清空了。

然后向后遍历,遇到key为null的直接清空掉。

如果不为null,就计算下标位,如果发现计算出来的下标位不是自己现在的位置,那么就说明当初set的时候,计算出来的索引位被占用了,被迫向后遍历了。

那么,把当前i位设置为null

为什么设置为null呢?

因为此时已经清理掉了一些key为null的元素,当初占用他位置的元素此时很有可能被清理掉了,所以他要去夺回属于自己的东西,邪笑。

紧接着,他就开始循环遍历,从计算出的h位开始寻找,一直找到空余的位置为止。

最后返回i,注意了,这个i位元素是为null的。

回到外层,再次进行一次清理,调用cleanSomeSlots方法

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

这里注意一下,因为当前i位元素是为null的,所以开始遍历的时候是从下一位开始遍历的,如果发现key为null的,再次调用之前的expungeStaleEntry方法开始清理。

如果没找到key为null的,那么会循环log2^n次,找到了重新赋值n = len,再次循环log2^n次。

再次回到外层方法,如果key不符合条件,那么判断key是否为null,为null再判断之前向前遍历的时候有没有发现key为null的Entry,没发现就设置开始清理的位置。

一直遍历到元素为null,如果都没有找到符合条件的就跳出循环。新增一个Entry插入到staleSlot位置。因为之前循环的时候没有找到符合条件的key,没有进行清理工作,所以此时会进行清理工作。和之前循环调用的方法cleanSomeSlots(expungeStaleEntry(slotToExpunge), len)一样。

回到一开始的set方法,如果循环中没有找到符合条件的key,也没有找到key为null的,那么就会构造一个Entry元素赋值到i位置上。

一般新增一个元素后都会判断是否需要扩容,所以此时同样会判断扩容,但是扩容之前会进行一次随机清理,如果正巧清理了key为null的元素,那么因为清理了元素,所以数组个数减少了,也就不用再判断扩容了,如果没有清理到,此时判断是否超过阈值,超过了进行扩容。调用rehash方法

private void rehash() {
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
    if (size >= threshold - threshold / 4)
        resize();
}

在进行真正的扩容之前会把数组全部遍历一遍,清理key为null的元素,expungeStaleEntries这个方法

private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j);
    }
}

可以看到,这里把数组从头到尾遍历了一遍,发现key为null的就调用expungeStaleEntry进行清理。

清理之后判断是否超过阈值,这里把阈值减小了,减到原来的3/4

这里作者可能考虑到,清理之后,如果元素数量还超过阈值的3/4,那么过不了多久肯定又会超过2/3,与其那个时候再扩容不如现在提前扩容算了。

调用resize进行扩容

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

这个方法比较简单,就是数组容量扩大一倍,然后把老数组的元素转移到新数组上,

到这里set方法我们就剖析完了,下面我们看get方法

get

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

get方法相对而言简单一些,首先获取当前线程的ThreadLocalMap变量,如果为null,调用setInitialValue初始化

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

initialValue方法返回的是个null,然后调用前面说的createMap方法进行初始化

如果ThreadLocalMap变量不为null,调用getEntry获取Entry元素。

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

这里如果计算出来的i索引位满足就返回,否则调用getEntryAfterMiss方法

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

如果e为null说明已经发生GC被回收掉了,返回null。否则,从i开始往后遍历,满足条件就返回,为null就清理,一直到e==null或者找到符合条件的为止。

最后回到get方法判断有没有找到符合条件的Entry,找到就返回,没找到继续调用setInitialValue方法,将当前ThreadLocal实例作为key,null作为value,构造一个Entry插入到数组中。

最后看下remove方法

reove

public void remove() {
   ThreadLocalMap m = getMap(Thread.currentThread());
   if (m != null)
       m.remove(this);
}

调用ThreadLocalMapremove方法

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

计算出当前ThreadLocal实例所在的i索引位,判断此位置的key是否是自己,是,就删除,然后调用expungeStaleEntry方法看看能不能清理掉一些元素,然后返回。

实战应用

在我们日常的工作中,线上出现问题的话需要去排查,而现在微服务盛行,许多项目都由传统的单体式拆分成了微服务,经常客户端一个请求过来会经过好几个系统,这个时候为了追踪整个链路的调用情况,我们通常会创建一个traceId,贯穿整个调用链路,这样,我们在查日志的时候就可以通过这个traceId将整个调用过程串联起来。

但是为了提高系统的快速响应能力,我们经常会创建线程池来进行异步执行,这个时候traceId就会断掉,如果恰巧是线程池执行出现了错误,那么就无法跟踪到了。

这个时候ThreadLocal就派上用场了。

在日志框架slf4j里有一个叫MDC的类,通过他就可以实现我们需要的功能。

我们先看一下正常情况下调用的过程。

/**
 * @author 程序员阿轩
 */
@Component
public class WebFilter extends GenericFilterBean {
    Logger logger = LoggerFactory.getLogger(WebFilter.class);

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) {
        System.out.println("WebFilter doFilter-----------");
        try {
            HttpServletRequest request = (HttpServletRequest) servletRequest;
            String traceId = request.getHeader(TraceConstants.X_COMMON_TRACE_ID);
            if (StrUtil.isBlank(traceId)) {
                traceId = TraceUtils.newTraceId();
            }
            TraceContext.setTraceId(traceId);
            System.out.println("WebFilter traceId ->" + traceId);
            filterChain.doFilter(servletRequest, servletResponse);
        } catch (Throwable e) {
            
        } finally {
            TraceContext.clear();
        }
    }
}

首先请求来到过滤器,我们在这里给他设置一个traceId

/**
 * @author 程序员阿轩
 */
public class TraceContext {
    private TraceContext() {
    }

    public static String getTraceId() {
        return MDC.get(TraceConstants.X_COMMON_TRACE_ID);
    }

    public static void setTraceId(String traceId) {
        MDC.put(TraceConstants.X_COMMON_TRACE_ID, traceId);
    }

    public static Map<StringStringgetContextMap() {
        return MDC.getCopyOfContextMap();
    }

    public static void setContextMap(Map<StringString> contextMap) {
        if (contextMap == null) {
            contextMap = new HashMap<>();
        }
        MDC.setContextMap(contextMap);
    }

    public static void clear() {
        MDC.remove(TraceConstants.X_COMMON_TRACE_ID);
    }

    public static void clearAll() {
        MDC.clear();
    }
}

接着请求来到controller

/**
 * @author 程序员阿轩
 */
@RestController
public class TraceController {
    @Autowired
    private TestService testService;
    
    @GetMapping("/trace")
    public String trace() {
        System.out.println("main->" + Thread.currentThread().getName());
        testService.test();
        return "程序员阿轩";
    }
}

service类

/**
 * @author 程序员阿轩
 */
@Service
public class TestService {
    @Async("ecsAsyncExecutor")
    public void test() {
        System.out.println("线程池中线程->" + Thread.currentThread().getName() + "---" + MDC.get(TraceConstants.X_COMMON_TRACE_ID));
        try {
            Thread.sleep(500000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

线程池配置类

/**
 * @author 程序员阿轩
 */
@Configuration
public class AsyncExecutorConfig implements AsyncConfigurer {
    private static final Logger LOGGER = LoggerFactory.getLogger(AsyncExecutorConfig.class);

    private final TaskExecutionProperties properties;

    public AsyncExecutorConfig(TaskExecutionProperties properties) {
        this.properties = properties;
    }

    @Override
    @Bean("ecsAsyncExecutor")
    public ThreadPoolTaskExecutor getAsyncExecutor() {
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor() {
            @Override
            public <T> Future<T> submit(Callable<T> task) {
                return super.submit(task);
            }

            @Override
            public void execute(Runnable task) {
                super.execute(task);
            }
        };
        executor.setCorePoolSize(properties.getPool().getCoreSize());
        executor.setMaxPoolSize(properties.getPool().getMaxSize());
        executor.setQueueCapacity(properties.getPool().getQueueCapacity());
        executor.setAllowCoreThreadTimeOut(properties.getPool().isAllowCoreThreadTimeout());
        executor.setKeepAliveSeconds((int) properties.getPool().getKeepAlive().getSeconds());
        executor.setThreadNamePrefix(properties.getThreadNamePrefix());
        executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        executor.initialize();
        return executor;
    }
}

这里主要为了演示,一些异常异常捕捉什么的就省掉了。

yaml配置

spring:
  task:
    execution:
      pool:
        allow-core-thread-timeout: true
        core-size: 1
        max-size: 5
        queue-capacity: 3
        keep-alive: 60s
      thread-name-prefix: a-xuan

运行程序,打印结果

WebFilter doFilter-----------
WebFilter traceId ->08dd98a2-7343-4695-b509-2103bda6f7ef
main->http-nio-9050-exec-1
线程池中线程->a-xuan1---null

可以看到,线程池中的线程获取traceId为null,获取不到。

我们稍微改造下线程池的配置类

/**
 * @author 程序员阿轩
 */
@Configuration
public class AsyncExecutorConfig implements AsyncConfigurer {
    private static final Logger LOGGER = LoggerFactory.getLogger(AsyncExecutorConfig.class);

    private final TaskExecutionProperties properties;

    public AsyncExecutorConfig(TaskExecutionProperties properties) {
        this.properties = properties;
    }

    @Override
    @Bean("ecsAsyncExecutor")
    public ThreadPoolTaskExecutor getAsyncExecutor() {
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor() {
            @Override
            public <T> Future<T> submit(Callable<T> task) {
                return super.submit(ThreadMdcUtil.wrap(task, MDC.getCopyOfContextMap()));
            }

            @Override
            public void execute(Runnable task) {
                super.execute(ThreadMdcUtil.wrap(task, MDC.getCopyOfContextMap()));
            }
        };
        executor.setCorePoolSize(properties.getPool().getCoreSize());
        executor.setMaxPoolSize(properties.getPool().getMaxSize());
        executor.setQueueCapacity(properties.getPool().getQueueCapacity());
        executor.setAllowCoreThreadTimeOut(properties.getPool().isAllowCoreThreadTimeout());
        executor.setKeepAliveSeconds((int) properties.getPool().getKeepAlive().getSeconds());
        executor.setThreadNamePrefix(properties.getThreadNamePrefix());
        executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        executor.initialize();
        return executor;
    }
}

我们把需要执行的任务包装一层

/**
 * @author 程序员阿轩
 */
public class ThreadMdcUtil {
    public static <T> Callable<T> wrap(final Callable<T> callable, final Map<StringString> context) {
        return new Callable<T>() {
            @Override
            public T call() throws Exception {
                if (context == null) {
                    MDC.clear();
                } else {
                    MDC.setContextMap(context);
                }
                System.out.println("wrap: " + Thread.currentThread().getName() + "---" + MDC.get(TraceConstants.X_COMMON_TRACE_ID));
                
                try {
                    return callable.call();
                } finally {
                    MDC.clear();
                }
            }
        };

    public static Runnable wrap(final Runnable runnable, final Map<StringString> context) {
        return () -> {
            if (context == null) {
                MDC.clear();
            } else {
                MDC.setContextMap(context);
            }
//            System.out.println("wrap: " + Thread.currentThread().getName() + "---" + MDC.get(TraceConstants.X_COMMON_TRACE_ID));

            try {
                runnable.run();
            } finally {
                MDC.clear();
            }
        };
    }
}

再次执行看下打印结果

WebFilter doFilter-----------
WebFilter traceId ->655795e7-a58e-408b-9758-e65c04aa4e4a
main->http-nio-9050-exec-1
submit->http-nio-9050-exec-1
wrap: a-xuan1---655795e7-a58e-408b-9758-e65c04aa4e4a
线程池中线程->a-xuan1---655795e7-a58e-408b-9758-e65c04aa4e4a

可以看到此时线程池中的线程拿到了traceId,从而完成了链路追踪的功能。

下面我们简单看下MDC是怎么实现这个功能的。

我们看下刚刚使用到的put和get方法

public static void put(String key, String val) throws IllegalArgumentException {
    if (key == null) {
        throw new IllegalArgumentException("key parameter cannot be null");
    } else if (mdcAdapter == null) {
        throw new IllegalStateException("MDCAdapter cannot be null. See also http://www.slf4j.org/codes.html#null_MDCA");
    } else {
        mdcAdapter.put(key, val);
    }
}
public static String get(String key) throws IllegalArgumentException {
    if (key == null) {
        throw new IllegalArgumentException("key parameter cannot be null");
    } else if (mdcAdapter == null) {
        throw new IllegalStateException("MDCAdapter cannot be null. See also http://www.slf4j.org/codes.html#null_MDCA");
    } else {
        return mdcAdapter.get(key);
    }
}

可以看到,MDC只是个门面,真正发挥作用的是MDCAdapter这个东西。

public interface MDCAdapter {
    void put(String var1, String var2);

    String get(String var1);

    void remove(String var1);

    void clear();

    Map<StringStringgetCopyOfContextMap();

    void setContextMap(Map<StringString> var1);
}

MDCAdapter实际上是一个接口,功能由他的子类来实现

现在我们日志框架通常使用的都是LogBack,我们看下LogBack的实现

public void put(String key, String val) throws IllegalArgumentException {
    if (key == null) {
        throw new IllegalArgumentException("key cannot be null");
    } else {
        Map<StringString> oldMap = (Map)this.copyOnThreadLocal.get();
        Integer lastOp = this.getAndSetLastOperation(1);
        if (!this.wasLastOpReadOrNull(lastOp) && oldMap != null) {
            oldMap.put(key, val);
        } else {
            Map<StringString> newMap = this.duplicateAndInsertNewMap(oldMap);
            newMap.put(key, val);
        }

    }
}

可以看到,核心是一个Map的变量copyOnThreadLocal,从名字其实已经能够看出来了

final ThreadLocal<Map<StringString>> copyOnThreadLocal = new ThreadLocal();
private static final int WRITE_OPERATION = 1;
private static final int MAP_COPY_OPERATION = 2;
final ThreadLocal<Integer> lastOperation = new ThreadLocal();

没错,他就是一个ThreadLocal,所有的一切都是围绕着这个ThreadLocal来进行的。

总结

本篇文章从ThreadLocal的源码剖析说到他在实际工作中的使用,其实小伙伴们可以发现,很多技术的底层都是我们熟悉的东西,只不过经过了层层包装,穿上了各种各样华丽的马甲之后,我们不认识他了,但是当你一步步去深究,像洋葱一样一层一层剥开他的时候,最后,你会情不自禁的感叹一句,哦---,原来如此,soga。