JAVA 学习笔记之ThreadLocal解析

240 阅读3分钟

ThreadLocal作用: 与synchroinzed 关键字一样。都是保证多线程下的并发访问 ThradLocal 与synchroinzed 的区别

Synchroinzed 主要是利用锁,让在某一时刻只能一个线程去访问变量

ThreadLocal 为每一个线程提供了变量的副本,每一个线程都是访问自己变量的副本,实现了线程的隔离 使用:

A、不使用ThreadLocal

public class NoThreadLocal {
    static Integer count = new Integer(1);

    /**
     * 运行3个线程
     */
    public void StartThreadArray(){
        Thread[] runs = new Thread[3];
        for(int i=0;i<runs.length;i++){
            runs[i]=new Thread(new TestTask(i));
        }
        for(int i=0;i<runs.length;i++){
            runs[i].start();
        }
    }

    /**
     *类说明:
     */
    public static class TestTask implements Runnable{
        int id;
        public TestTask(int id){
            this.id = id;
        }
        public void run() {
            System.out.println(Thread.currentThread().getName()+":start");
            count = count+id;
            System.out.println(Thread.currentThread().getName()+":"
                    +count);
        }
    }

    public static void main(String[] args){
        NoThreadLocal test = new NoThreadLocal();
        test.StartThreadArray();
    }
}

运行结果:

截图.png 此处可以看出来值是不规律运行

B、使用ThreadLocal

/**
 *类说明:演示ThreadLocal的使用
 */
public class UseThreadLocal {
   
   private static ThreadLocal<Integer> intLocal
            = new ThreadLocal<Integer>(){
        @Override
        protected Integer initialValue() {
            return 1;
        }
    };

    private static ThreadLocal<String> stringThreadLocal;

    /**
     * 运行3个线程
     */
    public void StartThreadArray(){
        Thread[] runs = new Thread[3];
        for(int i=0;i<runs.length;i++){
            runs[i]=new Thread(new TestThread(i));
        }
        for(int i=0;i<runs.length;i++){
            runs[i].start();
        }
    }
    
    /**
     *类说明:测试线程,线程的工作是将ThreadLocal变量的值变化,并写回,看看线程之间是否会互相影响
     */
    public static class TestThread implements Runnable{
        int id;
        public TestThread(int id){
            this.id = id;
        }
        public void run() {
            System.out.println(Thread.currentThread().getName()+":start");
            Integer s = intLocal.get();
            s = s+id;
            intLocal.set(s);
            System.out.println(Thread.currentThread().getName()
                    +":"+ intLocal.get());
            //intLocal.remove();
        }
    }

    public static void main(String[] args){
       UseThreadLocal test = new UseThreadLocal();
        test.StartThreadArray();
    }
}

截图.png

运行结果可以发现此时为每个线程取出其对应值了。

ThreadLocal源码解析 get()方法:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

源码可以看到里面有几个需要注意的变量

1、t.threadLocals

2、setInitialValue()

继续往下解析t.threadLocals

截图.png

由图可以看出是t.threadLocals thread里的map对象

重点解析: 什么是ThreadLocalMap?

static class ThreadLocalMap {

    /**
     * The entries in this hash map extend WeakReference, using
     * its main ref field as the key (which is always a
     * ThreadLocal object).  Note that null keys (i.e. entry.get()
     * == null) mean that the key is no longer referenced, so the
     * entry can be expunged from table.  Such entries are referred to
     * as "stale entries" in the code that follows.
     */
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    /**
     * The initial capacity -- MUST be a power of two.
     */
    private static final int INITIAL_CAPACITY = 16;

    /**
     * The table, resized as necessary.
     * table.length MUST always be a power of two.
     */
    private Entry[] table;
    
    ...
    
    }

看源码变量,里面包含一个Entry 内部类,一个Entry[] table对象

每个线程里都有自己的ThreadLocalMap对象,Map里的 K便是单独的threadLocal对象

set()方法

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

很简单了,每次set,判断当前线程是否存在了threadLocalMap,

存在,把当前得local变量作为K传进去,不存在则创建。

每个线程都要自己得threadLocal副本,都是作为K,传进去,故相互之间不会影响,过程如下图所示

截图.png

故用TreadLocal可以实现线程隔离

那么ThreadLocal的问题也哪些呢?

问题1 :如何保证每次传进去得ThreadLocal 的Key都是唯一的? set里会自动排序

get()方法里 setInitialValue()也会调这个,故可以保证每次put的Key的HashCode都不一样

具体流程可以自己看源码,不做深究

private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }

        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

问题2 :ThreadLocal 有什么不安全之处? ThreadLocal有可能会引发的内存泄漏 首先了解几个概念。

强引用:Object o = new Object();

栈上面有一个引 用指向堆,强引用不会垃圾回收

软引用:可以回收,要发生内存溢出,就进行回收

弱引用: 只要发生GC,就回收

虚引用:强度最低

看源码发现

截图.png

此时Entry 为弱引用使用

当一个线程有N个ThreadLocal时,

Map里的Key为ThreadLocal本身,发生GC回收时,这个ThreadLocal会被回收

但是Map是线程的强引用,只有线程本身被回收,整个Entry才会被回收,此时Value占用内存空间,无法回收。产生了内存泄漏

解决:调remove方法

threadLocal = null

set()与 get()会有 清除的方法,但不是每次都一定执行,故最好的避免方法调remove

问题3:为什么ThreadLocal 线程不安全?

操作的对象是静态的,会每个线程都去修改一次,导致值变化。

多个ThreadLocal存放的是同一个对象的引用。

/**
 * 类说明:ThreadLocal的线程不安全演示
 */
public class ThreadLocalUnsafe implements Runnable {

    public Number number = new Number(0);//此处是否加static为下面输出结果

    public void run() {
        //每个线程计数加一
        number.setNum(number.getNum()+1);
      //将其存储到ThreadLocal中
        value.set(number);
        SleepTools.ms(2);
        //输出num值
        System.out.println(Thread.currentThread().getName()+"="+value.get().getNum());
    }

    public static ThreadLocal<Number> value = new ThreadLocal<Number>() {
    };

    public static void main(String[] args) {
        for (int i = 0; i < 5; i++) {
            new Thread(new ThreadLocalUnsafe()).start();
        }
    }

    private static class Number {
        public Number(int num) {
            this.num = num;
        }

        private int num;

        public int getNum() {
            return num;
        }

        public void setNum(int num) {
            this.num = num;
        }

        @Override
        public String toString() {
            return "Number [num=" + num + "]";
        }
    }

}

加static 输出结果:

此时线程不安全

截图.png

去掉static ,线程安全输出结果

截图.png