ThreadLocal用法及原理

avatar
SugarTurboS Club @SugarTurboS
  • 苏格团队
  • 作者:ZhiCheng

与Synchonized的对照: ThreadLocal和Synchonized都用于解决多线程并发訪问。可是ThreadLocal与synchronized有本质的差别。synchronized是利用锁的机制,使变量或代码块在某一时该仅仅能被一个线程訪问。而ThreadLocal为每个线程都提供了变量的副本,使得每个线程在某一时间訪问到的并非同一个对象,这样就隔离了多个线程对数据的数据共享。而Synchronized却正好相反,它用于在多个线程间通信时可以获得数据共享。

Synchronized用于线程间的数据共享,而ThreadLocal则用于线程间的数据隔离。


一、用法

把要线程隔离的数据放进ThreadLocal

static ThreadLocal<T> threadLocal = new ThreadLocal<T>() {
	protected T initialValue() {
		这里一般new一个对象返回
    }
}

线程获取相关数据的时候只要

threadLocal.get();

想修改、赋值只要

threadLocal.set(val)

二 、使用场景

如上面说到的,ThreadLocal是用于线程间的数据隔离,ThreadLocal为每个线程都提供了变量的副本。

  • 举例1:联想一下服务器(例如tomcat)处理请求的时候,会从线程池中取一条出来进行处理请求,如果想把每个请求的用户信息保存到一个静态变量里以便在处理请求过程中随时获取到用户信息。这时候可以建一个拦截器,请求到来时,把用户信息存到一个静态ThreadLocal变量中,那么在请求处理过程中可以随时从静态ThreadLocal变量获取用户信息。

  • 举例2:Spring的事务实现也借助了ThreadLocal类。Spring会从数据库连接池中获得一个connection,然会把connection放进ThreadLocal中,也就和线程绑定了,事务需要提交或者回滚,只要从ThreadLocal中拿到connection进行操作。

三、原理分析

1、get()方法

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {//当map已存在
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();//初始化值
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

上面先取到当前线程,然后调用getMap方法获取对应的ThreadLocalMap,ThreadLocalMap是ThreadLocal的静态内部类,然后Thread类中有一个这样类型成员,所以getMap是直接返回Thread的成员

ThreadLocal.ThreadLocalMap threadLocals = null;

来看下ThreadLocal的内部类ThreadLocalMap源码,留个大致印象

static class ThreadLocalMap {
	private static final int INITIAL_CAPACITY = 16;//初始数组大小
    private Entry[] table;//每个可以拥有多个ThreadLocal
    private int size = 0;
    private int threshold;//扩容阀值
    static class Entry extends WeakReference<ThreadLocal<?>> {
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }
 
    private Entry getEntry(ThreadLocal<?> key) {
        int i = key.threadLocalHashCode & (table.length - 1);
        Entry e = table[i];
        if (e != null && e.get() == key)
            return e;
        else
            return getEntryAfterMiss(key, i, e);
    }
    
    private void set(ThreadLocal<?> key, Object value) {
        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);
        for (Entry e = tab[i];
             e != null;
             e = tab[i = nextIndex(i, len)]) {
            ThreadLocal<?> k = e.get();
            if (k == key) {
                e.value = value;
                return;
            }
            if (k == null) {
            		//循环利用key过期的Entry
                replaceStaleEntry(key, value, i);
                return;
            }
        }
        tab[i] = new Entry(key, value);
        int sz = ++size;
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }
	...
}

可以看到有个Entry内部静态类,它继承了WeakReference,总之它记录了两个信息,一个是ThreadLocal<?>类型,一个是Object类型的值。getEntry方法则是获取某个ThreadLocal对应的值,set方法就是更新或赋值相应的ThreadLocal对应的值。里面涉及到扩容策略、Entry哈希冲突、循环利用等等不再深入,留个大致印象就好。

回顾下get()方法中的代码

if (map != null) {
    ThreadLocalMap.Entry e = map.getEntry(this);
    if (e != null) {
        @SuppressWarnings("unchecked")
        T result = (T)e.value;
        return result;
    }
}
return setInitialValue();

map为null或e为null就会走到setInitialValue,如果我们是第一次get()方法,那map会是空的,所以接下来先看setInitialValue()方法

private T setInitialValue() {
		//调用我们实现的方法得到需要线程隔离的值
    T value = initialValue();
    Thread t = Thread.currentThread();
    //拿到相应线程的ThreadLocalMap成员变量
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

上面initialValue就是实例化ThreadLocal要实现的方法,这里又取了线程的ThreadLocalMap,不为空就把值set进去(键为TreadLocal本身,值就是initialValue返回的值);为空就创建一个map同时添加一个值进去,最后返回value。

map.set(this, value)这句代码在上面的ThreadLocalMap源码中可以看到大致流程,下面看看createMap()做了什么事

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}


ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    table = new Entry[INITIAL_CAPACITY];
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    	//创建一个Entry,加入数组
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    setThreshold(INITIAL_CAPACITY);
}

可以看到在new ThreadLocalMap之后,就会创建一个Entry加入到数组中,最后把ThreadLocalMap的引用赋值给Thread的threadLocals成员变量

在回顾下get()方法中的代码

if (map != null) {
    ThreadLocalMap.Entry e = map.getEntry(this);
    if (e != null) {
        @SuppressWarnings("unchecked")
        T result = (T)e.value;
        return result;
    }
}
return setInitialValue();

现在map不会为空了,再次调用get方法就会调用map的getEntry方法(上面的ThreadLocalMap源码中可以看到大致流程),拿到相应的Entry,然后就可以拿到相应的值返回出去

2、set方法

分析完get()方法,那么set()方法就自然而然的明白了,就不再赘述

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

总结

  • 原理

    ThreadLocal的实现原理是,在每个线程中维护一个Map,键是ThreadLocal类型,值是Object类型。当想获取ThreadLocal的值时,就从当前线程中拿出Map,然后在把ThreadLocal本身作为键从Map中拿出值返回。

  • 优缺点

    **优点:**提供线程内的局部变量。每个线程都自己管理自己的局部变量,互不影响

    **缺点:**内存泄漏问题。可以看到ThreadLocalMap中的Entry是继承WeakReference的,其中ThreadLocal是以弱引用形式存在Entry中,如果ThreadLocal在外部没有被强引用,那么垃圾回收的时候就会被回收掉,又因为Entry中的value是强引用,就会出现内存泄漏。虽然ThreadLocal源码中的会对这种情况进行了处理,但还是建议不需要用TreadLocal的时候,手动调remove方法。