Mybatis拦截器教程及几个实用自定义拦截器分享

2,133 阅读11分钟

本文内容

我在项目中用到了Mybatis拦截器,于是趁着知识还热乎将其整理了出来,都是一些实用性插件。

  • 自定义Mybatis拦截器教程
  • 写一个自动分配主键Id、创建者Id、更新者Id的拦截器插件
  • 写一个避免全表更新/删除的数据安全插件
  • 写一个乐观锁插件

自定义Mybatis拦截器教程

步骤如下:

  1. 写一个Mybatis拦截器类
  2. 新建Mybatis拦截器配置类,在里面注册刚写的拦截器

代码展示:一个测试用的Mybatis拦截器

package com.cc.interceptor;

import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Properties;

@Intercepts({
        @Signature(
                type = Executor.class,
                method = "update",
                args = {MappedStatement.class, Object.class}),
})
public class MybatisTestInterceptor implements Interceptor {
    private static final Logger log = LoggerFactory.getLogger(MybatisTestInterceptor.class);

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        System.out.println("触发Mybatis拦截器了");
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }
}

代码说明

  • @Intercepts

    一个装载容器,稍后用来添加@Signature注解,提示:这里可以注册多个@Signature。

  • @Signature

    该注解有三个参数:

    • type:拦截器类型

      Mybatis拦截器并不是每个对象里面的方法都可以被拦截的。Mybatis拦截器只能拦截Executor、ParameterHandler、StatementHandler、ResultSetHandler四个对象里面的方法,也就是说这个type参数只能传这四种类型。

      • Executor:执行器,Mybatis的调度核心,负责SQL语句的生成

      • ParameterHandler:负责对用户传递的参数转换成JDBC Statement所需要的参数

      • StatementHandler:负责对JDBC Statement的操作

      • ResultSetHandler:负责将JDBC返回的ResultSet结果集对象转换成List类型的集合

    • method:拦截器类型中的方法,上面代码示例的“update”,包括了SQL中的:INSERT/UPDATE/DELETE

    • args:method方法中的入参,在拦截器函数中可以利用args获取更具体的对象参数。

拦截器需要实现类Interceptor,才能注册到Mybatis拦截器配置类中。

代码的话我们只需重点关注:Object intercept(Invocation invocation)即可,invocation.proceed();表示继续放行。

Mybatis拦截器配置类

/**
 * Mybatis拦截器执行顺序配置类
 * 多个拦截器存在时,这里对执行顺序进行配置
 */
@Configuration
public class MybatisInterceptorConfig {
    @Autowired
    private List<SqlSessionFactory> sqlSessionFactoryList;

    // 只执行一次
    @PostConstruct
    public void addDefaultTimeInterceptor() {
        /**
         * Mybatis拦截器可以使用@Component注解也可以在这里进行配置
         * 在这里配置可以控制拦截器的执行顺序,所以注意去掉@Component注解
         */
        for (SqlSessionFactory sqlSessionFactory : sqlSessionFactoryList) {
            org.apache.ibatis.session.Configuration configuration = sqlSessionFactory.getConfiguration();

            // 最后添加的会更早执行
            configuration.addInterceptor(new MybatisTestInterceptor());
        }
    }
}

现在,我们调用一个关于新增、更新、删除的代码,即可触发该拦截器。

自动分配主键Id、创建者Id、更新者Id的拦截器插件

当我们的数据库表没有使用自增主键,而是使用如雪花Id这种方式的时候,就可以用上这个插件,在插入数据的时候,判断插入对象id是否有值,没有则自动分配一个;

同理,创建者id和更新者Id也是如此,虽然未必所有的表都会有id、创建者id、更新者id,但是下面的示例也应该能给到一定的参考。

现有条件

  • 数据库表结构为:(id, createBy, updateBy,...)
  • 雪花Id生成器:IdGenerator
  • MDC保存登录用户的id(logback提供的一种方便在多线程下记录日志的功能,这里用来在登陆后全局保存用户id)

代码展示

那么拦截器代码为:

/**
 * Mybatis插入时自动填充Id,包括主键id、创建人id(createBy)、修改人id(updateBy)
 * 注意:需要实体类有id、createBy、updateBy属性
 */
