【并发编程】线程安全策略

1,678 阅读11分钟

不可变对象

我们知道线程安全的问题就是出在多个线程同时修改共享变量,不可变对象的策略完全规避了对对象的修改,所以在多线程中使用一定是线程安全的。

不可变对象需要满足的条件:

  • 对象创建以后其状态就不能修改
  • 对象所有域都是final类型
  • 对象是正确创建的(在对象创建期间,this引用没有逸出)

下面来复习一下final关键字的作用

修饰类:

  • 不能被继承,final类中所有成员方法都会被隐式指定为final方法。

修饰方法:

  • 锁定方法不被继承类修改
  • 一个类的private方法会被隐式地指定为final方法。

修饰对象:

  • 基本数据类型变量:其数值初始化后无法修改
  • 引用类型变量:初始化后不能指向另外的对象(里面的值可以修改)

除了final修饰的方法来使对象不可变,还可以用Collections类中的unmodifiable为前缀的方法,包括Collection、List、Set、Map等,只需把对应集合的对象传入这个方法这个集合就不允许修改了。

同样地,在Guava中也有类似的方法immutableXXX可以达到相同的效果。

下面来验证一下

@Slf4j
public class ImmutableExample1 {

    private static Map<Integer, Integer> map = Maps.newHashMap();

    static {
        map.put(1, 2);
        map.put(3, 4);
        map.put(5, 6);
        map = Collections.unmodifiableMap(map);
    }

    public static void main(String[] args) {
        map.put(1, 3);
        log.info("{}", map.get(1));
    }
}

运行结果

可以看到程序报了一个不支持操作的异常,说明当map经过Collections.unmodifiableMap方法后就不支持更新操作了。

下面我们进入Collections.unmodifiableMap看一下它的实现

/**
 * Returns an unmodifiable view of the specified map.  This method
 * allows modules to provide users with "read-only" access to internal
 * maps.  Query operations on the returned map "read through"
 * to the specified map, and attempts to modify the returned
 * map, whether direct or via its collection views, result in an
 * <tt>UnsupportedOperationException</tt>.<p>
 *
 * The returned map will be serializable if the specified map
 * is serializable.
 *
 * @param <K> the class of the map keys
 * @param <V> the class of the map values
 * @param  m the map for which an unmodifiable view is to be returned.
 * @return an unmodifiable view of the specified map.
 */
public static <K,V> Map<K,V> unmodifiableMap(Map<? extends K, ? extends V> m) {
    return new UnmodifiableMap<>(m);
}

可以看到这个方法返回了一个新的不能被修改的map,我们来看一下这个map的实现。

/**
 * @serial include
 */
private static class UnmodifiableMap<K,V> implements Map<K,V>, Serializable {
    private static final long serialVersionUID = -1034234728574286014L;

    private final Map<? extends K, ? extends V> m;

    UnmodifiableMap(Map<? extends K, ? extends V> m) {
        if (m==null)
            throw new NullPointerException();
        this.m = m;
    }

    public int size()                        {return m.size();}
    public boolean isEmpty()                 {return m.isEmpty();}
    public boolean containsKey(Object key)   {return m.containsKey(key);}
    public boolean containsValue(Object val) {return m.containsValue(val);}
    public V get(Object key)                 {return m.get(key);}

    public V put(K key, V value) {
        throw new UnsupportedOperationException();
    }
    public V remove(Object key) {
        throw new UnsupportedOperationException();
    }
    public void putAll(Map<? extends K, ? extends V> m) {
        throw new UnsupportedOperationException();
    }
    public void clear() {
        throw new UnsupportedOperationException();
    }

    private transient Set<K> keySet;
    private transient Set<Map.Entry<K,V>> entrySet;
    private transient Collection<V> values;

    public Set<K> keySet() {
        if (keySet==null)
            keySet = unmodifiableSet(m.keySet());
        return keySet;
    }

    public Set<Map.Entry<K,V>> entrySet() {
        if (entrySet==null)
            entrySet = new UnmodifiableEntrySet<>(m.entrySet());
        return entrySet;
    }

    public Collection<V> values() {
        if (values==null)
            values = unmodifiableCollection(m.values());
        return values;
    }

    public boolean equals(Object o) {return o == this || m.equals(o);}
    public int hashCode()           {return m.hashCode();}
    public String toString()        {return m.toString();}

