这是我参与8月更文挑战的第22天,活动详情查看:8月更文挑战
前言
之前我发布了一篇 spring多数据源事务一致性的文章,虽然实现了需求但是确实henlow,侵入性很强,一直想改造优化,在我学习了spring 事务的相关源码和不断尝试终于算是优雅实现
这里可能有朋友会疑惑,多数据源事务一致性这不是用seata就行了么,为啥要自己实现
所以在开始前我先阐述理由:
- seata是分布式事务,本文的应用场景适合的是单体并且在同一个方法中的事务
- 另外小编项目中的数据源是动态的,提前是不知道的
- 最主要原因还是想自己写写好玩,借用Linux之父林纳斯的一句话来说就是
Just for Fun
思路
spring 事务源码(三)如何保证被@Transactional标记方法中的所有sql都在一个事务内
基础的代码思路都是来自上面的文章
首先 spring多数据源事务一致性的文章当时这个文章其实已经说得差不多了,目前就是减少代码的侵入性,变得优雅些,毕竟让写代码的人手动回滚、提交实在是low
当时难点在于无法分辨执行的sql或者service层的方法是属于哪个数据源事务,但是spring自定义切面实现拦截类全方法这篇文章写完之后我就有了新的思路
完全可以在所有的service层的方法形成切面,切面伪代码如下
try{
if(当前线程上下文有多数据源事务标记){
获取当前service方法需要执行的数据源事务
获取之后将事务管理器注册到异步线程A
用异步线程A执行执行service层方法
}
}catch{
}
同时还得用一个新的注解暂时定义为MoreDataTransactional,切面的伪代码如下
try{
在当前线程中设置多数据源事务标记
获取当前线程中所有注册的事务管理器
iter
commit
}
}catch{
获取当前线程中所有注册的事务管理器
iter
rollback
}
这样就能保证每个service层的方法执行的sql是在唯一对应的事务管理器,这里我们用的事务管理器实际上还是spring 自带的
因为mybatis执行sql创建sqlsession还是会在本线程中获取事务管理器,而获取出来的就是spring 容器中的(如果有事务的话)
具体可以参考org.mybatis.spring.SqlSessionUtils#getSqlSession(...)
代码实现
自定义事务注解及其切面实现
/**
* 多数据源事务处理
* 当方法有此注解会在 {@link MoDataTransactionSourcePointcut} 解析
* 具体处理是在 {@link MoreDataTransactionAnnotationInterceptor}
* 当此方法中所有相关mybatis service 方法调用均会被 {@link MoreDbTransactionInterceptor} 解析处理
*/
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Inherited
@Documented
public @interface MoreDataTransactional {
}
/**
* 多事务注解解析判断
* {@link MoreDataTransactional}
*/
@Slf4j
public class MoDataTransactionSourcePointcut extends StaticMethodMatcherPointcut implements Serializable {
@Override
public boolean matches(Method method, @Nullable Class<?> targetClass) {
if (targetClass == null) {
return false;
}
if (method.isAnnotationPresent(MoreDataTransactional.class)) {
log.info("{} MoDataTransactionSourcePointcut ===> {}", targetClass.getName(), method.getName());
return true;
}
return false;
}
@Override
public int hashCode() {
return MoDataTransactionSourcePointcut.class.hashCode();
}
}
/**
* 多事务注解切面
* {@link MoreDataTransactional}
*/
@Slf4j
public class MoreDataTransactionAnnotationInterceptor extends TransactionInterceptor {
// 当前线程中自定义的事务管理器
private static final ThreadLocal<NewMoreDataSourceTransactionManager> threadLocal = new ThreadLocal<>();
@Override
public Object invoke(MethodInvocation invocation) throws Throwable {
NewMoreDataSourceTransactionManager transactionManager = new NewMoreDataSourceTransactionManager();
threadLocal.set(transactionManager);
try {
return transactionManager.dos(invocation);
} finally {
threadLocal.remove();
}
}
public static NewMoreDataSourceTransactionManager getDataTransactionManager(){
return threadLocal.get();
}
}
/**
* 多数据元事务注解处理器
*/
@Component
public class MoreDataTransactionBeanFactoryPointcutAdvisor extends AbstractBeanFactoryPointcutAdvisor {
MoDataTransactionSourcePointcut pointcut = new MoDataTransactionSourcePointcut();
public MoreDataTransactionBeanFactoryPointcutAdvisor() {
setAdvice(new MoreDataTransactionAnnotationInterceptor());
}
public void setClassFilter(ClassFilter classFilter) {
this.pointcut.setClassFilter(classFilter);
}
@Override
public Pointcut getPointcut() {
return pointcut;
}
}
@Slf4j
public class NewMoreDataSourceTransactionManager {
// 不同租户不同的事务管理器
Map<String, ViceTransactionExecutor> viceTransactionExecutorMap;
public ViceTransactionExecutor getTran(String dbKey) {
if (this.viceTransactionExecutorMap == null) {
this.viceTransactionExecutorMap = new HashMap<>();
}
ViceTransactionExecutor viceTransactionExecutor = this.viceTransactionExecutorMap.get(dbKey);
// 没有则创建
if (viceTransactionExecutor == null) {
viceTransactionExecutor = new ViceTransactionExecutor(dbKey);
this.viceTransactionExecutorMap.put(dbKey, viceTransactionExecutor);
}
return viceTransactionExecutor;
}
// 事务执行
public Object dos(MethodInvocation invocation) throws Throwable {
try {
log.info("dos 开始运行");
Object run = invocation.proceed();
finish();
return run;
} catch (Throwable e) {
e.printStackTrace();
log.warn("运行异常 {},", e.getMessage());
rollback();
throw e;
}finally {
log.info("dos 运行结束");
}
}
public <T> T submit(String dbKey, Callable<T> callable) throws Exception {
return getTran(dbKey).submit(callable);
}
private void finish(){
log.info("dos 开始提交事务");
this.viceTransactionExecutorMap.values().forEach(ViceTransactionExecutor::finish);
log.info("dos 事务提交结束 ");
}
private void rollback(){
log.info("dos 开始回滚事务");
this.viceTransactionExecutorMap.values().forEach(ViceTransactionExecutor::rollback);
log.info("dos 事务回滚结束 ");
}
}
@Slf4j
public class ViceTransactionExecutor {
// 真正调用的事务管理器
PlatformTransactionManager transactionManager;
// 是否初始化
boolean init;
// 事务状态
TransactionStatus transactionStatus;
// 单线程池(保证在一个线程中运行)
final ThreadPoolExecutor singleThreadExecutor;
// 是否结束
boolean finish;
// 数据源租户关键字
String dbKey;
public ViceTransactionExecutor(PlatformTransactionManager transactionManager, String dbKey) {
this.transactionManager = transactionManager;
this.dbKey = dbKey;
this.singleThreadExecutor = new ThreadPoolExecutor(1, 1,
0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<>(1), new NamedThreadFactory("db-" + dbKey));
}
public ViceTransactionExecutor(String dbKey) {
this(SpringContext.getBean(PlatformTransactionManager.class), dbKey);
}
// 事务运行
public <T> T submit(Callable<T> command) throws Exception {
if (finish) {
throw new ServiceException("already finish");
}
if (!init) {
init();
}
Future<T> submit = singleThreadExecutor.submit(command);
return submit.get();
}
public void submit(Runnable command) throws Exception {
if (finish) {
throw new ServiceException("already finish");
}
if (!init) {
init();
}
Future<?> submit = singleThreadExecutor.submit(command);
submit.get();
}
// 初始化
private void init() throws Exception{
Future<TransactionStatus> submit = singleThreadExecutor.submit(() -> {
DyDataSource.setDbKey(dbKey);
DefaultTransactionDefinition def = new DefaultTransactionDefinition();
def.setPropagationBehavior(DefaultTransactionDefinition.PROPAGATION_REQUIRED);
def.setReadOnly(false);
log.info("{} 初始化事务", dbKey);
return transactionManager.getTransaction(def);
});
this.transactionStatus = submit.get();
init = true;
}
// 完成事务运行
public void finish(){
if (transactionStatus == null){
log.info("{} 暂未初始事务",dbKey);
return;
}else {
try {
singleThreadExecutor.submit(() -> {
log.info("{} 开始事务提交", dbKey);
this.transactionManager.commit(transactionStatus);
log.info("{} 事务提交结束", dbKey);
}).get();
} catch (Exception e) {
e.printStackTrace();
}
}
singleThreadExecutor.shutdown();
log.info("{} 线程池回收", dbKey);
finish = true;
}
public void rollback(){
if (transactionStatus == null){
log.info("{} 暂未初始事务",dbKey);
return;
}else {
try {
singleThreadExecutor.submit(() -> {
log.info("{} 开始事务回滚", dbKey);
this.transactionManager.rollback(transactionStatus);
log.info("{} 事务回滚结束", dbKey);
}).get();
} catch (Exception e) {
e.printStackTrace();
}
}
singleThreadExecutor.shutdown();
log.info("{} 线程池回收", dbKey);
}
}