自定义线程池的代码如下:
package org.example.lesson3;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Queue;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;
/**
 * A small teaching implementation of a thread pool that mimics the core
 * parameters of a JDK-style pool: core size, maximum size, keep-alive time,
 * a bounded work queue and a rejection handler.
 *
 * <p>Submission order in {@link #execute(Runnable)}: start a core worker while
 * fewer than {@code corePoolSize} workers exist, otherwise enqueue the task,
 * otherwise start a non-core worker while fewer than {@code maximumPoolSize}
 * workers exist, otherwise reject the task.
 *
 * <p>Accessors for the fields are generated by Lombok ({@code @Data}).
 */
@Data
@NoArgsConstructor(force = true)
public class ThreadPoolCustomer {
    /** Upper bound for a single park in {@link #await()} before the idle check is repeated (20 seconds). */
    private static final long AWAIT_PARK_NANOS = TimeUnit.SECONDS.toNanos(20);

    /** Number of workers that are created eagerly and never time out. */
    private volatile int corePoolSize;
    /** Maximum number of workers (core plus non-core) alive at the same time. */
    private volatile int maximumPoolSize;
    /** Idle time, in milliseconds, after which a non-core worker exits. */
    private volatile long keepAliveTime;
    /** Factory used to create the thread behind every worker. */
    private ThreadFactory threadFactory;
    /** All workers that are currently registered with the pool. */
    private volatile Queue<Worker> workers = new ConcurrentLinkedQueue<>();
    /** Tasks waiting for a free worker; may be {@code null}, in which case tasks are never queued. */
    private volatile BlockingQueue<Runnable> workQueue;
    /** Guards the dispatch decision in {@code execute} and the (de)registration of workers. */
    private ReentrantLock lock = new ReentrantLock();
    /** Number of registered workers (core and non-core). */
    private AtomicInteger activeCount = new AtomicInteger(0);
    /** Number of registered core workers. */
    private AtomicInteger activeCoreCount = new AtomicInteger(0);
    /** Number of workers that are executing a task right now. */
    private AtomicInteger runningWorkerCount = new AtomicInteger(0);
    /**
     * Number of accepted tasks that have not finished yet (queued, handed to a
     * worker, or running). It is incremented before a task becomes visible to
     * any worker, so it cannot read zero while accepted work is outstanding.
     */
    private AtomicInteger pendingTaskCount = new AtomicInteger(0);
    /** Rejection policy invoked when a task cannot be accepted. */
    private RejectedExecutionHandler handler;
    /** The "reporting" thread parked in {@link #await()}; it is unparked once all tasks are done. */
    private volatile Thread huibao = null;

    /**
     * Creates a pool.
     *
     * @param corePoolSize             number of core workers
     * @param maximumPoolSize          maximum number of workers
     * @param keepAliveTime            idle timeout of non-core workers in milliseconds
     * @param threadFactory            factory for worker threads
     * @param workQueue                queue holding tasks that wait for a worker
     * @param rejectedExecutionHandler policy for tasks that cannot be accepted
     */
    public ThreadPoolCustomer(int corePoolSize, int maximumPoolSize,
                              long keepAliveTime, ThreadFactory threadFactory,
                              BlockingQueue<Runnable> workQueue,
                              RejectedExecutionHandler rejectedExecutionHandler) {
        this.corePoolSize = corePoolSize;
        this.maximumPoolSize = maximumPoolSize;
        this.keepAliveTime = keepAliveTime;
        this.threadFactory = threadFactory;
        this.workQueue = workQueue;
        this.handler = rejectedExecutionHandler;
    }

