用过的举手🙌:Java多线程条件通行工具CyclicBarrier和CountDownLatch | 训练营第三期

937 阅读8分钟

本文已参与掘金创作者训练营第三期,详情查看:掘力计划|创作者训练营第三期正在进行,「写」出个人影响力

📖前言

方向不对努力白费

平常你也付出了很多的时间,但就是没有得到多少收益。就像有时候很多小伙伴问我,我是该怎么学一个我没接触过的内容。我的个人经验非常建议,先不要学太多理论性的内容,而是尝试实际操作下,把要学的内容做一些Demo案例出来。这有点像你买了个自行车是先拆了学学怎么个原理,还是先骑几圈呢?哪怕摔了跟头,但那都是必须经历后留下的经验。

今天博主将为大家分享一下使用 Java多线程条件通行工具 CyclicBarrierCountDownLatch,不喜勿喷,如有异议欢迎讨论!

✨概述

CyclicBarrier 是一个同步工具类,可以翻译成循环屏障,也叫障碍器或同步屏障。

CyclicBarrier 内部有一个计数器 count,调用障碍器的 await 方法会使计数器 count 的值减一,当计数器 count 的值为 0 时,表明调用了 await 方法线程已经达到了设置的数量。

当障碍器的屏障被打破后,会重置计数器,因此叫做循环屏障。

💕比较 CountDownLatchCyclicBarrier

  • CountDownLatch 的作用其实就是一个计数器,当阻塞的线程数达到 CountDownLatch 设置的临界值后,CountDownLatch 将会唤醒阻塞的线程,并且后面将失效不再阻塞线程,因此 CountDownLatch 也可以理解为一次性的障碍器

  • 相比较 CountDownLatch , CyclicBarrier 可以设置条件线程 barrierCommand 时,并且 CyclicBarrier 的是循环屏障,CyclicBarrier 只要内部不发生异常,是可以通过重置计数器来重复使用的。

  • CountDownLatchCyclicBarrier 的区别,其实比较简单,CountDownLatch 对执行任务的线程。比如有A,B,C三个线程,那么如果A是设置了 countDown.await() 这个方法,那么B和C就只能等到A了,什么时候A准备好了,那么B和C才可以执行.

  • 典型的场景就是,比如A需要从远程服务器拿到某个资源,然后B和C的执行,需要依赖A这个资源,但是又可能,A执行的时候,拿到的资源还是空的或者没有就绪,那么只有通过在主线程里,另外起一个线程,比如叫D吧,他循环判断A请求的资源是否就绪,而在主线程里在B和C执行以前,统一设置一个 countDown.await()方法。只有当D确认资源状态以后,并且调用了 countDown.countDown(2) 这里是2是假设的,到0,就表示OK了,那么被等待的方法才能继续执行下去

  • CyclicBarrier 呢,A B C三个都要准备好了,不然没有办法继续下去.

🎇原理

  • 障碍器内部有一个 ReentrantLock 变量 lock(显式锁),还有通过该 显式锁lock 获得的 Condition变量trip。在线程里调用 障碍器await方法,而在 await方法 内部调用了 dowait方法dowait方法 使用了显式锁变量 lock ),在 dowait方法 内部会根据计数器 count 判断,如果 count 不等于0,将会调用 Condition 变量 tripawait 方法,也就是说实际上障碍器的 await 方法是通过 Condition 变量 tripawait() 方法阻塞了所有的进行到这里的线程, 每个线程执行 await 方法都会令计数器 count 减一,当 count 值为 0 时,然后会调用 Condition 变量 tripsignalAll 方法,唤醒所有阻塞的线程。

🚀条件通行工具的作用

  • 设置一个屏障(也可以叫同步点),当某一个线程到达屏障后会阻塞该线程,只有当到达屏障的线程数达到临界值parties后,那些在屏障处被阻塞的线程才被唤醒继续执行

  • 可以在屏障处设置一个待执行的线程A,当所有线程到达屏障时,会执行线程A,然后打开屏障让哪些被阻塞的线程继续执行。这里容易有一个误解就是,并不是要等到线程A执行结束后,被阻塞的线程才继续执行,如果线程A中调用了wait()、yield方法,此时被阻塞的线程可以不必等到线程A执行完毕就能继续执行,而如果线程A调用了sleep方法,被阻塞的线程仍然需要等到线程A执行完毕后才能继续执行。

