MyBatis provides a "plugin" feature. Despite the name, a plugin is really an interceptor. So what exactly can these interceptors intercept in MyBatis?
Overview
MyBatis lets you intercept calls at certain points during the execution of a mapped statement. By default, plugins may intercept the following method calls:
- Executor (update, query, flushStatements, commit, rollback, getTransaction, close, isClosed): intercepts the executor's methods
- ParameterHandler (getParameterObject, setParameters): intercepts parameter handling
- ResultSetHandler (handleResultSets, handleOutputParameters): intercepts result-set handling
- StatementHandler (prepare, parameterize, batch, update, query): intercepts building and preparing the SQL statement
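MyBatis organizes multiple interceptors (plugins) into a chain of responsibility built from dynamic proxies. These interceptors can change MyBatis's default behavior, for example by rewriting SQL. Because plugins reach deep into MyBatis's core, it is worth understanding how they work before writing your own, so that your plugins are safe and efficient.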
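To make the dynamic-proxy idea concrete, here is a minimal stdlib-only sketch. It is not MyBatis code. A JDK dynamic proxy sits in front of a hypothetical `SqlRunner` interface and rewrites the SQL argument before delegating, which is the kind of behavior change a plugin can make. `SqlRunner`, `RewritingProxyDemo` and the appended `limit` clause are illustrative assumptions.

```java
import java.lang.reflect.Proxy;

// Hypothetical interface standing in for a MyBatis component that receives SQL.
interface SqlRunner {
    String run(String sql);
}

class RewritingProxyDemo {
    // Returns a JDK dynamic proxy that rewrites the SQL before the real object sees it.
    static SqlRunner withLimit(SqlRunner target, int limit) {
        return (SqlRunner) Proxy.newProxyInstance(
                SqlRunner.class.getClassLoader(),
                new Class<?>[] {SqlRunner.class},
                (proxy, method, args) -> {
                    if (method.getName().equals("run")) {
                        args[0] = args[0] + " limit " + limit; // alter the call's input
                    }
                    return method.invoke(target, args); // then delegate to the real object
                });
    }

    public static void main(String[] args) {
        SqlRunner real = sql -> "executed: " + sql;
        SqlRunner proxied = withLimit(real, 10);
        System.out.println(proxied.run("select * from student")); // executed: select * from student limit 10
    }
}
```

The caller only sees a `SqlRunner`. It cannot tell whether it holds the real object or a proxy, and that is what lets plugins be layered transparently.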
Using interceptors
Interceptor overview and configuration
First, let's look at the definition of the MyBatis Interceptor interface:
package org.apache.ibatis.plugin;
import java.util.Properties;
/**
* Interceptor interface; users implement it to extend MyBatis
*
* @author Clinton Begin
*/
public interface Interceptor {
/**
* Performs the actual interception
*
* @param invocation
*
* @return
*
* @throws Throwable
*/
Object intercept(Invocation invocation) throws Throwable;
/**
* Creates the proxy object for an Executor, ParameterHandler, ResultSetHandler or StatementHandler
*
* @param target the Executor, ParameterHandler, ResultSetHandler or StatementHandler instance
*
* @return
*/
default Object plugin(Object target) {
return Plugin.wrap(target, this);
}
/**
* Sets the plugin's properties
*
* @param properties
*/
default void setProperties(Properties properties) {
// NOP
}
}
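The interface is simple, with only three methods, two of which have default implementations. MyBatis itself ships no implementation of this interface, so developers implement interceptors that fit their own needs. Below is an example interceptor from the official MyBatis documentation: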
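import java.util.Properties;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
// Intercept Executor.update(MappedStatement, Object)
@Intercepts({@Signature(type= Executor.class, method = "update", args = {MappedStatement.class,Object.class})})
public class ExamplePlugin implements Interceptor {
public Object intercept(Invocation invocation) throws Throwable {
// pre-processing would go here; proceed() continues to the next interceptor or the target
return invocation.proceed();
}
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
public void setProperties(Properties properties) {
}
}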
Register it in the global XML configuration:
<plugins>
<plugin interceptor="org.format.mybatis.cache.interceptor.ExamplePlugin"></plugin>
</plugins>
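This interceptor intercepts the update method of the Executor interface. SqlSession's insert, update and delete operations all go through Executor.update, so every one of those calls passes through this interceptor.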
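MyBatis's `Plugin.getSignatureMap` turns each `@Signature` into a `java.lang.reflect.Method` by calling `type.getMethod(method, args)`, and throws a `PluginException` when no method matches. The `args` array must therefore list the intercepted method's parameter types exactly. The sketch below imitates that lookup so it runs without MyBatis on the classpath. It uses hypothetical look-alike annotations (`@Sig`, `@Sigs`) and a hypothetical `FakeExecutor`, and it reports a mismatch instead of throwing.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Hypothetical look-alikes of @Signature and @Intercepts (not the MyBatis annotations).
@Retention(RetentionPolicy.RUNTIME)
@interface Sig {
    Class<?> type();
    String method();
    Class<?>[] args();
}

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface Sigs {
    Sig[] value();
}

// Hypothetical stand-in for Executor.update(MappedStatement, Object).
interface FakeExecutor {
    int update(String statementId, Object parameter);
}

@Sigs({@Sig(type = FakeExecutor.class, method = "update", args = {String.class, Object.class})})
class MatchingPlugin {}

// Wrong argument list: update(String) does not exist on FakeExecutor.
@Sigs({@Sig(type = FakeExecutor.class, method = "update", args = {String.class})})
class MismatchedPlugin {}

class SignatureDemo {
    // Resolves each declared signature with type.getMethod(name, parameterTypes).
    static String resolve(Class<?> pluginClass) {
        Sigs sigs = pluginClass.getAnnotation(Sigs.class);
        if (sigs == null) {
            return "no annotation";
        }
        StringBuilder result = new StringBuilder();
        for (Sig sig : sigs.value()) {
            if (result.length() > 0) {
                result.append("; ");
            }
            try {
                Method m = sig.type().getMethod(sig.method(), sig.args());
                result.append("intercepts ").append(m.getName()).append('/').append(m.getParameterCount());
            } catch (NoSuchMethodException e) {
                result.append("no such method: ").append(sig.method());
            }
        }
        return result.toString();
    }

    public static void main(String[] args) {
        System.out.println(resolve(MatchingPlugin.class));   // intercepts update/2
        System.out.println(resolve(MismatchedPlugin.class)); // no such method: update
    }
}
```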
When are interceptors (plugins) registered, and when do they run? Let's walk through it in order.
Registering Executor plugins
DefaultSqlSessionFactory opens a SqlSession
/**
* Opens a session from the data source
*
* @param execType executor type; the default ExecutorType is ExecutorType.SIMPLE
* @param level transaction isolation level
* @param autoCommit whether to auto-commit the transaction
*
* @return the SqlSession
*/
private SqlSession openSessionFromDataSource(ExecutorType execType, TransactionIsolationLevel level, boolean autoCommit) {
/*
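Earlier, sqlSessionFactoryBuilder.build(InputStream inputStream)
===> xMLConfigBuilder.parse() ===> stored the data source and SQL statements from the configuration file in the Configuration object,
so the relevant information can be read directly from it here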
*/
Transaction tx = null;
try {
// get the environment configuration
final Environment environment = configuration.getEnvironment();
// get the transaction factory
final TransactionFactory transactionFactory = getTransactionFactoryFromEnvironment(environment);
// create a new transaction
tx = transactionFactory.newTransaction(environment.getDataSource(), level, autoCommit);
// create the Executor; Executor interceptors (plugins) are applied here
final Executor executor = configuration.newExecutor(tx, execType);
// create the default SqlSession
return new DefaultSqlSession(configuration, executor, autoCommit);
} catch (Exception e) {
closeTransaction(tx); // may have fetched a connection so lets call close()
throw ExceptionFactory.wrapException("Error opening session. Cause: " + e, e);
} finally {
ErrorContext.instance().reset();
}
}
Configuration creates the Executor
public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
executorType = executorType == null ? defaultExecutorType : executorType;
executorType = executorType == null ? ExecutorType.SIMPLE : executorType;
Executor executor;
if (ExecutorType.BATCH == executorType) {
// batch executor
executor = new BatchExecutor(this, transaction);
} else if (ExecutorType.REUSE == executorType) {
// reuse executor
executor = new ReuseExecutor(this, transaction);
} else {
// default (simple) executor
executor = new SimpleExecutor(this, transaction);
}
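/*
1. The first-level (local) cache is enabled by default.
2. The second-level cache is effectively off by default: it only applies to mappers that declare a cache (e.g. <cache/>).
3. When both are in play, the second-level cache is checked before the first-level cache.
*/
// cacheEnabled (true by default) is the global switch for the second-level cache;
// CachingExecutor adds it to the executor using the decorator pattern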
if (cacheEnabled) {
executor = new CachingExecutor(executor);
}
// key step: pass the executor through the interceptor chain so that every plugin can wrap it
executor = (Executor) interceptorChain.pluginAll(executor);
return executor;
}
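`pluginAll` runs after the `CachingExecutor` decoration, so an Executor plugin wraps the caching executor and is invoked before any second-level cache lookup. Below is a stdlib-only sketch of that assembly order. The hypothetical classes `BasicExecutor` and `CachingDecorator` stand in for `SimpleExecutor` and `CachingExecutor`, and a single recording proxy stands in for the plugin chain.

```java
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified executor interface.
interface QueryExecutor {
    String query(String sql);
}

// Plays the role of SimpleExecutor: the executor that does the real work.
class BasicExecutor implements QueryExecutor {
    private final List<String> trace;

    BasicExecutor(List<String> trace) {
        this.trace = trace;
    }

    public String query(String sql) {
        trace.add("SimpleExecutor");
        return "rows for " + sql;
    }
}

// Plays the role of CachingExecutor: a decorator around another executor.
class CachingDecorator implements QueryExecutor {
    private final QueryExecutor delegate;
    private final List<String> trace;

    CachingDecorator(QueryExecutor delegate, List<String> trace) {
        this.delegate = delegate;
        this.trace = trace;
    }

    public String query(String sql) {
        trace.add("CachingExecutor"); // a real cache lookup would happen here
        return delegate.query(sql);
    }
}

class ExecutorAssemblyDemo {
    // Same order as Configuration.newExecutor: base executor -> cache decorator -> plugin proxy.
    static List<String> run(boolean cacheEnabled) {
        List<String> trace = new ArrayList<>();
        QueryExecutor executor = new BasicExecutor(trace);
        if (cacheEnabled) {
            executor = new CachingDecorator(executor, trace);
        }
        QueryExecutor inner = executor;
        executor = (QueryExecutor) Proxy.newProxyInstance(
                QueryExecutor.class.getClassLoader(),
                new Class<?>[] {QueryExecutor.class},
                (proxy, method, args) -> {
                    trace.add("plugin"); // an Executor interceptor would run here
                    return method.invoke(inner, args);
                });
        executor.query("select 1");
        return trace;
    }

    public static void main(String[] args) {
        System.out.println(run(true));  // [plugin, CachingExecutor, SimpleExecutor]
        System.out.println(run(false)); // [plugin, SimpleExecutor]
    }
}
```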
InterceptorChain applies every registered plugin
/**
* Applies all interceptors to the target
*
* @param target
*
* @return
*/
public Object pluginAll(Object target) {
// iterate over all interceptors (plugins)
for (Interceptor interceptor : interceptors) {
// let the interceptor wrap the current target, which may already be a proxy
target = interceptor.plugin(target);
}
return target;
}
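Each iteration wraps the result of the previous one. As a result, the interceptor registered last (declared last under `<plugins>`) becomes the outermost proxy, and its `intercept` runs first. Below is a minimal sketch of that effect using plain JDK proxies. `Component`, `PluginAllDemo` and the interceptor names are hypothetical.

```java
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical, simplified target interface.
interface Component {
    void handle();
}

class PluginAllDemo {
    // Analogous to interceptor.plugin(target): wraps target in a proxy that records its name.
    static Component plugin(Component target, String name, List<String> trace) {
        return (Component) Proxy.newProxyInstance(
                Component.class.getClassLoader(),
                new Class<?>[] {Component.class},
                (proxy, method, args) -> {
                    trace.add(name);
                    return method.invoke(target, args);
                });
    }

    // Same loop shape as InterceptorChain.pluginAll.
    static List<String> run(List<String> interceptorsInRegistrationOrder) {
        List<String> trace = new ArrayList<>();
        Component target = () -> trace.add("target");
        for (String name : interceptorsInRegistrationOrder) {
            target = plugin(target, name, trace);
        }
        target.handle();
        return trace;
    }

    public static void main(String[] args) {
        System.out.println(run(Arrays.asList("A", "B", "C"))); // [C, B, A, target]
    }
}
```

If two plugins both rewrite the same thing, this ordering decides which change the other one sees.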
Plugin uses a JDK dynamic proxy to build the plugin proxy chain (each plugin's proxy wraps the previous layer)
/**
* Helper that simplifies creating the dynamic proxy
*
* @param target
* @param interceptor
*
* @return
*/
public static Object wrap(Object target, Interceptor interceptor) {
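// read the methods declared in the custom plugin's @Intercepts annotation
Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
// the target object's class
Class<?> type = target.getClass();
// collect the target's interfaces that appear in the signature map
Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
if (interfaces.length > 0) {
// interception needed: create a proxy with the JDK dynamic proxy API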
return Proxy.newProxyInstance(
type.getClassLoader(),
interfaces,
new Plugin(target, interceptor, signatureMap));
}
return target;
}
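`Plugin` applies two filters. First, `wrap` returns the target unchanged when none of its interfaces appear in the signature map. Second, `Plugin.invoke` only calls `interceptor.intercept(new Invocation(target, method, args))` for methods in that map and passes everything else straight to the target. The sketch below reproduces both filters with plain JDK reflection. `StmtHandler`, `RealStmtHandler` and `MiniPlugin` are simplified hypothetical stand-ins. Instead of calling a real interceptor, the sketch records a hit and proceeds.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical, simplified StatementHandler-like interface.
interface StmtHandler {
    String prepare(String sql);
    String query(String sql);
}

class RealStmtHandler implements StmtHandler {
    public String prepare(String sql) { return "prepared:" + sql; }
    public String query(String sql) { return "rows:" + sql; }
}

// Simplified take on Plugin: proxy only signed interfaces, intercept only signed methods.
class MiniPlugin implements InvocationHandler {
    private final Object target;
    private final Map<Class<?>, Set<Method>> signatureMap;
    private final List<String> intercepted;

    private MiniPlugin(Object target, Map<Class<?>, Set<Method>> signatureMap, List<String> intercepted) {
        this.target = target;
        this.signatureMap = signatureMap;
        this.intercepted = intercepted;
    }

    // Like Plugin.wrap: only interfaces present in the signature map are proxied.
    static Object wrap(Object target, Map<Class<?>, Set<Method>> signatureMap, List<String> intercepted) {
        Set<Class<?>> interfaces = new LinkedHashSet<>();
        for (Class<?> c = target.getClass(); c != null; c = c.getSuperclass()) {
            for (Class<?> i : c.getInterfaces()) {
                if (signatureMap.containsKey(i)) {
                    interfaces.add(i);
                }
            }
        }
        if (interfaces.isEmpty()) {
            return target; // nothing to intercept, so no proxy is created
        }
        return Proxy.newProxyInstance(target.getClass().getClassLoader(),
                interfaces.toArray(new Class<?>[0]), new MiniPlugin(target, signatureMap, intercepted));
    }

    // Like Plugin.invoke: only methods in the signature map are intercepted.
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        Set<Method> methods = signatureMap.get(method.getDeclaringClass());
        if (methods != null && methods.contains(method)) {
            // MyBatis would return interceptor.intercept(new Invocation(target, method, args)) here;
            // this sketch only records the hit and then proceeds.
            intercepted.add(method.getName());
        }
        try {
            return method.invoke(target, args);
        } catch (InvocationTargetException e) {
            throw e.getCause(); // rethrow the target's own exception
        }
    }

    static List<String> demo() {
        List<String> log = new ArrayList<>();
        Method prepare;
        try {
            prepare = StmtHandler.class.getMethod("prepare", String.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
        Map<Class<?>, Set<Method>> signatures = new HashMap<>();
        signatures.put(StmtHandler.class, new HashSet<>(Collections.singletonList(prepare)));
        StmtHandler handler = (StmtHandler) wrap(new RealStmtHandler(), signatures, log);
        log.add(handler.prepare("select 1")); // intercepted
        log.add(handler.query("select 1"));   // proxied, but passed straight through

        Map<Class<?>, Set<Method>> unrelated = new HashMap<>();
        unrelated.put(Runnable.class, new HashSet<>());
        RealStmtHandler plain = new RealStmtHandler();
        log.add("unchanged=" + (wrap(plain, unrelated, log) == plain));
        return log;
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```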
Executing Executor plugins
DefaultSqlSession runs a query
/**
* Returns the results as a list
*
* @param statement Unique identifier matching the statement to use.
* @param parameter A parameter object to pass to the statement.
* @param rowBounds Bounds to limit object retrieval
* @param <E>
*
* @return
*/
@Override
public <E> List<E> selectList(String statement, Object parameter, RowBounds rowBounds) {
try {
// look up the MappedStatement in the configuration by its id
MappedStatement ms = configuration.getMappedStatement(statement);
// wrapCollection(parameter) first wraps the input parameter,
// then executor.query() runs the query
// 1. the Executor plugins intercept this call
return executor.query(ms, wrapCollection(parameter), rowBounds, Executor.NO_RESULT_HANDLER);
} catch (Exception e) {
throw ExceptionFactory.wrapException("Error querying database. Cause: " + e, e);
} finally {
ErrorContext.instance().reset();
}
}
Registering ParameterHandler, ResultSetHandler and StatementHandler plugins
When SimpleExecutor runs a query in doQuery, it registers the StatementHandler plugins. The ParameterHandler and ResultSetHandler plugins are registered before that:
@Override
public <E> List<E> doQuery(MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
Statement stmt = null;
try {
// get the core Configuration object
Configuration configuration = ms.getConfiguration();
// create the StatementHandler (a RoutingStatementHandler); plugins are applied to it here
StatementHandler handler = configuration.newStatementHandler(wrapper, ms, parameter, rowBounds, resultHandler, boundSql);
// create the PreparedStatement and set its parameters
stmt = prepareStatement(handler, ms.getStatementLog());
// let the StatementHandler run the query (PreparedStatementHandler by default)
return handler.query(stmt, resultHandler);
} finally {
// release resources
closeStatement(stmt);
}
}
Configuration registers the StatementHandler plugins
public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);
return statementHandler;
}
Constructing a RoutingStatementHandler first creates the concrete delegate, whose constructor registers the ParameterHandler and ResultSetHandler plugins
public RoutingStatementHandler(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
// the core: create the concrete StatementHandler according to the statementType attribute
switch (ms.getStatementType()) {
case STATEMENT:
delegate = new SimpleStatementHandler(executor, ms, parameter, rowBounds, resultHandler, boundSql);
break;
case PREPARED:
delegate = new PreparedStatementHandler(executor, ms, parameter, rowBounds, resultHandler, boundSql);
break;
case CALLABLE:
delegate = new CallableStatementHandler(executor, ms, parameter, rowBounds, resultHandler, boundSql);
break;
default:
throw new ExecutorException("Unknown statement type: " + ms.getStatementType());
}
}
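The BaseStatementHandler constructor registers the ParameterHandler and ResultSetHandler plugins, through Configuration.newParameterHandler and Configuration.newResultSetHandler (shown after it)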
protected BaseStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
this.configuration = mappedStatement.getConfiguration();
this.executor = executor;
this.mappedStatement = mappedStatement;
this.rowBounds = rowBounds;
this.typeHandlerRegistry = configuration.getTypeHandlerRegistry();
this.objectFactory = configuration.getObjectFactory();
if (boundSql == null) { // issue #435, get the key before calculating the statement
generateKeys(parameterObject);
boundSql = mappedStatement.getBoundSql(parameterObject);
}
this.boundSql = boundSql;
// register the ParameterHandler plugins
this.parameterHandler = configuration.newParameterHandler(mappedStatement, parameterObject, boundSql);
// register the ResultSetHandler plugins
this.resultSetHandler = configuration.newResultSetHandler(executor, mappedStatement, rowBounds, parameterHandler, resultHandler, boundSql);
}
public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
return parameterHandler;
}
public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler,
ResultHandler resultHandler, BoundSql boundSql) {
ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
return resultSetHandler;
}
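At this point all four kinds of interceptors (plugins) are registered. The Executor interceptors have already been entered, and the intercepted query call is still in progress underneath them.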
Executing StatementHandler interceptors
Back in SimpleExecutor.doQuery, the prepareStatement method is called:
private Statement prepareStatement(StatementHandler handler, Log statementLog) throws SQLException {
Statement stmt;
// get the JDBC connection
Connection connection = getConnection(statementLog);
// create the Statement; this prints a log line such as ==> Preparing: select id,name,age,hobby from student WHERE id in ( ? , ? , ? ) AND name like ? AND age = ? order by name desc
stmt = handler.prepare(connection, transaction.getTimeout());
// then call the StatementHandler again to set the parameters
handler.parameterize(stmt);
return stmt;
}
Executing ParameterHandler interceptors
StatementHandler.parameterize calls ParameterHandler.setParameters, so this is when the ParameterHandler interceptors run:
@Override
public void parameterize(Statement statement) throws SQLException {
// the ParameterHandler plugins intercept this call
parameterHandler.setParameters((PreparedStatement) statement);
}
Executing ResultSetHandler interceptors
PreparedStatementHandler.query calls ResultSetHandler.handleResultSets, which triggers the interceptors registered for ResultSetHandler:
@Override
public <E> List<E> query(Statement statement, ResultHandler resultHandler) throws SQLException {
// cast to PreparedStatement
PreparedStatement ps = (PreparedStatement) statement;
// execute through JDBC
ps.execute();
// process the result set; the ResultSetHandler plugins intercept this call
return resultSetHandler.handleResultSets(ps);
}
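That covers both how MyBatis interceptors are registered and how they run. The overall order is: Executor registered -> Executor intercepted -> ParameterHandler registered -> ResultSetHandler registered -> StatementHandler registered -> StatementHandler intercepted (prepare, then parameterize) -> ParameterHandler intercepted (setParameters, called from parameterize) -> StatementHandler intercepted (query) -> ResultSetHandler intercepted (handleResultSets, called from query).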
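This sequence can be replayed without MyBatis. The sketch below uses heavily simplified hypothetical interfaces (`MiniExecutor`, `MiniParameterHandler`, `MiniResultSetHandler`, `MiniStatementHandler`) and one recording interceptor applied by a `pluginAll`-style helper. The call structure follows the source shown earlier (`doQuery` -> `prepare` -> `parameterize` -> `query`), and everything else is stubbed.

```java
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

// Hypothetical, heavily simplified stand-ins for the four pluggable MyBatis interfaces.
interface MiniExecutor { String query(String sql); }
interface MiniParameterHandler { void setParameters(); }
interface MiniResultSetHandler { String handleResultSets(); }
interface MiniStatementHandler {
    void prepare();
    void parameterize();
    String query();
}

class LifecycleDemo {
    private final List<String> trace = new ArrayList<>();

    // Plays the role of interceptorChain.pluginAll with one interceptor that records each call.
    private <T> T pluginAll(T target, Class<T> type) {
        trace.add("register " + type.getSimpleName());
        Object proxy = Proxy.newProxyInstance(type.getClassLoader(), new Class<?>[] {type},
                (p, method, args) -> {
                    trace.add("intercept " + type.getSimpleName() + "." + method.getName());
                    return method.invoke(target, args);
                });
        return type.cast(proxy);
    }

    // Mirrors BaseStatementHandler + Configuration.newStatementHandler:
    // the two helpers are created and plugged first, the StatementHandler itself last.
    private MiniStatementHandler newStatementHandler() {
        MiniParameterHandler params = pluginAll(() -> trace.add("bind parameters"), MiniParameterHandler.class);
        MiniResultSetHandler results = pluginAll(() -> "mapped rows", MiniResultSetHandler.class);
        MiniStatementHandler raw = new MiniStatementHandler() {
            public void prepare() { trace.add("create PreparedStatement"); }
            public void parameterize() { params.setParameters(); }
            public String query() { return results.handleResultSets(); }
        };
        return pluginAll(raw, MiniStatementHandler.class);
    }

    List<String> run() {
        // Plays the role of SimpleExecutor.doQuery.
        MiniExecutor simple = sql -> {
            MiniStatementHandler handler = newStatementHandler();
            handler.prepare();
            handler.parameterize();
            return handler.query();
        };
        MiniExecutor executor = pluginAll(simple, MiniExecutor.class);
        trace.add("result: " + executor.query("select ..."));
        return trace;
    }

    public static void main(String[] args) {
        new LifecycleDemo().run().forEach(System.out::println);
    }
}
```

Running `main` prints the ParameterHandler and ResultSetHandler registrations before the StatementHandler registration. It also prints the `StatementHandler.prepare` interception before the `ParameterHandler.setParameters` one.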
Writing a MyBatis pagination plugin by hand
package org.apache.ibatis.mytest;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.RowBounds;
import java.sql.*;
import java.util.Properties;
import static org.apache.ibatis.reflection.SystemMetaObject.DEFAULT_OBJECT_FACTORY;
import static org.apache.ibatis.reflection.SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY;
@Intercepts({
// @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
// @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
// @Signature(type = ParameterHandler.class, method = "getParameterObject", args = {}),
// @Signature(type = ParameterHandler.class, method = "setParameters", args = {PreparedStatement.class}),
@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class,Integer.class}),
// @Signature(type = StatementHandler.class, method = "query", args = {Statement.class,ResultHandler.class}),
// @Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class}),
})
@Slf4j
public class PageInterceptor implements Interceptor {
private static final String PARAM_KEY = "tenant";
// defaults used when the configuration variables do not set dialect / pageSqlId
private String defaultDialect = "mysql";
private String defaultPageSqlId = ".*Page$";
private String dialect;
private String pageSqlId;
@Override
public Object intercept(Invocation invocation) throws Throwable {
StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
MetaObject metaStatementHandler = MetaObject.forObject(statementHandler,
DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY, new DefaultReflectorFactory());
// Peel off the proxy chain. If other interceptors also wrap StatementHandler, the target
// can itself be a JDK proxy whose handler (field "h", a Plugin) keeps the next layer in its
// "target" field, so walk proxy -> Plugin -> target until the real RoutingStatementHandler.
while (metaStatementHandler.hasGetter("h")) {
Object object = metaStatementHandler.getValue("h.target");
metaStatementHandler = MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY,
DEFAULT_OBJECT_WRAPPER_FACTORY, new DefaultReflectorFactory());
}
Configuration configuration = (Configuration) metaStatementHandler.
getValue("delegate.configuration");
dialect = configuration.getVariables().getProperty("dialect");
if (null == dialect || "".equals(dialect)) {
log.warn("Property dialect is not set, using default 'mysql'");
dialect = defaultDialect;
}
pageSqlId = configuration.getVariables().getProperty("pageSqlId");
if (null == pageSqlId || "".equals(pageSqlId)) {
log.warn("Property pageSqlId is not set, using default '.*Page$'");
pageSqlId = defaultPageSqlId;
}
MappedStatement mappedStatement = (MappedStatement)
metaStatementHandler.getValue("delegate.mappedStatement");
// Only rewrite SQL that needs paging: the MappedStatement id is matched against pageSqlId,
// which by default selects statements whose id ends with "Page"
if (mappedStatement.getId().matches(pageSqlId)) {
BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
Object parameterObject = boundSql.getParameterObject();
if (parameterObject == null) {
throw new NullPointerException("parameterObject is null!");
} else {
// this plugin assumes the parameter object itself is a PageParameter
PageParameter page = (PageParameter) parameterObject;
String sql = boundSql.getSql();
// rewrite the SQL
String pageSql = buildPageSql(sql, page);
metaStatementHandler.setValue("delegate.boundSql.sql", pageSql);
// physical paging is used now, so MyBatis's in-memory paging is not needed; reset the two RowBounds values
metaStatementHandler.setValue("delegate.rowBounds.offset",
RowBounds.NO_ROW_OFFSET);
metaStatementHandler.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
Connection connection = (Connection) invocation.getArgs()[0];
// fill in the total record count and total page count in the paging parameter
setPageParameter(sql, connection, mappedStatement, boundSql, page);
}
}
// hand control to the next interceptor (or to the target itself)
return invocation.proceed();
}
/**
* Queries the total number of records, computes the total page count and writes both back into the
* paging parameter <code>PageParameter</code>, so the caller can read them from it.
*
* @param sql
* @param connection
* @param mappedStatement
* @param boundSql
* @param page
*/
private void setPageParameter(String sql, Connection connection, MappedStatement mappedStatement,
BoundSql boundSql, PageParameter page) {
// count all matching records; the alias has no "as" because Oracle rejects AS before a table alias
String countSql = "select count(0) from (" + sql + ") total";
// try-with-resources closes the statement and result set even if preparing fails
// (the original finally block threw NullPointerException in that case);
// the connection itself stays open because MyBatis still needs it
try (PreparedStatement countStmt = connection.prepareStatement(countSql)) {
BoundSql countBS = new BoundSql(mappedStatement.getConfiguration(), countSql,
boundSql.getParameterMappings(), boundSql.getParameterObject());
setParameters(countStmt, mappedStatement, countBS, boundSql.getParameterObject());
try (ResultSet rs = countStmt.executeQuery()) {
int totalCount = 0;
if (rs.next()) {
totalCount = rs.getInt(1);
}
page.setTotalCount(totalCount);
int totalPage = totalCount / page.getPageSize() + ((totalCount % page.getPageSize() == 0) ? 0 : 1);
page.setTotalPage(totalPage);
}
} catch (SQLException e) {
log.error("Failed to count records for paging; ignoring", e);
}
}
/**
* Binds values to the SQL parameters (?)
*
* @param ps
* @param mappedStatement
* @param boundSql
* @param parameterObject
* @throws SQLException
*/
private void setParameters(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql,
Object parameterObject) throws SQLException {
ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, boundSql);
parameterHandler.setParameters(ps);
}
private String buildPageSql(String sql, PageParameter page) {
if (page != null) {
StringBuilder pageSql;
if ("mysql".equals(dialect)) {
pageSql = buildPageSqlForMysql(sql, page);
} else if ("oracle".equals(dialect)) {
pageSql = buildPageSqlForOracle(sql, page);
} else {
return sql;
}
return pageSql.toString();
} else {
return sql;
}
}
public StringBuilder buildPageSqlForMysql(String sql, PageParameter page) {
StringBuilder pageSql = new StringBuilder(100);
String beginrow = String.valueOf((page.getCurrentPage() - 1) * page.getPageSize());
pageSql.append(sql);
pageSql.append(" limit " + beginrow + "," + page.getPageSize());
return pageSql;
}
public StringBuilder buildPageSqlForOracle(String sql, PageParameter page) {
StringBuilder pageSql = new StringBuilder(100);
String beginrow = String.valueOf((page.getCurrentPage() - 1) * page.getPageSize());
String endrow = String.valueOf(page.getCurrentPage() * page.getPageSize());
pageSql.append("select * from ( select temp.*, rownum row_id from ( ");
pageSql.append(sql);
pageSql.append(" ) temp where rownum <= ").append(endrow);
pageSql.append(") where row_id > ").append(beginrow);
return pageSql;
}
@Override
public Object plugin(Object target) {
return Interceptor.super.plugin(target);
}
@Override
public void setProperties(Properties properties) {
Interceptor.super.setProperties(properties);
}
}
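The proxy-peeling code in `intercept` relies on two implementation details. A JDK proxy keeps its `InvocationHandler` in the field `h`, and MyBatis's `Plugin` keeps the wrapped object in `target`. The same walk can be sketched with public JDK APIs (`Proxy.isProxyClass`, `Proxy.getInvocationHandler`). `SqlHolder` and `Layer` are hypothetical stand-ins for `StatementHandler` and `Plugin`. The point is to alternate proxy -> handler -> target until a non-proxy object is reached, however many layers there are.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

// Hypothetical, simplified stand-in for StatementHandler.
interface SqlHolder {
    String sql();
}

// Hypothetical stand-in for MyBatis's Plugin: an InvocationHandler that keeps the next layer.
class Layer implements InvocationHandler {
    final Object target;

    Layer(Object target) {
        this.target = target;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        return method.invoke(target, args);
    }
}

class UnwrapDemo {
    static SqlHolder wrap(SqlHolder target) {
        return (SqlHolder) Proxy.newProxyInstance(SqlHolder.class.getClassLoader(),
                new Class<?>[] {SqlHolder.class}, new Layer(target));
    }

    // Alternates proxy -> handler -> target until a non-proxy object is reached.
    static Object unwrap(Object obj) {
        while (Proxy.isProxyClass(obj.getClass())
                && Proxy.getInvocationHandler(obj) instanceof Layer) {
            obj = ((Layer) Proxy.getInvocationHandler(obj)).target;
        }
        return obj;
    }

    public static void main(String[] args) {
        SqlHolder real = () -> "select * from student";
        SqlHolder threeLayers = wrap(wrap(wrap(real)));
        System.out.println(unwrap(threeLayers) == real); // true
    }
}
```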
package org.apache.ibatis.mytest;
import lombok.Builder;
import lombok.Data;
@Data
@Builder
public class PageParameter {
private Integer pageSize;
private Integer currentPage;
private Integer totalCount;
private Integer totalPage;
}
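The paging arithmetic in `buildPageSqlForMysql`, `buildPageSqlForOracle` and `setPageParameter` can be checked without a database by pulling it out into pure functions. `PageMath` is a hypothetical helper for illustration. The SQL strings are built exactly as in the plugin. `totalPages` uses ceiling division, which gives the same result as the plugin's `totalCount / pageSize + (remainder == 0 ? 0 : 1)` for non-negative counts and positive page sizes.

```java
// Hypothetical helper: the paging arithmetic and SQL strings from PageInterceptor as pure functions.
class PageMath {
    // Offset of the first row on the page (pages are 1-based), as in buildPageSqlForMysql.
    static int offset(int currentPage, int pageSize) {
        return (currentPage - 1) * pageSize;
    }

    // Ceiling division: same result as totalCount / pageSize + (totalCount % pageSize == 0 ? 0 : 1).
    static int totalPages(int totalCount, int pageSize) {
        if (pageSize <= 0) {
            throw new IllegalArgumentException("pageSize must be positive");
        }
        return (totalCount + pageSize - 1) / pageSize;
    }

    // The plugin's statement filter: the MappedStatement id must match the pageSqlId regex.
    static boolean isPaged(String statementId, String pageSqlId) {
        return statementId.matches(pageSqlId);
    }

    static String mysql(String sql, int currentPage, int pageSize) {
        return sql + " limit " + offset(currentPage, pageSize) + "," + pageSize;
    }

    static String oracle(String sql, int currentPage, int pageSize) {
        int endRow = currentPage * pageSize;
        return "select * from ( select temp.*, rownum row_id from ( " + sql
                + " ) temp where rownum <= " + endRow
                + ") where row_id > " + offset(currentPage, pageSize);
    }

    public static void main(String[] args) {
        String sql = "select * from student";
        System.out.println(mysql(sql, 3, 10));  // select * from student limit 20,10
        System.out.println(oracle(sql, 3, 10));
        System.out.println(totalPages(45, 10)); // 5
        System.out.println(isPaged("StudentMapper.selectStudentPage", ".*Page$")); // true
    }
}
```

With page 3 and a page size of 10, both dialects skip 20 rows and return at most 10. MySQL does this with `limit 20,10`, and Oracle with the `rownum <= 30` / `row_id > 20` pair.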