mybatis+spring 多数据源单一方法事务一致性

858 阅读4分钟

这是我参与8月更文挑战的第22天,活动详情查看:8月更文挑战

前言

之前我发布了一篇 spring多数据源事务一致性的文章,虽然实现了需求但是确实henlow,侵入性很强,一直想改造优化,在我学习了spring 事务的相关源码和不断尝试终于算是优雅实现

这里可能有朋友会疑惑,多数据源事务一致性这不是用seata就行了么,为啥要自己实现

所以在开始前我先阐述理由:

  • seata是分布式事务,本文的应用场景适合的是单体并且在同一个方法中的事务
  • 另外小编项目中的数据源是动态的,提前是不知道的
  • 最主要原因还是想自己写写好玩,借用Linux之父林纳斯的一句话来说就是Just for Fun

思路

spring 事务源码(一)事务切面注入、解析

spring 事务源码(二)事务切面详细

spring 事务源码(三)如何保证被@Transactional标记方法中的所有sql都在一个事务内

spring自定义切面实现拦截类全方法

基础的代码思路都是来自上面的文章

首先 spring多数据源事务一致性的文章当时这个文章其实已经说得差不多了,目前就是减少代码的侵入性,变得优雅些,毕竟让写代码的人手动回滚、提交实在是low

当时难点在于无法分辨执行的sql或者service层的方法是属于哪个数据源事务,但是spring自定义切面实现拦截类全方法这篇文章写完之后我就有了新的思路

完全可以在所有的service层的方法形成切面,切面伪代码如下

try{
  if(当前线程上下文有多数据源事务标记){
  获取当前service方法需要执行的数据源事务
    获取之后将事务管理器注册到异步线程A
  用异步线程A执行执行service层方法
}
}catch{
  
}

同时还得用一个新的注解暂时定义为MoreDataTransactional,切面的伪代码如下

try{
  在当前线程中设置多数据源事务标记
    
    获取当前线程中所有注册的事务管理器
    iter
    commit
}
}catch{
  获取当前线程中所有注册的事务管理器
    iter
  rollback
}

这样就能保证每个service层的方法执行的sql是在唯一对应的事务管理器,这里我们用的事务管理器实际上还是spring 自带的

因为mybatis执行sql创建sqlsession还是会在本线程中获取事务管理器,而获取出来的就是spring 容器中的(如果有事务的话)

具体可以参考org.mybatis.spring.SqlSessionUtils#getSqlSession(...)

代码实现

自定义事务注解及其切面实现

/**
 * 多数据源事务处理
 * 当方法有此注解会在 {@link MoDataTransactionSourcePointcut} 解析
 * 具体处理是在 {@link MoreDataTransactionAnnotationInterceptor}
 * 当此方法中所有相关mybatis service 方法调用均会被 {@link MoreDbTransactionInterceptor} 解析处理
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Inherited
@Documented
public @interface MoreDataTransactional {
}
/**
 * 多事务注解解析判断
 * {@link MoreDataTransactional}
 */
@Slf4j
public class MoDataTransactionSourcePointcut extends StaticMethodMatcherPointcut implements Serializable {
	@Override
	public boolean matches(Method method, @Nullable Class<?> targetClass) {
		if (targetClass == null) {
			return false;
		}
		if (method.isAnnotationPresent(MoreDataTransactional.class)) {
			log.info("{} MoDataTransactionSourcePointcut ===> {}", targetClass.getName(), method.getName());
			return true;
		}
		return false;
	}


	@Override
	public int hashCode() {
		return MoDataTransactionSourcePointcut.class.hashCode();
	}
}
/**
 * 多事务注解切面
 * {@link MoreDataTransactional}
 */
@Slf4j
public class MoreDataTransactionAnnotationInterceptor extends TransactionInterceptor {
   // 当前线程中自定义的事务管理器
    private static final ThreadLocal<NewMoreDataSourceTransactionManager> threadLocal = new ThreadLocal<>();

    @Override
    public Object invoke(MethodInvocation invocation) throws Throwable {
        NewMoreDataSourceTransactionManager transactionManager = new NewMoreDataSourceTransactionManager();
        threadLocal.set(transactionManager);

        try {
            return transactionManager.dos(invocation);
        } finally {
            threadLocal.remove();
        }
    }

    public static NewMoreDataSourceTransactionManager getDataTransactionManager(){
        return threadLocal.get();
    }
}
/**
 * 多数据元事务注解处理器
 */
@Component
public class MoreDataTransactionBeanFactoryPointcutAdvisor extends AbstractBeanFactoryPointcutAdvisor {

    MoDataTransactionSourcePointcut pointcut = new MoDataTransactionSourcePointcut();

    public MoreDataTransactionBeanFactoryPointcutAdvisor() {
        setAdvice(new MoreDataTransactionAnnotationInterceptor());
    }

