从 PageHelper 到 MyBatis Plugin

1,125 阅读10分钟

我正在参加「掘金·启航计划」

在很多业务场景下我们需要去拦截 SQL,达到不入侵原有代码业务处理一些东西,比如:历史记录、分页操作、数据权限过滤操作、SQL 执行时间性能监控等等,这里我们就可以用到 MyBatis 的插件 Plugin。下面我们来了解一下 Plugin 到底是如何工作的。

一、背景

使用过 MyBatis 框架的朋友们肯定都听说过 PageHelper 这个分页神器吧,其实 PageHelper 的底层实现就是依靠 plugin。下面我们来看一下 PageHelper 是如何利用 plugin 实现分页的。

二、MyBatis 执行概要图

首先我们先看一下 MyBatis 的执行流程图,对其执行流程有一个大体的认识。 Mybatis流程

三、MyBatis 核心对象介绍

MyBatis 代码实现的角度来看,MyBatis 的主要的核心部件有以下几个:

  • Configuration:初始化基础配置,比如 MyBatis 的别名等,一些重要的类型对象,如,插件,映射器,ObjectFactorytypeHandler 对象,MyBatis 所有的配置信息都维持在 Configuration 对象之中。
  • SqlSessionFactorySqlSession 工厂,用于生产 SqlSession
  • SqlSession: 作为 MyBatis 工作的主要顶层 API,表示和数据库交互的会话,完成必要数据库增删改查功能
  • ExecutorMyBatis 执行器,是 MyBatis 调度的核心,负责 SQL 语句的生成和查询缓存的维护
  • StatementHandler:封装了 JDBC Statement 操作,负责对 JDBC Statement 的操作,如设置参数、将 Statement 结果集转换成List集合。
  • ParameterHandler:负责对用户传递的参数转换成 JDBC Statement 所需要的参数,
  • ResultSetHandler:负责将 JDBC 返回的 ResultSet 结果集对象转换成 List 类型的集合;
  • TypeHandler:负责 java 数据类型和 jdbc 数据类型之间的映射和转换
  • MappedStatementMappedStatement 维护了一条 <select|update|delete|insert> 节点的封装,
  • SqlSource:负责根据用户传递的 parameterObject,动态地生成 SQL 语句,将信息封装到 BoundSql 对象中,并返回
  • BoundSql:表示动态生成的 SQL 语句以及相应的参数信息

说了这么多,怎么还没进入正题啊,别急,下面就开始讲解 Plugin 的实现原理。

四、Plugin 实现原理

MyBatis 支持对 Executor、StatementHandler、PameterHandler和ResultSetHandler 接口进行拦截,也就是说会对这4种对象进行代理。下面我们结合 PageHelper 来讲解 Plugin 是怎样实现的。

1、定义 Plugin

要使用自定义 Plugin 首先要实现 Interceptor 接口。可以通俗的理解为一个 Plugin 就是一个拦截器。

public interface Interceptor {
  // 实现拦截逻辑   
  Object intercept(Invocation invocation) throws Throwable;
  // 获取代理类
  Object plugin(Object target);
  // 初始化配置
  void setProperties(Properties properties);

}

现在我们来看一下 PageHelper 是如何通过 Plugin 实现分页的。