    // Override default methods in Map
    @Override
    @SuppressWarnings("unchecked")
    public V getOrDefault(Object k, V defaultValue) {
        // Safe cast as we don't change the value
        return ((Map<K, V>)m).getOrDefault(k, defaultValue);
    }

    @Override
    public void forEach(BiConsumer<? super K, ? super V> action) {
        m.forEach(action);
    }

    @Override
    public void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {
        throw new UnsupportedOperationException();
    }

    @Override
    public V putIfAbsent(K key, V value) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean remove(Object key, Object value) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean replace(K key, V oldValue, V newValue) {
        throw new UnsupportedOperationException();
    }

    @Override
    public V replace(K key, V value) {
        throw new UnsupportedOperationException();
    }

    @Override
    public V computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction) {
        throw new UnsupportedOperationException();
    }

    @Override
    public V computeIfPresent(K key,
            BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
        throw new UnsupportedOperationException();
    }

    @Override
    public V compute(K key,
            BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
        throw new UnsupportedOperationException();
    }

    @Override
    public V merge(K key, V value,
            BiFunction<? super V, ? super V, ? extends V> remappingFunction) {
        throw new UnsupportedOperationException();
    }

从上面的实现中可以看到UnmodifiableMap 对于很多操作都是直接抛出不支持操作的异常。

Guava 中的immutable 方法也是类似原理。

线程封闭

线程封闭就是把对象封装到一个线程里,只有一个线程可以看到这个对象,这样就算这个对象不是线程安全也不会有线程安全问题。

实现线程封闭主要有三种方式

  • Ad-hoc线程封闭:程序控制实现,实现较复杂已弃用。
  • 堆栈封闭:能使用局部变量的地方就不使用全局变量,多线程下访问同一个方法时,方法中的局部变量都会拷贝一份到线程的栈中,也就是说每一个线程中都有只属于本线程的私有变量,因此局部变量不会被多个线程共享。
  • ThreadLocal线程封闭:使用map实现了线程封闭,map的key是线程id,map的值是封闭的对象。

下面主要来看ThreadLocal线程封闭方法。

ThreadLocal是为每一个线程都提供了一个线程内的局部变量,每个线程只能访问到属于它的副本。

我们来看一下ThreadLocal的源码中的get和set方法

     /**
     * Returns the value in the current thread's copy of this
     * thread-local variable.  If the variable has no value for the
     * current thread, it is first initialized to the value returned
     * by an invocation of the {@link #initialValue} method.
     *
     * @return the current thread's value of this thread-local
     */
    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }
    
    /**
     * Sets the current thread's copy of this thread-local variable
     * to the specified value.  Most subclasses will have no need to
     * override this method, relying solely on the {@link #initialValue}
     * method to set the values of thread-locals.
     *
     * @param value the value to be stored in the current thread's copy of
     *        this thread-local.
     */
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

从源码中可以看出:每一个线程拥有一个ThreadLocalMap,这个map存储了该线程拥有的所有局部变量。

set时先通过Thread.currentThread()获取当前线程,进而获取到当前线程的ThreadLocalMap,然后以ThreadLocal自己为key,要存储的对象为值,存到当前线程的ThreadLocalMap中。

get时也是先获得当前线程的ThreadLocalMap,以ThreadLocal自己为key,取出和该线程的局部变量。

一个线程内可以设置多个ThreadLocal,这样该线程就拥有了多个局部变量。比如当前线程为t1,在t1内创建了两个ThreadLocal分别是tl1tl2,那么t1ThreadLocalMap就有两个键值对。

t1.threadLocals.set(tl1, obj1) // 等价于在t1线程中调用tl1.set(obj1)
t1.threadLocals.set(tl2, obj2) // 等价于在t1线程中调用tl2.set(obj1)

t1.threadLocals.getEntry(tl1) // 等价于在t1线程中调用tl1.get()获得obj1
t1.threadLocals.getEntry(tl2) // 等价于在t1线程中调用tl2.get()获得obj2

同步容器

由于很多常见的容器都是线程不安全的,这就要求开发人员在任何访问到这些容器的地方进行同步处理,导致使用非常不便,因此Java提供了同步容器。

常见的同步容器有以下几种:

  • ArrayList -> Vector, Stack

  • HashMap -> HashTable(key,value不能为null)

  • Collections.synchronizedXXX(List, Set, Map)

