Mybatis-plus自定义sql拦截插件实现全表全字段分页模糊查询(支持多表)

624 阅读7分钟

Mybatis-plus自定义sql拦截器实现全表全字段分页模糊查询(支持多表)

Mybatis-plus是一个非常好用十分强大的操作sql数据库的框架,极大地提高了开发者的开发效率。

需求背景

某天我的前端小伙伴和我说他对原本分页查询的某些特定字段的模糊查询不满意,他希望能根据返回的每一个字段进行进行模糊查询,只要表头有的都能进行模糊查询,类似图下的效果:

企业微信截图_16927735204870.png 我听后说,这还不简单,我将返回的任何一个字段作为传参条件,用Mybatis-plus框架轻松搞定,我框框一顿写,写出下面的代码:

return xxxManager.lambdaQuery()
        ...
        .like(xxx != null,xxx,xxx)
        .or()
        .like(xxx != null,xxx,xxx)
        .or()
        .like(xxx != null,xxx,xxx)
        ...//以下省略n个
        .page(xxx);

直接秒杀! 然后他又告诉我,还有一个接口想要有这样的效果,没问题。还有一个,没问题,还有一个...不对劲,每个都这样写太麻烦了,难以容忍,事已至此,直接造轮子!

需求分析

显而易见,造的轮子必须满足以下特点:

  • 能根据返回的每一个字段进行模糊分页查询
  • 轮子简单易用、使用灵活,且不能对原有代码造成太大的影响

因此,我的思路如下:

  1. 实现InnerInterceptor接口,自定义拦截插件
  2. 使用注解声明的方式,判断拦截的Mapper方法是否需要被处理
  3. 若需要,则获取Mapper方法的入参,利用反射对入参的字段进行处理,拼接模糊查询的sql语句
  4. 将步骤3得到的sql语句插入原sql中,得到新的sql

如此,使用者只需在原来的Mapper方法上加个注解便可实现全字段的模糊查询。

代码实现

part1 定义注解

  • @QueryPage:声明该Mapper方法需要被拦截处理
/**
 * @author 摆渡人
 * @description 自定义分页全条件模糊查询
 * @date 2023/8/23 15:35
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface QueryPage {

    /**
     * 需要全字段分析的类
     * @return
     */
    Class<?> vo();

    /**
     *
     * 第几个where条件 从0开始
     * @return
     */
    int whereIndex() default 0;
}
  • @QueryParm:进行字段分析时,声明该字段的别名,或者是否需要排除
/**
 * @author 摆渡人
 * @description 自定义模糊字段
 * @date 2023/8/23 15:40
 */
@Target({ElementType.FIELD, ElementType.TYPE_USE})
@Retention(RetentionPolicy.RUNTIME)
public @interface QueryParm {

    /**
     * 字段别名
     * 比如外连接了t_user u,则填入u.name之类的
     * 优先级:@QueryParm.parm() > @TableField.value() > fileName
     * @return
     */
    String parm() default "";

    /**
     * 是否排除
     * @return
     */
    boolean exclude() default false;
}

part2 编写拦截插件

/**
 * @author 摆渡人
 * @description 分页查询多表全字段模糊查询sql拦截器
 * @date 2023/8/23 15:33
 */
@Slf4j
public class QueryPageSqlInterceptor implements InnerInterceptor {