@Intercepts(
    {
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
    }
)
public class PageInterceptor implements Interceptor {
    //缓存count查询的ms
    protected Cache<CacheKey, MappedStatement> msCountMap = null;
    private Dialect dialect;
    private String default_dialect_class = "com.github.pagehelper.PageHelper";
    private Field additionalParametersField;

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        try {
            Object[] args = invocation.getArgs();
            MappedStatement ms = (MappedStatement) args[0];
            Object parameter = args[1];
            RowBounds rowBounds = (RowBounds) args[2];
            ResultHandler resultHandler = (ResultHandler) args[3];
            Executor executor = (Executor) invocation.getTarget();
            CacheKey cacheKey;
            BoundSql boundSql;
            //由于逻辑关系,只会进入一次
            if(args.length == 4){
                //4 个参数时
                boundSql = ms.getBoundSql(parameter);
                cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
            } else {
                //6 个参数时
                cacheKey = (CacheKey) args[4];
                boundSql = (BoundSql) args[5];
            }
            List resultList;
            //调用方法判断是否需要进行分页,如果不需要,直接返回结果
            if (!dialect.skip(ms, parameter, rowBounds)) {
                //反射获取动态参数
                Map<String, Object> additionalParameters = (Map<String, Object>) additionalParametersField.get(boundSql);
                //判断是否需要进行 count 查询
                if (dialect.beforeCount(ms, parameter, rowBounds)) {
                    //创建 count 查询的缓存 key
                    CacheKey countKey = executor.createCacheKey(ms, parameter, RowBounds.DEFAULT, boundSql);
                    countKey.update(MSUtils.COUNT);
                    MappedStatement countMs = msCountMap.get(countKey);
                    if (countMs == null) {
                        //根据当前的 ms 创建一个返回值为 Long 类型的 ms
                        countMs = MSUtils.newCountMappedStatement(ms);
                        msCountMap.put(countKey, countMs);
                    }
                    //调用方言获取 count sql
                    String countSql = dialect.getCountSql(ms, boundSql, parameter, rowBounds, countKey);
                    countKey.update(countSql);
                    BoundSql countBoundSql = new BoundSql(ms.getConfiguration(), countSql, boundSql.getParameterMappings(), parameter);
                    //当使用动态 SQL 时,可能会产生临时的参数,这些参数需要手动设置到新的 BoundSql 中
                    for (String key : additionalParameters.keySet()) {
                        countBoundSql.setAdditionalParameter(key, additionalParameters.get(key));
                    }
                    //执行 count 查询
                    Object countResultList = executor.query(countMs, parameter, RowBounds.DEFAULT, resultHandler, countKey, countBoundSql);
                    Long count = (Long) ((List) countResultList).get(0);
                    //处理查询总数
                    //返回 true 时继续分页查询,false 时直接返回
                    if (!dialect.afterCount(count, parameter, rowBounds)) {
                        //当查询总数为 0 时,直接返回空的结果
                        return dialect.afterPage(new ArrayList(), parameter, rowBounds);
                    }
                }
                //判断是否需要进行分页查询
                if (dialect.beforePage(ms, parameter, rowBounds)) {
                    //生成分页的缓存 key
                    CacheKey pageKey = cacheKey;
                    //处理参数对象
                    parameter = dialect.processParameterObject(ms, parameter, boundSql, pageKey);
                    //调用方言获取分页 sql
                    String pageSql = dialect.getPageSql(ms, boundSql, parameter, rowBounds, pageKey);
                    BoundSql pageBoundSql = new BoundSql(ms.getConfiguration(), pageSql, boundSql.getParameterMappings(), parameter);
                    //设置动态参数
                    for (String key : additionalParameters.keySet()) {
                        pageBoundSql.setAdditionalParameter(key, additionalParameters.get(key));
                    }
                    //执行分页查询
                    resultList = executor.query(ms, parameter, RowBounds.DEFAULT, resultHandler, pageKey, pageBoundSql);
                } else {
                    //不执行分页的情况下,也不执行内存分页
                    resultList = executor.query(ms, parameter, RowBounds.DEFAULT, resultHandler, cacheKey, boundSql);
                }
            } else {
                //rowBounds用参数值,不使用分页插件处理时,仍然支持默认的内存分页
                resultList = executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
            }
            return dialect.afterPage(resultList, parameter, rowBounds);
        } finally {
            dialect.afterAll();
        }
    }

    @Override
    public Object plugin(Object target) {
        //TODO Spring bean 方式配置时,如果没有配置属性就不会执行下面的 setProperties 方法,就不会初始化,因此考虑在这个方法中做一次判断和初始化
        //TODO https://github.com/pagehelper/Mybatis-PageHelper/issues/26
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        //缓存 count ms
        msCountMap = CacheFactory.createCache(properties.getProperty("msCountCache"), "ms", properties);
        String dialectClass = properties.getProperty("dialect");
        if (StringUtil.isEmpty(dialectClass)) {
            dialectClass = default_dialect_class;
        }
        try {
            Class<?> aClass = Class.forName(dialectClass);
            dialect = (Dialect) aClass.newInstance();
        } catch (Exception e) {
            throw new PageException(e);
        }
        dialect.setProperties(properties);
        try {
            //反射获取 BoundSql 中的 additionalParameters 属性
            additionalParametersField = BoundSql.class.getDeclaredField("additionalParameters");
            additionalParametersField.setAccessible(true);
        } catch (NoSuchFieldException e) {
            throw new PageException(e);
        }
    }

}

代码太长不看系列: 其实这段代码最主要的逻辑就是在执行 Executor 方法的时候,拦截 query 也就是查询类型的 SQL, 首先会判断它是否需要分页,如果需要分页就会根据查询参数在 SQL 末尾加上 limit pageNum, pageSize来实现分页。

2、注册拦截器

  • 通过 SqlSessionFactoryBean 去构建 Configuration 添加拦截器并构建获取 SqlSessionFactory

public class SqlSessionFactoryBean implements FactoryBean<SqlSessionFactory>, InitializingBean, ApplicationListener<ApplicationEvent> {
 
 
    // ... 此处省略部分源码

