Mybatis源码-AOP

107 阅读3分钟

1. 快速入门

Service

/**
 * Minimal service contract used to demonstrate method interception.
 */
public interface Service {

    /**
     * Handles a greeting for {@code name} (the HelloService implementation prints it).
     *
     * @param name who to greet
     */
    void sayHello(String name);

    /**
     * @param serviceName new value for this service's name
     */
    void setServiceName(String serviceName);

    /**
     * @return this service's name
     */
    String getServiceName();
}

HelloService

/**
 * Plain {@link Service} implementation: prints a greeting and stores a service name.
 */
public class HelloService implements Service {

    private String serviceName;

    /**
     * Prints {@code "hello " + name} to standard output.
     *
     * @param name who to greet
     */
    @Override
    public void sayHello(String name) {
        System.out.println("hello " + name);
    }

    /**
     * @param serviceName new value for this service's name
     */
    @Override
    public void setServiceName(String serviceName) {
        this.serviceName = serviceName;
    }

    /**
     * @return the last value passed to {@link #setServiceName}, or {@code null} if never set
     */
    @Override
    public String getServiceName() {
        return serviceName;
    }
}

现在我要对 HelloService 的 sayHello 进行增强。

编写拦截器

@Intercepts({
    @Signature(type = Service.class, method = "sayHello", args = String.class)
})
class MyInterceptor1 implements Interceptor {

    /**
     * Shows what an interceptor can do around {@code Service.sayHello(String)}:
     * call other methods on the target, rewrite the call's arguments, and run
     * code before and after the original call.
     *
     * @param invocation the intercepted call
     * @return the result of the original call
     */
    @Override
    public Object intercept(Invocation invocation) {
        // The target can be used directly, e.g. to change its state before proceeding.
        Service target = (Service) invocation.getTarget();
        target.setServiceName("hello service");

        // Overwriting an element of the argument array changes what proceed() passes on.
        Object[] callArgs = invocation.getArgs();
        callArgs[0] = "new value";

        System.out.println("前置增强1");
        Object result = invocation.proceed();
        System.out.println("后置增强1");
        return result;
    }
}


@Intercepts({
    @Signature(type = Service.class, method = "sayHello", args = String.class)
})
class MyInterceptor2 implements Interceptor {

    /**
     * Sets and prints the target's service name, then surrounds the original
     * {@code sayHello} call with before/after output.
     *
     * @param invocation the intercepted call
     * @return the result of the original call
     */
    @Override
    public Object intercept(Invocation invocation) {
        Service target = (Service) invocation.getTarget();
        target.setServiceName("hello service");
        String currentName = target.getServiceName();
        System.out.println(currentName);

        System.out.println("前置增强2");
        Object result = invocation.proceed();
        System.out.println("后置增强2");
        return result;
    }
}

通过 @Intercepts / @Signature 注解指定需要拦截的类型是 Service.class,拦截的方法是 sayHello(String)。

通过 intercept 编写增强逻辑

使用

package com.hdu.myMabatisAop;


public class TestMain {

    /**
     * Registers two interceptors, wraps a HelloService with them and calls sayHello
     * through the resulting proxy chain.
     */
    public static void main(String[] args) {
        MyInterceptor1 myInterceptor1 = new MyInterceptor1();
        MyInterceptor2 myInterceptor2 = new MyInterceptor2();
        // InterceptorChain.pluginAll wraps in registration order, so myInterceptor1
        // (registered last) becomes the outermost proxy and its intercept() is entered first.
        InterceptorChain.addInterceptor(myInterceptor2);
        InterceptorChain.addInterceptor(myInterceptor1);
        Service helloService = (Service) InterceptorChain.pluginAll(new HelloService());
        helloService.sayHello("old value");

        // Expected output:
        // 前置增强1
        // hello service
        // 前置增强2
        // hello new value
        // 后置增强2
        // 后置增强1
    }
}

2. 源码解析

Signature

Signature 主要用于标记你要增强的方法

/**
 * Identifies one method to intercept: its declaring type, name and exact parameter types
 * (Plugin resolves it with {@code type().getMethod(method(), args())}).
 *
 * <p>This annotation is only meaningful as an element of {@link Intercepts#value()}.
 * {@code @Target({})} forbids placing it directly on a declaration, where nothing would read it.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({})
public @interface Signature {
    /** The interface declaring the method; Plugin only proxies interfaces named here. */
    Class<?> type();

    /** The method name. */
    String method();

    /** The exact parameter types, used to select the right overload. */
    Class<?>[] args();
}

