自动生成测试类,并填充测试数据(一)

560 阅读3分钟

一 说明

在日常开发过程中,需要对功能做单元测试,还需要制造测试数据。在测试的过程中,造数是比较蛋疼的事情,并且需要编写大量的set方法。所以就有感制作一个自动生成测试用例的代码。

二 设计思路

内容其实很简单,就是通过固定的字符串拼接 + 反射机制来实现功能。其中有几个需要注意的地方是,判断基本数据类型,List,Map,存在泛型的类,这些是否有复杂的嵌套(暂时没有实现多层嵌套,可以作为优化点。)

由于在开发过程中,核心逻辑都在Service层,所以目前开发了Service的单元测试生成器。后面再计划添加Controller层的。

三 上代码

本代码依赖 hutool 的工具包,请自行添加。

1 工具类 GenStrUtil

package com.cah.project.mock.util;

import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.StrUtil;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * 功能描述: 本项目的字符串工具类,扩展hutool的StrUtil <br/>
 */
public class GenStrUtil extends StrUtil {

    /**
     * 功能描述: 获取类型简称 <br/>
     *
     * @param type 类型
     * @return "java.lang.String"
     */
    public static String getTypeSimpleName(Type type) {
        String typeName = type.getTypeName();
        // 是否存在泛型,如 List<E>, Map<K,V>等
        if(typeName.contains("<")) {
            String child = typeName.substring(typeName.indexOf("<") + 1, typeName.indexOf(">"));
            String[] children = child.split(",");
            typeName = typeName.substring(0, typeName.indexOf("<"));
            return typeName.substring(typeName.lastIndexOf(".") + 1) +
                    "<" + Stream.of(children).map(c -> c.substring(c.lastIndexOf(".") + 1)).collect(Collectors.joining(", ")) + ">";
        } else {
            return typeName.substring(typeName.lastIndexOf(".") + 1);
        }
    }

    /**
     * 功能描述: 获取java类原本引入的包 <br/>
     *
     * @param projectName 项目模块名称
     * @param clazz java类
     * @return "java.lang.String"
     */
    public static String getClassImport(String projectName, Class<?> clazz) {
        File javaFile = new File(getJavaFilePath(projectName, clazz));
        try(BufferedReader reader = FileUtil.getUtf8Reader(javaFile)) {
            String line;
            StringBuilder sb = new StringBuilder();
            while ((line = reader.readLine()) != null) {
                if(line.indexOf("import ") == 0) {
                    sb.append(line).append("\n");
                }
                if(line.contains("public")) {
                    break;
                }
            }
            return sb.toString();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return "";
    }

    /**
     * 功能描述: 获取java文件类路径 <br/>
     *
     * @param projectName 项目模块名称
     * @param clazz java类
     * @return "java.lang.String"
     */
    public static String getJavaFilePath(String projectName, Class<?> clazz) {
        return getProjectModulePath(projectName) + File.separator +
                "src\\main\\java\\" +
                clazz.getName().replaceAll("\\.", "\\\\") + ".java";
    }

    /**
     * 功能描述: 获取项目模块路径 <br/>
     *
     * @param projectName 项目模块名称
     * @return "java.lang.String"
     */
    public static String getProjectModulePath(String projectName) {
    	// 这里可以改进一下,将 System.getProperty("user.dir") 替换成固定路径,或者看着改吧。
        return System.getProperty("user.dir") + File.separator + projectName;
    }

}

2 生成类 GenServiceTest

package com.cah.project.mock.gen;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.util.EnumUtil;
import com.cah.project.mock.util.GenStrUtil;
import com.cah.project.module.standard.service.IDictTypeService;
import lombok.SneakyThrows;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * 功能描述: 生成服务的测试方法 <br/>
 */
public class GenServiceTest {

    /** 项目目录 */
    private static final String PROJECT_MODULE = "project-service";
    /** 引入的包列表 */
    private static final Set<String> IMPORT_LIST = new HashSet<>();

    public static void main(String[] args) throws Exception {
    	// 这里替换成自己需要生成的服务接口。
        Class<?> clazz = IDictTypeService.class;
        String className = clazz.getSimpleName();
        String clazzName = GenStrUtil.lowerFirst(className);
        // 包路径
        String packageStr = clazz.getPackage() + ";\n\n";
        // 文件头
        String javaHeadStr = "\n" +
                "/**\n" +
                " * <p> TODO </p>\n" +
                " * \n" +
                " * @author TODO \n" +
                " * @date " + DateUtil.now() + "\n" +
                " */\n" +
                "@SpringBootTest(classes = Application.class)\n" +
                "@RunWith(SpringRunner.class)\n" +
                "public class " + className + "Test {\n";
        // 引入服务
        String autowiredStr = "\n" +
                "    @Autowired\n" +
                "    private " + className + " " + clazzName + ";\n\n";
        // 组装测试方法体(主体逻辑)
        String methodStr = getTestMethods(clazz, clazzName);
        // 添加固定引入包
        IMPORT_LIST.add("import org.junit.Test;\n");
        IMPORT_LIST.add("import org.junit.runner.RunWith;\n");
        IMPORT_LIST.add("import org.springframework.beans.factory.annotation.Autowired;\n");
        IMPORT_LIST.add("import org.springframework.boot.test.context.SpringBootTest;\n");
        IMPORT_LIST.add("import org.springframework.test.context.junit4.SpringRunner;\n");

        // 组装结果
        String str = packageStr +
                GenStrUtil.getClassImport(PROJECT_MODULE, clazz) +
                String.join("", IMPORT_LIST) +
                javaHeadStr +
                autowiredStr +
                methodStr +
                "}\n";
        // 输出
        System.out.println(str);
    }

