Mybatis源码解析(四) -- 拦截器链

345 阅读5分钟

拦截器链

还是回到这个Demo

  public static void main(String[] args) throws Exception {
    String resource = "mybatis.xml";
    InputStream inputStream = Resources.getResourceAsStream(resource);
    SqlSessionFactory sqlSessionFactory = new SqlSessionFactoryBuilder().build(inputStream);
    SqlSession sqlSession = sqlSessionFactory.openSession();
    UserMapper mapper = sqlSession.getMapper(UserMapper.class);
    System.out.println(mapper.findAll());
    sqlSession.close();
  }

当我们执行sqlSessionFactory.openSession();的时候,首先需要创建一个Executor。

DefaultSqlSessionFactory.openSession,这里configuration.getDefaultExecutorType()指的是ExecutorType.SIMPLE(默认),还有REUSE, BATCH两种类型。

  public SqlSession openSession() {
    return openSessionFromDataSource(configuration.getDefaultExecutorType(), null, false);
  }

DefaultSqlSessionFactory.openSessionFromDataSource

  private SqlSession openSessionFromDataSource(ExecutorType execType, TransactionIsolationLevel level, boolean autoCommit) {
    Transaction tx = null;
    try {
      final Environment environment = configuration.getEnvironment();
      final TransactionFactory transactionFactory = getTransactionFactoryFromEnvironment(environment);
      tx = transactionFactory.newTransaction(environment.getDataSource(), level, autoCommit);
      //创建执行器
      final Executor executor = configuration.newExecutor(tx, execType);
      return new DefaultSqlSession(configuration, executor, autoCommit);
    } catch (Exception e) {
      closeTransaction(tx); // may have fetched a connection so lets call close()
      throw ExceptionFactory.wrapException("Error opening session.  Cause: " + e, e);
    } finally {
      ErrorContext.instance().reset();
    }
  }

Configuration.newExecutor

  public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
    executorType = executorType == null ? defaultExecutorType : executorType;
    executorType = executorType == null ? ExecutorType.SIMPLE : executorType;
    Executor executor;
    //之前我们说到,Executor是嵌套的模式,这里就可以看出来
    //外层是CachingExecutor,内层SimpleExecutor
    if (ExecutorType.BATCH == executorType) {
      executor = new BatchExecutor(this, transaction);
    } else if (ExecutorType.REUSE == executorType) {
      executor = new ReuseExecutor(this, transaction);
    } else {
      executor = new SimpleExecutor(this, transaction);
    }
    if (cacheEnabled) {
      executor = new CachingExecutor(executor);
    }
    //最后加入拦截器链
    executor = (Executor) interceptorChain.pluginAll(executor);
    return executor;
  }

与之类似的,我们还能再Configuration中找到其他会加入到拦截器中的类ParameterHandler,ResultSetHandler,StatementHandler.它们在创建的时候都会加入到拦截器链中。

再次复习一下,ParameterHandler和ResultSetHandler是StatementHandler中的两个成员变量,分别用来处理入参和结果集。

public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
  ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
  parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
  return parameterHandler;
}

public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler,
    ResultHandler resultHandler, BoundSql boundSql) {
  ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
  resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
  return resultSetHandler;
}

public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
  StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
  statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);
  return statementHandler;
}

我们现在来分析一下拦截器。与我们第一篇文章的Demo类似,我们注册的拦截器会被存放到interceptors这个list集合中,然后使用pluginAll这个方法来处理我们的类。

public class InterceptorChain {

  private final List<Interceptor> interceptors = new ArrayList<>();

  public Object pluginAll(Object target) {
    for (Interceptor interceptor : interceptors) {
      target = interceptor.plugin(target);
    }
    return target;
  }

  public void addInterceptor(Interceptor interceptor) {
    interceptors.add(interceptor);
  }

  public List<Interceptor> getInterceptors() {
    return Collections.unmodifiableList(interceptors);
  }
}
public interface Interceptor {

  Object intercept(Invocation invocation) throws Throwable;

  default Object plugin(Object target) {
    return Plugin.wrap(target, this);
  }
  default void setProperties(Properties properties) {
    // NOP
  }
}

用下面这个注解举例,这里拦截了StatementHandler中的query。

//@Intercepts({
//  @Signature(type = StatementHandler.class, method = "query", args = {Statement.class, ResultHandler.class}),
//})