Intercepts

内部是一个 Signature 数组,也就是说一个拦截器可以同时增强多个方法。

/**
 * Marks a class as an interceptor and lists the methods it intercepts.
 * Each {@link Signature} selects one method, so a single interceptor may enhance several methods.
 * RUNTIME retention is required because Plugin reads this annotation reflectively.
 */
@Target({ElementType.TYPE})
@Retention(value = RetentionPolicy.RUNTIME)
public @interface Intercepts {
    /** The method signatures this interceptor applies to. */
    Signature[] value();
}

Invocation

封装原始方法执行逻辑

/**
 * Captures one intercepted call: the object it is dispatched to, the reflective method
 * and the arguments. An interceptor decides whether, and when, to run the original logic
 * via {@link #proceed()}.
 */
public class Invocation {

    private final Object target;
    private final Method method;
    private final Object[] args;

    public Invocation(Object target, Method method, Object[] args) {
        this.target = target;
        this.method = method;
        this.args = args;
    }

    /**
     * @return the object the call is dispatched to (may itself be a proxy when plugins are stacked)
     */
    public Object getTarget() {
        return target;
    }

    /**
     * @return the method being invoked
     */
    public Method getMethod() {
        return method;
    }

    /**
     * Returns the live argument array. It is intentionally not copied: writing into it
     * changes the arguments that {@link #proceed()} passes to the target.
     *
     * @return the call's arguments
     */
    public Object[] getArgs() {
        return args;
    }

    /**
     * Invokes the original method on the target with the current arguments.
     *
     * @return whatever the invoked method returns
     * @throws RuntimeException thrown by the target method, rethrown unchanged; checked
     *     exceptions and reflective access failures are wrapped in a RuntimeException
     */
    public Object proceed() {
        try {
            return method.invoke(target, args);
        } catch (InvocationTargetException e) {
            throw unwrap(e);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the target's unchecked exception as-is (rethrowing Errors); wraps a checked one. */
    private static RuntimeException unwrap(InvocationTargetException e) {
        Throwable cause = e.getTargetException();
        if (cause instanceof RuntimeException) {
            return (RuntimeException) cause;
        }
        if (cause instanceof Error) {
            throw (Error) cause;
        }
        return new RuntimeException(cause);
    }
}

Interceptor

拦截器接口,主要用于定义增强逻辑(intercept),以及把拦截器织入目标对象(plugin)。

/**
 * A plugin that adds behavior around the methods selected by its {@code @Intercepts} annotation.
 */
public interface Interceptor {

    /**
     * Runs the enhancement logic for one intercepted call.
     *
     * @param invocation the intercepted call; call {@link Invocation#proceed()} to run the original method
     * @return the value handed back to the caller of the proxied method
     */
    Object intercept(Invocation invocation);

    /**
     * Wraps {@code target} with this interceptor. The default delegates to {@link Plugin#wrap}:
     * it returns a proxy when target implements an interface named in this interceptor's
     * signatures, and target unchanged otherwise.
     *
     * @param target the object to wrap
     * @return a proxy around target, or target itself
     */
    default Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }
}

InterceptorChain

内部聚合多个 Interceptor。pluginAll 会按注册顺序逐层包装目标对象,从而实现 n 个 Interceptor 对同一个对象进行增强(最后注册的拦截器位于最外层,最先执行)。

/**
 * Registry of interceptors that are applied, in registration order, to every object passed
 * to {@link #pluginAll}. The registry is a single static, unsynchronized list.
 */
public class InterceptorChain {


    private static final List<Interceptor> INTERCEPTORS = new ArrayList<>();

    /**
     * Wraps target with every registered interceptor. Each interceptor wraps the result of the
     * previous one, so the last registered interceptor is the outermost layer and its
     * {@code intercept} is entered first on a matching call.
     *
     * @param target the object to enhance
     * @return the fully wrapped object (or target itself if no interceptor applies)
     */
    public static Object pluginAll(Object target) {
        for (Interceptor interceptor : INTERCEPTORS) {
            target = interceptor.plugin(target);
        }
        return target;
    }