    /**
     * Submits a task: core worker first, then the queue, then a non-core
     * worker, and finally the rejection handler.
     *
     * @param task the task to run; must not be {@code null}
     * @throws NullPointerException if {@code task} is {@code null}
     * @throws RuntimeException     if the task is rejected and no handler is configured
     */
    public void execute(Runnable task) {
        if (task == null) {
            throw new NullPointerException("task");
        }
        // Count the task before any worker can see it, so await() never
        // observes "no pending work" while this task is still outstanding.
        pendingTaskCount.incrementAndGet();
        boolean accepted = false;
        try {
            lock.lock();
            try {
                accepted = dispatch(task);
            } finally {
                // Always release the lock, even if starting a worker throws.
                lock.unlock();
            }
        } finally {
            if (!accepted) {
                pendingTaskCount.decrementAndGet();
                // Undoing the count may be what makes the pool idle.
                tryHuiBao();
            }
        }
        if (!accepted) {
            reject(task);
        }
    }

    /** Tries core worker, queue, then non-core worker; must be called with {@code lock} held. */
    private boolean dispatch(Runnable task) {
        if (workers.size() < corePoolSize && addWorker(task, true)) {
            return true;
        }
        BlockingQueue<Runnable> queue = workQueue;
        // offer() returns false instead of throwing when the queue is full.
        if (queue != null && queue.offer(task)) {
            return true;
        }
        return workers.size() < maximumPoolSize && addWorker(task, false);
    }

    /**
     * Registers and starts a worker whose first task is {@code firstTask}.
     * Returns {@code false} (with all bookkeeping rolled back) when the
     * factory declined to create a thread.
     */
    private boolean addWorker(Runnable firstTask, boolean core) {
        Worker worker = new Worker(threadFactory, this, core);
        worker.task = firstTask;
        workers.add(worker);
        boolean started = false;
        try {
            if (worker.t != null) {
                worker.start();
                started = true;
            }
        } finally {
            if (!started) {
                worker.deregister();
            }
        }
        return started;
    }

    /**
     * Hands a task that could not be accepted to the rejection handler. The
     * decision is taken from the dispatch result alone; re-reading the worker
     * count after the lock was released could disagree with it.
     */
    private void reject(Runnable task) {
        if (handler == null) {
            throw new RuntimeException("任务无法添加,原因未知");
        }
        handler.rejectedExecution(task, this);
    }

    /**
     * Blocks the calling thread until every accepted task has finished.
     *
     * <p>The caller registers itself as the waiter <em>before</em> checking
     * for idleness, so a worker that finishes the last task in between still
     * finds the waiter and unparks it. The park is time-limited and the check
     * is repeated after every wake-up. If the calling thread is interrupted,
     * the method returns early and leaves the interrupt flag set.
     */
    public void await() {
        Thread current = Thread.currentThread();
        for (; ; ) {
            huibao(current);
            if (pendingTaskCount.get() == 0) {
                clearHuiBao(current);
                System.out.println("Waiting for tasks to complete," +
                        "core=" + corePoolSize
                        + ",nonCore=" + Math.max(0, workers.size() - corePoolSize)
                        + ",workQueue=" + queuedTaskCount()
                );
                return;
            }
            if (current.isInterrupted()) {
                // parkNanos returns immediately for an interrupted thread;
                // stop waiting instead of spinning.
                clearHuiBao(current);
                return;
            }
            System.out.println("等待线程被挂起");
            LockSupport.parkNanos(AWAIT_PARK_NANOS);
            System.out.println("等待线程恢复");
        }
    }

    /** Wakes the registered waiter, if any, once no accepted task is outstanding. */
    private synchronized void tryHuiBao() {
        // Read the volatile field once so a concurrent reset cannot cause an NPE.
        Thread waiter = huibao;
        if (waiter != null && pendingTaskCount.get() == 0) {
            System.out.println(waiter.getName() + " 线程被唤醒,workQueue=" + queuedTaskCount() + ",时间="
                    + new SimpleDateFormat("HH:mm:ss").format(new Date()));
            LockSupport.unpark(waiter);
            huibao = null;
        }
    }

    /** Removes {@code waiter} as the registered waiter if it still is the current one. */
    private synchronized void clearHuiBao(Thread waiter) {
        if (huibao == waiter) {
            huibao = null;
        }
    }