    @SneakyThrows
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        // 利用ms获取方法上的注解
        QueryPage queryPage = this.getQueryPage(ms);
        if(queryPage != null) {
            Class<?> voClass = queryPage.vo();
            Object vo = this.findVo(voClass, parameter);
            if(vo == null) {
                return;
            }
            String originalSql = boundSql.getSql().trim();
            String newSql = this.joinSql(originalSql,queryPage,vo);
            // 通过反射替换sql
            Field field = boundSql.getClass().getDeclaredField("sql");
            field.setAccessible(true);
            field.set(boundSql,newSql);
        }
    }

    /**
     * 获取拦截方法上的@QueryPage注解
     * @param ms MappedStatement
     * @return QueryPageSqlConfigDTO
     */
    @SneakyThrows
    private QueryPage getQueryPage(MappedStatement ms) {
        String msId = ms.getId();
        int lastIndex = msId.lastIndexOf(".");
        String className = msId.substring(0, lastIndex);
        String methodName = msId.substring(lastIndex + 1);
        Class<?> clazz = Class.forName(className);
        Method[] methods = clazz.getMethods();
        for (Method method : methods) {
            if(method.getName().equals(methodName)) {
                return method.getAnnotation(QueryPage.class);
            }
        }
        return null;
    }

    /**
     * 从入参中匹配Vo
     * @param vo Class<?>
     * @param parameter parameter
     * @return vo
     */
    private Object findVo(Class<?> vo,Object parameter) {
        if (parameter != null) {
            if (parameter instanceof Map) {
                Map<?, ?> parameterMap = (Map<?, ?>) parameter;
                for (Map.Entry entry : parameterMap.entrySet()) {
                    if(entry.getValue() != null && entry.getValue().getClass().equals(vo)) {
                        return entry.getValue();
                    }
                }
            } else if (parameter.getClass().equals(vo)) {
                return parameter;
            }
        }
        return null;
    }

    /**
     * 插入sql
     * @param sql 原sql
     * @param queryPage 配置
     * @param vo vo
     * @return 新sql
     */
    private String joinSql(String sql,QueryPage queryPage,Object vo) {
        if(queryPage == null) {
            return sql;
        }

        String appendSql = this.getAppendSql(vo);
        if (StringUtils.isEmpty(appendSql)) {
            return sql;
        }

        Integer appendSqlWhereIndex = queryPage.whereIndex();
        String where = "where";
        String order = "order by";
        String group = "group by";
        int whereIndex = StringUtils.ordinalIndexOf(sql.toLowerCase(), where, appendSqlWhereIndex + 1);
        int orderIndex = sql.toLowerCase().indexOf(order);
        int groupIndex = sql.toLowerCase().indexOf(group);
        if (whereIndex > - 1) {
            String subSql = sql.substring(0, whereIndex + where.length() + 1);
            subSql = subSql + " " + appendSql + " AND " + sql.substring(whereIndex + where.length() + 1);
            return subSql;
        }

        if (groupIndex > - 1) {
            String subSql = sql.substring(0, groupIndex);
            subSql = subSql + " where " + appendSql + " " + sql.substring(groupIndex);
            return subSql;
        }
        if (orderIndex > - 1) {
            String subSql = sql.substring(0, orderIndex);
            subSql = subSql + " where " + appendSql + " " + sql.substring(orderIndex);
            return subSql;
        }
        sql += " where " + appendSql;
        return sql;
    }

    /**
     * 获取拼接的sql
     * @param vo vo
     * @return appendSql
     */
    private String getAppendSql(Object vo) {
        Field[] fields = this.getAllFields(vo.getClass());
        StringBuilder sb = new StringBuilder();
        sb.append(" (");
        Arrays.stream(fields).filter(Objects::nonNull).forEach(field -> {
            field.setAccessible(true);
            try {
                QueryParm queryParm = field.getAnnotation(QueryParm.class);
                String parm = "";
                if(queryParm != null) {
                    // 判断是否排除
                    boolean exclude = queryParm.exclude();
                    if(exclude) {
                        return;
                    }
                    parm = queryParm.parm();
                }

                // 获取字段的值,忽略空字段
                Object fieldValue = field.get(vo);
                if(fieldValue == null || "".equals(fieldValue)) {
                    return;
                }

                // 获取字段名
                TableField tableField = field.getAnnotation(TableField.class);
                String fieldName = parm.equals("") ? ((tableField != null && !"".equals(tableField.value())) ? tableField.value() : field.getName()) : parm;

                // 转下划线
                fieldName = com.baomidou.mybatisplus.core.toolkit.StringUtils.camelToUnderline(fieldName);

                sb.append(" ").append(fieldName).append(" like ").append("'%").append(fieldValue).append("%' or");
            } catch (Exception e) {
                log.error("拼接查询sql发生异常:",e);
                throw new BusinessException("查询失败");
            }
        });
        sb.append(") ");
        if(sb.toString().equals(" () ")) {
            return null;
        }
        // sb去掉最后的or
        int lastOrIndex = sb.lastIndexOf("or");
        sb.delete(lastOrIndex, lastOrIndex + 2);
        return sb.toString();
    }

    /**
     * 获取类的所有field,包含其继承的父类
     * @param clazz
     * @return
     */
    public Field[] getAllFields(Class<?> clazz) {
        List<Field> fieldList = new ArrayList<>();
        while (clazz != null){
            fieldList.addAll(new ArrayList<>(Arrays.asList(clazz.getDeclaredFields())));
            clazz = (Class<T>) clazz.getSuperclass();
        }
        Field[] fields = new Field[fieldList.size()];
        return fieldList.toArray(fields);
    }
}

