自定义线程池

158 阅读1分钟
package com.angel.item.advanced.diy;

import com.google.common.collect.Sets;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BiConsumer;

/**
 * 自定义线程池
 */
@Slf4j
public class ThreadPool<T> {

    private final ReentrantLock lock = new ReentrantLock();
    /**
     * 核心线程数大小
     */
    private final Integer corePoolSize;
    /**
     * 空闲线程存活时长
     */
    private final Long keepAliveTime;
    private final BlockQueue<T> workerQueue;
    private final BiConsumer<BlockQueue<T>, Runnable> rejectPolicy;
    private final Set<Worker> workers = Sets.newConcurrentHashSet();
    /**
     * 存活时长单位
     */
    private final TimeUnit unit;
    /**
     * 最大线程数大小
     */
    private Integer maximumPoolSize;

    public ThreadPool(Integer corePoolSize, Long keepAliveTime, TimeUnit unit, BlockQueue<T> workerQueue, BiConsumer<BlockQueue<T>, Runnable> rejectPolicy) {
        this.corePoolSize = corePoolSize;
        this.keepAliveTime = keepAliveTime;
        this.unit = unit;
        this.workerQueue = workerQueue;
        this.rejectPolicy = rejectPolicy;
    }

    public static void main(String[] args) {
        ThreadPool<Runnable> threadPool = new ThreadPool<>(5, 10L, TimeUnit.MICROSECONDS, new BlockQueue<>(5), (queue, task) -> {
            // 饱和策略
            // 抛出异常
            // throw new RuntimeException("任务队列已满!");
            // 等待一段时间
            // queue.push(task, 10, TimeUnit.SECONDS);
            // 让调用者自己执行
            // task.run();
        });

        for (int i = 0; i < 20; i++) {
            int j = i;
            threadPool.execute(() -> {
                try {
                    Thread.sleep(1000L);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                log.debug("{}", j);
            });
        }


    }

    public void execute(Runnable runnable) {
        lock.lock();
        try {
            if (workers.size() < corePoolSize) {
                // 当线程数没有超过核心线程数的时候, 就会创建一个新的线程去执行任务
                Worker worker = new Worker(runnable);
                workers.add(worker);
                worker.start();
            } else {
                // 当达到核心线程数的时候, 就需要进入阻塞队列中进入等待, 等线程空闲以后, 再去队列中取任务执行
                workerQueue.tryPush(rejectPolicy, runnable);
            }
        } finally {
            lock.unlock();
        }
    }

    class Worker extends Thread {

        private Runnable task;

        public Worker(Runnable task) {
            this.task = task;
        }

        @Override
        public void run() {
            while (task != null || (task = (Runnable) workerQueue.pop(keepAliveTime, unit)) != null) {
                try {
                    task.run();
                } finally {
                    task = null;
                }
            }

            lock.lock();
            try {
                // 当线程没有任务执行的时候, 则将线程从worker中移除, 该线程则会自己销毁
                if (workers.contains(this)) {
                    log.debug("worker 被移除{}", this);
                    workers.remove(this);
                }
            } finally {
                lock.unlock();
            }
        }
    }
}

/**
 * 阻塞队列: 用于存放还未执行的任务
 * 和生产者消费者模型一样
 */
class BlockQueue<T> extends ReentrantLock {

    /**
     * 队列容量
     */
    private final int capacity;
    /**
     * 阻塞队列容器
     */
    private final Deque<T> queue;
    public Condition produce = this.newCondition();
    public Condition consumer = this.newCondition();


    public BlockQueue(int capacity) {
        this.capacity = capacity;
        this.queue = new ArrayDeque(capacity);
    }

    /**
     * 获取队列中的数据
     *
     * @param timeout  超时时间
     * @param timeUnit 超时单位
     * @return 返回任务, 如果超时则返回null
     */
    public T pop(long timeout, TimeUnit timeUnit) {
        lock();
        try {
            long timeoutNanos = timeUnit.toNanos(timeout);
            while (CollectionUtils.isEmpty(queue)) {
                try {
                    // 当等待时间小于等于0的时候, 说明超时时间已过, 直接return
                    if (timeoutNanos <= 0) {
                        return null;
                    }
                    // awaitNanos 的返回值是剩余的等待时间
                    timeoutNanos = consumer.awaitNanos(timeoutNanos);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            T t = queue.removeFirst();
            produce.signalAll();
            return t;
        } finally {
            unlock();
        }
    }

    /**
     * 向阻塞队列中添加任务
     *
     * @param t        任务对象
     * @param timeout  超时时间
     * @param timeUnit 超时单位
     */
    public void push(T t, long timeout, TimeUnit timeUnit) {
        lock();
        try {
            long timeoutNanos = timeUnit.toNanos(timeout);
            while (queue.size() == capacity) {
                try {
                    if (timeoutNanos <= 0) {
                        return;
                    }
                    timeoutNanos = produce.awaitNanos(timeoutNanos);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            queue.addLast(t);
            consumer.signalAll();
        } finally {
            unlock();
        }
    }

    public void tryPush(BiConsumer<BlockQueue<T>, Runnable> reject, Runnable runnable) {
        lock();
        try {
            if (queue.size() == capacity) {
                // 队列满了以后, 具体的饱和策略, 由调用者实现
                reject.accept(this, runnable);
            } else {
                queue.addLast((T) runnable);
                consumer.signalAll();
            }
        } finally {
            unlock();
        }
    }
}