@Intercepts({
        @Signature(
                type = Executor.class,
                method = "update",
                args = {MappedStatement.class, Object.class}),
})
public class MybatisAutoGenerateIdInterceptor implements Interceptor {
    private static final Logger log = LoggerFactory.getLogger(MybatisAutoGenerateIdInterceptor.class);

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // 对应上面类注解的args,获取需要的MappedStatement对象
        final Object[] args = invocation.getArgs();
        MappedStatement statement = (MappedStatement) args[0];

        // 插入对象,即实体类如:SysUser
        Object obj = args[1];

        /**
         * 插入或更新的时候,赋予createBy或updateBy值
         */
        // 从MDC中获取用户id
        String userIdStr = MDC.get(GlobalConstant.MDC_USER_ID);
        Long userId = null;
        if (!StringUtils.isEmpty(userIdStr)) {
            userId = Long.parseLong(userIdStr);
        }

        // 识别SQL类型,看是INSERT还是UPDATE
        if (SqlCommandType.INSERT.equals(statement.getSqlCommandType())) {
            // 单个插入,BaseEntity是我的实体类的父类,该类有id、createBy、updateBy等通用属性
            if (obj instanceof BaseEntity) {
                BaseEntity entity = ((BaseEntity) obj);
                assignIdIfNull(entity, userId);
            }
            // 批量插入
            if (obj instanceof Map) {
                Map<?, ?> map = (Map<?, ?>) obj;
                List<?> list = (List<?>) map.get("list");
                if (list != null) {
                    for (Object entity : list) {
                        assignIdIfNull((BaseEntity) entity, userId);
                        // 批量操作不同于单个操作,单个操作支持部分字段,批量操作不支持,所以要提前设定默认值
                        assignDefaultValue((BaseEntity) entity);
                    }
                }
            }
        }

        if (SqlCommandType.UPDATE.equals(statement.getSqlCommandType())) {
            // 单个更新
            if (obj instanceof BaseEntity) {
                BaseEntity entity = ((BaseEntity) obj);
                entity.setUpdateBy(Optional.ofNullable(userId).orElse(0L));
            }
            // 批量更新
            if (obj instanceof Map) {
                Map<?, ?> map = (Map<?, ?>) obj;
                List<?> list = (List<?>) map.get("list");
                if (list != null) {
                    for (Object item : list) {
                        BaseEntity entity = ((BaseEntity) item);
                        entity.setUpdateBy(Optional.ofNullable(userId).orElse(0L));
                    }
                }
            }
        }

        return invocation.proceed();
    }

    /**
     * 赋予id、createBy值
     * @param entity 实体类
     * @param userId 操作用户id
     */
    private void assignIdIfNull(BaseEntity entity, Long userId) {
        // 如果数据没有id值则自动分配
        if (entity.getId() == null || entity.getId() == 0L) {
            entity.setId(IdGenerator.INSTANCE.nextId());
        }
		// 如果数据没有创建者并且当前登录用户id不为空
        if (entity.getCreateBy() == null && userId != null) {
            entity.setCreateBy(userId);
        }
    }

    /**
     * 批量插入的时候给基础属性赋予默认值,不然会失败
     * 这里是插入的最后一级,所以优先级最高,业务代码内手动设置的也会被这里覆盖掉,也应该被覆盖
     * @param entity
     */
    private void assignDefaultValue(BaseEntity entity) {
        entity.setUpdateBy(0L);
        entity.setClientId(0);
        Date now = new Date();
        entity.setCreateTime(now);
        entity.setUpdateTime(now);
    }


    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }
}

这个插件十分简单,就是在保存数据前做一些预操作。

避免全表更新/删除的数据安全插件

为了避免不小心导致的全表更新或删除操作,我们可以再写一个数据安全插件,实现目的是:不允许不带WHERE条件的更新或删除操作

但是太严格了不好,所以留个后门:可以设置一些允许的表,那么我们就可以为它写一个配置类:

@Component
@ConfigurationProperties(prefix = "allow-tables")
public class AllowUpdateTables {
    /**
     * 启用
     */
    private boolean enable = false;

