在进行单元测试的时候有一个需求:需要在多线程中同时跑多个@Test方法。看了一下JUnit的源码,终于找到了方法,分享一下。单元测试方法默认是运行在org.junit.runners.JUnit4中的,但它是final类,不能继承;于是找到它的父类org.junit.runners.BlockJUnit4ClassRunner。JUnit4没有重写父类的任何方法,所以可以直接继承BlockJUnit4ClassRunner并重写相关方法。ThreadTest类:重写runChild方法,把测试方法加入到线程池中执行。
/**
 * JUnit 4 runner that executes every non-ignored {@code @Test} method of a class on the shared
 * pool {@code ThreadPollTest.threadPollExecute} instead of on the calling thread.
 *
 * <p>{@link #runChild} returns as soon as a test has been queued, so JUnit moves on to the next
 * method (and the next class) while earlier tests are still running. Only the last child of the
 * last registered class blocks until every queued test has finished, so the overall run does not
 * end before the pool has drained.
 *
 * <p>NOTE(review): because earlier classes do not wait for their own tests, JUnit may run a
 * class's {@code @AfterClass} methods while some of that class's tests are still executing on the
 * pool. Also, a class whose children are never run (e.g. all filtered out) never registers, so the
 * final wait would not trigger — confirm this is acceptable for the intended use.
 */
public class ThreadTest extends BlockJUnit4ClassRunner {

    /** Submission pauses while more than this many tests are waiting in the pool's queue. */
    private static final int MAX_QUEUED_TASKS = 1000;

    /** Tests handed to the pool whose runnable has not finished yet (across all classes). */
    private static final java.util.concurrent.atomic.AtomicInteger PENDING_TESTS =
            new java.util.concurrent.atomic.AtomicInteger(0);

    /**
     * Creates the runner and counts one more participating test class.
     *
     * @param klass the test class to run
     * @throws InitializationError if JUnit rejects the class (the class is then not counted)
     */
    public ThreadTest(Class<?> klass) throws InitializationError {
        super(klass);
        ThreadPollTest.classCountInteger.addAndGet(1);
    }

    /**
     * Reports an ignored method immediately, or queues a normal method on the shared pool. If this
     * is the last child of the whole run, blocks until every queued test has finished.
     */
    @Override
    protected void runChild(final FrameworkMethod method, final RunNotifier notifier) {
        final Description description = this.describeChild(method);
        // Register before branching so classes whose tests are all @Ignore'd are counted too.
        registerTestClass(description);

        if (method.getAnnotation(Ignore.class) != null) {
            synchronized (notifier) {
                notifier.fireTestIgnored(description);
            }
        } else {
            submit(method, description, notifier);
        }

        // Ignored children are counted as well: totalTestCount comes from testCount(), which
        // includes ignored methods, so counting only submitted tests could never reach it.
        int handledCount = ThreadPollTest.addInThreadCount.addAndGet(1);
        boolean allClassesStarted =
                ThreadPollTest.classCountInteger.get() <= ThreadPollTest.hashMap.size();
        if (allClassesStarted && handledCount >= ThreadPollTest.totalTestCount.get()) {
            awaitPendingTests();
        }
    }

    /** Adds this class's test count to the run-wide total the first time one of its children runs. */
    private void registerTestClass(Description description) {
        if (ThreadPollTest.hashMap.putIfAbsent(description.getClassName(), "") == null) {
            ThreadPollTest.totalTestCount.addAndGet(testCount());
        }
    }