注意:同步容器不是绝对的线程安全。

Vector

@Slf4j
public class VectorExample1 {

    /**
     * 请求总数
     */
    public static int clientTotal = 5000;
    /**
     * 同时并发执行线程数
     */
    public static int threadTotal = 200;

    private static Vector<Integer> list = new Vector<>();

    public static void main(String[] args) throws InterruptedException {
        ExecutorService executorService = Executors.newCachedThreadPool();
        final Semaphore semaphore = new Semaphore(threadTotal);
        final CountDownLatch countDownLatch = new CountDownLatch(clientTotal);
        for (int i = 0; i < clientTotal; i++){
            final int count = i;
            executorService.execute(() -> {
                try {
                    semaphore.acquire();
                    update(count);
                    semaphore.release();
                } catch (Exception e){
                    log.error("exception", e);
                }
                countDownLatch.countDown();
            });
        }
        countDownLatch.await();
        executorService.shutdown();
        log.info("size:{}", list.size());
    }

    private static void update(int i){

        list.add(i);
    }
}

运行结果

在这里Vector是线程安全的。

下面来看一个线程不安全的例子

public class VectorExample2 {

    private static Vector<Integer> vector = new Vector<>();

    public static void main(String[] args) {

        for (int i = 0; i < 10; i++){
            vector.add(i);
        }

        Thread thread1 = new Thread() {
            public void run() {

                for (int i = 0; i < 10; i++){
                    vector.remove(i);
                }
            }
        };

        Thread thread2 = new Thread() {
            public void run() {

                for (int i = 0; i < 10; i++){
                    vector.get(i);
                }
            }
        };

        thread1.start();
        thread2.start();
    }
}

运行结果

可以看到抛出了数组越界的异常。这是因为thread2 中可能会get到已经被thread1移除的下标。

HashTable

@Slf4j
public class HashTableExample {

    /**
     * 请求总数
     */
    public static int clientTotal = 5000;
    /**
     * 同时并发执行线程数
     */
    public static int threadTotal = 200;

    private static Map<Integer, Integer> map = new Hashtable<>();

    public static void main(String[] args) throws InterruptedException {
        ExecutorService executorService = Executors.newCachedThreadPool();
        final Semaphore semaphore = new Semaphore(threadTotal);
        final CountDownLatch countDownLatch = new CountDownLatch(clientTotal);
        for (int i = 0; i < clientTotal; i++){
            final int count = i;
            executorService.execute(() -> {
                try {
                    semaphore.acquire();
                    update(count);
                    semaphore.release();
                } catch (Exception e){
                    log.error("exception", e);
                }
                countDownLatch.countDown();
            });
        }
        countDownLatch.await();
        executorService.shutdown();
        log.info("size:{}", map.size());
    }

    private static void update(int i){

        map.put(i, i);
    }
}

运行结果

Collections.sync

将之前例子中的容器类改成

private static List<Integer> list = Collections.synchronizedList(Lists.newArrayList());

运行结果始终是5000,线程安全。

将容器换成SetMap也是一样。

同步容器的错误用法

public class VectorExample3 {

    private static void test1(Vector<Integer> v1) {
        for (Integer i : v1) {
            if (i.equals(3)){
                v1.remove(i);
            }
        }
    }

    private static void test2(Vector<Integer> v1) {
        Iterator<Integer> iterator = v1.iterator();
        while (iterator.hasNext()) {
            Integer i = iterator.next();
            if (i.equals(3)){
                v1.remove(i);
            }
        }
    }

    private static void test3(Vector<Integer> v1) {
        for (int i = 0; i < v1.size(); i++){
            if (v1.get(i).equals(3)) {
                v1.remove(i);
            }
        }
    }

    public static void main(String[] args) {
        Vector<Integer> vector = new Vector<>();
        vector.add(1);
        vector.add(2);
        vector.add(3);
        test1(vector);
    }

}

这里定义了3种对Vector遍历后删除指定值的方法,依次对每个方法进行测试。

测试结果:

test1test2都抛出java.util.ConcurrentModificationException异常

test3运行正常

下面来看一下异常产生的原因

从第一个报错处点进去可以看到

final void checkForComodification() {
    if (modCount != expectedModCount)
        throw new ConcurrentModificationException();
}