🙌CyclicBarrier 的作用

线程进入等待后,需要达到一定数量的等待线程后,再一次性开启放行。

CyclicBarrier 创建了一个栅栏,其维护了一组线程,所有在线程的 run方法 内部执行了某个CyclicBarrier对象的await()方法的任务都会在执行一次之后处于等待状态,当 CyclicBarrier 所维护的所有线程的任务都执行过一次之后(所有线程都处于等待状态),CyclicBarrier 会使其重复的再次都执行一次,如此反复。在某次执行过程中,如果有某个任务还处于执行状态,那么已经处于等待状态的线程将等待所有任务都执行完成。

  • CyclicBarrier(int, Runnable) 构造方法,参数1为通行所需的线程数量,参数2为条件满足时的监听器。
  • int await()/int await(long, TimeUnit) 线程进入等待,并返回一个进入等待时的倒计索引。
  • int getParties() 通行所需的线程数量
  • int getNumberWaiting() 当前线程数量
  • boolean isBroken() 本批次是否已经终止
  • eset() 释放本批次通行,并重新接收下一批线程进入。

例子1:主线程创建了若干子线程,主线程需要等待这若干子线程结束后才结束。

例子2:线程有若干任务,分多个线程来完成,需要等待这若干任务被完成后,才继续运行处理。

源码如下:

        /**
     * @since 1.5
     * @see CountDownLatch
     *
     * @author Doug Lea
     */
    public class CyclicBarrier {
    
        public CyclicBarrier(int parties) {
            this(parties, null);
        }
    
        public CyclicBarrier(int parties, Runnable barrierAction) {
            if (parties <= 0)
                throw new IllegalArgumentException();
            this.parties = parties;
            this.count = parties;
            this.barrierCommand = barrierAction;
        }
    
        private static class Generation {
            boolean broken = false;
        }
    
        private final ReentrantLock lock = new ReentrantLock();
        private final Condition trip = lock.newCondition();
        private final int parties;
        private final Runnable barrierCommand;
        private Generation generation = new Generation();
    
        private int count;
    
        private void nextGeneration() {
            // 通知本批次所有线程可通行
            trip.signalAll();
            // 重置计数器
            count = parties;
            // 重建批次对象,即不同批次使用不同对象
            generation = new Generation();
        }
    
        private void breakBarrier() {
            // 标示本批次已终止
            generation.broken = true;
            // 重置计数器
            count = parties;
            // 通知本批次所有线程可通行
            trip.signalAll();
        }
    
        public int getParties() {
            // 返回通行所需的线程数量
            return parties;
        }
    
        public int await() throws InterruptedException, BrokenBarrierException {
            try {
                return dowait(false, 0L);
            } catch (TimeoutException toe) {
                throw new Error(toe);
            }
        }
    
        public int await(long timeout, TimeUnit unit) throws InterruptedException,
                BrokenBarrierException, TimeoutException {
            return dowait(true, unit.toNanos(timeout));
        }
    
        private int dowait(boolean timed, long nanos) throws InterruptedException,
                BrokenBarrierException, TimeoutException {
            final ReentrantLock lock = this.lock;
            // 进入同步
            lock.lock();
            try {
                final Generation g = generation;
    
                // 如果本批次已终止,则抛出异常
                if (g.broken)
                    throw new BrokenBarrierException();
    
                // 如果线程已终止,则终止本批次
                if (Thread.interrupted()) {
                    breakBarrier();
                    throw new InterruptedException();
                }
    
                // 更新计数器
                int index = --count;
                // 判断是否达到可释放的线程数量
                if (index == 0) {
                    // 观察监听器是否正常运行结束
                    boolean ranAction = false;
                    try {
                        // 执行监听器
                        final Runnable command = barrierCommand;
                        if (command != null)
                            command.run();
                        // 标记正常运行
                        ranAction = true;
                        // 通知所有线程并重置
                        nextGeneration();
                        // 返回索引
                        return 0;
                    } finally {
                        // 如果监听器是运行时异常结束,则终止本批次
                        if (!ranAction)
                            breakBarrier();
                    }
                }
    
                for (;;) {
                    // 进入等待或计时等待
                    try {
                        if (!timed)
                            trip.await();
                        else if (nanos > 0L)
                            nanos = trip.awaitNanos(nanos);
                    } catch (InterruptedException ie) {
                        if (g == generation && !g.broken) {
                            breakBarrier();
                            throw ie;
                        } else {
                            Thread.currentThread().interrupt();
                        }
                    }
    
                    if (g.broken)
                        throw new BrokenBarrierException();
    
                    // 如果已经换批,则返回索引退出
                    if (g != generation)
                        // 返回索引
                        return index;
    
                    // 如果超时,则 止本批次
                    if (timed && nanos <= 0L) {
                        breakBarrier();
                        throw new TimeoutException();
                    }
                }
            } finally {
                // 退出同步
                lock.unlock();
            }
        }
    
        public boolean isBroken() {
            final ReentrantLock lock = this.lock;
            lock.lock();
            try {
                // 返回本批次是否已经终止
                return generation.broken;
            } finally {
                lock.unlock();
            }
        }
    
        public void reset() {
            final ReentrantLock lock = this.lock;
            lock.lock();
            try {
                // 终止本批次
                breakBarrier();
                // 开始下一批
                nextGeneration();
            } finally {
                lock.unlock();
            }
        }
    
        public int getNumberWaiting() {
            final ReentrantLock lock = this.lock;
            lock.lock();
            try {
                // 返回本批次等待中的线程数量
                return parties - count;
            } finally {
                lock.unlock();
            }
        }
    }

