手撕线程池

37 阅读1分钟
class MyThreadPool{
    int threadPoolCoreSize;
    MyMessageQueue<Runnable> queue;
    long time;
    TimeUnit unit;
    HashSet<MyWorker> myWorkerHashSet;
    class MyWorker extends Thread{
        Runnable task;
        public MyWorker(Runnable task){
            this.task = task;
        }

        @Override
        public void run() {

            while(task!=null || (task = queue.get()) != null){
                task.run();
                task = null;
            }
            synchronized (myWorkerHashSet){
                myWorkerHashSet.remove(this);
            }



        }
    }
    public MyThreadPool(int threadPoolCoreSize, long time, TimeUnit unit,int messageQueueSize){
        this.threadPoolCoreSize = threadPoolCoreSize;
        this.time = time;
        this.unit = unit;
        this.queue = new MyMessageQueue<Runnable>(messageQueueSize);
    }

    public void execute(Runnable task){
        if(this.threadPoolCoreSize >= myWorkerHashSet.size()){
            MyWorker myWorker = new MyWorker(task);
            myWorkerHashSet.add(myWorker);
            myWorker.start();
        }else{
            queue.put(task);
        }
    }
}

interface MyRejectPolicy<T>{
    public void reject(MyMessageQueue<T> queue, T task);
}




class MyMessageQueue<T>{
    ReentrantLock lock = new ReentrantLock();
    Condition producerWaiter = lock.newCondition();
    Condition consumerWaiter = lock.newCondition();
    Queue<T> queue;
    int size;
    public MyMessageQueue(int size){
        this.queue = new LinkedList<>();
        this.size = size;
    }

    public void put(T task){
        lock.lock();
        try{
            while ((this.size == queue.size())) {
                producerWaiter.await();
            }
            queue.add(task);
            consumerWaiter.signalAll();
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        } finally {
            lock.unlock();
        }
    }

    public T get(){
        lock.lock();
        try {
            while(queue.isEmpty()){
                consumerWaiter.await();
            }
            T t = queue.remove();
            producerWaiter.signalAll();
            return t;

        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        } finally {
            lock.unlock();
        }
    }

    public void putWithRejectPolicy(T task, MyRejectPolicy<T> rejectPolicy){

        lock.lock();
        try {
            if(queue.size() == size){
                rejectPolicy.reject(this,task);

            }else{
                queue.add(task);
                consumerWaiter.signalAll();
            }


        }finally {
            lock.unlock();
        }
    }

    public T poll(long timeout, TimeUnit unit){
        long nanos = unit.toNanos(timeout);
        lock.lock();
        try {
            while(queue.isEmpty()){

                if(nanos <= 0){
                    return null;
                }

                nanos = producerWaiter.awaitNanos(nanos);

            }
            T t = queue.remove();
            producerWaiter.signalAll();
            return t;


        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        } finally {
            lock.unlock();
        }
    }


}

代码细节说明如下:

  1. 对于拒绝策略,这个拒绝的过程应该是发生在向阻塞队列添加消息的时候。因此应该在阻塞队列这个类里面添加带有阻塞队列的方法如96至112行。同时这个阻塞队列的函数的形参应该是这个队列以及任务对象,原因是应该在这个队列添加任务消息的时候作出一些响应动作。
  2. 在工作对象,myWorker类,应该做的事情,不止仅仅传入的Runnable对象,同时还应该持续不断地获取消息队列里面的其他消息,当没有消息的时候才算是结束。对应于代码的14至26行。同时完成任务以后,应该在工作队列中移除这个消息。