    /** Number of queued tasks, or 0 when the pool has no queue. */
    private int queuedTaskCount() {
        BlockingQueue<Runnable> queue = workQueue;
        return queue == null ? 0 : queue.size();
    }

    /**
     * Registers the thread that should be unparked when the pool becomes idle.
     *
     * @param thread the waiting ("reporting") thread
     */
    public void huibao(Thread thread) {
        huibao = thread;
    }

    /**
     * A pool thread together with its bookkeeping. A worker runs its first
     * task (if any) and then keeps pulling tasks from the pool's queue.
     */
    protected static class Worker implements Runnable {
        /** Task to run next; set by the pool before start, afterwards only touched by the worker thread. */
        private Runnable task;
        private final boolean core;
        /** Thread created by the factory; {@code null} if the factory declined. */
        private final Thread t;
        private final ThreadPoolCustomer pool;

        /**
         * Creates the worker and counts it in the pool's {@code activeCount}
         * (and {@code activeCoreCount} for core workers).
         *
         * @param threadFactory factory that creates the backing thread
         * @param pool          owning pool
         * @param core          whether this is a core worker
         */
        public Worker(ThreadFactory threadFactory, ThreadPoolCustomer pool, boolean core) {
            this.core = core;
            this.pool = pool;
            this.t = threadFactory.newThread(this, core);
            this.pool.activeCount.incrementAndGet();
            if (core) {
                this.pool.activeCoreCount.incrementAndGet();
            }
        }

        @Override
        public void run() {
            runWorker();
        }

        /**
         * Main loop: run the current task, then fetch the next one. The loop
         * ends when a non-core worker times out, when the pool has no queue,
         * or when the thread is interrupted. On every exit path the worker is
         * removed from the pool's bookkeeping.
         */
        private void runWorker() {
            try {
                for (; ; ) {
                    if (task == null) {
                        task = nextTask();
                        if (task == null) {
                            return;
                        }
                    }
                    runTask();
                }
            } catch (InterruptedException e) {
                // Treat interruption as a request to stop; keep the flag for the owner of the thread.
                Thread.currentThread().interrupt();
            } finally {
                deregister();
                pool.tryHuiBao();
            }
        }

        /**
         * Fetches the next task. Core workers block without timeout, non-core
         * workers wait at most {@code keepAliveTime} milliseconds.
         *
         * @return the next task, or {@code null} if the worker should exit
         */
        private Runnable nextTask() throws InterruptedException {
            BlockingQueue<Runnable> queue = pool.workQueue;
            if (queue == null) {
                return null;
            }
            return core ? queue.take() : queue.poll(pool.keepAliveTime, TimeUnit.MILLISECONDS);
        }

        /**
         * Runs the current task. The counters are restored in {@code finally},
         * so a failing task cannot leave the pool looking busy forever. A
         * {@link RuntimeException} from the task is passed to the thread's
         * uncaught-exception handler and the worker keeps serving tasks.
         */
        private void runTask() {
            pool.runningWorkerCount.incrementAndGet();
            try {
                task.run();
            } catch (RuntimeException e) {
                Thread current = Thread.currentThread();
                current.getUncaughtExceptionHandler().uncaughtException(current, e);
            } finally {
                task = null;
                pool.runningWorkerCount.decrementAndGet();
                pool.pendingTaskCount.decrementAndGet();
                pool.tryHuiBao();
            }
        }

        /** Removes this worker from the pool and undoes the counts taken in the constructor. */
        private void deregister() {
            pool.lock.lock();
            try {
                pool.workers.remove(this);
                pool.activeCount.decrementAndGet();
                if (core) {
                    pool.activeCoreCount.decrementAndGet();
                }
            } finally {
                pool.lock.unlock();
            }
        }

        /** Starts the backing thread. */
        public void start() {
            t.start();
        }
    }

    /** Policy that decides what happens to a task the pool cannot accept. */
    @FunctionalInterface
    public interface RejectedExecutionHandler {
        /**
         * Called by {@link ThreadPoolCustomer#execute(Runnable)} for a task that was not accepted.
         *
         * @param r        the rejected task
         * @param executor the pool that rejected it
         */
        void rejectedExecution(Runnable r, ThreadPoolCustomer executor);
    }