    protected SqlSessionFactory buildSqlSessionFactory() throws IOException {
        // ... 此处省略部分源码
 
        // 查看是否注入拦截器,有的话添加到Interceptor集合里面
        if (!isEmpty(this.plugins)) {
            for (Interceptor plugin : this.plugins) {
                configuration.addInterceptor(plugin);
                if (LOGGER.isDebugEnabled()) {
                    LOGGER.debug("Registered plugin: '" + plugin + "'");
                }
            }
        }
 
        // ... 此处省略部分源码
 
        return this.sqlSessionFactoryBuilder.build(configuration);
    }
 
    // ... 此处省略部分源码
}


  • 通过原始的 XMLConfigBuilder 构建 configuration 添加拦截器
public class XMLConfigBuilder extends BaseBuilder {
    //解析配置
    private void parseConfiguration(XNode root) {
        try {
            //省略部分代码
            pluginElement(root.evalNode("plugins"));
 
        } catch (Exception e) {
            throw new BuilderException("Error parsing SQL Mapper Configuration. Cause: " + e, e);
        }
    }
 
    private void pluginElement(XNode parent) throws Exception {
        if (parent != null) {
            for (XNode child : parent.getChildren()) {
                String interceptor = child.getStringAttribute("interceptor");
                Properties properties = child.getChildrenAsProperties();
                Interceptor interceptorInstance = (Interceptor) resolveClass(interceptor).newInstance();
                interceptorInstance.setProperties(properties);
                //调用InterceptorChain.addInterceptor
                configuration.addInterceptor(interceptorInstance);
            }
        }
    }
}

上面是两种不同的形式构建 configuration 并添加拦截器 interceptor,上面第二种一般是以前 XML 配置的情况,这里主要是解析配置文件的 plugin 节点,根据配置的 interceptor 属性实例化 Interceptor 对象,然后添加到 Configuration 对象中的 InterceptorChain 属性中。

如果定义多个拦截器就会它们链起来形成一个拦截器链,初始化配置文件的时候就把所有的拦截器添加到拦截器链中。

public class InterceptorChain {
 
 
  private final List<Interceptor> interceptors = new ArrayList<Interceptor>();
 
  public Object pluginAll(Object target) {
    //循环调用每个Interceptor.plugin方法
    for (Interceptor interceptor : interceptors) {
      target = interceptor.plugin(target);
    }
    return target;
  }
   // 添加拦截器
  public void addInterceptor(Interceptor interceptor) {
    interceptors.add(interceptor);
  }
  
  public List<Interceptor> getInterceptors() {
    return Collections.unmodifiableList(interceptors);
  }
 
}

3、执行拦截器

从以下代码可以看出 MyBatis 在实例化 Executor、ParameterHandler、ResultSetHandler、StatementHandler 四大接口对象的时候调用 interceptorChain.pluginAll() 方法插入进去的。其实就是循环执行拦截器链所有的拦截器的 plugin() 方法, MyBatis 官方推荐的 plugin 方法是 Plugin.wrap() 方法,这个就会生成代理类。


public class Configuration {
 
    protected final InterceptorChain interceptorChain = new InterceptorChain();
    //创建参数处理器
    public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
        //创建ParameterHandler
        ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
        //插件在这里插入
        parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
        return parameterHandler;
    }
 
    //创建结果集处理器
    public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler,
                                                ResultHandler resultHandler, BoundSql boundSql) {
        //创建DefaultResultSetHandler
        ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
        //插件在这里插入
        resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
        return resultSetHandler;
    }
 
    //创建语句处理器
    public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
        //创建路由选择语句处理器
        StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
        //插件在这里插入
        statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);
        return statementHandler;
    }
 
    public Executor newExecutor(Transaction transaction) {
        return newExecutor(transaction, defaultExecutorType);
    }
 
    //产生执行器
    public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
        executorType = executorType == null ? defaultExecutorType : executorType;
        //这句再做一下保护,囧,防止粗心大意的人将defaultExecutorType设成null?
        executorType = executorType == null ? ExecutorType.SIMPLE : executorType;
        Executor executor;
        //然后就是简单的3个分支,产生3种执行器BatchExecutor/ReuseExecutor/SimpleExecutor
        if (ExecutorType.BATCH == executorType) {
            executor = new BatchExecutor(this, transaction);
        } else if (ExecutorType.REUSE == executorType) {
            executor = new ReuseExecutor(this, transaction);
        } else {
            executor = new SimpleExecutor(this, transaction);
        }
        //如果要求缓存,生成另一种CachingExecutor(默认就是有缓存),装饰者模式,所以默认都是返回CachingExecutor
        if (cacheEnabled) {
            executor = new CachingExecutor(executor);
        }
        //此处调用插件,通过插件可以改变Executor行为
        executor = (Executor) interceptorChain.pluginAll(executor);
        return executor;
    }
}

