自定义线程池

92 阅读3分钟

自定义线程池的代码如下:

package org.example.lesson3;

import lombok.Data;
import lombok.NoArgsConstructor;


import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Queue;
import java.util.concurrent.*;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;

@Data

@NoArgsConstructor(force = true)
public class ThreadPoolCustomer {
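    /**
     * Pool sizing parameters.
     * corePoolSize: number of core worker threads
     * maximumPoolSize: maximum total number of worker threads
     * keepAliveTime: idle timeout for non-core workers, in milliseconds
     * (the thread-name prefix is not a pool field; it is supplied through the ThreadFactory)
     */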
    private volatile int corePoolSize;
    private volatile int maximumPoolSize;
    private volatile long keepAliveTime;


    private ThreadFactory threadFactory;

    private volatile Queue<Worker> workers = new ConcurrentLinkedQueue<>();

    private volatile BlockingQueue<Runnable> workQueue;

    private ReentrantLock lock = new ReentrantLock();

    private AtomicInteger activeCount = new AtomicInteger(0);

    private AtomicInteger activeCoreCount = new AtomicInteger(0);

    private AtomicInteger runningWorkerCount = new AtomicInteger(0);

    // Rejection policy: called when a task can neither be queued nor given a new worker
    private RejectedExecutionHandler handler;

    // The "report" thread: the thread parked in await(), unparked once every task has finished
    private volatile Thread huibao = null;

    /**
     * Creates a pool with the given sizing, thread factory, work queue and rejection policy.
     *
     * <p>Arguments are validated up front so that a misconfigured pool fails at construction
     * time rather than later inside {@code execute}. The checks are the same ones
     * {@code java.util.concurrent.ThreadPoolExecutor} applies to its constructor.
     *
     * @param corePoolSize             number of core workers; must be {@code >= 0}
     * @param maximumPoolSize          maximum number of workers; must be {@code > 0} and
     *                                 {@code >= corePoolSize}
     * @param keepAliveTime            idle timeout of non-core workers in milliseconds;
     *                                 must be {@code >= 0}
     * @param threadFactory            factory used to create worker threads; not null
     * @param workQueue                queue holding tasks that wait for a worker; not null
     * @param rejectedExecutionHandler policy invoked for tasks the pool cannot accept; not null
     * @throws IllegalArgumentException if a numeric argument is out of range
     * @throws NullPointerException     if a reference argument is null
     */
    public ThreadPoolCustomer(int corePoolSize, int maximumPoolSize,
                              long keepAliveTime, ThreadFactory threadFactory,
                              BlockingQueue<Runnable> workQueue,
                              RejectedExecutionHandler rejectedExecutionHandler) {
        if (corePoolSize < 0
                || maximumPoolSize <= 0
                || maximumPoolSize < corePoolSize
                || keepAliveTime < 0) {
            throw new IllegalArgumentException("invalid pool configuration: corePoolSize=" + corePoolSize
                    + ", maximumPoolSize=" + maximumPoolSize
                    + ", keepAliveTime=" + keepAliveTime);
        }
        if (threadFactory == null) {
            throw new NullPointerException("threadFactory");
        }
        if (workQueue == null) {
            throw new NullPointerException("workQueue");
        }
        if (rejectedExecutionHandler == null) {
            throw new NullPointerException("rejectedExecutionHandler");
        }
        this.corePoolSize = corePoolSize;
        this.maximumPoolSize = maximumPoolSize;
        this.keepAliveTime = keepAliveTime;
        this.threadFactory = threadFactory;
        this.workQueue = workQueue;
        this.handler = rejectedExecutionHandler;
    }

    /**
     * Submits a task to the pool.
     *
     * <p>Placement is decided in this order, under the pool lock:
     * <ol>
     *   <li>fewer workers than {@code corePoolSize}: start a core worker with the task;</li>
     *   <li>otherwise offer the task to the work queue;</li>
     *   <li>if the queue refuses it and there are fewer workers than
     *       {@code maximumPoolSize}: start a non-core worker with the task;</li>
     *   <li>otherwise the task is rejected.</li>
     * </ol>
     * The rejection handler is called after the lock has been released.
     *
     * @param task the task to run; must not be null
     * @throws NullPointerException       if {@code task} is null
     * @throws RejectedExecutionException if the task is rejected and no handler is configured
     */
    public void execute(Runnable task) {
        if (task == null) {
            // Checked before locking: BlockingQueue.offer(null) would otherwise throw while the lock is held.
            throw new NullPointerException("task");
        }
        boolean accepted = false;
        lock.lock();
        try {
            if (workers.size() < corePoolSize) {
                startWorker(task, true);
                accepted = true;
            } else if (workQueue != null && workQueue.offer(task)) {
                // offer() appends at the tail and returns false (instead of throwing) when the queue is full.
                accepted = true;
                if (workers.isEmpty()) {
                    // Only possible with corePoolSize == 0: without a worker the queued task would never run.
                    startWorker(null, false);
                }
            } else if (workers.size() < maximumPoolSize) {
                startWorker(task, false);
                accepted = true;
            }
        } finally {
            // Always release, even if worker creation or the queue throws.
            lock.unlock();
        }
        if (accepted) {
            return;
        }
        // Every path that reaches this point is a saturation: queue full (or absent) and worker limit reached.
        if (handler == null) {
            throw new RejectedExecutionException(
                    "task rejected: pool is saturated and no RejectedExecutionHandler is configured");
        }
        handler.rejectedExecution(task, this);
    }

    /** Creates a worker, registers it with the pool and starts its thread; {@code firstTask} may be null. */
    private void startWorker(Runnable firstTask, boolean core) {
        Worker worker = new Worker(threadFactory, this, core);
        worker.task = firstTask;
        workers.add(worker);
        worker.start();
    }

    /**
     * Blocks the calling thread until the work queue is empty and no worker is running a task.
     *
     * <p>The caller registers itself as the "report" thread and parks for up to 20 seconds at a
     * time. A worker that observes the pool going idle unparks it, and the condition is
     * re-checked after every wake-up.
     *
     * <p>If the calling thread is interrupted, the method returns early with the interrupt flag
     * still set; callers that need to distinguish this case should check
     * {@code Thread.currentThread().isInterrupted()} afterwards.
     *
     * <p>NOTE(review): only one waiter slot exists, so concurrent callers overwrite each other's
     * registration. A worker that has dequeued a task but not yet counted itself as running can
     * still make the pool look idle for an instant.
     */
    public void await() {
        final Thread waiter = Thread.currentThread();
        for (; ; ) {
            // Register BEFORE checking the condition. If the last task finishes between the check
            // and the park, the worker's unpark leaves a permit and parkNanos returns immediately.
            huibao(waiter);

            if (workQueue.isEmpty() && runningWorkerCount.get() == 0) {
                clearWaiter(waiter);
                System.out.println("Waiting for tasks to complete," +
                        "core=" + corePoolSize
                        // Clamp: with fewer workers than corePoolSize the difference would be negative.
                        + ",nonCore=" + Math.max(0, workers.size() - corePoolSize)
                        + ",workQueue=" + workQueue.size()
                );
                return;
            }
            if (waiter.isInterrupted()) {
                // parkNanos returns at once while the flag is set; bail out instead of spinning.
                clearWaiter(waiter);
                return;
            }
            System.out.println("等待线程被挂起");
            // Park for at most 20 seconds, then re-check.
            LockSupport.parkNanos(TimeUnit.SECONDS.toNanos(20));
            System.out.println("等待线程恢复");
        }
    }

    /**
     * Removes {@code waiter} from the report slot if it is still registered there. Synchronized on
     * the pool so it cannot interleave with the synchronized check-then-unpark in tryHuiBao().
     */
    private synchronized void clearWaiter(Thread waiter) {
        if (huibao == waiter) {
            huibao = null;
        }
    }

    /**
     * Wakes the registered report thread if the pool has gone idle.
     *
     * <p>Does nothing unless a report thread is registered, no worker is running a task and the
     * work queue is empty. Otherwise it logs the wake-up, unparks the thread and clears the slot.
     */
    private synchronized void tryHuiBao() {
        boolean nothingToReport = huibao == null
                || runningWorkerCount.get() != 0
                || !workQueue.isEmpty();
        if (nothingToReport) {
            return;
        }
        String waiterName = huibao.getName();
        int queued = workQueue.size();
        String now = new SimpleDateFormat("HH:mm:ss").format(new Date());
        System.out.println(waiterName + " 线程被唤醒,workQueue=" + queued + ",时间=" + now);
        LockSupport.unpark(huibao);
        huibao = null;
    }

    /**
     * Registers the thread to be unparked the next time the pool is observed idle.
     *
     * @param thread the thread to record in the report slot
     */
    public void huibao(Thread thread) {
        this.huibao = thread;
    }


    /**
     * A pool thread together with the task it is about to run.
     *
     * <p>A core worker blocks on the work queue with {@code take()}. A non-core worker polls with
     * the pool's {@code keepAliveTime} (milliseconds) and retires when the poll times out.
     *
     * <p>Book-keeping is exception-safe. A task that throws is reported to the thread's uncaught
     * exception handler and the worker keeps serving. A worker whose thread is interrupted while
     * waiting for work retires: it deregisters itself and leaves the interrupt flag set.
     */
    protected static class Worker implements Runnable {

        // Next task to run. In the visible code it is assigned by the pool before start() and
        // afterwards only read and written on the worker's own thread.
        private Runnable task;

        // NOTE(review): not referenced anywhere in the visible code; kept so the class shape is unchanged.
        private final ReentrantLock WORKER_LOCK = new ReentrantLock();
        private final boolean core;

        private final Thread t;

        private final ThreadPoolCustomer pool;

        /**
         * Creates the worker and its thread, and counts it in the pool's active counters.
         *
         * @param threadFactory factory that builds the worker thread
         * @param pool          owning pool
         * @param core          whether this is a core worker (never times out)
         * @throws IllegalStateException if the factory declines to create a thread
         */
        public Worker(ThreadFactory threadFactory, ThreadPoolCustomer pool, boolean core) {
            Thread thread = threadFactory.newThread(this, core);
            if (thread == null) {
                // The factory contract allows null; fail before touching any counter.
                throw new IllegalStateException("ThreadFactory returned no thread for a "
                        + (core ? "core" : "non-core") + " worker");
            }
            this.t = thread;
            this.core = core;
            this.pool = pool;
            this.pool.activeCount.incrementAndGet();
            if (core) {
                this.pool.activeCoreCount.incrementAndGet();
            }
        }

        @Override
        public void run() {
            runWorker();
        }

        /** Main loop: run the assigned task if there is one, otherwise fetch the next from the queue. */
        private void runWorker() {
            try {
                for (; ; ) {
                    if (task == null) {
                        task = core
                                // Core workers wait indefinitely.
                                ? pool.getWorkQueue().take()
                                // Non-core workers give up after keepAliveTime milliseconds.
                                : pool.getWorkQueue()
                                        .poll(pool.keepAliveTime, java.util.concurrent.TimeUnit.MILLISECONDS);
                        if (task == null) {
                            if (core) {
                                continue;
                            }
                            // Non-core worker timed out: leave the loop, cleanup happens in finally.
                            return;
                        }
                    }
                    runTask();
                }
            } catch (InterruptedException e) {
                // Interrupted while waiting for work: keep the flag for whoever inspects the thread.
                Thread.currentThread().interrupt();
            } finally {
                retire();
            }
        }

        /** Runs the current task and keeps the pool's running count balanced whatever the task does. */
        private void runTask() {
            pool.runningWorkerCount.incrementAndGet();
            try {
                task.run();
            } catch (RuntimeException e) {
                // Report without killing the worker; the default handler prints the stack trace.
                Thread current = Thread.currentThread();
                current.getUncaughtExceptionHandler().uncaughtException(current, e);
            } finally {
                pool.runningWorkerCount.decrementAndGet();
                task = null;
                pool.tryHuiBao();
            }
        }

        /** Deregisters this worker from the pool and lets a waiting report thread re-check. */
        private void retire() {
            pool.workers.remove(this);
            pool.activeCount.decrementAndGet();
            if (core) {
                pool.activeCoreCount.decrementAndGet();
            }
            pool.tryHuiBao();
        }

        /** Starts the underlying thread. */
        public void start() {
            t.start();
        }

    }


    /**
     * Policy applied to a task that the pool could not accept.
     *
     * <p>Declared as a functional interface so that it can be supplied as a lambda.
     */
    @FunctionalInterface
    public interface RejectedExecutionHandler {
        /**
         * Handles a rejected task.
         *
         * @param r        the task that was rejected
         * @param executor the pool that rejected it
         */
        void rejectedExecution(Runnable r, ThreadPoolCustomer executor);
    }


    /**
     * Creates the threads that back the pool's workers.
     *
     * <p>Unlike {@code java.util.concurrent.ThreadFactory}, the factory is told whether the
     * thread is for a core or a non-core worker.
     */
    @FunctionalInterface
    public interface ThreadFactory {

        /**
         * Constructs a new {@code Thread}.  Implementations may also initialize
         * priority, name, daemon status, {@code ThreadGroup}, etc.
         *
         * @param r    a runnable to be executed by new thread instance
         * @param core {@code true} when the thread is for a core worker,
         *             {@code false} when it is for a non-core worker
         * @return constructed thread, or {@code null} if the request to
         * create a thread is rejected
         */
        Thread newThread(Runnable r, boolean core);
    }

    /**
     * Factory producing non-daemon threads named {@code <prefix>[*]-<n>} for core workers and
     * {@code <prefix>[$]-<n>} for non-core workers, where {@code n} is a counter starting at 1
     * that is shared by both kinds.
     */
    static class DefaultThreadFactory implements ThreadPoolCustomer.ThreadFactory {
        private final String prefix;
        private final AtomicInteger counter = new AtomicInteger(1);

        /**
         * @param prefix text placed at the start of every thread name
         */
        public DefaultThreadFactory(String prefix) {
            this.prefix = prefix;
        }

        @Override
        public Thread newThread(Runnable r, boolean core) {
            String marker = core ? "*" : "$";
            int sequence = counter.getAndIncrement();
            String threadName = prefix + "[" + marker + "]-" + sequence;

            Thread created = new Thread(r, threadName);
            // Non-daemon: worker threads keep the JVM alive.
            created.setDaemon(false);
            return created;
        }
    }

    /**
     * Demo task: logs which thread picked it up and when, then sleeps for three seconds to
     * simulate work.
     */
    static class Task implements Runnable {
        private final String name;

        /** @return the label given to this task at construction */
        public String getName() {
            return name;
        }

        /**
         * @param name label used in log output
         */
        public Task(String name) {
            this.name = name;
        }

        /** @return the current wall-clock time formatted as {@code HH:mm:ss} */
        public String getTime() {
            // A fresh formatter per call: SimpleDateFormat instances must not be shared across threads.
            return new SimpleDateFormat("HH:mm:ss").format(new Date());
        }

        /**
         * Logs the start of the task and sleeps for three seconds.
         *
         * @throws RuntimeException wrapping the {@link InterruptedException} if the sleep is
         *                          interrupted; the thread's interrupt flag is restored first
         */
        @Override
        public void run() {
            // Leading space keeps the thread name and the message apart in the output.
            System.out.println(Thread.currentThread().getName()
                    + " working on task: " + name + ", time: " + getTime());
            try {
                TimeUnit.SECONDS.sleep(3);
            } catch (InterruptedException e) {
                // sleep() cleared the flag; restore it so the owning thread can still see the interrupt.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }

        }
    }
}

测试代码如下:

package org.example.lesson3;

import java.util.concurrent.ArrayBlockingQueue;

/**
 * Demo driver: builds a pool with 2 core workers, at most 5 workers and a queue of 5, submits
 * 11 tasks, waits for the pool to go idle, then submits 2 more.
 */
public class TestThreadPool {

    public static void main(String[] args) {
        ThreadPoolCustomer.ThreadFactory factory =
                new ThreadPoolCustomer.DefaultThreadFactory("worker");

        // Rejected tasks are only logged by name.
        ThreadPoolCustomer.RejectedExecutionHandler logRejection = (rejected, executor) -> {
            ThreadPoolCustomer.Task refused = (ThreadPoolCustomer.Task) rejected;
            System.out.println(refused.getName() + " rejected");
        };

        ThreadPoolCustomer pool = new ThreadPoolCustomer(
                2, 5, 5000L,
                factory,
                new ArrayBlockingQueue<>(5),
                logRejection
        );

        submitBatch(pool, 11);
        pool.await();
        submitBatch(pool, 2);
    }

    /** Submits {@code count} tasks named task-1 .. task-count to the pool. */
    private static void submitBatch(ThreadPoolCustomer pool, int count) {
        for (int number = 1; number <= count; number++) {
            pool.execute(new ThreadPoolCustomer.Task("task-" + number));
        }
    }
}

这段代码模拟了线程池的核心参数和任务执行过程,方便开发者更好地理解 Java 线程池的参数含义和使用方式。