MyBatis Mapper 参数预检：检验 POJO 中是否漏写参数

75 阅读2分钟
package priv.wjh.study.mybatis;

import org.apache.ibatis.builder.xml.XMLMapperBuilder;
import org.apache.ibatis.executor.ErrorContext;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ParameterMode;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.ParamNameResolver;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.LocalCacheScope;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.TypeAliasRegistry;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.util.ClassUtils;

import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Array;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

/**
 * Offline check that detects parameters referenced by MyBatis XML mappers which are missing from
 * the parameter POJO.
 *
 * <p>For every mapped statement whose id resolves to a method on a mapper interface, the check
 * builds placeholder arguments for that method and evaluates the statement's dynamic SQL with
 * them (running the reached {@code <if>} conditions). It then resolves every {@code #{...}}
 * parameter against those arguments. A property that cannot be read is printed as
 * {@code Error Parameter <property> not found}.
 *
 * <p>No database connection is used: the {@link Configuration} is built only from the mapper XML
 * files found on the classpath.
 */
public class MybatisMapperCheck {

    /** Default mapper XML location: every {@code .xml} file below {@code dao/}, recursively. */
    private static final String DEFAULT_MAPPER_LOCATION = "classpath*:dao/**/*.xml";

    /** Default (comma separated, Ant-style) package pattern scanned for type aliases. */
    private static final String DEFAULT_TYPE_ALIASES_PACKAGE = "com.**.domain";

    /** Resource pattern, relative to an alias package, used to find candidate classes. */
    static final String DEFAULT_RESOURCE_PATTERN = "**/*.class";