给个栗子,代码如下:

    package com.test;
    
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Random;
    import java.util.concurrent.BrokenBarrierException;
    import java.util.concurrent.CyclicBarrier;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.TimeUnit;
    
    /**
     * 
     * @Description: 赛马示例:将每次马的前进(0-2步)看作一次重复,当某匹马已经前进过一次之后,其必须等待,其他所有的马都前进过一次才能再次前进,当有一匹马到达终点(FINISH_LINE)时,游戏结束
     * @ClassName: HorseRace.java
     * @author ChenYongJia
     * @Date 2019年4月17日 晚上23:25
     * @Email chen87647213@163.com
     */
    public class HorseRace {
    
    	static final int FINISH_LINE = 75;
    	private List<Horse> horses = new ArrayList<>();
    	private ExecutorService exec = Executors.newCachedThreadPool();
    	private CyclicBarrier barrier;
    
    	public HorseRace(int nHorse, final int pause) {
    		barrier = new CyclicBarrier(nHorse, () -> {
    			StringBuilder s = new StringBuilder();
    			for (int i = 0; i < FINISH_LINE; i++) {
    				s.append("=");
    			}
    			System.out.println(s);
    			horses.forEach(horse -> System.out.println(horse.tracks()));
    			for (Horse horse : horses) {
    				if (horse.getStrides() >= FINISH_LINE) {
    					System.out.println(horse + " won!");
    					exec.shutdownNow();
    					return;
    				}
    				try {
    					TimeUnit.MICROSECONDS.sleep(pause);
    				} catch (InterruptedException e) {
    					System.out.println("barrier-action sleep interrupted");
    				}
    			}
    		});
    		for (int i = 0; i < nHorse; i++) {
    			Horse horse = new Horse(barrier);
    			horses.add(horse);
    			exec.execute(horse);
    		}
    	}
    
    	public static void main(String[] args) {
    		int nHorses = 7;
    		int pause = 20;
    		new HorseRace(nHorses, pause);
    	}
    }
    
    class Horse implements Runnable {
    	private static int counter = 0;
    	private final int id = counter++;
    	private int strides = 0;
    	private static Random random = new Random(47);
    	private static CyclicBarrier barrier;
    
    	public Horse(CyclicBarrier barrier) {
    		this.barrier = barrier;
    	}
    
    	public synchronized int getStrides() {
    		return strides;
    	}
    
    	@Override
    	public void run() {
    		try {
    			while (!Thread.interrupted()) {
    				synchronized (this) {
    					strides += random.nextInt(3);
    				}
    				barrier.await(); // 使当前线程处于等待状态,当barrier中所有线程的任务都完成(处于等待状态时)又开始执行
    			}
    		} catch (InterruptedException e) {
    		} catch (BrokenBarrierException e) {
    			throw new RuntimeException(e);
    		}
    	}
    
    	@Override
    	public String toString() {
    		return "Horse " + id + " ";
    	}
    
    	public String tracks() {
    		StringBuilder s = new StringBuilder();
    		for (int i = 0; i < getStrides(); i++) {
    			s.append("*");
    		}
    		s.append(id);
    		return s.toString();
    	}
    
    } 