    /**
     * 允许全表删除的表
     */
    private List<String> delete;

    /**
     * 允许全表更新的表
     */
    private List<String> update;

    ...
}

这样在配置文件中就要配合一下(application.yml):

# 在这里设置允许不带WHERE条件删除/更新的表,可以避免全表的误操作,配合MyabtisInterceptor拦截器使用
allow-tables:
  enable: true
  delete:
    - test
    - sys_user
    - sys_menu
  update:
    - test

准备工作完毕,现在贴拦截器代码:

/**
 * Mybatis安全更新拦截器,对update、delete操作进行保护处理,防止出现全表的更新或删除
 */
@Intercepts({
        @Signature(
                type = Executor.class,
                method = "update",
                args = {MappedStatement.class, Object.class}),
})
public class MybatisSafeUpdateInterceptor implements Interceptor {
    private static final Logger log = LoggerFactory.getLogger(MybatisSafeUpdateInterceptor.class);

    // 关于该插件的配置
    private final AllowUpdateTables allowUpdateTables;

    public MybatisSafeUpdateInterceptor(AllowUpdateTables allowUpdateTables) {
        this.allowUpdateTables = allowUpdateTables;
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (allowUpdateTables == null) {
            throw new BaseException("allowUpdateTables为空,请注册成Bean");
        }
        if (!allowUpdateTables.isEnable()) {
            return invocation.proceed();
        }

        final Object[] args = invocation.getArgs();
        MappedStatement statement = (MappedStatement) args[0];

        // 全表更新、删除的时候
        if (SqlCommandType.UPDATE.equals(statement.getSqlCommandType())) {
            boolean allow = checkAllowTableOperation(invocation);
            if (!allow) {
                throw new BaseException("该表不支持不带where条件的更新语句");
            }
        }

        if (SqlCommandType.DELETE.equals(statement.getSqlCommandType())) {
            // 删除
            boolean allow = checkAllowTableOperation(invocation);
            if (!allow) {
                throw new BaseException("该表不支持不带where条件的删除语句");
            }
        }

        return invocation.proceed();
    }
    /**
     * 对没有带where条件的update和delete语句进行检查,看是否在允许列表中
     * 目的是为了避免忘记写条件导致的全表误操作
     */
    private boolean checkAllowTableOperation(Invocation invocation) throws JSQLParserException {
        String sql = getSqlByInvocation(invocation);
        if (StringUtils.isEmpty(sql)) {
            throw new BaseException("待执行的SQL不能为空");
        }

        // 这里用JSqlParser工具来判断是否存在where条件
        final Statement smt = CCJSqlParserUtil.parse(sql);
        if (smt instanceof Update) {
            final Update update = (Update) smt;
            return update.getWhere() != null;
        } else if (smt instanceof Delete) {
            final Delete delete = (Delete) smt;
            return delete.getWhere() != null;
        }

        return false;
    }

    /**
     * 获取待执行的sql语句
     */
    private String getSqlByInvocation(Invocation invocation) {
        try {
            MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
            final Object[] queryArgs = invocation.getArgs();
            final Object parameter = queryArgs[1];
            BoundSql boundSql = mappedStatement.getBoundSql(parameter);
            return boundSql.getSql();
        } catch (Exception e) {
            log.error("获取SQL语句失败", e);
        }
        return null;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }
}

最后注册到Mybatis拦截器配置类的时候要记得把配置类加上:

@Configuration
public class MybatisInterceptorConfig {
    @Autowired
    private List<SqlSessionFactory> sqlSessionFactoryList;

    @Autowired
    private AllowUpdateTables allowUpdateTables;

    // 只执行一次
    @PostConstruct
    public void addDefaultTimeInterceptor() {
        /**
         * Mybatis拦截器可以使用@Component注解也可以在这里进行配置
         * 在这里配置可以控制拦截器的执行顺序,所以注意去掉@Component注解
         */
        MybatisAutoGenerateIdInterceptor autoGenerateIdInterceptor = new MybatisAutoGenerateIdInterceptor();
        MybatisSafeUpdateInterceptor safeUpdateInterceptor = new MybatisSafeUpdateInterceptor(allowUpdateTables);

        for (SqlSessionFactory sqlSessionFactory : sqlSessionFactoryList) {
            org.apache.ibatis.session.Configuration configuration = sqlSessionFactory.getConfiguration();

            // 最后添加的会更早执行
            configuration.addInterceptor(autoGenerateIdInterceptor);
            configuration.addInterceptor(safeUpdateInterceptor);
        }
    }
}