Plugin.wrap。Plugin还实现了InvocationHandler接口,是个代理。

  public static Object wrap(Object target, Interceptor interceptor) {
    //这个方法解析拦截器上的注解@Intercepts,@Signature
    Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
    Class<?> type = target.getClass();
    //根据传入的target对象,查找target的接口,并根据signatureMap中的值,找到需要拦截的接口。
    //例如我拦截的是StatementHandler中的方法,query
    //根据传入的target RountingStatementHandler找到它的接口类StatementHandler.class
    //如果这个接口类在signatureMap存在,说明是需要被拦截的,加入到interfaces中
    //如果进来的target是Executor,或者其它接口,不属于signatureMap中,那就不用拦截。
    Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
    //如果大于0说明要拦截
    if (interfaces.length > 0) {
    //这里也是动态代理
      return Proxy.newProxyInstance(
          type.getClassLoader(),
          interfaces,
          new Plugin(target, interceptor, signatureMap));
    }
    return target;
  }

Plugin.getSignatureMap 解析拦截器上的注解@Intercepts,@Signature

  private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
    //获取注解@Intercepts
    Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
    // issue #251
    if (interceptsAnnotation == null) {
      throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
    }
    //获取注解@Intercepts中的@Signature
    Signature[] sigs = interceptsAnnotation.value();
	//该容器用来记录需要拦截的接口,key是class类型,value是需要拦截的接口
    Map<Class<?>, Set<Method>> signatureMap = new HashMap<>();
    for (Signature sig : sigs) {
      //根据例子sig.type()是StatementHandler
      Set<Method> methods = signatureMap.computeIfAbsent(sig.type(), k -> new HashSet<>());
      try {
      //根据方法名称,和方法参数确定一个方法
      //根据例子sig.method()获取方法名称query
      //例StatementHandler.query中有两个参数Statement, ResultHandler,那么根据这个方法签名就能找到该方法了。
        Method method = sig.type().getMethod(sig.method(), sig.args());
        methods.add(method);
      } catch (NoSuchMethodException e) {
        throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
      }
    }
    return signatureMap;
  }

Plugin.getAllInterfaces

  private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
    Set<Class<?>> interfaces = new HashSet<>();
    while (type != null) {
      //获取该类的所有接口
      for (Class<?> c : type.getInterfaces()) {
        //判断该接口是否需要被拦截
        if (signatureMap.containsKey(c)) {
          interfaces.add(c);
        }
      }
      type = type.getSuperclass();
    }
    //返回需要类拦截的接口类
    return interfaces.toArray(new Class<?>[interfaces.size()]);
  }

接下来是动态代理的部分Plugin,这是重写的invoke方法

  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
      //用根据method获取到class,使用class当作key获取到这个class中需要代理的方法集
      Set<Method> methods = signatureMap.get(method.getDeclaringClass());
      //这里接口中的所有方法都会被代理,但是我们只会拦截特定的接口。
      if (methods != null && methods.contains(method)) {
      //如果方法集中有包含当前传入的method,那么就执行拦截器中的intercept方法
        return interceptor.intercept(new Invocation(target, method, args));
      }
      //否则就直接执行方法体本身
      return method.invoke(target, args);
    } catch (Exception e) {
      throw ExceptionUtil.unwrapThrowable(e);
    }
  }

最后我们需要条用Invocation.proceed,才能执行方法体本身.

  public Object proceed() throws InvocationTargetException, IllegalAccessException {
    //这里类似递归调用,如果还有拦截器,继续拦截
    return method.invoke(target, args);
  }

通常我们可能会有多个拦截器,那么在一个拦截器执行完它的方法之后 执行method.invoke,但是由于该方法又被下一层的拦截器拦截了,所以又会调用下一个拦截器的intercept方法,直到最后一层的拦截器。从代码中也能看出,是先拦截器中的intercept,然后再执行要拦截的方法

小结

梳理一下流程。Mybatis会解析配置文件中的plugins节点,将拦截器通过addInterceptor方法添加到InterceptorChain的interceptors容器中。之后在Mybatis创建Executor,StatementHandler,ParameterHandler,ResultSetHandler这四大类的时候,会遍历所有的拦截器。如果该拦截器拦截的是这些接口,那么就会将其包装成动态代理。

当这个接口执行方法时,在invoke方法中再具体判断是否拦截的是注解中的方法,如果不是就执行方法本身,否则执行拦截器的intercept。在所有的拦截器都遍历之后,就形成了一层一层的代理。

具体举例进来说(plugins配置忽略)

@Intercepts({
  @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class,Integer.class})
})
public class MyInterceptor implements Interceptor {
  @Override
  public Object intercept(Invocation invocation) throws Throwable {
	System.out.println("this is a interceptor");
    return invocation.proceed();
  }
}

这里我们写了一个拦截器。那么就会拦截StatementHandler接口的所有方法都会被动态代理,但是只有prepare被加入signatureMap中,所以只有该方法会执行拦截器中的intercept。而其他方法就直接执行。其他的接口如Executor,我们没有写拦截器,那么代理对象都不会生成。