多线程执行封装实现

468 阅读2分钟

起因

总有一些时候，领导会说接口返回太慢，问能不能改成并行读取。之前有同事写过一个实现，但我觉得不太好理解，于是自己重新写了一个。

目的

让多线程调用用起来更简单，同时减少重复代码。

代码实现

部分库调用:

  • import cn.hutool.core.collection.CollUtil;
  • import cn.hutool.core.lang.Dict;
/**
 * Static helper that runs a batch of {@link Supplier}s in parallel on a shared
 * thread pool and optionally collects their results.
 * <p>
 * All blocking waits performed by this class are bounded by
 * {@link #TIMEOUT_SECONDS}.
 * <p>
 * Created by songzhaoying on 2021/8/19 15:59.
 *
 * @author songzhaoying@com.
 * @date 2021/8/19 15:59.
 */
public class ParallelReadUtil {

    private static final Logger logger = LoggerFactory.getLogger(ParallelReadUtil.class);

    /**
     * Maximum time, in seconds, a caller waits for a whole batch to finish.
     */
    private static final long TIMEOUT_SECONDS = 10;

    /**
     * Capacity of the bounded work queue in front of the pool.
     */
    private static final int QUEUE_CAPACITY = 160;

    private static final int CORE_POOL_SIZE = 8;

    private static final int MAXIMUM_POOL_SIZE = 16;

    /**
     * Keep-alive time, in seconds, passed to the pool constructor.
     */
    private static final long KEEP_ALIVE_SECONDS = 60;

    /**
     * Private constructor that hides the implicit public one; this class only
     * exposes static members.
     */
    private ParallelReadUtil() {

    }

    /**
     * Shared pool used for every task submitted through this class.
     * <p>
     * Rejected tasks are handled by {@link ExecutorRejectedPolicy}, which logs
     * the rejection and then delegates to the caller-runs behaviour it inherits.
     */
    protected static final ThreadPoolExecutor EXECUTOR_SERVICE = new ThreadPoolExecutor(CORE_POOL_SIZE
            , MAXIMUM_POOL_SIZE
            , KEEP_ALIVE_SECONDS, TimeUnit.SECONDS
            , new LinkedBlockingQueue<>(QUEUE_CAPACITY)
            , new BillTreadFactory(), new ExecutorRejectedPolicy());

    /**
     * Runs every supplier in parallel and returns the results as a {@link Dict}
     * keyed by the same keys as the input map.
     *
     * @param parallelReadSupplierMap suppliers to run, keyed by result name
     * @return a {@code Dict} built from the result of
     *         {@link #mapCompletable(Map)}
     * @throws Exception if the batch fails, is interrupted, or does not finish
     *                   within {@link #TIMEOUT_SECONDS} seconds
     */
    public static Dict getCompletableDict(Map<String, Supplier<Object>> parallelReadSupplierMap) throws Exception {
        return new Dict(mapCompletable(parallelReadSupplierMap));
    }

    /**
     * Runs every supplier in parallel and returns the results as a map keyed by
     * the same keys as the input map.
     *
     * @param parallelReadSupplierMap suppliers to run, keyed by result name
     * @return key-to-result map; an empty mutable map when the input is empty
     * @throws Exception if any supplier is {@code null}, if the batch fails, is
     *                   interrupted, or does not finish within
     *                   {@link #TIMEOUT_SECONDS} seconds
     */
    public static Map<String, Object> mapCompletable(Map<String, Supplier<Object>> parallelReadSupplierMap) throws Exception {
        if (CollectionUtils.isEmpty(parallelReadSupplierMap)) {
            return new HashMap<>();
        }

        // Snapshot keys and suppliers in a single pass so that position i in both
        // lists always refers to the same map entry.
        List<String> keys = new ArrayList<>(parallelReadSupplierMap.size());
        List<Supplier<Object>> suppliers = new ArrayList<>(parallelReadSupplierMap.size());
        for (Map.Entry<String, Supplier<Object>> entry : parallelReadSupplierMap.entrySet()) {
            keys.add(entry.getKey());
            suppliers.add(entry.getValue());
        }

        List<Object> results = listCompletable0(suppliers, true).get(TIMEOUT_SECONDS, TimeUnit.SECONDS);

        return CollUtil.zip(keys, results);
    }

    /**
     * Runs the suppliers in parallel, discards their results and waits for all
     * of them; a failing supplier makes this call throw.
     *
     * @param suppliers tasks to run; must be non-empty and contain no
     *                  {@code null}
     * @throws Exception if the arguments are invalid, a task fails, the wait is
     *                   interrupted, or the batch exceeds
     *                   {@link #TIMEOUT_SECONDS} seconds
     */
    @SafeVarargs
    public static void consumerListCompletable(Supplier<Object>... suppliers) throws Exception {
        consumerListCompletable(true, suppliers);
    }

    /**
     * Runs the suppliers in parallel and discards their results.
     *
     * @param careException {@code true} to block until the batch finishes and
     *                      propagate any failure to the caller; {@code false}
     *                      to return immediately and only log failures
     * @param suppliers     tasks to run; must be non-empty and contain no
     *                      {@code null}
     * @throws Exception if the arguments are invalid, or (only when
     *                   {@code careException} is {@code true}) a task fails,
     *                   the wait is interrupted, or the batch exceeds
     *                   {@link #TIMEOUT_SECONDS} seconds
     */
    @SafeVarargs
    public static void consumerListCompletable(boolean careException, Supplier<Object>... suppliers) throws Exception {
        CompletableFuture<List<Object>> batch = listCompletable0(Arrays.asList(suppliers), false);

        if (careException) {
            batch.get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
            return;
        }

        // Fire-and-forget: nobody waits on the batch, so surface failures in the
        // log instead of dropping them silently.
        batch.whenComplete((ignored, failure) -> {
            if (failure != null) {
                logger.warn("fire-and-forget parallel task failed", failure);
            }
        });
    }