    /**
     * Appends one interceptor to the chain.
     *
     * @param interceptor the interceptor to register
     */
    public static void addInterceptor(Interceptor interceptor) {
        INTERCEPTORS.add(interceptor);
    }

    /**
     * Appends several interceptors in the given order; equivalent to calling
     * {@link #addInterceptor} for each of them.
     *
     * @param interceptors the interceptors to register
     */
    public static void addInterceptors(Interceptor... interceptors) {
        for (Interceptor interceptor : interceptors) {
            addInterceptor(interceptor);
        }
    }

}

Plugin

Plugin 是 JDK 动态代理的 InvocationHandler,负责代理原始对象:命中 @Signature 的方法交给 Interceptor 的 intercept 处理,其余方法直接调用原始对象。

/**
 * JDK dynamic-proxy handler that puts one {@link Interceptor} in front of a target object.
 * Calls to methods listed in the interceptor's {@code @Intercepts} signatures are routed to
 * {@link Interceptor#intercept}; every other call goes straight to the target.
 */
public class Plugin implements InvocationHandler {

    private final Object target;
    private final Interceptor interceptor;
    // Intercepted methods, grouped by the type that declares them.
    private final Map<Class<?>, Set<Method>> signatureMap;


    private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap) {
        this.target = target;
        this.interceptor = interceptor;
        this.signatureMap = signatureMap;
    }

    /**
     * Returns a proxy for target that routes matching calls through interceptor, or target
     * itself when none of its interfaces appears in the interceptor's signatures.
     *
     * @param target the object to wrap (may already be a proxy from an earlier wrap)
     * @param interceptor supplies both the signatures and the enhancement logic
     * @return a proxy implementing the matching interfaces, or target unchanged
     */
    public static Object wrap(Object target, Interceptor interceptor) {
        Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
        Class<?> type = target.getClass();
        Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
        if (interfaces.length > 0) {
            return newProxyInstance(
                type.getClassLoader(),
                interfaces,
                new Plugin(target, interceptor, signatureMap));
        }
        return target;
    }

    /**
     * Collects the interfaces declared by type or any of its superclasses that are keys of
     * signatureMap. Only the interfaces each class lists directly are checked; super-interfaces
     * of those interfaces are not.
     */
    private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
        Set<Class<?>> interfaces = new HashSet<>();
        while (type != null) {
            for (Class<?> c : type.getInterfaces()) {
                if (signatureMap.containsKey(c)) {
                    interfaces.add(c);
                }
            }
            type = type.getSuperclass();
        }
        return interfaces.toArray(new Class<?>[0]);
    }

    /**
     * Reads the interceptor's {@code @Intercepts} annotation into a map from declaring type to
     * the resolved methods. A missing annotation, an empty signature list or a signature that
     * does not resolve to a public method is reported through unValidParams.
     */
    @SuppressWarnings("all")
    private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
        Intercepts intercepts = interceptor.getClass().getAnnotation(Intercepts.class);
        // NOTE(review): unValidParams is defined outside this view; it is assumed to throw when its
        // first argument is true, otherwise intercepts.value() below could dereference null.
        unValidParams(intercepts == null, "@Intercepts must be set on the Interceptor class");
        Signature[] sigs = intercepts.value();
        unValidParams(sigs.length == 0, "@Intercepts method must have Signature");
        Map<Class<?>, Set<Method>> signatureMap = new HashMap<>();
        for (Signature sig : sigs) {
            Set<Method> methods = signatureMap.computeIfAbsent(sig.type(), k -> new HashSet<>());
            try {
                Method method = sig.type().getMethod(sig.method(), sig.args());
                methods.add(method);
            } catch (NoSuchMethodException e) {
                unValidParams(
                    true,
                    String.format("Could not find method on %s named %s", sig.type().getName(), sig.method())
                );
            }
        }
        return signatureMap;
    }


    /**
     * Dispatches a proxied call: signature-matching methods go to the interceptor, all others
     * are invoked on the target directly. An unchecked exception thrown by the target is
     * rethrown unchanged instead of being wrapped, so stacked plugins do not bury it under
     * layers of RuntimeException/InvocationTargetException.
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) {
        Set<Method> methods = signatureMap.get(method.getDeclaringClass());
        if (methods != null && methods.contains(method)) {
            return interceptor.intercept(new Invocation(target, method, args));
        }
        try {
            return method.invoke(target, args);
        } catch (InvocationTargetException e) {
            throw unwrap(e);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the target's unchecked exception as-is (rethrowing Errors); wraps a checked one. */
    private static RuntimeException unwrap(InvocationTargetException e) {
        Throwable cause = e.getTargetException();
        if (cause instanceof RuntimeException) {
            return (RuntimeException) cause;
        }
        if (cause instanceof Error) {
            throw (Error) cause;
        }
        return new RuntimeException(cause);
    }
}