    /** Queues one test on the shared pool, throttling first while the queue is too long. */
    private void submit(final FrameworkMethod method, final Description description,
                        final RunNotifier notifier) {
        waitForQueueCapacity();
        PENDING_TESTS.incrementAndGet();
        try {
            ThreadPollTest.threadPollExecute.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        threadRunLeaf(ThreadTest.this.methodBlock(method), description, notifier);
                    } finally {
                        PENDING_TESTS.decrementAndGet();
                    }
                }
            });
        } catch (RuntimeException e) {
            // The task was never queued (e.g. rejected by the pool), so it will never decrement.
            PENDING_TESTS.decrementAndGet();
            throw e;
        }
    }

    /** Avoids piling too many tasks into the pool: waits while the queue exceeds the limit. */
    private static void waitForQueueCapacity() {
        while (ThreadPollTest.threadPollExecute.getQueue().size() > MAX_QUEUED_TASKS) {
            if (!sleepQuietly(10)) {
                return;
            }
        }
    }

    /** Blocks until every test queued by this runner class has finished running. */
    private static void awaitPendingTests() {
        while (PENDING_TESTS.get() > 0) {
            if (!sleepQuietly(100)) {
                return;
            }
        }
    }

    /**
     * Sleeps for the given number of milliseconds.
     *
     * @return {@code true} if the sleep completed; {@code false} if the thread was interrupted,
     *     in which case the interrupt flag is restored so callers can stop waiting
     */
    private static boolean sleepQuietly(long millis) {
        try {
            Thread.sleep(millis);
            return true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    /**
     * Runs one test statement on the current (pool) thread and reports start, failure, failed
     * assumption and finish to {@code notifier}.
     *
     * <p>Notifications are serialized on the shared {@code notifier}, the same lock used when
     * reporting ignored tests. The per-call {@link EachTestNotifier} is thread-local, so locking
     * on it would give no mutual exclusion.
     */
    protected final void threadRunLeaf(Statement statement, Description description, RunNotifier notifier) {
        EachTestNotifier eachNotifier = new EachTestNotifier(notifier, description);
        synchronized (notifier) {
            eachNotifier.fireTestStarted();
        }
        try {
            statement.evaluate();
        } catch (AssumptionViolatedException e) {
            synchronized (notifier) {
                eachNotifier.addFailedAssumption(e);
            }
        } catch (Throwable e) {
            synchronized (notifier) {
                eachNotifier.addFailure(e);
            }
        } finally {
            synchronized (notifier) {
                eachNotifier.fireTestFinished();
            }
        }
    }
}
线程池管理类,最少两个线程
/**
 * Shared state used by the {@code ThreadTest} runner: the pool that executes test methods and the
 * counters {@code ThreadTest.runChild} uses to recognise the last test of the run.
 */
public class ThreadPollTest {
    /** Number of ThreadTest runners successfully constructed (incremented in its constructor). */
    public static final AtomicInteger classCountInteger = new AtomicInteger(0);
    /** Number of test methods dispatched so far by ThreadTest.runChild. */
    public static final AtomicInteger addInThreadCount = new AtomicInteger(0);
    /** Sum of testCount() of every test class registered in {@link #hashMap}. */
    public static final AtomicInteger totalTestCount = new AtomicInteger(0);
    /** Names of test classes already registered; the values are unused placeholders. */
    public static final ConcurrentHashMap<String, String> hashMap = new ConcurrentHashMap<String, String>();

    /** Lower bound on the pool size, as promised for this class. */
    private static final int MIN_THREADS = 2;

    /** Pool that runs the test methods; its core size comes from {@link #getCupThreadCount()}. */
    public static final ScheduledThreadPoolExecutor threadPollExecute = new ScheduledThreadPoolExecutor(getCupThreadCount(), new DefaultThreadFactory("THREAD_POLLS"));

    /**
     * Returns the pool size to use: the number of available processors, but never fewer than two.
     * The chosen value is printed to standard output.
     *
     * @return the thread count, always at least {@value #MIN_THREADS}
     */
    public static int getCupThreadCount() {
        // availableProcessors() never throws and never returns less than 1, so no fallback is
        // needed; Math.max enforces the two-thread minimum even on single-CPU machines.
        int threadCount = Math.max(MIN_THREADS, Runtime.getRuntime().availableProcessors());
        System.out.println("线程核心数:" + threadCount);
        return threadCount;
    }
}
ThreadFactory的子类,可有可无,可以直接在线程池中设置线程数量
/**
 * {@link ThreadFactory} that names its threads {@code <name>-pool-<poolNo>-thread-<threadNo>}
 * and makes sure each new thread is non-daemon with normal priority.
 */
public class DefaultThreadFactory implements ThreadFactory {
    /** Sequence shared by every factory instance; each factory takes the next pool number. */
    private static final AtomicInteger POOL_SEQUENCE = new AtomicInteger(1);

    /** Group that all threads from this factory are created in. */
    private final ThreadGroup threadGroup;
    /** Per-factory thread counter, starting at 1. */
    private final AtomicInteger nextThreadIndex = new AtomicInteger(1);
    /** Common name prefix; the thread index is appended to it. */
    private final String threadNamePrefix;

    /**
     * Creates a factory whose threads are named after {@code name}.
     *
     * @param name leading part of every thread name
     */
    public DefaultThreadFactory(String name) {
        this.threadGroup = resolveThreadGroup();
        int poolId = POOL_SEQUENCE.getAndIncrement();
        this.threadNamePrefix = name + "-pool-" + poolId + "-thread-";
    }

    /** Uses the security manager's group when one is installed, else the creating thread's group. */
    private static ThreadGroup resolveThreadGroup() {
        SecurityManager manager = System.getSecurityManager();
        if (manager == null) {
            return Thread.currentThread().getThreadGroup();
        }
        return manager.getThreadGroup();
    }

    /** Creates an unstarted, non-daemon, normal-priority thread that will run {@code r}. */
    @Override
    public Thread newThread(Runnable r) {
        String threadName = threadNamePrefix + nextThreadIndex.getAndIncrement();
        Thread created = new Thread(threadGroup, r, threadName, 0L);
        if (created.isDaemon()) {
            created.setDaemon(false);
        }
        if (created.getPriority() != Thread.NORM_PRIORITY) {
            created.setPriority(Thread.NORM_PRIORITY);
        }
        return created;
    }
}
最后在测试类中加入注解调用,单元测试的代码就完成了,其他的单元测试类直接继承BaseTest就可以实现单元测试不同方法在不同线程中跑
/**
 * Base class for tests that should run their {@code @Test} methods on the shared thread pool.
 * Subclasses inherit the {@code ThreadTest} runner through this annotation.
 */
@RunWith(value = ThreadTest.class)
public abstract class BaseTest {
    // Intentionally empty: this class only exists to carry the @RunWith annotation.
}