使用配置

在Mybatis-plus的配置类中添加该拦截插件即可

/**
 * @author 摆渡人
 * @description Mybatis-plus配置类
 * @date 2022/11/28 20:43
 */
@Component
public class MybatisPlusConfig implements MetaObjectHandler {

    @Override
    public void insertFill(MetaObject metaObject) {
        this.setFieldValByName("createTime", LocalDateTime.now(), metaObject);
        this.setFieldValByName("updateTime", LocalDateTime.now(), metaObject);
        this.fillStrategy(metaObject, "countId", UUID.randomUUID().toString().replace("-", ""));
        this.setFieldValByName("timeStamp",new Date().getTime(),metaObject);
    }

    @Override
    public void updateFill(MetaObject metaObject) {
        this.setFieldValByName("updateTime",LocalDateTime.now(),metaObject);
        this.setFieldValByName("completeTime", LocalDateTime.now(), metaObject);
    }

    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor(){
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        // 全字段模糊查询插件
        interceptor.addInnerInterceptor(new QueryPageSqlInterceptor());
        // 添加分页插件
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
        return interceptor;
    }
}

测试用例

环境准备

part1 表和数据如下

image.png

part2 conrtoller代码

@Slf4j
@RestController
@Api("测试用例")
public class BookController {

    @Autowired
    private BookMapper bookMapper;

    @ApiOperation("分页查询-书")
    @GetMapping("/book/page")
    @NoNeedLogin
    public ResponseDTO<IPage<BookEntity>> book(BookEntity bookEntity,Integer pageNum,Integer pageSize) {
        Page page = new Page();
        page.setCurrent(pageNum);
        page.setSize(pageSize);
        List<BookEntity> bookEntities = bookMapper.pageBook(page, bookEntity);
        page.setRecords(bookEntities);
        return ResponseDTO.ok(page);
    }

    @ApiOperation("分页查询-标签")
    @GetMapping("/tag/page")
    @NoNeedLogin
    public ResponseDTO<IPage<TagEntity>> tag(TagEntity tagEntity,Integer pageNum,Integer pageSize) {
        Page page = new Page();
        page.setCurrent(pageNum);
        page.setSize(pageSize);
        List<TagEntity> tagEntities = bookMapper.pageTag(page, tagEntity);
        page.setRecords(tagEntities);
        return ResponseDTO.ok(page);
    }

    @ApiOperation("分页查询-书和标签")
    @GetMapping("/book-tag/page")
    @NoNeedLogin
    public ResponseDTO<IPage<BookTagVo>> bookTag(BookTagVo bookTagVo,Integer pageNum,Integer pageSize) {
        Page page = new Page();
        page.setCurrent(pageNum);
        page.setSize(pageSize);
        List<BookTagVo> bookTagVoList = bookMapper.page(page, bookTagVo);
        page.setRecords(bookTagVoList);
        return ResponseDTO.ok(page);
    }
}

part3 BookTagVo

@Data
public class BookTagVo {

    @QueryParm(exclude = true)
    private Integer id;

    private String bookName;

    @TableField("remark")
    @QueryParm(parm = "b.remark")
    private String bookRemark;

    private String author;

    private LocalDateTime createTime;

    @QueryParm(exclude = true)
    private Integer tagId;

    private String tagName;

    @TableField("remark")
    @QueryParm(parm = "t.remark")
    private String tagRemark;

}

part4 BookMapper

/**
 * @author 摆渡人
 * @description
 * @date 2023/8/25 11:06
 */
@Mapper
public interface BookMapper {

    @Select("select * from t_book")
    @QueryPage(vo = BookEntity.class)
    List<BookEntity> pageBook(Page page,BookEntity bookEntity);

    @Select("select * from t_tag")
    @QueryPage(vo = TagEntity.class)
    List<TagEntity> pageTag(Page page,TagEntity tagEntity);

    @Select("select b.id as bookId,b.book_name,b.remark as bookRemark,b.author,b.tag_id,b.create_time,t.tag_name,t.remark as tagRemark " +
            "from t_book b " +
            "left join t_tag t on b.tag_id = t.id")
    @QueryPage(vo = BookTagVo.class)
    List<BookTagVo> page(Page page,BookTagVo bookTagVo);

}