    /** Creates the threads that back the pool's workers. */
    @FunctionalInterface
    public interface ThreadFactory {
        /**
         * Constructs a new {@code Thread}. Implementations may also initialize
         * priority, name, daemon status, {@code ThreadGroup}, etc.
         *
         * @param r    a runnable to be executed by new thread instance
         * @param core whether the thread backs a core worker
         * @return constructed thread, or {@code null} if the request to
         * create a thread is rejected
         */
        Thread newThread(Runnable r, boolean core);
    }

    /**
     * Factory that names threads {@code prefix[*]-n} for core workers and
     * {@code prefix[$]-n} for non-core workers, with a shared running number.
     */
    static class DefaultThreadFactory implements ThreadPoolCustomer.ThreadFactory {
        private final String prefix;
        private final AtomicInteger counter = new AtomicInteger(1);

        public DefaultThreadFactory(String prefix) {
            this.prefix = prefix;
        }

        @Override
        public Thread newThread(Runnable r, boolean core) {
            Thread thread = new Thread(r, prefix + "[" + (core ? "*" : "$") + "]-" + counter.getAndIncrement());
            thread.setDaemon(false);
            return thread;
        }
    }

    /** Demo task: prints which thread runs it and then sleeps for three seconds. */
    static class Task implements Runnable {
        private final String name;

        public String getName() {
            return name;
        }

        public Task(String name) {
            this.name = name;
        }

        /** @return the current wall-clock time formatted as {@code HH:mm:ss} */
        public String getTime() {
            return new SimpleDateFormat("HH:mm:ss").format(new Date());
        }

        @Override
        public void run() {
            System.out.println(Thread.currentThread().getName()
                    + "working on task: " + name + ", time: " + getTime());
            try {
                TimeUnit.SECONDS.sleep(3);
            } catch (InterruptedException e) {
                // Restore the flag so the code running this task can see the interruption.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
    }
}
测试代码如下:
package org.example.lesson3;
import java.util.concurrent.ArrayBlockingQueue;
/**
 * Demo driver for {@link ThreadPoolCustomer}: submits a burst of tasks, waits
 * for the pool to drain, then submits a second, smaller batch.
 */
public class TestThreadPool {
    /**
     * Builds a pool with 2 core workers, at most 5 workers, a 5000 ms
     * keep-alive and a queue of capacity 5, then submits 11 tasks, waits via
     * {@code await()}, and submits 2 more tasks.
     *
     * <p>NOTE(review): with these sizes the first burst presumably ends with
     * one rejected task (2 core + 5 queued + 3 non-core = 10 accepted); the
     * exact outcome depends on thread timing.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Rejection policy: report the name of the task that was turned away.
        ThreadPoolCustomer.RejectedExecutionHandler reportRejected = (rejected, executor) -> {
            ThreadPoolCustomer.Task named = (ThreadPoolCustomer.Task) rejected;
            System.out.println(named.getName() + " rejected");
        };
        ThreadPoolCustomer.DefaultThreadFactory factory =
                new ThreadPoolCustomer.DefaultThreadFactory("worker");

        ThreadPoolCustomer pool = new ThreadPoolCustomer(
                2, 5, 5000L, factory, new ArrayBlockingQueue<>(5), reportRejected);

        submitBatch(pool, 11);
        pool.await();
        submitBatch(pool, 2);
    }

    /** Submits {@code count} tasks named {@code task-1} .. {@code task-count} to the pool. */
    private static void submitBatch(ThreadPoolCustomer pool, int count) {
        for (int number = 1; number <= count; number++) {
            pool.execute(new ThreadPoolCustomer.Task("task-" + number));
        }
    }
}
这段代码模拟了线程池的核心参数和任务执行过程,方便开发者更好地理解 Java 线程池的参数含义和使用方式。