    /**
     * 功能描述: 组装测试方法体 <br/>
     *
     * @param clazz     类
     * @param clazzName 类的属性名称
     * @return "java.lang.String"
     */
    private static String getTestMethods(Class<?> clazz, String clazzName) {
        StringBuilder sb = new StringBuilder();
        // 获取自身全部的方法
        Method[] methods = clazz.getDeclaredMethods();
        // 方法重载的加序号
        Map<String, Integer> methodNameMap = new HashMap<>();
        for(Method method : methods) {
            if(method.isDefault() || method.getName().contains("$")) {
                // 是 lambda 的,则跳过
                continue;
            }
            // 测试方法名(避免重复)
            String methodName = method.getName();
            int count = Optional.ofNullable(methodNameMap.get(methodName)).orElse(0);
            if(count > 0) {
                methodName = methodName + count;
            }
            methodNameMap.put(methodName, count + 1);
            // 组装方法名
            sb.append("    @Test\n");
            sb.append("    public void test" + GenStrUtil.upperFirst(methodName) + "() {\n");
            // 定义入参名称
            List<String> intoNameList = new ArrayList<>();
            Type[] paramTypes = method.getGenericParameterTypes();
            for(int i = 0; i < paramTypes.length; i++) {
                Type paramType = paramTypes[i];
                // 参数名
                String paramName = "arg" + i;
                // 获取参数类型
                String paramTypeName = GenStrUtil.getTypeSimpleName(paramType);
                // 组装
                sb.append("        " + paramTypeName + " " + paramName + " = " + getNewObject(paramType, paramName));
                // 添加到参数列表中
                intoNameList.add(paramName);
            }
            // 入参集合拼接
            String intoName = CollUtil.isEmpty(intoNameList) ? "" : String.join(", ", intoNameList);
            // 组装执行与返回
            Type returnType = method.getGenericReturnType();
            String execMethod = clazzName + "." + method.getName() + "(" + intoName + ");\n";
            sb.append("        ");
            if(!GenStrUtil.equals("void", returnType.getTypeName())) {
                sb.append(GenStrUtil.getTypeSimpleName(returnType) + " result = ");
            }
            sb.append(execMethod);
            if(!GenStrUtil.equals("void", returnType.getTypeName())) {
                sb.append("        System.out.println(result);\n");
            }
            sb.append("    }\n\n");
        }
        return sb.toString();
    }

    /**
     * 功能描述: 构建新对象 <br/>
     *
     * @param type 对象类型
     * @return "java.lang.String"
     */
    @SneakyThrows
    private static String getNewObject(Type type, String paramName) {
        String typeName = type.getTypeName();
        if (typeName.contains("List")) {
            IMPORT_LIST.add("import java.util.ArrayList;\n");
            return "new ArrayList<>();\n";
        }
        if (typeName.contains("Map")) {
            IMPORT_LIST.add("import java.util.HashMap;\n");
            return "new HashMap<>();\n";
        }
        if(typeName.contains("Integer") || typeName.contains("int")) {
            return "1;\n";
        }
        if(typeName.contains("Long") || typeName.contains("long")) {
            return "1L;\n";
        }
        if(typeName.contains("String")) {
            return "\"\";\n";
        }
        if(typeName.contains("Boolean") || typeName.contains("boolean")) {
            return "false;\n";
        }
        if(typeName.contains("BigDecimal")) {
            return "BigDecimal.ZERO;\n";
        }
        Class<?> clazz = Class.forName(typeName);
        if(EnumUtil.isEnum(clazz)) {
            return GenStrUtil.getTypeSimpleName(type) + "." + EnumUtil.getNames((Class<? extends Enum<?>>) clazz).get(0) + ";\n";
        } else {
            String newStr =  "new " + clazz.getSimpleName() + "();\n";
            // 添加set属性
            StringBuilder sb = new StringBuilder();
            Field[] fields = clazz.getDeclaredFields();
            for(Field field : fields) {
                // 终态和静态不设置
                if(!Modifier.isFinal(field.getModifiers()) && !Modifier.isStatic(field.getModifiers())) {
                    System.out.println();
                    sb.append("        " + paramName + ".set" + GenStrUtil.upperFirst(field.getName()) + "(\"XXXX\");\n");
                }
            }
            return newStr + sb.toString();
        }
    }

}

四 结语

该生成器有待优化的内容:

  • 多级泛型的判断与创建
  • 在set的时候,判断类型
  • 自动添加mock数据(下一个为mock框架,有业务逻辑的mock数据)

五 感谢关注

优惠券合集,购买返佣。

柒宅mini二维码.jpg