🎊CountDownLatch 的作用

线程进入等待后,需要计数器达到0才能通行。

  • CountDownLatch(int) 构造方法,指定初始计数。
  • await() 等待计数减至0。
  • await(long, TimeUnit) 在指定时间内,等待计数减至0。
  • countDown() 计数减1。
  • getCount() 获取剩余计数。

例子1:主线程创建了若干子线程,主线程需要等待这若干子线程结束后才结束。

例子2:线程有若干任务,分多个线程来完成,需要等待这若干任务被完成后,才继续运行处理。

源码如下:

    /**
     * @since 1.5
     * @author Doug Lea
     */
    public class CountDownLatch {
    
        private final Sync sync;
    
        public CountDownLatch(int count) {
            if (count < 0) throw new IllegalArgumentException("count < 0");
            this.sync = new Sync(count);
        }
        
        private static final class Sync extends AbstractQueuedSynchronizer {
            private static final long serialVersionUID = 4982264981922014374L;
    
            Sync(int count) {
                setState(count);
            }
    
            int getCount() {
                return getState();
            }
    
            protected int tryAcquireShared(int acquires) {
                // 当数量达到0时,才能通行,否则阻塞
                return (getState() == 0) ? 1 : -1;
            }
    
            protected boolean tryReleaseShared(int releases) {
                for (;;) {
                    int c = getState();
                    // 如果数量达到0,则释放失败
                    if (c == 0)
                        return false;
                    int nextc = c-1;
                    // 尝试把数量递减
                    if (compareAndSetState(c, nextc))
                        return nextc == 0;
                }
            }
        }
    
        public void await() throws InterruptedException {
            // 获取共享锁
            sync.acquireSharedInterruptibly(1);
        }
    
        public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
            // 尝试获取共享锁
            return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
        }
    
        public void countDown() {
            // 释放共享锁
            sync.releaseShared(1);
        }
    
        public long getCount() {
            return sync.getCount();
        }
    
        public String toString() {
            return super.toString() + "[Count = " + sync.getCount() + "]";
        }
    }

给个栗子,代码如下:

    package com.test;
    
    import java.util.HashMap;
    import java.util.Iterator;
    import java.util.concurrent.CountDownLatch;
    
    /**
     * 
     * @Description: Java 多线程条件通行工具——CountDownLatch
     * @ClassName: CountDownLatch.java
     * @author ChenYongJia
     * @Date 2019年4月17日 晚上22:54
     * @Email chen87647213@163.com
     */
    public class CountDownLatchTest {
    
    	public static void main(String[] args) throws InterruptedException {
    
    		CountDownLatchTest test = new CountDownLatchTest();
    
    		CountDownLatch latch = new CountDownLatch(1);
    
    		Thread t1 = new Thread(test.new MapOper(latch));
    		Thread t2 = new Thread(test.new MapOper(latch));
    		Thread t3 = new Thread(test.new MapOper(latch));
    		Thread t4 = new Thread(test.new MapOper(latch));
    
    		t1.setName("Thread1");
    		t2.setName("Thread2");
    		t3.setName("Thread3");
    		t4.setName("Thread4");
    		t1.start();
    		t2.start();
    		t3.start();
    		t4.start();
    
    		System.out.println("线程已经启动,休眠一会儿...");
    
    		Thread.sleep(1000);
    		latch.countDown();
    
    	}
    
    	public class MapOper implements Runnable {
    
    		CountDownLatch latch;
    
    		public MapOper(CountDownLatch latch) {
    			this.latch = latch;
    		}
    
    		public void run() {
    			try {
    				latch.await();
    			} catch (InterruptedException e) {
    				e.printStackTrace();
    			}
    			System.out.println(Thread.currentThread().getName() + " 同步开始!");
    		}
    	}
    
    }

控制台输出:

    线程已经启动,休眠一会儿...
    Thread2 同步开始!
    Thread3 同步开始!
    Thread1 同步开始!
    Thread4 同步开始!