Mybatis源码解析-插件开发

319 阅读 · 约 8 分钟

MyBatis提供了一种插件(plugin)的功能,虽然叫做插件,但其实这是拦截器功能。那么拦截器拦截MyBatis中的哪些内容呢?

概述

MyBatis 允许你在已映射语句执行过程中的某一点进行拦截调用。默认情况下,MyBatis允许使用插件来拦截的方法调用包括:

  • Executor (update, query, flushStatements, commit, rollback, getTransaction, close, isClosed) 拦截执行器的方法
  • ParameterHandler (getParameterObject, setParameters) 拦截参数的处理
  • ResultSetHandler (handleResultSets, handleOutputParameters) 拦截结果集的处理
  • StatementHandler (prepare, parameterize, batch, update, query) 拦截SQL语句的创建(预编译)、参数设置与执行

Mybatis采用责任链模式,通过动态代理组织多个拦截器(插件),通过这些拦截器可以改变Mybatis的默认行为(诸如SQL重写之类的),由于插件会深入到Mybatis的核心,因此在编写自己的插件前最好了解下它的原理,以便写出安全高效的插件。

拦截器的使用

拦截器介绍及配置

首先我们看下MyBatis拦截器的接口定义:

package org.apache.ibatis.plugin;

import java.util.Properties;

/**
 * Extension point for MyBatis plugins.
 *
 * <p>An implementation is offered the Executor, ParameterHandler, ResultSetHandler and
 * StatementHandler instances MyBatis creates: {@link #plugin(Object)} decides whether to wrap
 * each one, and {@link #intercept(Invocation)} runs whenever a wrapped method is called.
 *
 * @author Clinton Begin
 */
public interface Interceptor {

    /**
     * Handles one intercepted method call.
     *
     * @param invocation the intercepted target, method and arguments
     *
     * @return the value handed back to the original caller
     *
     * @throws Throwable anything thrown by the interception logic or by the proceeding call
     */
    Object intercept(Invocation invocation) throws Throwable;

    /**
     * Returns the object MyBatis should use in place of {@code target}.
     *
     * @param target an Executor, ParameterHandler, ResultSetHandler or StatementHandler instance
     *
     * @return a proxy around {@code target}, or {@code target} itself when nothing is intercepted
     */
    default Object plugin(Object target) {
        // Plugin.wrap builds the JDK proxy, or returns target unchanged when no signature matches.
        final Interceptor self = this;
        return Plugin.wrap(target, self);
    }

    /**
     * Receives the properties configured for this plugin.
     *
     * @param properties the plugin's configured properties
     */
    default void setProperties(Properties properties) {
        // Intentionally empty: implementations that need configuration override this.
    }

}

接口比较简单,只有3个方法,其中plugin和setProperties都提供了默认实现。MyBatis本身没有提供该接口的实现类,开发者可以按需实现自己的拦截器。下面是MyBatis官网的一个拦截器示例:

@Intercepts({@Signature(type= Executor.class, method = "update", args = {MappedStatement.class,Object.class})})
public class ExamplePlugin implements Interceptor {
  public Object intercept(Invocation invocation) throws Throwable {
    return invocation.proceed();
  }
  public Object plugin(Object target) {
    return Plugin.wrap(target, this);
  }
  public void setProperties(Properties properties) {
  }
}

全局xml配置:

<plugins>
    <!-- Registers the interceptor by its fully qualified class name. Per the MyBatis docs, nested
         <property> elements here are passed to the interceptor's setProperties method. -->
    <plugin interceptor="org.format.mybatis.cache.interceptor.ExamplePlugin"></plugin>
</plugins>

这个拦截器拦截Executor接口的update方法(其实也就是SqlSession的新增,删除,修改操作),所有执行executor的update方法都会被该拦截器拦截到。

拦截器(插件)是什么时候注册并且执行的?下面按照顺序来进行解析!

注册Executor插件

DefaultSqlSessionFactory打开一个SqlSession连接

/**
 * Opens a SqlSession backed by the configured data source.
 *
 * @param execType   executor type requested by the caller; Configuration.newExecutor falls back to
 *                   the configured default and then to ExecutorType.SIMPLE when it is null
 * @param level      transaction isolation level handed to the transaction factory
 * @param autoCommit whether the new transaction auto-commits
 *
 * @return the new DefaultSqlSession
 */
private SqlSession openSessionFromDataSource(ExecutorType execType, TransactionIsolationLevel level, boolean autoCommit) {
    /*
        The Configuration was populated earlier: sqlSessionFactoryBuilder.build(InputStream inputStream)
          ===> xMLConfigBuilder.parse() ===> stores the data source and the SQL statements from the
        config files in the Configuration object, so everything needed here can be read from it.
     */
    Transaction tx = null;
    try {
        // Environment configuration.
        final Environment environment = configuration.getEnvironment();
        // Transaction factory for that environment.
        final TransactionFactory transactionFactory = getTransactionFactoryFromEnvironment(environment);
        // Open a new transaction; tx is assigned here so the catch block can close it if a later step fails.
        tx = transactionFactory.newTransaction(environment.getDataSource(), level, autoCommit);
        // Create the Executor; this is where Executor-type interceptors (plugins) are applied.
        final Executor executor = configuration.newExecutor(tx, execType);
        // Create the default SqlSession around the (possibly plugin-wrapped) executor.
        return new DefaultSqlSession(configuration, executor, autoCommit);
    } catch (Exception e) {
        closeTransaction(tx); // may have fetched a connection so lets call close()
        throw ExceptionFactory.wrapException("Error opening session.  Cause: " + e, e);
    } finally {
        ErrorContext.instance().reset();
    }
}

Configuration 创建Executor执行器

/**
 * Creates the Executor for a new session. It picks the implementation for {@code executorType},
 * optionally decorates it with CachingExecutor, and finally passes it through the interceptor chain
 * so Executor plugins can wrap it.
 *
 * @param transaction  the transaction the executor works in
 * @param executorType requested type; null falls back to defaultExecutorType, then to SIMPLE
 * @return the (possibly cache-decorated and plugin-wrapped) executor
 */
public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
    executorType = executorType == null ? defaultExecutorType : executorType;
    executorType = executorType == null ? ExecutorType.SIMPLE : executorType;
    Executor executor;
    if (ExecutorType.BATCH == executorType) {
        // Batch executor.
        executor = new BatchExecutor(this, transaction);
    } else if (ExecutorType.REUSE == executorType) {
        // Reuse executor.
        executor = new ReuseExecutor(this, transaction);
    } else {
        // Default: simple executor.
        executor = new SimpleExecutor(this, transaction);
    }
    /*
            Caching layers:
            1. The first-level (session) cache is not switched on or off here; presumably it lives
               inside the executor created above (not visible in this excerpt).
            2. cacheEnabled is the global switch for the second-level cache. Per the MyBatis settings
               docs it defaults to true, but a mapper still has to declare a cache for it to apply.
            3. Because CachingExecutor wraps the executor, the second-level cache is consulted first,
               before the call reaches the wrapped executor.
     */
    // Decorator pattern: CachingExecutor wraps the real executor to add second-level caching.
    if (cacheEnabled) {
        executor = new CachingExecutor(executor);
    }
    // Key step: pass the executor through the interceptor chain so Executor plugins can wrap it.
    // This runs after the CachingExecutor decoration, so plugins end up outside the caching layer.
    executor = (Executor) interceptorChain.pluginAll(executor);
    return executor;
}

InterceptorChain拦截器链注册所有的插件

/**
 * Offers {@code target} to every registered interceptor, in registration order.
 *
 * <p>Each interceptor receives whatever the previous one returned. When several of them wrap
 * the object, the last registered interceptor ends up as the outermost layer.
 *
 * @param target the object to decorate
 *
 * @return the result of the last interceptor's {@code plugin} call, or {@code target} if none are registered
 */
public Object pluginAll(Object target) {
    Object current = target;
    for (final Interceptor each : interceptors) {
        current = each.plugin(current);
    }
    return current;
}

Plugin通过JDK动态代理生成一个插件的代理链对象(多个插件之间层层代理)

/**
 * Wraps {@code target} in a JDK dynamic proxy that routes calls to {@code interceptor}, but only
 * if the target implements at least one interface named in the interceptor's {@code @Intercepts}.
 *
 * @param target      the object to wrap
 * @param interceptor the interceptor whose signatures decide what is proxied
 *
 * @return the proxy, or {@code target} itself when none of its interfaces are intercepted
 */
public static Object wrap(Object target, Interceptor interceptor) {
    Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
    Class<?> targetType = target.getClass();
    Class<?>[] proxiedInterfaces = getAllInterfaces(targetType, signatureMap);
    if (proxiedInterfaces.length == 0) {
        // Nothing to intercept on this object: hand it back untouched.
        return target;
    }
    Plugin handler = new Plugin(target, interceptor, signatureMap);
    return Proxy.newProxyInstance(targetType.getClassLoader(), proxiedInterfaces, handler);
}

执行Executor插件

DefaultSqlSession进行查询操作执行

/**
 * Runs the mapped statement identified by {@code statement} and returns the resulting rows.
 *
 * @param statement Unique identifier matching the statement to use.
 * @param parameter A parameter object to pass to the statement (passed through wrapCollection first).
 * @param rowBounds Bounds to limit object retrieval
 * @param <E>       element type of the returned list
 *
 * @return the result rows
 */
@Override
public <E> List<E> selectList(String statement, Object parameter, RowBounds rowBounds) {
    try {
        // Look the MappedStatement up by its id in the Configuration.
        final MappedStatement mappedStatement = configuration.getMappedStatement(statement);
        final Object effectiveParameter = wrapCollection(parameter);
        // The executor may be a plugin proxy, so Executor.query interceptors fire on this call.
        return executor.query(mappedStatement, effectiveParameter, rowBounds, Executor.NO_RESULT_HANDLER);
    } catch (Exception e) {
        throw ExceptionFactory.wrapException("Error querying database.  Cause: " + e, e);
    } finally {
        ErrorContext.instance().reset();
    }
}

注册ParameterHandler,ResultSetHandler,StatementHandler插件

SimpleExecutor执行doQuery进行查询时,会通过Configuration创建StatementHandler并注册StatementHandler插件;而在创建StatementHandler的过程中,会先注册ParameterHandler、ResultSetHandler插件。

/**
 * Runs a query. It builds the StatementHandler, prepares and parameterizes the JDBC statement,
 * executes it through the handler, and always closes the statement afterwards.
 */
@Override
public <E> List<E> doQuery(MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
    Statement stmt = null;
    try {
        // The core Configuration object is reached through the MappedStatement.
        Configuration configuration = ms.getConfiguration();
        // Configuration.newStatementHandler builds a RoutingStatementHandler and passes it through the
        // interceptor chain. ParameterHandler/ResultSetHandler plugins are applied while it is
        // constructed (via the BaseStatementHandler constructor).
        // NOTE(review): `wrapper` is an executor field not shown here; presumably the outermost executor.
        StatementHandler handler = configuration.newStatementHandler(wrapper, ms, parameter, rowBounds, resultHandler, boundSql);
        // Create the JDBC Statement and set its parameters (see prepareStatement).
        stmt = prepareStatement(handler, ms.getStatementLog());
        // Execute through the StatementHandler; with statementType PREPARED this reaches PreparedStatementHandler.query.
        return handler.query(stmt, resultHandler);
    } finally {
        // Release the statement even if preparing or executing failed.
        closeStatement(stmt);
    }
}

Configuration来进行StatementHandler插件的注册

/**
 * Creates the StatementHandler for one statement execution and lets the interceptor chain wrap it,
 * which is where StatementHandler plugins are registered.
 */
public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
    // The routing handler picks the concrete handler from the statement type.
    final StatementHandler routing = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
    return (StatementHandler) interceptorChain.pluginAll(routing);
}

创建RoutingStatementHandler实例时,会根据statementType创建具体的StatementHandler实现,而在这些实现的构造过程中会先注册ParameterHandler、ResultSetHandler插件

/**
 * Creates the concrete StatementHandler that matches the mapped statement's statementType
 * (STATEMENT, PREPARED or CALLABLE) and keeps it in {@code delegate}.
 *
 * @throws ExecutorException if the statement type is none of the three handled cases
 */
public RoutingStatementHandler(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
    // The core of this class: choose the StatementHandler implementation by the statementType attribute.
    // NOTE(review): the concrete handlers presumably extend BaseStatementHandler, whose constructor
    // creates the plugin-wrapped ParameterHandler and ResultSetHandler; confirm against those classes.
    switch (ms.getStatementType()) {
        case STATEMENT:
            delegate = new SimpleStatementHandler(executor, ms, parameter, rowBounds, resultHandler, boundSql);
            break;
        case PREPARED:
            delegate = new PreparedStatementHandler(executor, ms, parameter, rowBounds, resultHandler, boundSql);
            break;
        case CALLABLE:
            delegate = new CallableStatementHandler(executor, ms, parameter, rowBounds, resultHandler, boundSql);
            break;
        default:
            throw new ExecutorException("Unknown statement type: " + ms.getStatementType());
    }

}

BaseStatementHandler的构造函数里面会注册ParameterHandler,ResultSetHandler插件

/**
 * Shared constructor for the statement handlers. It stores the statement context, resolves the
 * BoundSql when the caller passed none, and then creates the ParameterHandler and ResultSetHandler.
 * Both of these go through the interceptor chain, which is where those two plugin types are registered.
 */
protected BaseStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
    this.configuration = mappedStatement.getConfiguration();
    this.executor = executor;
    this.mappedStatement = mappedStatement;
    this.rowBounds = rowBounds;

    this.typeHandlerRegistry = configuration.getTypeHandlerRegistry();
    this.objectFactory = configuration.getObjectFactory();

    if (boundSql == null) { // issue #435, get the key before calculating the statement
        // generateKeys must run before getBoundSql builds the SQL for this parameter object.
        generateKeys(parameterObject);
        boundSql = mappedStatement.getBoundSql(parameterObject);
    }

    this.boundSql = boundSql;
    
    // Create the ParameterHandler; newParameterHandler runs it through the interceptor chain (ParameterHandler plugins).
    this.parameterHandler = configuration.newParameterHandler(mappedStatement, parameterObject, boundSql);
    // Create the ResultSetHandler (ResultSetHandler plugins); it receives the possibly wrapped parameterHandler.
    this.resultSetHandler = configuration.newResultSetHandler(executor, mappedStatement, rowBounds, parameterHandler, resultHandler, boundSql);
}
/**
 * Creates the ParameterHandler for a statement via the statement's language driver, then lets the
 * interceptor chain wrap it (ParameterHandler plugins).
 */
public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
    final ParameterHandler raw = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
    return (ParameterHandler) interceptorChain.pluginAll(raw);
}

/**
 * Creates the DefaultResultSetHandler for a statement, then lets the interceptor chain wrap it
 * (ResultSetHandler plugins).
 */
public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler,
                                            ResultHandler resultHandler, BoundSql boundSql) {
    final ResultSetHandler raw = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
    return (ResultSetHandler) interceptorChain.pluginAll(raw);
}

至此,四种类型的拦截器(插件)都已注册完毕。注意:此时Executor的拦截逻辑已经触发,但Executor的query方法并没有执行完——doQuery正是在它内部被调用的。

执行StatementHandler拦截器

回到SimpleExecutor的doQuery方法里面调用了prepareStatement方法

/**
 * Creates the JDBC Statement through {@code handler} and sets its parameters.
 *
 * @param handler      the (possibly plugin-wrapped) StatementHandler
 * @param statementLog log passed to getConnection
 * @return the prepared and parameterized statement
 * @throws SQLException if obtaining the connection, preparing or parameterizing fails
 */
private Statement prepareStatement(StatementHandler handler, Log statementLog) throws SQLException {
    Statement stmt;
    // Obtain the JDBC connection.
    Connection connection = getConnection(statementLog);
    // Create the Statement; StatementHandler.prepare interceptors (such as a paging plugin) run here.
    // With statement logging on, a line like ==>  Preparing: select id,name,age,hobby from student WHERE id in ( ? , ? , ? ) AND name like ? AND age = ? order by name desc
    // appears at this step.
    stmt = handler.prepare(connection, transaction.getTimeout());
    // Set the parameters via StatementHandler.parameterize, which calls ParameterHandler.setParameters.
    handler.parameterize(stmt);
    return stmt;
}

执行ParameterHandler拦截器

StatementHandler的parameterize方法会调用ParameterHandler的setParameters方法,因此ParameterHandler拦截器也是在这个时候执行的。

/**
 * Binds the statement parameters by delegating to the ParameterHandler. That handler may be a
 * plugin proxy, so ParameterHandler.setParameters interceptors run at this point.
 */
@Override
public void parameterize(Statement statement) throws SQLException {
    final PreparedStatement prepared = (PreparedStatement) statement;
    parameterHandler.setParameters(prepared);
}

执行ResultSetHandler拦截器

PreparedStatementHandler执行query方法的时候会调用到ResultSetHandler的handleResultSets方法。这时候会调用到ResultSetHandler注册的拦截器

/**
 * Executes the prepared statement and maps its results. The ResultSetHandler may be a plugin
 * proxy, so ResultSetHandler.handleResultSets interceptors run at this point.
 */
@Override
public <E> List<E> query(Statement statement, ResultHandler resultHandler) throws SQLException {
    final PreparedStatement preparedStatement = (PreparedStatement) statement;
    // Hand the statement to JDBC.
    preparedStatement.execute();
    return resultSetHandler.handleResultSets(preparedStatement);
}

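至此,MyBatis拦截器的注册和执行流程就分析完了。整体顺序为:Executor注册 -> Executor执行(query被拦截) -> ParameterHandler注册 -> ResultSetHandler注册 -> StatementHandler注册 -> StatementHandler执行(prepare、parameterize) -> ParameterHandler执行(setParameters) -> StatementHandler执行(query) -> ResultSetHandler执行(handleResultSets)。注意ParameterHandler和ResultSetHandler是在StatementHandler的构造过程中注册的,所以它们早于StatementHandler本身被插件包装。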

手写一个Mybatis分页插件

package org.apache.ibatis.mytest;

import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.RowBounds;

import java.sql.*;
import java.util.Properties;

import static org.apache.ibatis.reflection.SystemMetaObject.DEFAULT_OBJECT_FACTORY;
import static org.apache.ibatis.reflection.SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY;

/**
 * Physical-pagination plugin.
 *
 * <p>Intercepts {@code StatementHandler.prepare(Connection, Integer)}, i.e. the call made right before
 * the JDBC statement is created. For statements whose id matches {@code pageSqlId}, it rewrites
 * the SQL into a dialect-specific paged query and writes the total row/page counts back into the
 * {@link PageParameter}.
 *
 * <p>Settings are resolved on every call; the first non-blank value wins:
 * <ol>
 *   <li>plugin properties {@code dialect} / {@code pageSqlId} (received by {@link #setProperties});</li>
 *   <li>configuration variables {@code dialect} / {@code pageSqlId};</li>
 *   <li>defaults {@code mysql} / {@code .*Page$}.</li>
 * </ol>
 * Supported dialects are {@code mysql} and {@code oracle} (case-insensitive).
 *
 * <p>The {@link PageParameter} is either the statement's parameter object itself or its
 * {@code page} property/map entry.
 */
@Intercepts({
        // Other interceptable signatures, kept for reference:
//        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
//        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),

//        @Signature(type = ParameterHandler.class, method = "getParameterObject", args = {}),
//        @Signature(type = ParameterHandler.class, method = "setParameters", args = {PreparedStatement.class}),

        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}),
//        @Signature(type = StatementHandler.class, method = "query", args = {Statement.class,ResultHandler.class}),

//        @Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class}),

})
@Slf4j
public class PageInterceptor implements Interceptor {
    /** Dialect used when neither plugin properties nor configuration variables define one. */
    private static final String DEFAULT_DIALECT = "mysql";
    /** Statement-id regex used when none is configured: ids ending with "Page". */
    private static final String DEFAULT_PAGE_SQL_ID = ".*Page$";
    /** Reused for every MetaObject lookup instead of creating a new factory per call. */
    private static final DefaultReflectorFactory REFLECTOR_FACTORY = new DefaultReflectorFactory();

    /** Dialect from the plugin's own properties; null/blank when not configured there. */
    private String dialect;
    /** Statement-id regex from the plugin's own properties; null/blank when not configured there. */
    private String pageSqlId;

    /**
     * Rewrites the SQL of matching statements into a paged query, then lets the call proceed.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare(Connection, Integer)} call
     * @return the result of {@link Invocation#proceed()}
     * @throws NullPointerException     if a matching statement has a null parameter object
     * @throws IllegalArgumentException if the PageParameter's pageSize or currentPage is missing or not positive
     * @throws Throwable                anything thrown by the proceeding call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MetaObject metaStatementHandler = unwrapTarget(invocation.getTarget());
        Configuration configuration = (Configuration) metaStatementHandler.getValue("delegate.configuration");
        // Resolved into locals: this single interceptor instance handles every intercepted call,
        // so per-call values must not be written into fields.
        String currentDialect = resolveSetting(dialect, configuration, "dialect", DEFAULT_DIALECT,
                "Property dialect is not setted,use default 'mysql' ");
        String currentPageSqlId = resolveSetting(pageSqlId, configuration, "pageSqlId", DEFAULT_PAGE_SQL_ID,
                "Property pageSqlId is not setted,use default '.*Page$' ");
        MappedStatement mappedStatement = (MappedStatement)
                metaStatementHandler.getValue("delegate.mappedStatement");
        // Only statements whose id matches the regex (by default: ids ending with "Page") are rewritten.
        if (!mappedStatement.getId().matches(currentPageSqlId)) {
            return invocation.proceed();
        }
        BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
        Object parameterObject = boundSql.getParameterObject();
        if (parameterObject == null) {
            throw new NullPointerException("parameterObject is null!");
        }
        PageParameter page = findPageParameter(parameterObject);
        if (page == null) {
            log.warn("Statement {} matches pageSqlId but carries no PageParameter; it is not paged",
                    mappedStatement.getId());
            return invocation.proceed();
        }
        if (!isSupportedDialect(currentDialect)) {
            // Leave SQL and RowBounds untouched; resetting RowBounds without rewriting the SQL
            // would silently return every row.
            log.warn("Unsupported dialect '{}'; statement {} is not paged", currentDialect, mappedStatement.getId());
            return invocation.proceed();
        }
        checkPage(page);
        String sql = boundSql.getSql();
        metaStatementHandler.setValue("delegate.boundSql.sql", buildPageSql(sql, page, currentDialect));
        // Physical paging replaces MyBatis' in-memory RowBounds paging, so reset it to "no offset, no limit".
        metaStatementHandler.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
        metaStatementHandler.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
        Connection connection = (Connection) invocation.getArgs()[0];
        // Count rows on the original (un-paged) SQL and write the totals back into the PageParameter.
        setPageParameter(sql, connection, mappedStatement, boundSql, page);
        // Hand control to the next interceptor / the real prepare().
        return invocation.proceed();
    }

    /**
     * Queries the total row count, derives the total page count, and writes both into {@code page}
     * so the caller can read them after the query.
     *
     * <p>A failing count query is logged and tolerated (the paged query still runs); totalCount and
     * totalPage are then left unchanged.
     *
     * @param sql             the original, un-paged SQL
     * @param connection      the connection the intercepted statement is being prepared on
     * @param mappedStatement the statement being paged
     * @param boundSql        the original BoundSql (parameter mappings and additional parameters)
     * @param page            validated paging parameter (positive pageSize)
     */
    private void setPageParameter(String sql, Connection connection, MappedStatement mappedStatement,
                                  BoundSql boundSql, PageParameter page) {
        // No "AS" before the derived-table alias: Oracle rejects it, while MySQL accepts both forms.
        String countSql = "select count(0) from (" + sql + ") total";
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql,
                boundSql.getParameterMappings(), boundSql.getParameterObject());
        copyAdditionalParameters(boundSql, countBoundSql);
        // try-with-resources closes both JDBC objects and never dereferences one that was not created.
        try (PreparedStatement countStmt = connection.prepareStatement(countSql)) {
            setParameters(countStmt, mappedStatement, countBoundSql, boundSql.getParameterObject());
            int totalCount = 0;
            try (ResultSet rs = countStmt.executeQuery()) {
                if (rs.next()) {
                    totalCount = rs.getInt(1);
                }
            }
            page.setTotalCount(totalCount);
            int pageSize = page.getPageSize();
            int totalPage = totalCount / pageSize + ((totalCount % pageSize == 0) ? 0 : 1);
            page.setTotalPage(totalPage);
        } catch (SQLException e) {
            log.error("Count query failed for {}; totalCount/totalPage were not set", mappedStatement.getId(), e);
        }
    }

    /**
     * Sets the SQL parameters (?) of {@code ps} using MyBatis' default parameter handling.
     *
     * @param ps              the statement to fill
     * @param mappedStatement the statement whose parameter configuration is used
     * @param boundSql        supplies the parameter mappings and additional parameters
     * @param parameterObject the statement's parameter object
     * @throws SQLException if setting a parameter fails
     */
    private void setParameters(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql,
                               Object parameterObject) throws SQLException {
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, boundSql);
        parameterHandler.setParameters(ps);
    }

    /**
     * Copies the "additional parameters" referenced by the source's parameter mappings into
     * {@code target}. Dynamic SQL (for example a {@code <foreach>}) binds values there, and a freshly
     * constructed BoundSql does not have them, so without this copy the count query cannot bind its
     * placeholders.
     */
    private static void copyAdditionalParameters(BoundSql source, BoundSql target) {
        source.getParameterMappings().forEach(mapping -> {
            String root = rootPropertyName(mapping.getProperty());
            if (source.hasAdditionalParameter(root)) {
                target.setAdditionalParameter(root, source.getAdditionalParameter(root));
            }
        });
    }

    /** Returns the leading property name of a path, e.g. {@code a} for {@code a[0].b}. */
    private static String rootPropertyName(String property) {
        int dot = property.indexOf('.');
        String head = dot < 0 ? property : property.substring(0, dot);
        int bracket = head.indexOf('[');
        return bracket < 0 ? head : head.substring(0, bracket);
    }

    /**
     * Wraps {@code sql} in the paging syntax for {@code targetDialect}.
     *
     * @return the paged SQL, or {@code sql} unchanged when {@code page} is null or the dialect is unsupported
     */
    private String buildPageSql(String sql, PageParameter page, String targetDialect) {
        if (page == null) {
            return sql;
        }
        if ("mysql".equalsIgnoreCase(targetDialect)) {
            return buildPageSqlForMysql(sql, page).toString();
        }
        if ("oracle".equalsIgnoreCase(targetDialect)) {
            return buildPageSqlForOracle(sql, page).toString();
        }
        return sql;
    }

    /**
     * Builds {@code <sql> limit <offset>,<pageSize>} for MySQL.
     * The offset is computed in {@code long} so large page numbers cannot overflow.
     */
    public StringBuilder buildPageSqlForMysql(String sql, PageParameter page) {
        long offset = (long) (page.getCurrentPage() - 1) * page.getPageSize();
        StringBuilder pageSql = new StringBuilder(sql.length() + 32);
        pageSql.append(sql).append(" limit ").append(offset).append(',').append(page.getPageSize());
        return pageSql;
    }

    /**
     * Builds a ROWNUM-based paged query for Oracle that keeps rows {@code (offset, offset + pageSize]}.
     * Row bounds are computed in {@code long} so large page numbers cannot overflow.
     */
    public StringBuilder buildPageSqlForOracle(String sql, PageParameter page) {
        long startRow = (long) (page.getCurrentPage() - 1) * page.getPageSize();
        long endRow = (long) page.getCurrentPage() * page.getPageSize();
        StringBuilder pageSql = new StringBuilder(sql.length() + 100);
        pageSql.append("select * from ( select temp.*, rownum row_id from ( ");
        pageSql.append(sql);
        pageSql.append(" ) temp where rownum <= ").append(endRow);
        pageSql.append(") where row_id > ").append(startRow);
        return pageSql;
    }

    /**
     * Locates the paging parameter: either the parameter object itself or its {@code page}
     * property/map entry.
     *
     * @return the PageParameter, or null when none is present
     */
    private static PageParameter findPageParameter(Object parameterObject) {
        if (parameterObject instanceof PageParameter) {
            return (PageParameter) parameterObject;
        }
        // MetaObject does not support property lookups on raw collections.
        if (parameterObject instanceof java.util.Collection) {
            return null;
        }
        MetaObject metaParameter = metaOf(parameterObject);
        if (metaParameter.hasGetter("page")) {
            Object candidate = metaParameter.getValue("page");
            if (candidate instanceof PageParameter) {
                return (PageParameter) candidate;
            }
        }
        return null;
    }

    /**
     * Rejects page settings that would otherwise cause an NPE, a division by zero, or a negative offset.
     *
     * @throws IllegalArgumentException if pageSize or currentPage is null or not positive
     */
    private static void checkPage(PageParameter page) {
        Integer pageSize = page.getPageSize();
        Integer currentPage = page.getCurrentPage();
        if (pageSize == null || pageSize <= 0) {
            throw new IllegalArgumentException("pageSize must be a positive number, but was " + pageSize);
        }
        if (currentPage == null || currentPage <= 0) {
            throw new IllegalArgumentException("currentPage must be >= 1, but was " + currentPage);
        }
    }

    /** Returns whether {@link #buildPageSql} knows how to page for {@code candidate}. */
    private static boolean isSupportedDialect(String candidate) {
        return "mysql".equalsIgnoreCase(candidate) || "oracle".equalsIgnoreCase(candidate);
    }

    /**
     * Picks the plugin-level value, then the configuration variable {@code key}, then
     * {@code defaultValue}. It logs {@code fallbackWarning} when the default is used.
     */
    private static String resolveSetting(String pluginValue, Configuration configuration, String key,
                                         String defaultValue, String fallbackWarning) {
        if (!isBlank(pluginValue)) {
            return pluginValue;
        }
        Properties variables = configuration.getVariables();
        String configured = variables == null ? null : variables.getProperty(key);
        if (!isBlank(configured)) {
            return configured;
        }
        log.warn(fallbackWarning);
        return defaultValue;
    }

    private static boolean isBlank(String value) {
        return value == null || value.trim().isEmpty();
    }

    /**
     * Peels every plugin layer off {@code target} and returns a MetaObject for the innermost handler.
     *
     * <p>Layers alternate: a JDK proxy whose InvocationHandler is a Plugin, whose {@code target} is
     * either the next proxy or the real handler. Handling both kinds in a single loop works for any
     * number of stacked plugins. {@link java.lang.reflect.Proxy#getInvocationHandler} avoids
     * reflective access to the JDK-internal {@code Proxy.h} field.
     */
    private static MetaObject unwrapTarget(Object target) {
        Object current = target;
        while (true) {
            if (java.lang.reflect.Proxy.isProxyClass(current.getClass())) {
                current = java.lang.reflect.Proxy.getInvocationHandler(current);
                continue;
            }
            MetaObject meta = metaOf(current);
            if (!meta.hasGetter("target")) {
                return meta;
            }
            current = meta.getValue("target");
        }
    }

    /** Creates a MetaObject using the default object/wrapper factories and the shared reflector factory. */
    private static MetaObject metaOf(Object object) {
        return MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY, REFLECTOR_FACTORY);
    }

    @Override
    public Object plugin(Object target) {
        return Interceptor.super.plugin(target);
    }

    /**
     * Reads the optional plugin properties {@code dialect} and {@code pageSqlId}; when set, they take
     * precedence over the configuration variables of the same name.
     */
    @Override
    public void setProperties(Properties properties) {
        if (properties == null) {
            return;
        }
        this.dialect = properties.getProperty("dialect");
        this.pageSqlId = properties.getProperty("pageSqlId");
    }


}
package org.apache.ibatis.mytest;

import lombok.Builder;
import lombok.Data;

/**
 * Paging request/response holder used by PageInterceptor.
 *
 * <p>Callers fill in pageSize and currentPage; the interceptor's count query writes totalCount and
 * totalPage back. Lombok generates the accessors (@Data) and a builder (@Builder).
 */
@Data
@Builder
public class PageParameter {
    /** Rows per page. */
    private Integer pageSize;
    /** 1-based page number: the SQL builders compute the offset as (currentPage - 1) * pageSize. */
    private Integer currentPage;
    /** Total matching rows, filled in by the interceptor's count query. */
    private Integer totalCount;
    /** Total pages (totalCount / pageSize, rounded up), filled in by the interceptor. */
    private Integer totalPage;
}