AQS应用篇

111 阅读3分钟

概述

原理分析

相关实例

  • CountDownLatch(共享锁)

  • 概述 CountDownLatch是一个同步工具类,它允许一个或多个线程一直等待,直到其他线程执行完后再执行。例如,应用程序的主线程希望在负责启动框架服务的线程已经启动所有框架服务之后执行。

  • 实例


public class CountDownLatchDemo {


    public static void main(String[] args) throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(10);

        for (int i = 0; i < 9; i++) {
            new Thread(new Runnable() {
                @Override
                public void run() {
                    System.out.println(Thread.currentThread().getName() + " 运行");
                    try {
                        Thread.sleep(3000);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    } finally {
                        latch.countDown();
                    }
                }
            }).start();
        }

        System.out.println("等待子线程运行结束");
        latch.await(10, TimeUnit.SECONDS);
        System.out.println("子线程运行结束");

    }
}

Thread-3 运行
Thread-2 运行
Thread-6 运行
Thread-0 运行
Thread-7 运行
Thread-1 运行
Thread-5 运行
Thread-4 运行
Thread-8 运行
子线程运行结束
  • 源码分析

Sync内部类

private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;

    Sync(int count) {
        setState(count);
    }

    int getCount() {
        return getState();
    }

    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

构造函数

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

wait方法

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

public boolean await(long timeout, TimeUnit unit)
    throws InterruptedException {
    return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}


//SYN tryAcquireShared 方法
protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}


//AQS acquireSharedInterruptibly 方法
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}


//AQS doAcquireSharedInterruptibly 方法
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}





countDown方法


public void countDown() {
    sync.releaseShared(1);
}


//SYN 内部类tryReleaseShared 方法
protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

//AQS releaseShared 方法
public final boolean releaseShared(int arg) {
    //status==0 时唤醒后继节点
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}


  • Semaphore

  • 概述 Semaphore 通常我们叫它信号量, 可以用来控制同时访问特定资源的线程数量,通过协调各个线程,以保证合理的使用资源

  • ReentrantReadWriteLock

  • CyclicBarrier

  • 概述 CyclicBarrier 的字面意思是可循环使用(Cyclic)的屏障(Barrier)。 它要做的事情是,让一组线程到达一个屏障(也可以叫同步点)时被阻塞,直到最后一个线程到达屏障时,屏障才会开门,所有被屏障拦截的线程才会继续干活。 CyclicBarrier 默认的构造方法是 CyclicBarrier(int parties),其参数表示屏障拦截的线程数量,每个线程调用 await 方法告诉 CyclicBarrier 我已经到达了屏障,然后当前线程被阻塞。

  • 对比CountDownLatch

  1. CyclicBarrier的计数器由自己控制,而CountDownLatch的计数器则由使用者来控制
  2. CountDownLatch只能拦截一轮,而CyclicBarrier可以实现循环拦截

image.png

  • 实例
public class CyclicBarrierDemo {

    public static void main(String[] args) {
        ExecutorService executor = Executors.newCachedThreadPool();
        //周末3人聚会,需要等待3个人全部到齐餐厅后才能开始吃饭
        CyclicBarrier cb = new CyclicBarrier(3);
        System.out.println("初始化:有" + (3 - cb.getNumberWaiting()) + "个人正在赶来餐厅");
        for (int i = 0; i < 3; i++) {   //定义3个任务,即3个人从家里赶到餐厅
            //设置用户的编号
            final int person = i;
            executor.execute(() -> {    //lambda表达式
                try {
                    //此处睡眠,模拟3个人从家里来到餐厅所花费的时间
                    Thread.sleep((long) (Math.random() * 10000));
                    System.out.println(Thread.currentThread().getName() + "---用户" + person + "即将达到餐厅," +
                            "用户" + person + "到达餐厅了。" + "当前已有" + (cb.getNumberWaiting() + 1) + "个人到达餐厅");
                    cb.await();
                    System.out.println("三个人都到到餐厅啦," + Thread.currentThread().getName() + "开始吃饭了");
                    //todo 吃完饭后想去网吧开黑  这里具体代码我就不写啦  留给小伙伴自己实现 >.<
                    //再次wait(),等待3个人全部到达网吧  cb是可以复用的!
                    cb.await();
                    //3个人都到达网吧了,开始玩游戏 playGame()...
                } catch (InterruptedException | BrokenBarrierException e) {
                    e.printStackTrace();
                }
            });
        }
        executor.shutdown();    //关闭线程池
    }

}

image.png

  • 源码分析 CyclicBarrier 基于 Condition 和 ReentrantLock 来实现的。在CyclicBarrier类的内部有一个计数器,每个线程在到达屏障点的时候都会调用await方法将自己阻塞,此时计数器会减1,当计数器减为0的时候所有因调用await方法而被阻塞的线程将被唤醒。这就是实现一组线程相互等待的原理

CyclicBarrier 属性

//同步操作锁
private final ReentrantLock lock = new ReentrantLock();
//线程拦截器
private final Condition trip = lock.newCondition();
//每次拦截的线程数
private final int parties;
//换代前执行的任务
private final Runnable barrierCommand;
//表示栅栏的当前代
private Generation generation = new Generation();
//计数器
private int count;
 
//静态内部类Generation
private static class Generation {
  boolean broken = false;
  }

CyclicBarrier 构造函数

 public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

    public CyclicBarrier(int parties) {
        this(parties, null);
    }

CyclicBarrier 方法

public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}

private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        final Generation g = generation;

        //检查当前栅栏是否被打翻
        if (g.broken)
            throw new BrokenBarrierException();

        //检查当前线程是否被中断
        if (Thread.interrupted()) {
            //如果当前线程被中断会做以下三件事
            //1.打翻当前栅栏
            //2.唤醒拦截的所有线程
            //3.抛出中断异常
            breakBarrier();
            throw new InterruptedException();
        }

        //每次都将计数器的值减1
        int index = --count;
        
        //计数器的值减为0则需唤醒所有线程并转换到下一代
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                //醒所有线程前先执行指定的任务
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;
                //唤醒所有线程并转到下一代
                nextGeneration();
                return 0;
            } finally {
                //保在任务未成功执行时能将所有线程唤醒
                if (!ranAction)
                    breakBarrier();
            }
        }

        // 如果计数器不为0则执行此循环
        for (;;) {
            try {
                // 据传入的参数来决定是定时等待还是非定时等待
                if (!timed)
                    trip.await();
                else if (nanos > 0L)
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                //若当前线程在等待期间被中断则打翻栅栏唤醒其他线程
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                   // 若在捕获中断异常前已经完成在栅栏上的等待, 则直接调用中断操作
                   Thread.currentThread().interrupt();
                }
            }

            //如果线程因为打翻栅栏操作而被唤醒则抛出异常
            if (g.broken)
                throw new BrokenBarrierException();
                
            //如果线程因为换代操作而被唤醒则返回计数器的值
            if (g != generation)
                return index;
            
            //如果线程因为时间到了而被唤醒则打翻栅栏并抛出异常
            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        // 释放锁
        lock.unlock();
    }
}

breakBarrie 方法

 //如果当前线程被中断会做以下三件事
//1.打翻当前栅栏
//2.唤醒拦截的所有线程
//3.重置计算器
private void breakBarrier() {
    generation.broken = true;
    count = parties;
    trip.signalAll();
}

nextGeneration 方法

private void nextGeneration() {
    // signal completion of last generation
    trip.signalAll();
    // set up next generation
    count = parties;
    generation = new Generation();
}
  • ReentranLock