    /**
     * Runs the check.
     *
     * @param args optional: {@code args[0]} overrides the mapper location pattern,
     *             {@code args[1]} overrides the type-alias package pattern
     */
    public static void main(String[] args) {
        String mapperLocation = args.length > 0 ? args[0] : DEFAULT_MAPPER_LOCATION;
        String aliasesPackage = args.length > 1 ? args[1] : DEFAULT_TYPE_ALIASES_PACKAGE;
        try {
            test(mapperLocation, aliasesPackage);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Checks all mappers found under the default location using the default alias package.
     *
     * @throws IOException if the mapper resources cannot be listed
     */
    public static void test() throws IOException {
        test(DEFAULT_MAPPER_LOCATION, DEFAULT_TYPE_ALIASES_PACKAGE);
    }

    /**
     * Parses every mapper XML matching {@code mapperLocationPattern} and checks each mapped
     * statement. A failure in one statement is printed and does not stop the others.
     *
     * @param mapperLocationPattern Spring resource pattern of the mapper XML files,
     *                              e.g. {@code classpath*:dao/**}{@code /*.xml}
     * @param typeAliasesPackage    comma separated package pattern(s) scanned for type aliases
     * @throws IOException          if the mapper resources cannot be listed
     * @throws RuntimeException     if a mapper XML file cannot be parsed
     */
    public static void test(String mapperLocationPattern, String typeAliasesPackage) throws IOException {
        Configuration configuration = newConfiguration(typeAliasesPackage);

        ResourcePatternResolver resourceLoader = new PathMatchingResourcePatternResolver();
        for (Resource mapperLocation : resourceLoader.getResources(mapperLocationPattern)) {
            if (mapperLocation != null) {
                parseMapper(configuration, mapperLocation);
            }
        }

        for (String id : collectStatementIds(configuration)) {
            try {
                check(configuration, id);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        System.out.println("---------");
    }

    /** Creates a cache-free configuration with the type aliases of the given packages registered. */
    private static Configuration newConfiguration(String typeAliasesPackage) {
        Configuration configuration = new Configuration();
        configuration.setCacheEnabled(false);
        configuration.setLocalCacheScope(LocalCacheScope.STATEMENT);
        configuration.setJdbcTypeForNull(JdbcType.NULL);
        TypeAliasRegistry typeAliasRegistry = configuration.getTypeAliasRegistry();
        for (String aliasPackage : getTypeAliasesPackage(typeAliasesPackage)) {
            typeAliasRegistry.registerAliases(aliasPackage);
        }
        return configuration;
    }

    /** Parses one mapper XML into {@code configuration}, always closing its input stream. */
    private static void parseMapper(Configuration configuration, Resource mapperLocation) {
        try (InputStream inputStream = mapperLocation.getInputStream()) {
            XMLMapperBuilder xmlMapperBuilder = new XMLMapperBuilder(inputStream, configuration,
                    mapperLocation.toString(), configuration.getSqlFragments());
            xmlMapperBuilder.parse();
        } catch (Exception e) {
            throw new RuntimeException("Failed to parse mapping resource: '" + mapperLocation + "'", e);
        } finally {
            ErrorContext.instance().reset();
        }
    }

    /** Returns the ids of all mapped statements, sorted so output is grouped by mapper. */
    private static Set<String> collectStatementIds(Configuration configuration) {
        Set<String> ids = new TreeSet<>();
        // Iterated as Object on purpose: the statement map can also hold
        // org.apache.ibatis.session.Configuration.StrictMap.Ambiguity markers,
        // which are not MappedStatements.
        for (Object mappedStatement : configuration.getMappedStatements()) {
            if (mappedStatement instanceof MappedStatement) {
                ids.add(((MappedStatement) mappedStatement).getId());
            }
        }
        return ids;
    }

    /**
     * Checks a single mapped statement against the mapper interface method it is bound to.
     *
     * <p>Prints the value resolved for every non-OUT {@code #{...}} parameter as
     * {@code fill: <id>#<property>:<value>}; unreadable properties are printed as
     * {@code Error Parameter <property> not found}.
     *
     * @param configuration configuration holding the parsed mappers
     * @param id            statement id of the form {@code <mapper class>.<method>}
     * @throws ReflectiveOperationException if the mapper class cannot be loaded or a placeholder
     *                                      argument cannot be instantiated
     */
    private static void check(Configuration configuration, String id) throws ReflectiveOperationException {
        System.out.println(id + ": --------- start");
        int dot = id.lastIndexOf('.');
        if (dot < 0) {
            // No "<namespace>." prefix, so there is no mapper class to look the method up in.
            return;
        }
        Method method = getMethod(id.substring(0, dot), id.substring(dot + 1));
        if (method == null) {
            return;
        }

        // One placeholder argument per mapper method parameter.
        Parameter[] parameters = method.getParameters();
        Object[] args = new Object[parameters.length];
        for (int j = 0; j < parameters.length; j++) {
            args[j] = createArgument(parameters[j].getParameterizedType());
        }

        Object namedParams = new ParamNameResolver(configuration, method).getNamedParams(args);
        MappedStatement mappedStatement = configuration.getMappedStatement(id);
        // Building the BoundSql evaluates the dynamic SQL (<if> conditions etc.) against the
        // arguments, which can already reveal properties missing from the POJO.
        BoundSql boundSql = mappedStatement.getBoundSql(namedParams);

        TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        for (ParameterMapping parameterMapping : parameterMappings) {
            if (parameterMapping.getMode() == ParameterMode.OUT) {
                continue;
            }
            String propertyName = parameterMapping.getProperty();
            Object value;
            if (boundSql.hasAdditionalParameter(propertyName)) {
                value = boundSql.getAdditionalParameter(propertyName);
            } else if (namedParams == null) {
                value = null;
            } else if (typeHandlerRegistry.hasTypeHandler(namedParams.getClass())) {
                // A single scalar argument is bound as-is.
                value = namedParams;
            } else {
                try {
                    MetaObject metaObject = configuration.newMetaObject(namedParams);
                    value = metaObject.getValue(propertyName);
                } catch (Exception e) {
                    e.printStackTrace();
                    value = "Error Parameter " + propertyName + " not found";
                }
            }
            System.out.println("fill: " + id + "#" + propertyName + ":" + value);
        }
        System.out.println(id + ": --------- end");
    }

    /**
     * Returns the first public method of {@code classId} named {@code methodName}, or
     * {@code null} if there is none.
     */
    private static Method getMethod(String classId, String methodName) throws ClassNotFoundException {
        for (Method method : Class.forName(classId).getMethods()) {
            if (method.getName().equals(methodName)) {
                return method;
            }
        }
        return null;
    }

    /**
     * Creates a placeholder argument for a mapper method parameter of the given declared type.
     * {@code List}/{@code Collection} and {@code Set} parameters get a one-element collection
     * holding a placeholder of their element type.
     *
     * @throws RuntimeException if the parameter is of any other generic type
     */
    private static Object createArgument(Type type) throws ReflectiveOperationException {
        if (!(type instanceof ParameterizedType)) {
            return createObject(type.getTypeName());
        }
        ParameterizedType genericType = (ParameterizedType) type;
        String rawTypeName = genericType.getRawType().getTypeName();
        Collection<Object> collection;
        switch (rawTypeName) {
            case "java.util.List":
            case "java.util.Collection":
                collection = new ArrayList<>();
                break;
            case "java.util.Set":
                collection = new HashSet<>();
                break;
            default:
                throw new RuntimeException(rawTypeName);
        }
        collection.add(createObject(genericType.getActualTypeArguments()[0].getTypeName()));
        return collection;
    }

    /**
     * Creates a placeholder value for the type with the given (Java source style) name.
     * Numbers become zero, booleans {@code true}, arrays empty arrays; any other type is
     * instantiated through its no-arg constructor.
     */
    private static Object createObject(String typeName) throws ReflectiveOperationException {
        switch (typeName) {
            case "int":
            case "java.lang.Integer":
                return 0;
            case "long":
            case "java.lang.Long":
                return 0L;
            case "short":
            case "java.lang.Short":
                return (short) 0;
            case "double":
            case "java.lang.Double":
                return 0D;
            case "float":
            case "java.lang.Float":
                return 0F;
            case "boolean":
            case "java.lang.Boolean":
                return true;
            case "java.math.BigDecimal":
                return BigDecimal.ZERO;
            default: {
                // ClassUtils.forName also understands primitive names and "Foo[]" array names,
                // which Class.forName does not.
                Class<?> type = ClassUtils.forName(typeName, null);
                if (type.isArray()) {
                    return Array.newInstance(type.getComponentType(), 0);
                }
                return type.getDeclaredConstructor().newInstance();
            }
        }
    }

    /**
     * Resolves the concrete package names that contain classes matching the given package
     * pattern(s). Adapted from {@code com.ruoyi.framework.config.MyBatisConfig#setTypeAliasesPackage(String)}.
     * Package names are derived from the class names, so scanned classes are not loaded or
     * initialized.
     *
     * @param typeAliasesPackage comma separated, Ant-style package pattern(s), e.g. {@code com.**.domain}
     * @return the matching package names; empty if nothing matched or scanning failed
     */
    private static Set<String> getTypeAliasesPackage(String typeAliasesPackage) {
        Set<String> res = new HashSet<>();
        ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
        MetadataReaderFactory metadataReaderFactory = new CachingMetadataReaderFactory(resolver);
        try {
            for (String aliasesPackage : typeAliasesPackage.split(",")) {
                String pattern = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX
                        + ClassUtils.convertClassNameToResourcePath(aliasesPackage.trim())
                        + "/" + DEFAULT_RESOURCE_PATTERN;
                for (Resource resource : resolver.getResources(pattern)) {
                    if (!resource.isReadable()) {
                        continue;
                    }
                    MetadataReader metadataReader = metadataReaderFactory.getMetadataReader(resource);
                    String packageName = ClassUtils.getPackageName(metadataReader.getClassMetadata().getClassName());
                    if (!packageName.isEmpty()) {
                        res.add(packageName);
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return res;
    }
}