「源码学习」简单实现线程池

141 阅读5分钟

线程源码学习系列

「源码学习」Thread 类

「源码学习」ThreadLocal 类

「源码学习」ThreadPoolExecutor类

曾写过多次线程池,经过本次线程的源码学习后,根据所领悟到的功力,参考源码,动手一步步实现一个简易版的线程池,主要有以下几点:

  • 核心线程
  • 最大线程
  • 线程阻塞队列
  • 拒绝策略
  • 关闭线程池

MyThreadPoolExecutor

package com.yx.thread.pool.v1;

import java.util.HashSet;
import java.util.UUID;
import java.util.concurrent.*;
import java.util.concurrent.locks.ReentrantLock;

/**
 * A simplified thread pool modelled on {@link java.util.concurrent.ThreadPoolExecutor}.
 *
 * <p>Supports core/maximum pool sizes, a bounded task queue, pluggable rejection
 * policies, graceful {@link #shutdown()} and forceful {@link #shutdownNow()}.
 *
 * <p>{@link #workers} is a plain {@link HashSet} (not thread-safe), so every read and
 * write of it, including {@code size()}, happens while holding {@link #mainLock}.
 */
public class MyThreadPoolExecutor {

    /** Core pool size used when the caller passes a non-positive value. */
    private static final int CORE_SIZE_DEFAULT = 5;
    /** Queue capacity used when the caller passes a non-positive value. */
    private static final int WAIT_QUEUE_SIZE_DEFAULT = 10;
    /** Maximum pool size used when the caller passes a non-positive value. */
    private static final int MAX_SIZE_DEFAULT = 10;
    /**
     * Upper bound on how long an idle worker blocks on the queue before re-checking the
     * pool state. Without it, idle workers blocked on the queue would never notice shutdown.
     */
    private static final long IDLE_POLL_MILLIS = 100L;

    /** Number of workers kept alive even when the queue is empty. */
    private volatile int corePoolSize;
    /** Upper bound on the number of workers. */
    private volatile int maxPoolSize;

    /** Cleared by {@link #shutdown()}: no new tasks are accepted, queued tasks still run. */
    private volatile boolean running = true;
    /** Set by {@link #shutdownNow()}: workers stop without draining the queue. */
    private volatile boolean stopped = false;

    /** Live workers; guarded by {@link #mainLock}. */
    private final HashSet<Worker> workers = new HashSet<Worker>();
    /** Tasks waiting for a free worker. */
    private final BlockingQueue<Runnable> waitQueue;
    /** Creates the thread that backs each worker. */
    private final ThreadFactory threadFactory;

    /** Rejection policy used when none is supplied. */
    private static final MyRejectedExecutionHandler defaultHandler =
            new AbortPolicy();

    /** Rejection policy invoked for tasks the pool cannot accept. */
    private volatile MyRejectedExecutionHandler handle;

    /** Guards {@link #workers} and every "check worker count, then add/remove a worker" decision. */
    private final ReentrantLock mainLock = new ReentrantLock();

    /**
     * Creates a pool with the default queue capacity, thread factory and {@link AbortPolicy}.
     *
     * @param corePoolSize workers kept alive when idle (non-positive: {@value #CORE_SIZE_DEFAULT})
     * @param maxPoolSize  maximum number of workers (non-positive: {@value #MAX_SIZE_DEFAULT})
     */
    public MyThreadPoolExecutor(int corePoolSize, int maxPoolSize) {
        this(corePoolSize, maxPoolSize, WAIT_QUEUE_SIZE_DEFAULT, Executors.defaultThreadFactory(), defaultHandler);
    }

    /**
     * Creates a pool with the default thread factory and {@link AbortPolicy}.
     *
     * @param corePoolSize workers kept alive when idle (non-positive: default)
     * @param maxPoolSize  maximum number of workers (non-positive: default)
     * @param queueSize    queue capacity (non-positive: {@value #WAIT_QUEUE_SIZE_DEFAULT})
     */
    public MyThreadPoolExecutor(int corePoolSize, int maxPoolSize, int queueSize) {
        this(corePoolSize, maxPoolSize, queueSize, Executors.defaultThreadFactory(), defaultHandler);
    }

    /**
     * Creates a pool with the default thread factory.
     *
     * @param corePoolSize workers kept alive when idle (non-positive: default)
     * @param maxPoolSize  maximum number of workers (non-positive: default)
     * @param queueSize    queue capacity (non-positive: default)
     * @param handle       rejection policy
     */
    public MyThreadPoolExecutor(int corePoolSize, int maxPoolSize, int queueSize, MyRejectedExecutionHandler handle) {
        this(corePoolSize, maxPoolSize, queueSize, Executors.defaultThreadFactory(), handle);
    }

    /**
     * Creates a pool.
     *
     * @param corePoolSize  workers kept alive when idle (non-positive: default)
     * @param maxPoolSize   maximum number of workers (non-positive: default)
     * @param queueSize     queue capacity (non-positive: default)
     * @param threadFactory factory for worker threads
     * @param handle        rejection policy
     * @throws NullPointerException     if {@code threadFactory} or {@code handle} is null
     * @throws IllegalArgumentException if, after defaults are applied, max &lt; core
     */
    public MyThreadPoolExecutor(int corePoolSize, int maxPoolSize, int queueSize,
                                ThreadFactory threadFactory, MyRejectedExecutionHandler handle) {
        if (threadFactory == null || handle == null) {
            throw new NullPointerException("threadFactory and handle must not be null");
        }
        int core = corePoolSize <= 0 ? CORE_SIZE_DEFAULT : corePoolSize;
        int max = maxPoolSize <= 0 ? MAX_SIZE_DEFAULT : maxPoolSize;
        // Validate after applying defaults so the stored pair is always consistent.
        if (max < core) {
            throw new IllegalArgumentException("maxPoolSize < corePoolSize");
        }
        this.corePoolSize = core;
        this.maxPoolSize = max;
        this.waitQueue = new ArrayBlockingQueue<Runnable>(queueSize <= 0 ? WAIT_QUEUE_SIZE_DEFAULT : queueSize);
        this.threadFactory = threadFactory;
        this.handle = handle;
    }

    /**
     * Submits a task.
     *
     * <p>Flow:
     * <ol>
     *   <li>fewer than {@code corePoolSize} workers: start a new worker that runs {@code task} first;</li>
     *   <li>otherwise try to enqueue {@code task};</li>
     *   <li>queue full and fewer than {@code maxPoolSize} workers: start a new worker for {@code task};</li>
     *   <li>otherwise, or if the pool is shut down: hand {@code task} to the rejection policy.</li>
     * </ol>
     *
     * @param task the task to run
     * @throws NullPointerException if {@code task} is null
     */
    public void execute(Runnable task) {
        if (task == null) {
            throw new NullPointerException("task");
        }
        final ReentrantLock lock = this.mainLock;
        lock.lock();
        try {
            if (!isShutdown()) {
                if (workers.size() < corePoolSize) {
                    System.out.println("< corePoolSize ,create new worker");
                    addWorker(task);
                    return;
                }
                if (waitQueue.offer(task)) {
                    return;
                }
                if (workers.size() < maxPoolSize) {
                    System.out.println("< maxPoolSize,create new worker");
                    addWorker(task);
                    return;
                }
            }
        } finally {
            lock.unlock();
        }
        // Outside the lock: handlers run arbitrary code (CallerRunsPolicy runs the task,
        // DiscardOldestPolicy re-submits) and must not stall other submitters or workers.
        reject(task);
    }

    /**
     * Creates, registers and starts a worker. The caller must hold {@link #mainLock}.
     *
     * @param firstTask task the worker runs before polling the queue; may be null
     */
    private void addWorker(Runnable firstTask) {
        Worker worker = new Worker(firstTask);
        workers.add(worker);
        boolean started = false;
        try {
            worker.thread.start();
            started = true;
        } finally {
            // Do not leave a never-started worker counted in the pool.
            if (!started) {
                workers.remove(worker);
            }
        }
    }

    /**
     * Graceful shutdown: rejects new tasks; workers exit once the queue has drained.
     * Idle workers notice within {@link #IDLE_POLL_MILLIS} milliseconds.
     */
    public void shutdown() {
        this.running = false;
    }

    /**
     * Forceful shutdown: rejects new tasks, stops workers without draining the queue and
     * interrupts every worker thread, including those currently running a task.
     */
    public void shutdownNow() {
        this.stopped = true;
        final ReentrantLock lock = this.mainLock;
        lock.lock();
        try {
            for (Worker worker : workers) {
                worker.thread.interrupt();
            }
        } finally {
            lock.unlock();
        }
    }

    /** Returns true once {@link #shutdown()} or {@link #shutdownNow()} has been called. */
    public boolean isShutdown() {
        return !this.running || this.stopped;
    }

    /** Returns the live task queue (used by {@link DiscardOldestPolicy}). */
    public BlockingQueue<Runnable> getQueue() {
        return this.waitQueue;
    }

    /** Passes {@code task} to the configured rejection policy. */
    final void reject(Runnable task) {
        this.handle.rejectedExecution(task, this);
    }


    /**
     * A worker: one thread that runs its first task (if any), then keeps taking tasks
     * from {@link #waitQueue} until it retires or the pool shuts down.
     */
    private final class Worker implements Runnable {

        final Thread thread;
        /** Task handed directly to this worker at creation; consumed on the first loop iteration. */
        private Runnable firstTask;

        Worker(Runnable firstTask) {
            this.firstTask = firstTask;
            this.thread = threadFactory.newThread(this);
            // Random UUID as the thread name.
            this.thread.setName(UUID.randomUUID().toString());
            // The thread is started by addWorker, not here, so `this` never runs half-constructed.
        }

        @Override
        public void run() {
            Runnable task = firstTask;
            firstTask = null;
            try {
                while (task != null || (task = getTask()) != null) {
                    try {
                        task.run();
                        System.out.println("end task >> " + this.thread.getName() + ">>>" + task.toString());
                    } catch (RuntimeException e) {
                        // A failing task must not kill the worker.
                        e.printStackTrace();
                    } finally {
                        // Drop the reference so the finished task can be garbage-collected.
                        task = null;
                    }
                }
            } finally {
                processWorkerExit(this);
            }
        }

        /**
         * Returns the next task, or null when this worker should exit. The worker exits when
         * the pool is stopped, when it is shut down with an empty queue, or when this worker
         * is surplus (queue empty and more than corePoolSize workers).
         */
        private Runnable getTask() {
            for (;;) {
                if (stopped || (!running && waitQueue.isEmpty())) {
                    System.out.println("close worker >>>>>>>>>" + this.thread.getName());
                    return null;
                }
                if (retireIfSurplus()) {
                    return null;
                }
                try {
                    // Bounded wait so the state checks above run again periodically.
                    Runnable next = waitQueue.poll(IDLE_POLL_MILLIS, TimeUnit.MILLISECONDS);
                    if (next != null) {
                        return next;
                    }
                } catch (InterruptedException ignored) {
                    // Woken by shutdownNow(); loop to re-check the pool state.
                }
            }
        }

        /**
         * Under the lock, removes this worker if the queue is empty and the pool has more than
         * corePoolSize workers. Checking and removing atomically ensures that concurrent idle
         * workers can never shrink the pool below its core size.
         */
        private boolean retireIfSurplus() {
            final ReentrantLock lock = mainLock;
            lock.lock();
            try {
                if (waitQueue.isEmpty() && workers.size() > corePoolSize) {
                    System.out.println("interrupt worker >>>" + this.thread.getName());
                    workers.remove(this);
                    return true;
                }
                return false;
            } finally {
                lock.unlock();
            }
        }
    }

    /**
     * Deregisters an exiting worker. If the worker died abnormally (e.g. an Error escaped a
     * task) while tasks are still queued and the pool is below its core size, starts a replacement
     * so that the queued tasks are not stranded.
     */
    private void processWorkerExit(Worker worker) {
        final ReentrantLock lock = this.mainLock;
        lock.lock();
        try {
            workers.remove(worker);
            System.out.println("worker:" + workers.size() + ",waitqueue:" + waitQueue.size());
            if (!stopped && !waitQueue.isEmpty() && workers.size() < corePoolSize) {
                addWorker(null);
            }
        } finally {
            lock.unlock();
        }
    }

    /**
     * Modelled on the four ThreadPoolExecutor rejection policies.
     * Throws an exception for the rejected task.
     */
    public static class AbortPolicy implements MyRejectedExecutionHandler {
        public AbortPolicy() {
        }

        @Override
        public void rejectedExecution(Runnable r, MyThreadPoolExecutor executor) {
            // RejectedExecutionException extends RuntimeException, so existing catch blocks still work.
            throw new RejectedExecutionException("AbortPolicy:" + r.toString());
        }
    }

    /**
     * Runs the rejected task in the submitting (caller) thread, unless the pool is shut down.
     */
    public static class CallerRunsPolicy implements MyRejectedExecutionHandler {
        public CallerRunsPolicy() { }

        @Override
        public void rejectedExecution(Runnable r, MyThreadPoolExecutor executor) {
            if (!executor.isShutdown()) {
                System.out.println("CallerRunsPolicy:" + r.toString());
                r.run();
            }
        }
    }

    /**
     * Discards the oldest queued task, then re-submits the rejected one (unless shut down).
     */
    public static class DiscardOldestPolicy implements MyRejectedExecutionHandler {
        public DiscardOldestPolicy() { }

        @Override
        public void rejectedExecution(Runnable r, MyThreadPoolExecutor executor) {
            if (!executor.isShutdown()) {
                // May be null if workers drained the queue concurrently.
                Runnable pollRunnable = executor.getQueue().poll();
                executor.execute(r);
                System.out.println("DiscardOldestPolicy:" + pollRunnable + ">>>" + r.toString());
            }
        }
    }

    /**
     * Silently drops the rejected task (only logs it).
     */
    public static class DiscardPolicy implements MyRejectedExecutionHandler {
        public DiscardPolicy() { }

        @Override
        public void rejectedExecution(Runnable r, MyThreadPoolExecutor e) {
            System.out.println("DiscardPolicy:" + r.toString());
        }
    }
}

MyRejectedExecutionHandler

/**
 * Strategy that {@link MyThreadPoolExecutor} consults for a task it cannot accept,
 * for example when both its queue and its worker pool are full.
 */
public interface MyRejectedExecutionHandler {

    /**
     * Deals with a task that {@code executor} refused.
     *
     * @param task     the refused task
     * @param executor the pool that refused it
     */
    void rejectedExecution(Runnable task, MyThreadPoolExecutor executor);
}

Test

package com.yx.thread.pool.v1;

public class Test {

    /** How many tasks the demo submits. */
    private static final int TASK_COUNT = 30;

    /**
     * Demo entry point: floods a pool (2 core, 5 max, queue capacity 2) with slow tasks,
     * reports each one the pool refuses, then requests a graceful shutdown.
     */
    public static void main(String[] args) throws InterruptedException {
        MyThreadPoolExecutor executor = new MyThreadPoolExecutor(2, 5, 2);
        int index = 0;
        while (index < TASK_COUNT) {
            submit(executor, "task[" + index + "]");
            index++;
        }
        executor.shutdown();
    }

    /** Logs and submits one named task; a refusal is printed instead of propagated. */
    private static void submit(MyThreadPoolExecutor executor, String taskName) {
        System.out.println("add " + taskName);
        try {
            executor.execute(new MyTask(taskName));
        } catch (Exception rejection) {
            System.out.println("add task fail" + rejection.getMessage());
        }
    }
}
/**
 * Demo task: prints its name, then sleeps for 9 seconds to simulate slow work.
 */
class MyTask implements Runnable {

    private final String taskName;

    public MyTask(String taskName){
        this.taskName = taskName;
    }

    @Override
    public void run() {
        try {
            System.out.println("start task >> " + this.taskName);
            Thread.sleep(9000);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the running thread (e.g. a pool worker
            // during shutdownNow) can still observe the interruption.
            Thread.currentThread().interrupt();
            System.out.println("interrupted task >> " + this.taskName);
        }
    }

    @Override
    public String toString() {
        // '\'' is the single-quote char literal closing the quoted name.
        return "MyTask{" +
                "taskName='" + taskName + '\'' +
                '}';
    }
}

日志输出

add task[0]
< corePoolSize ,create new worker
add task[1]
< corePoolSize ,create new worker
add task[2]
start task >> task[2]
add task[3]
add task[4]
start task >> task[3]
add task[5]
add task[6]
< maxPoolSize,create new worker
add task[7]
< maxPoolSize,create new worker
start task >> task[4]
add task[8]
add task[9]
< maxPoolSize,create new worker
start task >> task[5]
add task[10]
add task[11]
add task failAbortPolicy:MyTask{taskName='task[11]'}
add task[12]
start task >> task[8]
add task[13]
add task failAbortPolicy:MyTask{taskName='task[13]'}
add task[14]
add task failAbortPolicy:MyTask{taskName='task[14]'}
add task[15]
add task failAbortPolicy:MyTask{taskName='task[15]'}
add task[16]
add task failAbortPolicy:MyTask{taskName='task[16]'}
add task[17]
add task failAbortPolicy:MyTask{taskName='task[17]'}
add task[18]
add task failAbortPolicy:MyTask{taskName='task[18]'}
add task[19]
add task failAbortPolicy:MyTask{taskName='task[19]'}
add task[20]
add task failAbortPolicy:MyTask{taskName='task[20]'}
add task[21]
add task failAbortPolicy:MyTask{taskName='task[21]'}
add task[22]
add task failAbortPolicy:MyTask{taskName='task[22]'}
add task[23]
add task failAbortPolicy:MyTask{taskName='task[23]'}
add task[24]
add task failAbortPolicy:MyTask{taskName='task[24]'}
add task[25]
add task failAbortPolicy:MyTask{taskName='task[25]'}
add task[26]
add task failAbortPolicy:MyTask{taskName='task[26]'}
add task[27]
add task failAbortPolicy:MyTask{taskName='task[27]'}
add task[28]
add task failAbortPolicy:MyTask{taskName='task[28]'}
add task[29]
add task failAbortPolicy:MyTask{taskName='task[29]'}
end task >> 5931157f-2dd5-4507-b494-f92e3b96e714>>>MyTask{taskName='task[8]'}
end task >> 91f0a092-1dad-4ff6-966d-235227adbbd1>>>MyTask{taskName='task[4]'}
end task >> 37f31cac-99ff-4d70-b5de-671d70dda9ab>>>MyTask{taskName='task[2]'}
end task >> 17f079fb-b407-4994-a82b-525f16cd6cf8>>>MyTask{taskName='task[3]'}
end task >> c4997ecf-1da3-4827-8fac-7e7f15b6e0cb>>>MyTask{taskName='task[5]'}
interrupt worker >>>17f079fb-b407-4994-a82b-525f16cd6cf8
interrupt worker >>>37f31cac-99ff-4d70-b5de-671d70dda9ab
worker:4,waitqueue:0
start task >> task[12]
worker:3,waitqueue:0
start task >> task[10]
interrupt worker >>>c4997ecf-1da3-4827-8fac-7e7f15b6e0cb
worker:2,waitqueue:0
end task >> 91f0a092-1dad-4ff6-966d-235227adbbd1>>>MyTask{taskName='task[12]'}
end task >> 5931157f-2dd5-4507-b494-f92e3b96e714>>>MyTask{taskName='task[10]'}
close worker >>>>>>>>>91f0a092-1dad-4ff6-966d-235227adbbd1
close worker >>>>>>>>>5931157f-2dd5-4507-b494-f92e3b96e714

需要注意

  • HashSet 不是线程安全的,因此对 workers 的所有访问(包括 size() 这样的读操作)都应在持有全局锁 ReentrantLock(mainLock)时进行。
  • BlockingQueue 是线程安全的,其实现内部自带锁。例如下面的 ArrayBlockingQueue.put 源码:
/**
 * Inserts {@code e}, waiting while the queue is full.
 * Quoted JDK source; appears to be java.util.concurrent.ArrayBlockingQueue#put
 * (JDK 8 style, given checkNotNull) — TODO confirm version.
 *
 * @param e element to insert; checkNotNull presumably rejects null with NullPointerException
 * @throws InterruptedException if interrupted while acquiring the lock or while waiting for space
 */
public void put(E e) throws InterruptedException {
        checkNotNull(e);
        // Copy the lock field into a local before use.
        final ReentrantLock lock = this.lock;
        // Acquire the queue's lock, but abort with InterruptedException if interrupted.
        lock.lockInterruptibly();
        try {
            // Loop, not if: after await() returns the queue may be full again
            // (spurious wakeup or another producer filled the slot), so re-check.
            while (count == items.length)
                notFull.await();
            enqueue(e);
        } finally {
            // Always release the lock, even when await() throws.
            lock.unlock();
        }
    }

最后

为什么使用 BlockingQueue 呢?

首先,它是线程安全的。BlockingQueue 是一个特殊的队列:从 BlockingQueue 中取数据时,如果队列为空,取数据的操作会进入阻塞状态,直到 BlockingQueue 中有了新数据,这个操作才会被重新唤醒。同理,如果 BlockingQueue 已满,往其中存数据的操作也会进入阻塞状态,直到 BlockingQueue 中有了空闲空间,存数据的操作才会被重新唤醒。需要注意的是,本实现在提交任务时使用的是非阻塞的 offer():队列满时它会立即返回 false,从而触发创建非核心线程或执行拒绝策略,而不会阻塞提交任务的线程。

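既然 HashSet 线程不安全,为什么线程池还要用它存放 worker 线程呢?

主要还是因为 HashSet 的特性:它会自动消除重复数据,确保同一个 worker 不会出现多个 entry,并且增删查效率高。至于线程安全,则由全局锁 mainLock 保证:只要所有对 workers 的访问都在持有 mainLock 时进行,就不需要并发集合。JDK 的 ThreadPoolExecutor 也是这样做的。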

以上实现了一个精简版的线程池,仍有不少不足和可以优化的地方:

  • worker 线程名称待优化:当前使用随机 UUID,可改为通过线程计数器命名。
  • 增加一些统计值,线程数、任务数、线程峰值等等。
  • 当前非核心 worker 线程一旦发现队列为空就立即退出,方式比较暴力。可增加 keepAliveTime 参数,允许其空闲一段时间后再退出。
  • 没有像ThreadPoolExecutor源码逻辑那么严谨。

当前实现代码,仅供参考。

后面会持续优化....