4、Plugin 的动态代理

我们首先看一下Plugin.wrap() 方法,这个方法的作用是为实现Interceptor注解的接口实现类生成代理对象的。

    // 如果是Interceptor注解的接口的实现类会产生代理类 
    public static Object wrap(Object target, Interceptor interceptor) {
    //从拦截器的注解中获取拦截的类名和方法信息
    Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
    //取得要改变行为的类(ParameterHandler|ResultSetHandler|StatementHandler|Executor)
    Class<?> type = target.getClass();
    //取得接口
    Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
    //产生代理,是Interceptor注解的接口的实现类才会产生代理
    if (interfaces.length > 0) {
      return Proxy.newProxyInstance(type.getClassLoader(),interfaces,new Plugin(target, interceptor, signatureMap));
    }
    return target;
  }

Plugin 中的 getSignatureMap、 getAllInterfaces 两个辅助方法,来帮助判断是否为是否Interceptor注解的接口实现类。

  //取得签名Map,就是获取Interceptor实现类上面的注解,要拦截的是那个类(Executor 
  //,ParameterHandler, ResultSetHandler,StatementHandler)的那个方法 
private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
    //取Intercepts注解
    Intercepts interceptsAnnotation =interceptor.getClass().getAnnotation(Intercepts.class);
   
    //必须得有Intercepts注解,没有报错
    if (interceptsAnnotation == null) {
      throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());      
    }
    //value是数组型,Signature的数组
      Signature[] sigs = interceptsAnnotation.value();
    //每个class里有多个Method需要被拦截,所以这么定义
      Map<Class<?>, Set<Method>> signatureMap = new HashMap<Class<?>, Set<Method>>();
     for (Signature sig : sigs) {
      Set<Method> methods = signatureMap.get(sig.type());
        if (methods == null) {
          methods = new HashSet<Method>();
          signatureMap.put(sig.type(), methods);
      }
      try {
         Method method = sig.type().getMethod(sig.method(), sig.args());
         methods.add(method);
      } catch (NoSuchMethodException e) {
         throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
      }
    }
    return signatureMap;
  }
    
 //取得接口
 private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
    Set<Class<?>> interfaces = new HashSet<Class<?>>();
      while (type != null) {
        for (Class<?> c : type.getInterfaces()) {
        //拦截其他的无效
        if (signatureMap.containsKey(c)) {
          interfaces.add(c);
        }
      }
      type = type.getSuperclass();
    }
    return interfaces.toArray(new Class<?>[interfaces.size()]);
  }
}

我们来看一下代理类的 query 方法,其实就是调用了 Plugin.invoke() 方法。代理类屏蔽了 intercept 方法的调用。

    public final List query(MappedStatement mappedStatement, Object object, RowBounds rowBounds, ResultHandler resultHandler, CacheKey cacheKey, BoundSql boundSql) throws SQLException {
        try {
            // 这里的 h 就是一个 Plugin
            return (List)this.h.invoke(this, m5, new Object[]{mappedStatement, object, rowBounds, resultHandler, cacheKey, boundSql});
        }
        catch (Error | RuntimeException | SQLException throwable) {
            throw throwable;
        }
        catch (Throwable throwable) {
            throw new UndeclaredThrowableException(throwable);
        }
    }

最后 Plugin.invoke() 就是判断当前方法是否拦截,如果需要拦截则会调用 Interceptor.intercept() 对当前方法执行拦截逻辑。

public class Plugin implements InvocationHandler {

    ...
  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
      //获取需要拦截的方法
      Set<Method> methods = signatureMap.get(method.getDeclaringClass());
      //是Interceptor实现类注解的方法才会拦截处理
      if (methods != null && methods.contains(method)) {
        //调用Interceptor.intercept,即调用自己写的逻辑
        return interceptor.intercept(new Invocation(target, method, args));
      }
      //最后执行原来逻辑
        return method.invoke(target, args);
    } catch (Exception e) {
        throw ExceptionUtil.unwrapThrowable(e);
    }
  }
    
    ...

总结

我们以 PageHelper 为切入点讲解了 MyBatis Plugin 的实现原理,其中 MyBatis 拦截器用到责任链模式+动态代理+反射机制。 通过上面的分析可以知道,所有可能被拦截的处理类都会生成一个代理类,如果有 N 个拦截器,就会有 N 个代理,层层生成动态代理是比较耗性能的。而且虽然能指定插件拦截的位置,但这个是在执行方法时利用反射动态判断的,初始化的时候就是简单的把拦截器插入到了所有可以拦截的地方。所以尽量不要编写不必要的拦截器,并且拦截器尽量不要写复杂的逻辑。