这个插件其实也是很简单的,执行语句前拦截一下,判断能不能执行。

这个插件引出来一个工具类:JSqlParser,这是一个解析SQL的工具,很好用,该插件第一版我是通过SQl里面有没有WHERE字符串来判断的,有了这个工具第二版代码就好看多了。

乐观锁插件

乐观锁原理

先简单说下乐观锁的原理吧,已经理解乐观锁的可以跳过这段

乐观锁是用来解决并发更新问题的,很典型的商品抢购场景,一件商品只有10的库存,同时涌入了100个购买请求,我们需要保证不超过10个请求可以成功,这时候乐观锁就派上用场了。

假设商品表结构为:(id, name(商品名称), price(价格), store(库存))

为了使用乐观锁,我们还要新增一个int类型的version字段,所以:

(id, name(商品名称), price(价格), store(库存), version(版本))

那么我们的商品售卖流程为:

  1. 查询商品信息

    SELECT * FROM 商品 WHERE id = 10;
    

    得到商品详情:

    {
        "id": 10,
        "name": "锤子手机"
        "price": 998.00,
        "store": 10,
        "version": 1
    }
    
  2. 商品库存减一,版本号加一,并且WHERE条件带上旧版本号:

    UPDATE 商品 SET store = store-1, version = 2 WHERE id = 10 AND version = 1;
    

    如果在第一步查询到的商品库存充足,但是在同一时刻被其他请求抢购完了,那么在更新商品库存的时候,就会因为WHERE version = 1的条件不符合导致更新失败,即抢购失败。

乐观锁的意思是更新数据前乐观的认为读取的数据是没有被修改过的,所以不用锁数据,这样做的缺点是SQL执行出错的可能性比较大,所以要为此做一些用户反馈,比如“抢购失败”、“当前用户过多”等提示。

科普一下悲观锁,悲观锁是悲观的认为读取的数据都会被修改,所以每次读取数据的时候都要锁住,不允许其他人修改,这样的好处是SQL执行的成功率很高,同样的,缺点是锁数据会影响其他的操作事务,即降低了系统的数据吞吐量。

实现思路

由上可以知道,乐观锁关键在于version字段,所以我们能思考出其实现的步骤:

  1. 数据库表增加version字段,最好是整型,INT就足够了
  2. 实体类的version属性用乐观锁注解修饰,让乐观锁插件可以扫描出
  3. 插件拦截得到数据对象和执行SQL,获取数据对象的version值,令其+1,同时增加WHERE条件:version = 旧值

代码展示

数据库和实体类的version字段太简单,就不多说明了,

首先是乐观锁注解,该注解只能用在属性上:

/**
 * 乐观锁字段标记
 */
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
public @interface OptimisticLockVersion {
}

然后用在实体类的version属性上:

@OptimisticLockVersion
@ApiModelProperty(value = "版本")
private Integer version;

乐观锁标记注解不是必要的,它的目的是为了让插件识别出哪一个是乐观锁字段,如果有规范说乐观锁字段的名字就叫version或者其他,那么也更直截了当一些,只是用乐观锁标记注解可以让这个字段名更加灵活。

然后是拦截器代码:

/**
 * 乐观锁插件
 */
@Intercepts({
        @Signature(
                type = StatementHandler.class,
                method = "prepare",
                args = {Connection.class, Integer.class}),
})
public class MybatisOptimisticLockInterceptor implements Interceptor {
    private static final Logger log = LoggerFactory.getLogger(MybatisOptimisticLockInterceptor.class);
    
    /**
     * 记录没有悲观锁的表,就不用去判断了
     */
    private List<String> ignoreEntityList;