我们在对一个集合进行遍历操作的同时对它进行了增删的操作,导致了modCount != expectedModCount 从而抛出异常。

因此当我们用for-each迭代器遍历集合时不要对集合进行更新操作。如果需要对集合进行增删操作,推荐的做法是在遍历过程中标记好要增删的位置,遍历结束后再进行相关的操作。

并发容器

CopyOnWriteArrayList

核心思想:

  • 读写分离
  • 最终一致性
  • 通过另外开辟空间解决并发冲突

相比于ArrayListCopyOnWriteArrayList是线程安全的。

当有新元素添加到CopyOnWriteArrayList时,它先从原有的数组中拷贝一份出来,然后在新数组上做写操作,写完后再将原有的数组指向到新的数组。CopyOnWriteArrayList的整个add操作都是在锁的保护下进行的。

缺点:

  • 拷贝数组会消耗内存,元素多时可能会导致GC问题。
  • 不能用于实时读的场景。拷贝数组,新增元素都需要时间,所以调用get操作后读取到的数据可能还是旧的,CopyOnWriteArrayList只能保证最终的一致性,不能满足实时性的要求。

CopyOnWriteArrayList的读操作都是在原数组上读的,不需要加锁。

下面来coding测试一下

public class CopyOnWriteArrayListExample {

    /**
     * 请求总数
     */
    public static int clientTotal = 5000;
    /**
     * 同时并发执行线程数
     */
    public static int threadTotal = 200;

    private static List<Integer> list = new CopyOnWriteArrayList<>();

    public static void main(String[] args) throws InterruptedException {
        ExecutorService executorService = Executors.newCachedThreadPool();
        final Semaphore semaphore = new Semaphore(threadTotal);
        final CountDownLatch countDownLatch = new CountDownLatch(clientTotal);
        for (int i = 0; i < clientTotal; i++){
            final int count = i;
            executorService.execute(() -> {
                try {
                    semaphore.acquire();
                    update(count);
                    semaphore.release();
                } catch (Exception e){
                    log.error("exception", e);
                }
                countDownLatch.countDown();
            });
        }
        countDownLatch.await();
        executorService.shutdown();
        log.info("size:{}", list.size());
    }

    private static void update(int i){

        list.add(i);
    }
}

运行结果始终是5000,线程安全。

下面我们进入CopyOnWriteArrayListadd方法看一下

/**
 * Appends the specified element to the end of this list.
 *
 * @param e element to be appended to this list
 * @return {@code true} (as specified by {@link Collection#add})
 */
public boolean add(E e) {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        Object[] elements = getArray();
        int len = elements.length;
        Object[] newElements = Arrays.copyOf(elements, len + 1);
        newElements[len] = e;
        setArray(newElements);
        return true;
    } finally {
        lock.unlock();
    }
}

可以看到整个方法是加了锁的,添加新元素时是把整个数组复制到一个新的数组中。

CopyOnWriteArraySet

HashSet对应的线程安全类。

底层实现是基于CopyOnWriteArrayList ,因此它符合CopyOnWriteArrayList的特点和适用场景。

迭代器不支持可变的remove操作。

ConcurrentSkipListSet

TreeSet对应的线程安全类。

基于Map集合,在多线程环境下addremove等操作都是线程安全的,但是批量操作如addAllremoveAll等并不能保证以原子方式执行。原因是它们的底层调用的还是addremove等方法,需要手动做同步操作。

不能存储null值。

ConcurrentHashMap

HashMap的线程安全类。

不能存储null值。

对读操作做了大量优化,后面会详细介绍。

ConcurrentSkipListMap

TreeMap的线程安全类。

内部使用SkipList来实现。

key有序,相比于ConcurrentHashMap支持更高并发,存取数与线程没有关系,也就是说在相同条件下并发线程越多ConcurrentSkipListMap优势越大。

安全共享对象策略总结

  • 1.线程限制:一个被线程限制的对象,由线程独占,并且只能被占有它的线程修改。
  • 2.共享只读:一个共享只读的对象,在没有额外同步的情况下,可以被多个线程访问,但任何线程都不能修改它。
  • 3.线程安全对象:一个线程安全对象或者容器,在内部通过同步机制来保证线程安全,所以其他线程无需额外的同步就可以通过公共接口随意访问它。
  • 4.被守护对象:被守护对象只能通过获取特定的锁来访问。

Written by Autu

2019.7.19