    /**
     * Submits every supplier to {@link #EXECUTOR_SERVICE} and returns a future
     * that completes once all of them have completed.
     * <p>
     * If any supplier fails, the returned future completes exceptionally and
     * carries the original failure as its cause.
     *
     * @param supplierList tasks to run; must be non-empty and contain no
     *                     {@code null}
     * @param isReturn     {@code true} to collect each task's result in
     *                     submission order; {@code false} to produce a list of
     *                     {@code null}s of the same size
     * @return future holding one list element per supplier
     * @throws Exception declared for interface compatibility; invalid
     *                   arguments are reported as
     *                   {@link IllegalArgumentException}
     */
    private static CompletableFuture<List<Object>> listCompletable0(List<Supplier<Object>> supplierList, boolean isReturn) throws Exception {
        if (CollectionUtils.isEmpty(supplierList)) {
            throw new IllegalArgumentException("参数错误");
        }
        if (supplierList.stream().anyMatch(Objects::isNull)) {
            throw new IllegalArgumentException("参数错误, 不允许存在null Supplier");
        }

        List<CompletableFuture<Object>> futures = supplierList.stream()
                .map(supplier -> CompletableFuture.supplyAsync(supplier, EXECUTOR_SERVICE))
                .collect(Collectors.toList());

        CompletableFuture<Void> allDone = CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0]));

        // Results are not wanted: keep the list shape but fill it with nulls.
        if (!isReturn) {
            return allDone.thenApply(v ->
                    futures.stream()
                            .map(future -> (Object) null)
                            .collect(Collectors.toList()));
        }

        // This callback only runs when allOf completed normally, i.e. every
        // future already succeeded, so join() cannot block or throw here. A
        // failing task short-circuits thenApply and propagates its cause.
        return allDone.thenApply(v ->
                futures.stream()
                        .map(CompletableFuture::join)
                        .collect(Collectors.toList()));
    }

    /**
     * Thread factory that gives pool threads a numbered
     * {@code pBill-thread-N} name and logs each creation.
     */
    static class BillTreadFactory implements ThreadFactory {

        private final AtomicInteger mThreadNum = new AtomicInteger(1);

        @Override
        public Thread newThread(Runnable r) {
            Thread t = new Thread(r, "pBill-thread-" + mThreadNum.getAndIncrement());

            logger.info("{} has been created", t.getName());

            // Also echoed to stdout when debug logging is enabled.
            if (logger.isDebugEnabled()) {
                System.out.println(t.getName() + " has been created");
            }

            return t;
        }
    }

    /**
     * Rejection policy that logs the rejected task and then falls back to the
     * inherited {@link ThreadPoolExecutor.CallerRunsPolicy} handling.
     */
    public static class ExecutorRejectedPolicy extends ThreadPoolExecutor.CallerRunsPolicy {

        @Override
        public void rejectedExecution(Runnable r, ThreadPoolExecutor e) {
            doLog(r);

            logger.info("{} main Thread run", r.getClass().getName());
            super.rejectedExecution(r, e);
        }

        /**
         * Logs that a task was rejected by the pool.
         *
         * @param r the rejected task
         */
        private void doLog(Runnable r) {
            logger.info("{} ThreadPoolExecutor rejected", r.getClass().getName());

            // Also echoed to stderr when debug logging is enabled.
            if (logger.isDebugEnabled()) {
                System.err.println(r.getClass().getName() + " rejected");
            }
        }
    }

    /**
     * Manual smoke test: runs a fire-and-forget batch with one failing task,
     * then a result-collecting batch, and prints the results.
     *
     * @param args unused
     * @throws Exception if the result-collecting batch fails or times out
     */
    public static void main(String[] args) throws Exception {
        try {
            Map<String, Supplier<Object>> map = new HashMap<>();
            map.put("test1", () -> {
                System.out.println("get test1 at: " + System.currentTimeMillis() + ", thread: " + Thread.currentThread().getName());
                return "hello";
            });
            map.put("test2", () -> {
                System.out.println("get test2 at: " + System.currentTimeMillis() + ", thread: " + Thread.currentThread().getName());
                if (false) {
                    throw new RuntimeException("getCompletableDict error");
                }
                return 100;
            });

            consumerListCompletable(false
                    , () -> {
                        System.out.println("consumer test1 at: " + System.currentTimeMillis() + ", thread: " + Thread.currentThread().getName());
                        if (true) {
                            throw new RuntimeException("consumerListCompletable error");
                        }
                        System.out.println("consumer test1 error end");
                        return null;
                    }
                    , () -> {
                        System.out.println("consumer test2 at: " + System.currentTimeMillis() + ", thread: " + Thread.currentThread().getName());
                        return null;
                    });

            Dict completableDict = ParallelReadUtil.getCompletableDict(map);
            String test1 = completableDict.get("test1", "");
            Integer test2 = completableDict.get("test2", null);

            System.out.println("f1: " + test1);
            System.out.println("f2: " + test2);
        } finally {
            // The pool's core threads are non-daemon; without a shutdown the JVM
            // would stay alive after this demo finishes.
            EXECUTOR_SERVICE.shutdown();
        }
    }
}