    public MybatisOptimisticLockInterceptor() {
        ignoreEntityList = new ArrayList<>();
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        /**
         * 实现思路:
         * 1. 判断语句类型,仅支持update
         * 2. 获取数据对象,提取出version值,设该值为oldVersion,令其+1设该值为newVersion
         * 3. 获取SQL语句,用JSqlParser工具修改version参数值为newVersion,添加version的查询条件:where version = oldVersion
         * 4. 将新的SQL语句覆盖进去,然后继续执行
         */

        // 下面代码直接抄
        StatementHandler handler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = SystemMetaObject.forObject(handler);
        MappedStatement ms = (MappedStatement) metaObject.getValue("delegate.mappedStatement");

        // 操作类型
        SqlCommandType sqlType = ms.getSqlCommandType();
        if (sqlType != SqlCommandType.UPDATE) {
            return invocation.proceed();
        }

        // 获取数据对象
        final Object parameterObject = metaObject.getValue("delegate.boundSql.parameterObject");

        // 不支持数组
        if (parameterObject instanceof List) {
            return invocation.proceed();
        }

        // 有没有记录过
        String key = parameterObject.getClass().getName();
        if (ignoreEntityList.contains(key)) {
            return invocation.proceed();
        }

        // 检查是否有乐观锁字段
        final ArrayList<Field> fieldList = FieldLoader.INSTANCE.getFieldList(parameterObject.getClass());
        // 乐观锁注解校验
        OptimisticLockVersion optimisticLockVersion = null;
        Field versionField = null;
        for (Field field : fieldList) {
            OptimisticLockVersion annotation = field.getDeclaredAnnotation(OptimisticLockVersion.class);
            if (annotation != null) {
                if (optimisticLockVersion != null) {
                    throw new BaseException("同一个实体类不能有两个乐观锁字段:" + key);
                }
                optimisticLockVersion = annotation;
                versionField = field;
            }
        }

        if (optimisticLockVersion == null) {
            // 记录在案,下次就省的处理了
            ignoreEntityList.add(key);
            return invocation.proceed();
        }

        String versionFieldName = versionField.getName();

        Object value = metaObject.getValue("delegate.boundSql.parameterObject." + versionFieldName);
        if (value == null) {
            throw new BaseException("乐观锁字段的值不能为空");
        }

        // 旧值
        int oldVersion = Integer.parseInt((value.toString()));
        if (oldVersion <= 0) {
            return invocation.proceed();
        }
        // 新值
        int newVersion = oldVersion + 1;

        // 写入新值
        metaObject.setValue("delegate.boundSql.parameterObject." + versionFieldName, newVersion);

        // 修改SQL语句
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        // 原始SQL
        String originalSql = boundSql.getSql();

        final Statement smt = CCJSqlParserUtil.parse(originalSql);
        if (!(smt instanceof Update)) {
            // 双重防护
            return invocation.proceed();
        }

        Update update = ((Update) smt);

        // 添加查询条件
        final Expression originalWhere = update.getWhere();
        Expression newWhere = new EqualsTo(new Column("version"), new LongValue(oldVersion));

        if (originalWhere != null) {
            newWhere = new AndExpression(originalWhere, newWhere);
        }
        update.setWhere(newWhere);
        originalSql = update.toString();

        // 覆盖为新的SQL
        metaObject.setValue("delegate.boundSql.sql", originalSql);

        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }

    public List<String> getIgnoreEntityList() {
        return ignoreEntityList;
    }

    public void setIgnoreEntityList(List<String> ignoreEntityList) {
        this.ignoreEntityList = ignoreEntityList;
    }
}

注意!!在乐观锁插件中我们拦截的是StatementHandler对象,即生成SQL的时候,在上面的几个拦截器中我们拦截的是Executor执行器对象,此时SQL已经生成完毕了,而乐观锁插件的目的是动态修改SQL,所以我们拦截的对象也变成了StatementHandler,如果不是,我们替换新SQL是不会生效的。

这个插件也是借助了SQL解析器JSqlParser工具,这个工具真的太棒了。

最后别忘了将插件注册进Mybatis拦截器配置类中。

结语

以上是我在项目中用的几个Mybatis实用插件,能够优化项目中不少代码,增加可维护性,最主要的是优雅!优雅!优雅!