A Brief Look at JUnit 4, Part 17: How to Implement a Simple Runner (Part 1)


For an ordinary test class $\text{T}$, the corresponding Runner, $\text{Runner}_\text{T}$, is of type org.junit.runners.JUnit4. If we have custom requirements, can we implement our own Runner and have JUnit 4 run it? This article explores that question.

Key points

  • By extending org.junit.runners.BlockJUnit4ClassRunner, we can implement our own Runner
  • Annotating a test class $\text{T}$ with @RunWith(XXXRunner.class) makes XXXRunner run the test methods in $\text{T}$
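To make the second point more concrete: JUnit 4 reads the runner class named in @RunWith and instantiates it reflectively, which is why a custom runner such as MyRunner must declare a public constructor that takes a Class<?>. The sketch below imitates that lookup with plain JDK reflection. MyRunWith, SimpleRunner, DefaultRunner, CustomRunner, runnerFor and the two test classes are all hypothetical stand-ins invented for illustration, not JUnit's real types, and JUnit's actual builder logic handles more cases than this.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

public class RunWithDemo {

    // Hypothetical stand-in for @RunWith; NOT the real org.junit.runner.RunWith
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.TYPE)
    @interface MyRunWith {
        Class<? extends SimpleRunner> value();
    }

    // Hypothetical stand-in for a runner base class; NOT org.junit.runner.Runner
    public abstract static class SimpleRunner {
        protected final Class<?> testClass;

        protected SimpleRunner(Class<?> testClass) {
            this.testClass = testClass;
        }

        public abstract String describe();
    }

    public static class DefaultRunner extends SimpleRunner {
        public DefaultRunner(Class<?> testClass) {
            super(testClass);
        }

        @Override
        public String describe() {
            return "DefaultRunner for " + testClass.getSimpleName();
        }
    }

    public static class CustomRunner extends SimpleRunner {
        public CustomRunner(Class<?> testClass) {
            super(testClass);
        }

        @Override
        public String describe() {
            return "CustomRunner for " + testClass.getSimpleName();
        }
    }

    static class PlainTest {}

    @MyRunWith(CustomRunner.class)
    static class AnnotatedTest {}

    // Read the annotation (falling back to a default runner) and invoke the
    // runner's public (Class<?>) constructor reflectively
    static SimpleRunner runnerFor(Class<?> testClass) {
        MyRunWith annotation = testClass.getAnnotation(MyRunWith.class);
        Class<? extends SimpleRunner> runnerClass =
                annotation != null ? annotation.value() : DefaultRunner.class;
        try {
            return runnerClass.getConstructor(Class.class).newInstance(testClass);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot instantiate " + runnerClass.getName(), e);
        }
    }

    public static void main(String[] args) {
        System.out.println(runnerFor(PlainTest.class).describe());     // DefaultRunner for PlainTest
        System.out.println(runnerFor(AnnotatedTest.class).describe()); // CustomRunner for AnnotatedTest
    }
}
```

If a runner class lacked a public (Class<?>) constructor, getConstructor would throw NoSuchMethodException here; in real JUnit a missing suitable constructor likewise prevents the runner from being created.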

Main text

For an ordinary test class $\text{T}$, its Runner is of type org.junit.runners.JUnit4. Its simplified class diagram is shown below ⬇️

[Figure: simplified class diagram of org.junit.runners.JUnit4]

If we want to measure how long each test method takes, one option is to implement a dedicated Runner that contains this logic. Since org.junit.runners.JUnit4 is a final class, we cannot extend it. We can, however, extend its parent class, org.junit.runners.BlockJUnit4ClassRunner. The next subsection shows the project code.
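The underlying idea is ordinary subclassing: the parent runner calls an overridable per-child hook, and the subclass wraps a call to super with timing code. The JDK-only sketch below shows just that pattern. BaseRunner and TimingRunner are hypothetical classes that stand in for BlockJUnit4ClassRunner and a custom runner; they are not the JUnit API. It uses System.nanoTime, a monotonic clock intended for measuring elapsed time.

```java
import java.util.ArrayList;
import java.util.List;

public class TimingOverrideDemo {

    // Hypothetical base class standing in for BlockJUnit4ClassRunner; NOT the JUnit API
    public static class BaseRunner {
        protected final List<String> log = new ArrayList<>();

        // Runs each child (think: test method) through the overridable runChild hook
        public void run(List<String> children) {
            for (String child : children) {
                runChild(child);
            }
        }

        protected void runChild(String child) {
            log.add("ran " + child);
        }
    }

    // Subclass adds timing around super.runChild(...) without touching the base logic
    public static class TimingRunner extends BaseRunner {
        @Override
        protected void runChild(String child) {
            long start = System.nanoTime(); // monotonic clock, suited to elapsed time
            super.runChild(child);
            long elapsedMillis = (System.nanoTime() - start) / 1_000_000;
            log.add(String.format("It took %s ms to run %s()", elapsedMillis, child));
        }
    }

    public static void main(String[] args) {
        TimingRunner runner = new TimingRunner();
        runner.run(List.of("test_case1", "test_case2"));
        runner.log.forEach(System.out::println);
    }
}
```

Because the timing lives entirely in the subclass, the parent's scheduling logic stays untouched; the same shape applies when overriding a real runner's per-method hook.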

Project code

The project structure is as follows

.
├── pom.xml
└── src
    ├── main
    │   └── java
    │       └── org
    │           └── study
    │               └── SimpleProductCalculator.java
    └── test
        └── java
            └── org
                └── study
                    ├── runners
                    │   └── MyRunner.java
                    └── SimpleProductCalculatorTest.java

SimpleProductCalculator.java

The code of SimpleProductCalculator.java is shown below ⬇️ (SimpleProductCalculator computes the product of two non-negative integers the naive way, by repeated addition)

package org.study;

public class SimpleProductCalculator {
    public int calculateProduct(int a, int b) {
        int product = 0;
        while (b > 0) {
            product += a;
            b--;
        }
        return product;
    }
}
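Note that the loop above silently returns 0 whenever b is negative, because the condition b > 0 is never true; this is why the calculator is only meant for non-negative integers. The sketch below is a hypothetical variant (CheckedProductCalculator is not part of the project) that keeps the same repeated-addition approach but enforces the non-negative contract and uses Math.addExact to detect int overflow instead of wrapping around.

```java
public class CheckedProductCalculator {

    // Hypothetical variant of SimpleProductCalculator: same repeated-addition loop,
    // but the "non-negative inputs" contract is enforced instead of assumed
    public int calculateProduct(int a, int b) {
        if (a < 0 || b < 0) {
            throw new IllegalArgumentException("a and b must be non-negative");
        }
        int product = 0;
        // The body runs exactly b times, so the cost grows linearly with b
        for (int i = 0; i < b; i++) {
            product = Math.addExact(product, a); // throws ArithmeticException on int overflow
        }
        return product;
    }

    public static void main(String[] args) {
        CheckedProductCalculator calculator = new CheckedProductCalculator();
        System.out.println(calculator.calculateProduct(1000, 1000)); // 1000000
    }
}
```

The linear cost in b is exactly what makes this calculator a convenient workload for the timing runner: larger b means visibly longer test methods.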

MyRunner.java

The code of MyRunner.java is shown below ⬇️

package org.study.runners;

import org.junit.runner.notification.RunNotifier;
import org.junit.runners.BlockJUnit4ClassRunner;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InitializationError;

import java.util.ArrayList;
import java.util.List;

public class MyRunner extends BlockJUnit4ClassRunner {

    public MyRunner(Class<?> clazz) throws InitializationError {
        super(clazz);
    }

    @Override
    public void run(RunNotifier notifier) {
        displayChildrenInfo();
        super.run(notifier);
    }

    private void displayChildrenInfo() {
        List<FrameworkMethod> children = getChildren();

        List<String> lines = new ArrayList<>();
        lines.add(String.format("There are %s child nodes (i.e. test methods)", children.size()));
        lines.add("Child nodes are listed as follows");
        lines.add("");

        int n = 0;
        for (FrameworkMethod child : children) {
            n++;
            String info = String.format("%s. %s()", n, child.getName());
            lines.add(info);
        }

        prettyPrint(lines);
    }

    private void prettyPrint(List<String> lines) {
        int maxWidth = lines.stream().map(String::length)
                .mapToInt(x -> x)
                .max()
                .orElseThrow();

        System.out.println("+".repeat(maxWidth + 4));
        for (String line : lines) {
            System.out.println("| " + line + " ".repeat(maxWidth - line.length()) + " |");
        }
        System.out.println("+".repeat(maxWidth + 4));
    }

    @Override
    protected void runChild(FrameworkMethod method, RunNotifier notifier) {
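        // System.nanoTime() is monotonic, so it is better suited than currentTimeMillis() for elapsed time
        long start = System.nanoTime();
        super.runChild(method, notifier);
        long elapsedMillis = (System.nanoTime() - start) / 1_000_000;

        String message =
                String.format(
                        "It took %s ms to run %s() test method",
                        elapsedMillis,
                        method.getName()
                );
        System.out.println(message);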
    }
}
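The prettyPrint helper in MyRunner pads every line to the width of the longest one so that the right-hand borders line up. The sketch below is a hypothetical refactoring (BoxFormatter.format is an invented name) that returns the box as a String instead of printing it, which makes the layout easy to check in isolation; it relies on String.repeat (Java 11+) and OptionalInt.orElseThrow() (Java 10+), just like the original.

```java
import java.util.List;

public class BoxFormatter {

    // Pure-function version of MyRunner.prettyPrint: returns the box as a String
    // instead of printing it, which makes the layout easy to check
    public static String format(List<String> lines) {
        int maxWidth = lines.stream()
                .mapToInt(String::length)
                .max()
                .orElseThrow(); // NoSuchElementException for an empty list, like the original

        String border = "+".repeat(maxWidth + 4);
        StringBuilder sb = new StringBuilder(border).append('\n');
        for (String line : lines) {
            // Pad every line to maxWidth so the right-hand "|" characters line up
            sb.append("| ").append(line)
              .append(" ".repeat(maxWidth - line.length()))
              .append(" |\n");
        }
        return sb.append(border).append('\n').toString();
    }

    public static void main(String[] args) {
        System.out.print(format(List.of("ab", "c")));
    }
}
```

For the input ["ab", "c"] this produces a 6-character border, then "| ab |" and "| c  |", then the border again. The width is measured in chars, so wide characters such as CJK text would not line up in a terminal.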

SimpleProductCalculatorTest.java

The code of SimpleProductCalculatorTest.java is shown below ⬇️

package org.study;

import org.junit.*;
import org.junit.runner.JUnitCore;
import org.junit.runner.Result;
import org.junit.runner.RunWith;
import org.junit.runner.notification.Failure;
import org.study.runners.MyRunner;

@RunWith(MyRunner.class)
public class SimpleProductCalculatorTest {

    private final SimpleProductCalculator productCalculator = new SimpleProductCalculator();

    @Test
    public void test_case1() {
        int a = 1000;
        int b = 1000;
        Assert.assertEquals(a * b, productCalculator.calculateProduct(a, b));
    }

    @Test
    public void test_case2() {
        int a = 10000;
        int b = 10000;
        Assert.assertEquals(a * b, productCalculator.calculateProduct(a, b));
    }

    @Test
    public void test_case3() {
        int a = 10;
        int b = 1_0000_0000;
        Assert.assertEquals(a * b, productCalculator.calculateProduct(a, b));
    }

    public static void main(String[] args) {
        Result result = JUnitCore.runClasses(SimpleProductCalculatorTest.class);
        for (Failure failure : result.getFailures()) {
            System.out.println(failure);
        }
    }
}
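Since calculateProduct adds a to itself b times, the number of loop iterations in each test equals b. A quick worked check of the three cases (arithmetic only, not measured timings) also confirms that the expected value a * b stays within the range of int, so the assertions compare correct values:

$$
\begin{aligned}
\text{test\_case1}&:\ a = 10^3,\ b = 10^3 \;\Rightarrow\; a \times b = 10^6,\ \text{loop iterations} = 10^3\\
\text{test\_case2}&:\ a = 10^4,\ b = 10^4 \;\Rightarrow\; a \times b = 10^8,\ \text{loop iterations} = 10^4\\
\text{test\_case3}&:\ a = 10,\ b = 10^8 \;\Rightarrow\; a \times b = 10^9,\ \text{loop iterations} = 10^8\\
&\ 10^9 < 2^{31} - 1 = 2\,147\,483\,647
\end{aligned}
$$

So test_case3 performs by far the most iterations of the three.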

pom.xml

The content of pom.xml is shown below ⬇️

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>junit-study</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>25</maven.compiler.source>
        <maven.compiler.target>25</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.13.2</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

</project>

A simplified class diagram of the classes defined in the project is shown below ⬇️

[Figure: simplified class diagram of the classes defined in the project]

Run results

Run the main method of SimpleProductCalculatorTest and you should see output similar to the figure below (the actual timings will vary)

[Figure: sample console output from running SimpleProductCalculatorTest]

Other

PlantUML code used to draw the diagrams

Code for "Simplified class diagram of org.junit.runners.JUnit4"

@startuml

title Simplified class diagram of <i>org.junit.runners.JUnit4</i>

interface org.junit.runner.Describable
abstract org.junit.runner.Runner
interface org.junit.runner.manipulation.Filterable
interface org.junit.runner.manipulation.Sortable
interface org.junit.runner.manipulation.Orderable
abstract org.junit.runners.ParentRunner<T>
class org.junit.runners.BlockJUnit4ClassRunner
class org.junit.runners.JUnit4

org.junit.runner.Describable <|.. org.junit.runner.Runner
org.junit.runner.Runner <|-- org.junit.runners.ParentRunner
org.junit.runner.manipulation.Filterable <|.. org.junit.runners.ParentRunner
org.junit.runner.manipulation.Sortable <|-- org.junit.runner.manipulation.Orderable
org.junit.runner.manipulation.Orderable <|.. org.junit.runners.ParentRunner
org.junit.runners.ParentRunner <|-- org.junit.runners.BlockJUnit4ClassRunner : extends ParentRunner<FrameworkMethod>
org.junit.runners.BlockJUnit4ClassRunner <|-- org.junit.runners.JUnit4

@enduml

Code for "Simplified class diagram of the classes defined in the project"

@startuml

title Simplified class diagram of the classes defined in the project
caption Note: the inheritance hierarchy of <i>MyRunner</i> is not shown

class org.study.SimpleProductCalculator {
    + int calculateProduct(int a, int b)
}

class org.study.runners.MyRunner {
    + MyRunner(Class<?> clazz) throws InitializationError
    + void run(RunNotifier notifier)
    - void displayChildrenInfo()
    - void prettyPrint(List<String> lines)
    # void runChild(FrameworkMethod method, RunNotifier notifier)
}

class org.study.SimpleProductCalculatorTest {
    - final SimpleProductCalculator productCalculator
    + void test_case1()
    + void test_case2()
    + void test_case3()
    + {static} void main(String[] args)
}

@enduml

References