Making JUnit Unit Tests Run in Multiple Threads


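While writing unit tests I needed to run several @Test methods at the same time on multiple threads. After reading the JUnit source I found a way to do it, which I'd like to share.

By default, JUnit 4 runs a test class with org.junit.runners.JUnit4. That class is final and cannot be extended, but its superclass, org.junit.runners.BlockJUnit4ClassRunner, can. JUnit4 does not override any of its parent's methods, so extending BlockJUnit4ClassRunner directly gives the same behavior and lets us override only what we need.

The ThreadTest runner overrides runChild so that each test method is submitted to a thread pool instead of running on the calling thread: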

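import org.junit.Ignore;
import org.junit.internal.AssumptionViolatedException;
import org.junit.internal.runners.model.EachTestNotifier;
import org.junit.runner.Description;
import org.junit.runner.notification.RunNotifier;
import org.junit.runners.BlockJUnit4ClassRunner;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InitializationError;
import org.junit.runners.model.Statement;

public class ThreadTest extends BlockJUnit4ClassRunner {
    private static final String TAG = ThreadTest.class.getSimpleName();

    public ThreadTest(Class<?> klass) throws InitializationError {
        super(klass);
        // One more test class is using this runner
        ThreadPollTest.classCountInteger.addAndGet(1);
    }

    @Override
    protected void runChild(final FrameworkMethod method, final RunNotifier notifier) {
        final Description description = this.describeChild(method);
        if (method.getAnnotation(Ignore.class) != null) {
            synchronized (notifier) {
                notifier.fireTestIgnored(description);
            }
        } else {
            // Wrap the test method (with its @Before/@After and rules) in a task for the pool
            Runnable runnable = new Runnable() {
                @Override
                public void run() {
                    threadRunLeaf(ThreadTest.this.methodBlock(method), description, notifier);
                }
            };
            // Throttle submission so the pool's unbounded queue does not grow too large
            while (ThreadPollTest.threadPollExecute.getQueue().size() > 1000) {
                try {
                    Thread.sleep(10);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            ThreadPollTest.threadPollExecute.execute(runnable);
            int addInThreadCount = ThreadPollTest.addInThreadCount.addAndGet(1);

            // The first time a test class is seen, add its test count to the total
            if (!ThreadPollTest.hashMap.containsKey(description.getClassName())) {
                ThreadPollTest.totalTestCount.addAndGet(testCount());
                ThreadPollTest.hashMap.put(description.getClassName(), "");
            }

            // Once every class is registered and every test submitted, block until the pool
            // drains, so JUnit does not report the run finished while tests are still running.
            // Caveat: testCount() also counts @Ignore'd methods, which are never submitted,
            // so this wait is skipped whenever any test is ignored.
            while (ThreadPollTest.classCountInteger.get() <= ThreadPollTest.hashMap.size()
                    && addInThreadCount >= ThreadPollTest.totalTestCount.get()
                    && ThreadPollTest.threadPollExecute.getCompletedTaskCount() < ThreadPollTest.totalTestCount.get()) {
                try {
                    while (ThreadPollTest.threadPollExecute.getActiveCount() > 0
                            || ThreadPollTest.threadPollExecute.getQueue().size() > 0) {
                        Thread.sleep(100);
                    }
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    protected final void threadRunLeaf(Statement statement, Description description, RunNotifier notifier) {
        EachTestNotifier eachNotifier = new EachTestNotifier(notifier, description);
        // Lock the shared RunNotifier, not eachNotifier: every task creates its own
        // EachTestNotifier, so locking it would not serialize calls to the listeners
        synchronized (notifier) {
            eachNotifier.fireTestStarted();
        }
        try {
            statement.evaluate();
        } catch (AssumptionViolatedException e) {
            synchronized (notifier) {
                eachNotifier.addFailedAssumption(e);
            }
        } catch (Throwable e) {
            synchronized (notifier) {
                eachNotifier.addFailure(e);
            }
        } finally {
            synchronized (notifier) {
                eachNotifier.fireTestFinished();
            }
        }
    }
}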
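The polling at the end of runChild exists so the run does not finish while tests are still executing on pool threads. A simpler way to wait, sketched here with the JDK alone, is to keep the Future returned by each submit() and call get() on it: get() blocks until that task has finished and rethrows any failure wrapped in an ExecutionException. This is a different technique from the author's counter-based polling, not a drop-in replacement; ParallelWaitSketch and runAll are hypothetical names, and each task only increments a counter in place of a real test body. Within JUnit 4 itself, the hook intended for this is ParentRunner.setScheduler(RunnerScheduler), whose finished() method is where a scheduler waits for its children; org.junit.experimental.ParallelComputer is built on it.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical sketch: wait for submitted tasks via their Futures instead of polling pool counters
public class ParallelWaitSketch {

    // Runs taskCount stand-in "tests" on a fixed pool and returns only after all have finished
    static int runAll(int taskCount, int threads) {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        AtomicInteger finished = new AtomicInteger();
        List<Future<?>> futures = new ArrayList<>();
        try {
            for (int i = 0; i < taskCount; i++) {
                // Each task stands in for one test method body
                futures.add(pool.submit(() -> {
                    finished.incrementAndGet();
                }));
            }
            for (Future<?> future : futures) {
                try {
                    future.get(); // blocks until this task has completed
                } catch (ExecutionException e) {
                    // The task threw; a runner would report e.getCause() as a test failure
                    e.getCause().printStackTrace();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        } finally {
            pool.shutdown();
        }
        return finished.get();
    }

    public static void main(String[] args) {
        System.out.println(runAll(50, 4)); // prints 50
    }
}
```

Because every Future is awaited, no shared counters or approximate values such as getCompletedTaskCount() are involved in deciding when the run is over.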

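The thread-pool holder class. It keeps the state shared by all ThreadTest instances (the counters, the set of test classes already seen, and the pool itself). The pool has one thread per available CPU core, with a minimum of two: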

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicInteger;

public class ThreadPollTest {
    // Number of test classes whose ThreadTest runner has been created
    public static final AtomicInteger classCountInteger = new AtomicInteger(0);
    // Number of test methods submitted to the pool so far
    public static final AtomicInteger addInThreadCount = new AtomicInteger(0);
    // Total number of tests in all test classes seen so far
    public static final AtomicInteger totalTestCount = new AtomicInteger(0);
    // Names of test classes already added to totalTestCount (values are unused)
    public static final ConcurrentHashMap<String, String> hashMap = new ConcurrentHashMap<String, String>();
    public static final ScheduledThreadPoolExecutor threadPollExecute = new ScheduledThreadPoolExecutor(getCupThreadCount(), new DefaultThreadFactory("THREAD_POLLS"));

    public static int getCupThreadCount() {
        // availableProcessors() is always at least 1; enforce a minimum of two threads
        int threadCount = Math.max(2, Runtime.getRuntime().availableProcessors());
        System.out.println("Thread pool size: " + threadCount);
        return threadCount;
    }
}
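ThreadTest throttles submission by sleeping while getQueue().size() exceeds 1000, because ScheduledThreadPoolExecutor's work queue is unbounded. A common alternative, shown as a sketch rather than as the author's approach, is a ThreadPoolExecutor with a bounded ArrayBlockingQueue and ThreadPoolExecutor.CallerRunsPolicy: when the queue is full, the submitting thread runs the task itself, which slows submission without a sleep loop and without rejecting work. BoundedPoolSketch, newBoundedPool and runTasks are hypothetical names, and the capacity of 1000 is simply copied from the original loop.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical alternative to ThreadPollTest.threadPollExecute with built-in back-pressure
public class BoundedPoolSketch {

    // Fixed-size pool whose queue holds at most queueCapacity waiting tasks
    static ThreadPoolExecutor newBoundedPool(int queueCapacity) {
        int threads = Math.max(2, Runtime.getRuntime().availableProcessors());
        return new ThreadPoolExecutor(
                threads, threads,                                  // fixed number of threads
                0L, TimeUnit.MILLISECONDS,                         // no idle timeout for core threads
                new ArrayBlockingQueue<Runnable>(queueCapacity),   // bounded work queue
                new ThreadPoolExecutor.CallerRunsPolicy());        // queue full: the caller runs the task
    }

    // Submits taskCount counter tasks, waits for the pool to finish, returns how many ran
    static int runTasks(int taskCount, int queueCapacity) {
        ThreadPoolExecutor pool = newBoundedPool(queueCapacity);
        AtomicInteger done = new AtomicInteger();
        for (int i = 0; i < taskCount; i++) {
            pool.execute(done::incrementAndGet); // never busy-waits, never rejects
        }
        pool.shutdown();
        try {
            pool.awaitTermination(1, TimeUnit.MINUTES);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return done.get();
    }

    public static void main(String[] args) {
        System.out.println(runTasks(5000, 1000)); // prints 5000
    }
}
```

The trade-off is that a caller-run task executes on the thread that is submitting tests, which in a JUnit runner means the main test thread occasionally runs a test itself.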

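An implementation of ThreadFactory. It is optional: it only gives the pool threads readable names and makes them non-daemon threads with normal priority. Without it, you can create the pool from the thread count alone, for example new ScheduledThreadPoolExecutor(getCupThreadCount()).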

import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

public class DefaultThreadFactory implements ThreadFactory {
    private static final AtomicInteger POOL_NUMBER = new AtomicInteger(1);
    private final ThreadGroup group;
    private final AtomicInteger threadNumber = new AtomicInteger(1);
    private final String namePrefix;

    public DefaultThreadFactory(String name) {
        // Same approach as the JDK 8 implementation of Executors.defaultThreadFactory()
        SecurityManager s = System.getSecurityManager();
        group = (s != null) ? s.getThreadGroup() : Thread.currentThread().getThreadGroup();
        namePrefix = name + "-pool-" + POOL_NUMBER.getAndIncrement() + "-thread-";
    }

    @Override
    public Thread newThread(Runnable r) {
        // Names look like THREAD_POLLS-pool-1-thread-1
        Thread t = new Thread(group, r, namePrefix + threadNumber.getAndIncrement(), 0);
        // Pool threads are always non-daemon and run at normal priority
        if (t.isDaemon()) {
            t.setDaemon(false);
        }
        if (t.getPriority() != Thread.NORM_PRIORITY) {
            t.setPriority(Thread.NORM_PRIORITY);
        }
        return t;
    }
}
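Since the factory is optional, it helps to see what it actually changes. The JDK-only sketch below runs one task on a pool with and without a custom factory and reports the executing thread's name: the JDK default names threads pool-N-thread-M, while a factory lets you choose a prefix that is easier to spot in thread dumps and logs. NamedFactorySketch, namedFactory and threadNameUsed are hypothetical names, and the lambda factory is a simplified stand-in for DefaultThreadFactory (no thread group or priority handling), not a copy of it.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

public class NamedFactorySketch {

    // Simplified stand-in for DefaultThreadFactory: prefixed names, non-daemon threads
    static ThreadFactory namedFactory(String prefix) {
        AtomicInteger counter = new AtomicInteger(1);
        return r -> {
            Thread t = new Thread(r, prefix + "-thread-" + counter.getAndIncrement());
            t.setDaemon(false);
            return t;
        };
    }

    // Runs one task on the pool, shuts the pool down, and returns the executing thread's name
    static String threadNameUsed(ExecutorService pool) {
        Callable<String> task = () -> Thread.currentThread().getName();
        try {
            return pool.submit(task).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        // JDK default factory: a name of the form pool-N-thread-M
        System.out.println(threadNameUsed(Executors.newFixedThreadPool(2)));
        // Custom factory: the first thread is named THREAD_POLLS-thread-1
        System.out.println(threadNameUsed(Executors.newFixedThreadPool(2, namedFactory("THREAD_POLLS"))));
    }
}
```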

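Finally, put the @RunWith annotation on a base test class. That completes the setup: any other test class that simply extends BaseTest will have its test methods run on different threads.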

import org.junit.runner.RunWith;

@RunWith(ThreadTest.class)
public abstract class BaseTest {
}
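Subclasses pick up the runner because JUnit's @RunWith annotation is itself meta-annotated with @Inherited, so Class.getAnnotation on a subclass also returns an annotation declared on its superclass. The JDK-only sketch below demonstrates that mechanism without loading JUnit; RunWithLike is a hypothetical stand-in for @RunWith, and ParallelRunner, BaseTest and UserServiceTest are illustrative nested classes standing in for ThreadTest, the real BaseTest, and a real test class.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

public class InheritedRunnerSketch {

    // Hypothetical stand-in for org.junit.runner.RunWith, which is also @Inherited
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.TYPE)
    @Inherited
    @interface RunWithLike {
        Class<?> value();
    }

    static class ParallelRunner {}                     // stands in for ThreadTest

    @RunWithLike(ParallelRunner.class)
    abstract static class BaseTest {}                  // annotated once

    static class UserServiceTest extends BaseTest {}   // hypothetical test class, no annotation of its own

    // The runner class a test class would be executed with, as seen through reflection
    static Class<?> runnerFor(Class<?> testClass) {
        RunWithLike annotation = testClass.getAnnotation(RunWithLike.class);
        return annotation == null ? null : annotation.value();
    }

    public static void main(String[] args) {
        System.out.println(runnerFor(UserServiceTest.class).getSimpleName()); // prints ParallelRunner
    }
}
```

With the real classes, a test class declared as `public class UserServiceTest extends BaseTest` (a hypothetical name) therefore needs no @RunWith of its own. Note that @Inherited only works through superclasses: putting the annotation on an interface would not apply it to implementing classes.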