同步工具类

155 阅读4分钟

同步工具类

并行任务中,如果我们想要指定任务的执行策略(如123任务完成后再开始456任务),我们可以借助同步工具类实现。接下来, 我将介绍一些常用的同步工具类。

闭锁Latch

闭锁是一种同步工具类,可以延迟线程的进度知道其达到终止状态。闭锁相当于一扇门,在闭锁到达结束状态之前,这扇门一直是关闭的,并且没有任何线程能通过,当闭锁达到结束状态时,闭锁会打开并允许所有线程通过。

CountDownLatch是一个经典灵活又常见的闭锁实现。CountDownLatch可以在被构造时指定其计数值,调用对象的countDown方法减少计数值,而await方法将会阻塞直到计数值为0

import java.util.concurrent.CountDownLatch;

public class CountDownLatchSimplifiedExample {
    private static final int NUM_WORKERS = 3;

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch startSignal = new CountDownLatch(1); // 启动信号
        CountDownLatch doneSignal = new CountDownLatch(NUM_WORKERS); // 完成信号

        for (int i = 0; i < NUM_WORKERS; i++) {
            new Thread(()->{
                try {
                    startSignal.await(); // 等待启动信号
                    System.out.println("工作线程 " + Thread.currentThread().getName() + " 开始工作...");
                    Thread.sleep(1000); // 工作线程执行任务
                    System.out.println("工作线程 " + Thread.currentThread().getName() + " 完成工作。");
                    doneSignal.countDown(); // 完成信号
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }).start();
        }
        System.out.println("准备启动所有工作线程...");
        startSignal.countDown(); // 启动所有工作线程
        doneSignal.await(); // 等待所有工作线程完成
        System.out.println("所有工作线程已完成。继续主线程的工作...");
    }
}

CountDownLatch 对象不能被重用。它的计数器只能被减少,不能被重置。一旦计数器达到零,CountDownLatch 对象就不能再次使用。因此,如果需要重用计数器,可以考虑使用其他的同步机制,比如 CyclicBarrier

栅栏Barrier

栅栏与闭锁的关键区别在于,所有线程必须到达栅栏位置,才能继续执行。闭锁用于等待事件,而栅栏用于等待其他线程。

CyclicBarrier可以使一定数量的参与方反复地在栅栏位置汇集,它在并发算法中非常有用:这种算法通常将一个问题拆分成一系列相互独立的子问题。当线程到达栅栏位置时使用await方法,这个方法将阻塞直到所有线程都到达这个位置。所有线程到达后,栅栏将释放所有线程,然后栅栏将重置以便下次使用。如果栅栏得知有的线程无法到达栅栏处(通常是被中断),那么所有阻塞的线程都会抛出Broken Barrier Exception。

CyclicBarrier十分适合自动化模拟任务,以下是一个示例

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

public class ReusableBarrierExample {
    private static final int ROUNDS = 5;

    public static void main(String[] args) {
        int numberOfParticipants = 3;
        CyclicBarrier barrier = new CyclicBarrier(numberOfParticipants, new Runnable() {
            @Override
            public void run() {
                System.out.println("All participants have finished a round!");
            }
        });

        for (int i = 1; i <= numberOfParticipants; i++) {
            int participantNumber = i;
            new Thread(new Runnable() {
                private CyclicBarrier barrier;
                private int participantNumber;

                // Constructor
                public Runnable init(CyclicBarrier barrier, int participantNumber) {
                    this.barrier = barrier;
                    this.participantNumber = participantNumber;
                    return this;
                }

                @Override
                public void run() {
                    try {
                        for (int round = 1; round <= ROUNDS; round++) {
                            System.out.println("Participant " + participantNumber + " is starting round " + round);
                            // Simulate time taken to complete the round
                            Thread.sleep((long) (Math.random() * 10000));
                            System.out.println("Participant " + participantNumber + " has finished round " + round);

                            // Wait for other participants to finish the round
                            barrier.await();
                        }
                    } catch (InterruptedException | BrokenBarrierException e) {
                        e.printStackTrace();
                    }
                }
            }.init(barrier, participantNumber)).start();
        }
    }
}

信号量Semaphore

计数信号量CotuntingSemaphore用来控制同时访问某个特定资源的操作数量,或者同时执行某个指定操作的数量。信号量可以实现某种资源池,或者对容器施加边界。

Semaphore中管理这一组虚拟许可,可以通过acquire取得一个许可,无可用时将阻塞直到取得许可。release方法将返回一个许可。这样来看,Semaphore类似于一个无序的队列/栈。

以下是一个连接池的示例

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Semaphore;

public class ConnectionPool {
    private final List<Connection> pool;
    private final Semaphore semaphore;
    private final int poolSize;

    public ConnectionPool(String url, String user, String password, int poolSize) throws SQLException {
        this.poolSize = poolSize;
        this.pool = new ArrayList<>(poolSize);
        this.semaphore = new Semaphore(poolSize, true);

        // Initialize the connection pool
        for (int i = 0; i < poolSize; i++) {
            Connection connection = DriverManager.getConnection(url, user, password);
            pool.add(connection);
        }
    }

    public Connection acquireConnection() throws InterruptedException {
        semaphore.acquire();
        synchronized (pool) {
            return pool.remove(pool.size() - 1);
        }
    }

    public void releaseConnection(Connection connection) {
        synchronized (pool) {
            pool.add(connection);
        }
        semaphore.release();
    }

    public void shutdown() throws SQLException {
        for (Connection connection : pool) {
            connection.close();
        }
    }

    public static void main(String[] args) {
        try {
            ConnectionPool connectionPool = new ConnectionPool(
                    "jdbc:mysql://localhost:3306/test",
                    "root",
                    "password",
                    10
            );

            // Example usage
            Connection connection = connectionPool.acquireConnection();
            // Use the connection...

            // Release the connection back to the pool
            connectionPool.releaseConnection(connection);

            // Shutdown the pool
            connectionPool.shutdown();

        } catch (SQLException | InterruptedException e) {
            e.printStackTrace();
        }
    }
}