测试用例一(单表查询)

  • 请求参数 image.png
  • 响应结果
    {
        "code": 0,
        "level": null,
        "msg": "success",
        "ok": true,
        "data": {
            "records": [
                {
                    "id": 1,
                    "bookName": "高等数学",
                    "remark": "和宋浩老师学的,好难",
                    "author": "xxx出版社",
                    "tagId": 1,
                    "createTime": "2023-08-25 10:57:23"
                },
                {
                    "id": 3,
                    "bookName": "大学物理",
                    "remark": "上课想睡觉,好难",
                    "author": "xxx老师",
                    "tagId": 2,
                    "createTime": "2023-08-25 10:57:23"
                },
                {
                    "id": 4,
                    "bookName": "大学物理实验",
                    "remark": "得花不少时间,好难",
                    "author": "xxx老师",
                    "tagId": 2,
                    "createTime": "2023-08-25 10:57:23"
                }
            ],
            "total": 3,
            "size": 10,
            "current": 1,
            "orders": [],
            "optimizeCountSql": true,
            "searchCount": true,
            "countId": null,
            "maxLimit": null,
            "pages": 1
        }
    }
    
  • 实际执行的sql
    SELECT COUNT(*) AS total FROM t_book WHERE (book_name LIKE '%数学%' OR remark LIKE '%好难%');
    select * from t_book where ( book_name like '%数学%' or remark like '%好难%' ) LIMIT 10;
    

测试用例二(多表查询)

  • 请求参数 image.png
  • 响应结果
    {
        "code": 0,
        "level": null,
        "msg": "success",
        "ok": true,
        "data": {
            "records": [
                {
                    "id": 1,
                    "bookName": "高等数学",
                    "bookRemark": "和宋浩老师学的,好难",
                    "author": "xxx出版社",
                    "createTime": "2023-08-25 10:57:23",
                    "tagId": 1,
                    "tagName": "数学",
                    "tagRemark": "这是与数学有关的书,蛮难的"
                },
                {
                    "id": 3,
                    "bookName": "大学物理",
                    "bookRemark": "上课想睡觉,好难",
                    "author": "xxx老师",
                    "createTime": "2023-08-25 10:57:23",
                    "tagId": 2,
                    "tagName": "物理",
                    "tagRemark": "学好物理"
                },
                {
                    "id": 4,
                    "bookName": "大学物理实验",
                    "bookRemark": "得花不少时间,好难",
                    "author": "xxx老师",
                    "createTime": "2023-08-25 10:57:23",
                    "tagId": 2,
                    "tagName": "物理",
                    "tagRemark": "学好物理"
                },
                {
                    "id": 5,
                    "bookName": "大学英语",
                    "bookRemark": "hello world",
                    "author": "xxx报社",
                    "createTime": "2023-08-25 10:57:23",
                    "tagId": 3,
                    "tagName": "英语",
                    "tagRemark": "say hello"
                },
                {
                    "id": 6,
                    "bookName": "大学英语口语",
                    "bookRemark": "how are you",
                    "author": "xxx报社",
                    "createTime": "2023-08-25 10:57:23",
                    "tagId": 3,
                    "tagName": "英语",
                    "tagRemark": "say hello"
                }
            ],
            "total": 5,
            "size": 10,
            "current": 1,
            "orders": [],
            "optimizeCountSql": true,
            "searchCount": true,
            "countId": null,
            "maxLimit": null,
            "pages": 1
        }
    }
    
  • 实际执行的sql
SELECT COUNT(*) AS total FROM t_book b LEFT JOIN t_tag t ON b.tag_id = t.id WHERE (book_name LIKE '%数学%' OR b.remark LIKE '%好难%' OR t.remark LIKE '%say%');
select b.id,b.book_name,b.remark as bookRemark,b.author,b.tag_id,b.create_time,t.tag_name,t.remark as tagRemark from t_book b left join t_tag t on b.tag_id = t.id where ( book_name like '%数学%' or b.remark like '%好难%' or t.remark like '%say%' ) LIMIT 10;

优化分析

1. 判断是否需要拦截

为了判断拦截的Mapper是否需要拦截,每次都会调用获取@QueryPage的方法QueryPage getQueryPage(MappedStatement ms),方法中会截取ms.getId()成两部分,类路径方法名,然后利用Class.forName(className)获取Mapper类,再遍历Mapper类中的所有方法,找到与方法名相同的方法,然后再判断该方法上是否有@QueryPage注解。

我觉得,一个Mapper方法上是否有@QueryPage注解是在项目编译运行后就确定下来且不会变的,每次都像上述那样遍历找一遍似乎不太优雅,但目前也没找到解决方案,希望大佬指点!