3. mybatis的aop用在哪里?

mybatis的aop体系主要用在插件增强。

MyBatis 中可以被插件增强的对象有以下四种:

  1. Executor(执行器)
  2. StatementHandler(JDBC处理器)
  3. ParameterHandler(参数处理器)
  4. ResultSetHandler(结果处理器)

举个例子,利用插件实现自动分页。

interceptor.png

我们利用 MyBatis 的插件体系拦截 StatementHandler.prepare,在 SQL 预处理阶段进行增强。

/**
 * Pagination interceptor.
 *
 * <p>Hooks {@code StatementHandler.prepare(Connection, Integer)}. When the statement's parameter
 * object is a {@code Page}, or a {@code Map} whose values contain one, it first runs a count query
 * over the original SQL to fill the page's total, then rewrites the SQL to fetch only that page.
 */
@Intercepts(
    @Signature(
        type = StatementHandler.class,
        method = "prepare",
        args = {Connection.class, Integer.class}
    )
)
public class PageInterceptor implements Interceptor {

    // Wraps the original statement as a derived table so any query can be counted.
    private static final String COUNT_SQL_TEMPLATE = "select count(*) from (%s) as _page";

    // "LIMIT <size> OFFSET <offset>" is accepted by MySQL, PostgreSQL, SQLite and H2;
    // the previous "offset m, limit n" form is not valid SQL in any of them.
    private static final String NEW_SQL_TEMPLATE = "%s limit %d offset %d";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        Page page = findPage(boundSql.getParameterObject());

        // 1. No Page parameter: this statement is not paginated, run it untouched.
        if (page == null) {
            return invocation.proceed();
        }

        // 2. Fill in the total row count: select count(*) from (original sql) as _page
        page.setTotal(selectCount(invocation));

        // 3. Rewrite the original SQL, e.g. "select * from user" -> "select * from user limit 50 offset 0".
        //    Locale.ROOT keeps %d as ASCII digits regardless of the JVM's default locale.
        String newSql = String.format(
            java.util.Locale.ROOT,
            NEW_SQL_TEMPLATE,
            boundSql.getSql(),
            page.getSize(),
            page.getOffset()
        );
        SystemMetaObject.forObject(boundSql).setValue("sql", newSql);

        return invocation.proceed();
    }

    /** Returns the Page passed directly or as one of the Map's values, or null when there is none. */
    private static Page findPage(Object parameterObject) {
        if (parameterObject instanceof Page) {
            return (Page) parameterObject;
        }
        if (parameterObject instanceof Map) {
            return ((Map<?, ?>) parameterObject).values()
                .stream()
                .filter(Page.class::isInstance)
                .map(Page.class::cast)
                .findFirst()
                .orElse(null);
        }
        return null;
    }

    /**
     * Runs a count query over the original SQL on the connection passed to {@code prepare},
     * binding the same parameters as the original statement.
     *
     * @return the first column of the first row, or 0 if the query returns no rows
     */
    private int selectCount(Invocation invocation) throws SQLException {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        String countSql = String.format(COUNT_SQL_TEMPLATE, statementHandler.getBoundSql().getSql());
        // The connection comes from the intercepted prepare() call and is deliberately not closed
        // here; only the statement and result set opened by this method are released.
        Connection connection = (Connection) invocation.getArgs()[0];
        try (PreparedStatement preparedStatement = connection.prepareStatement(countSql)) {
            // Bind the original statement's parameters; the count SQL contains the same placeholders.
            statementHandler.getParameterHandler().setParameters(preparedStatement);
            try (ResultSet resultSet = preparedStatement.executeQuery()) {
                return resultSet.next() ? resultSet.getInt(1) : 0;
            }
        }
    }
}

4.源码

mybatis-aop: mybatis-aop (gitee.com)