本文中源码来自JDK 8。
Semaphore,即信号量,用来控制并发访问特定资源的线程数量,实现流量控制。
1 使用场景
对数量有限资源的访问进行限制。 比如程序中的数据库连接池,最大连接数为15个。那么,应当控制最多15个线程同时获得数据库连接,多余线程需等待有空闲连接出现,才能获取成功。
2 原理
内部类Sync是AbstractQueuedSynchronizer的子类,是同步器实现。Semaphore中所有操作同步状态的方法,都转交给Sync实现。
Sync又有公平和非公平两个子类实现:FairSync、NonfairSync
2.1 初始化
指许可数量,即可创建Semaphore实例,默认采用非公平模式。
2.2 acquire()方法
申请一个许可。acquire()是可中断实现,acquireUninterruptibly()是不可中断实现。
它们都通过Sync转交给AQS的模板方法,又会调用被复写的tryAcquireShared(int arg)方法。
FairSync中,tryAcquireShared实现如下。
NonfairSync的tryAcquireShared实现,与FairSync中的相似,只是没有
hasQueuedPredecessors()的if判断:非公平时,新来的线程可以直接争抢许可,如果成功则不用入队。
2.3 release()方法
归还一个许可。转调AQS的releaseShared(int arg)方法。
重写后的tryReleaseShared(arg)
2.4 其他方法
availablePermits():返回当前可用许可数
drainPermits():获取当前所有可用许可
isFair():当前Semaphore是否是公平的
hasQueuedThreads():是否有线程正在等待获取许可
getQueueLength():返回队列中等待的线程数
3 使用Semaphore实现连接池
假设你需要自己实现一个连接池:
- 池中持有与某一地址的多个连接;
- 线程从中获取连接,使用完成后归还连接;
- 当连接都在使用中时,获取连接的线程将阻塞,直到有连接空闲;
- 不支持空闲连接回收、连接保活和连接异常检测。
使用Semaphore可以轻松实现这些功能。
import java.io.IOException;
import java.net.Socket;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
public class SimpleConnectionPool {
private final int poolSize;
private final String ip;
private final int port;
// 连接池
private final SimpleConnection[] connections;
// 连接状态数组:0 表示空闲, 1 表示繁忙
private final AtomicIntegerArray states;
private static final int IDLE = 0;
private static final int IN_USE = 1;
private final Semaphore semaphore;
public SimpleConnectionPool(int poolSize, String ip, int port) {
this.poolSize = poolSize;
this.ip = ip;
this.port = port;
// 许可数与连接数一致
this.semaphore = new Semaphore(poolSize);
this.connections = new SimpleConnection[poolSize];
// 连接初始状态为IDLE
this.states = new AtomicIntegerArray(poolSize);
// 初始化所有连接
for (int i = 0; i < poolSize; i++) {
try {
connections[i] = new SimpleConnection(ip, port);
} catch (IOException e) {
throw new RuntimeException("failed to create connection.", e);
}
}
}
public SimpleConnection borrow() throws InterruptedException {
semaphore.acquire();
for (int i = 0; i < poolSize; i++) {
if (states.get(i) == IDLE) {
if (states.compareAndSet(i, IDLE, IN_USE)) {
return connections[i];
}
}
}
// 不会执行到这里
return null;
}
public void free(SimpleConnection conn) {
for (int i = 0; i < poolSize; i++) {
if (connections[i] == conn) {
if (states.get(i) == IDLE) {
return;
}
states.set(i, IDLE);
semaphore.release();
return;
}
}
throw new RuntimeException("conn isn't in the pool.");
}
}
// 应提供更多功能
class SimpleConnection {
private static final AtomicInteger serialNumber = new AtomicInteger(1);
private final String name;
private final String ip;
private final int port;
private final Socket socket;
public SimpleConnection(String ip, int port) throws IOException {
this.name = "conn-" + serialNumber.getAndIncrement() + "[" + ip + ":" + port + "]";
this.ip = ip;
this.port = port;
this.socket = new Socket(ip, port);
}
private Socket socket() throws IOException {
return socket;
}
public String toString() {
return name;
}
}
我们来测试一下。
@Slf4j
public class TestConnectionPool {
public static void main(String[] args) {
// 与百度建立连接
SimpleConnectionPool connectionPool = new SimpleConnectionPool(3, "www.baidu.com", 80);
for (int i = 0; i < 5; i++) {
new Thread(() -> {
try {
SimpleConnection conn = connectionPool.borrow();
log.info("get conn:" + conn);
// 模拟网络通信耗时
Thread.sleep(1000);
connectionPool.free(conn);
} catch (InterruptedException e) {
}
}).start();
}
}
}