    public void setClassFilter(ClassFilter classFilter) {
        this.pointcut.setClassFilter(classFilter);
    }

    @Override
    public Pointcut getPointcut() {
        return pointcut;
    }
}
@Slf4j
public class NewMoreDataSourceTransactionManager {

  
		// 不同租户不同的事务管理器
    Map<String, ViceTransactionExecutor> viceTransactionExecutorMap;


    public ViceTransactionExecutor getTran(String dbKey) {
        if (this.viceTransactionExecutorMap == null) {
            this.viceTransactionExecutorMap = new HashMap<>();
        }
        ViceTransactionExecutor viceTransactionExecutor = this.viceTransactionExecutorMap.get(dbKey);
      // 没有则创建
        if (viceTransactionExecutor == null) {
            viceTransactionExecutor = new ViceTransactionExecutor(dbKey);
            this.viceTransactionExecutorMap.put(dbKey, viceTransactionExecutor);
        }
        return viceTransactionExecutor;
    }

    // 事务执行
    public Object dos(MethodInvocation invocation) throws Throwable {

        try {
            log.info("dos 开始运行");
            Object run = invocation.proceed();
            finish();
            return run;
        } catch (Throwable e) {
            e.printStackTrace();
            log.warn("运行异常 {},", e.getMessage());
            rollback();
            throw e;
        }finally {
            log.info("dos 运行结束");
        }

    }

    public <T> T submit(String dbKey, Callable<T> callable) throws Exception {
        return getTran(dbKey).submit(callable);
    }

    private void finish(){
        log.info("dos 开始提交事务");
        this.viceTransactionExecutorMap.values().forEach(ViceTransactionExecutor::finish);
        log.info("dos 事务提交结束 ");
    }

    private void rollback(){
        log.info("dos 开始回滚事务");
        this.viceTransactionExecutorMap.values().forEach(ViceTransactionExecutor::rollback);
        log.info("dos 事务回滚结束 ");
    }


}


@Slf4j
public class ViceTransactionExecutor {

    // 真正调用的事务管理器
    PlatformTransactionManager transactionManager;

    // 是否初始化
    boolean init;

    // 事务状态
    TransactionStatus transactionStatus;

    // 单线程池(保证在一个线程中运行)
    final ThreadPoolExecutor singleThreadExecutor;

  // 是否结束
    boolean finish;

  // 数据源租户关键字
    String dbKey;

    public ViceTransactionExecutor(PlatformTransactionManager transactionManager, String dbKey) {
        this.transactionManager = transactionManager;
        this.dbKey = dbKey;
        this.singleThreadExecutor = new ThreadPoolExecutor(1, 1,
                0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<>(1), new NamedThreadFactory("db-" + dbKey));

    }

    public ViceTransactionExecutor(String dbKey) {
        this(SpringContext.getBean(PlatformTransactionManager.class), dbKey);
    }


    // 事务运行
    public <T> T submit(Callable<T> command) throws Exception {
        if (finish) {
            throw new ServiceException("already finish");
        }
        if (!init) {
            init();
        }
        Future<T> submit = singleThreadExecutor.submit(command);
        return submit.get();
    }

    public void submit(Runnable command) throws Exception {
        if (finish) {
            throw new ServiceException("already finish");
        }
        if (!init) {
            init();
        }
        Future<?> submit = singleThreadExecutor.submit(command);
        submit.get();
    }

    // 初始化
    private void init() throws Exception{
        Future<TransactionStatus> submit = singleThreadExecutor.submit(() -> {
            DyDataSource.setDbKey(dbKey);
            DefaultTransactionDefinition def = new DefaultTransactionDefinition();
            def.setPropagationBehavior(DefaultTransactionDefinition.PROPAGATION_REQUIRED);
            def.setReadOnly(false);
            log.info("{} 初始化事务", dbKey);
            return transactionManager.getTransaction(def);
        });


        this.transactionStatus = submit.get();
        init = true;
    }

    // 完成事务运行
    public void finish(){
        if (transactionStatus == null){
            log.info("{} 暂未初始事务",dbKey);
            return;
        }else {
            try {
                singleThreadExecutor.submit(() -> {
                    log.info("{} 开始事务提交", dbKey);
                    this.transactionManager.commit(transactionStatus);
                    log.info("{} 事务提交结束", dbKey);
                }).get();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        singleThreadExecutor.shutdown();
        log.info("{} 线程池回收", dbKey);
        finish = true;
    }


    public void rollback(){
        if (transactionStatus == null){
            log.info("{} 暂未初始事务",dbKey);
            return;
        }else {
            try {
                singleThreadExecutor.submit(() -> {
                    log.info("{} 开始事务回滚", dbKey);
                    this.transactionManager.rollback(transactionStatus);
                    log.info("{} 事务回滚结束", dbKey);
                }).get();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        singleThreadExecutor.shutdown();
        log.info("{} 线程池回收", dbKey);
    }

}