Spring 切面无法通过注解拦截 mapper 方法的问题解决
背景
项目中需要实现一个 sql 拦截并分析查询条件的功能(用作租户隔离的校验)。不同方法需要校验不同的表和列,因此借助注解来声明,并在 mybatis 拦截器中获取注解值进行 sql 分析。最初的想法是用切面拦截 mapper 层方法上的注解,把注解值写入 ThreadLocal,再在 mybatis 拦截器中读取 ThreadLocal 的值进行 sql 分析。但做到最后才发现,spring4 无法对 mapper 接口上的注解切面进行织入(参考: blog.csdn.net/z69183787/a…)。
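于是只能另寻他法。经过摸索,发现 @Transactional(默认的代理模式下)并不是靠 @Aspect 注解切面实现的,而是由 TransactionInterceptor(一个 AOP Alliance 的 MethodInterceptor)配合 Advisor 完成拦截。于是仿照 TransactionInterceptor,按 mapper 的包路径配置切点进行拦截,再在拦截器中判断方法或类上是否存在注解,以此实现对 mapper 的 AOP。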
实现
1 首先定义一个注解类和一个注解属性类
注解类
package com.yyigou.ddc.services.sql.analysis.util.annotation;
import com.yyigou.ddc.services.sql.analysis.util.enums.EffectiveScope;
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Enables tenant-isolation checking on a DAO (mapper) type or method.
 *
 * <p>Once enabled, SELECT statements issued through the annotated method(s) are analysed for a
 * condition on the configured {@link #columns()}; a warning is logged when none is found.
 * A method-level annotation takes precedence over a type-level one.
 *
 * <p>Note that {@link Inherited} only applies to annotations on classes: an annotation placed on an
 * interface is not inherited by implementing classes.
 *
 * @author miaoyc
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
@Inherited
public @interface EnableEnterpriseIsolate {

    /**
     * Tenant columns to check, each written as {@code table.column}, e.g. {@code bdc_goods.enterprise_no}.
     * Declare several columns as separate array elements.
     *
     * @return the columns to check
     */
    String[] columns() default {};

    /**
     * Which methods a type-level annotation applies to; ignored for method-level annotations.
     * Defaults to SELECTs whose mapper method returns a {@code List}.
     *
     * @return the effective scope
     */
    EffectiveScope effectiveScope() default EffectiveScope.SELECT_FOR_LIST;

    /**
     * How the results for several columns are combined: {@code true} (default) means at least one
     * column must be constrained, {@code false} means all of them must be.
     *
     * @return whether the columns are combined with OR
     */
    boolean operatorOr() default true;
}
注解属性类
package com.yyigou.ddc.services.sql.analysis.util.annotation;
import com.yyigou.ddc.services.sql.analysis.util.enums.EffectiveScope;
import lombok.Data;
import java.io.Serializable;
/**
 * Resolved form of an {@code @EnableEnterpriseIsolate} annotation, published per call so the
 * MyBatis interceptor can read it, plus two flags used by the resolving aspect.
 *
 * @author WJP
 * @Date: 2023/2/22 19:03
 */
@Data
public class EnterpriseIsolateAttr implements Serializable {

    /** Columns to check, each in {@code table.column} form. */
    String[] columns;

    /** {@code true}: one satisfied column is enough (OR); {@code false}: every column must be satisfied (AND). */
    boolean operatorOr;

    /** Which methods a type-level annotation applies to (for example, only list-returning SELECTs). */
    EffectiveScope effectiveScope;

    /**
     * Whether the attributes came from a type-level rather than a method-level annotation; only then
     * is {@link #effectiveScope} consulted.
     */
    boolean annoOnClass;

    /** Marks the sentinel instance used to cache "no annotation" results. */
    boolean nullObject;
}
2 定义切面类
package com.yyigou.ddc.services.sql.analysis.util.aspect;
import com.yyigou.ddc.common.util.CallContextFilterUtil;
import com.yyigou.ddc.services.sql.analysis.util.annotation.EnableEnterpriseIsolate;
import com.yyigou.ddc.services.sql.analysis.util.annotation.EnterpriseIsolateAttr;
import com.yyigou.ddc.services.sql.analysis.util.conf.SqlAnalysisConfig;
import com.yyigou.ddc.services.sql.analysis.util.enums.EffectiveScope;
import com.yyigou.ddc.services.sql.analysis.util.tools.SqlEnterpriseAnalysisUtil;
import lombok.extern.slf4j.Slf4j;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.springframework.aop.support.AopUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.BridgeMethodResolver;
import org.springframework.core.MethodClassKey;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.AnnotationAttributes;
import org.springframework.util.ClassUtils;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * AOP Alliance interceptor placed around mapper methods. It resolves the {@link EnableEnterpriseIsolate}
 * configuration for the invoked method and publishes it through
 * {@link SqlEnterpriseAnalysisUtil#setEnterpriseIsolate} for the duration of the call, so that the
 * MyBatis interceptor can analyse the executed SQL. The published value is always removed afterwards.
 *
 * <p>The lookup mirrors Spring's transaction-attribute resolution. It checks, in order:
 * the most specific method, that method's class, the originally invoked (interface) method, and that
 * method's class. Results, including "not annotated", are cached per method / target-class key.
 *
 * @author WJP
 * @Date: 2021/8/23 14:57
 */
@Slf4j
//@Aspect
public class EnterpriseIsolateAspect implements MethodInterceptor {

    /**
     * Resolved attributes per method / target-class key; holds {@link #NULL_CACHE_OBJECT} for methods
     * without an annotation.
     */
    private final Map<Object, EnterpriseIsolateAttr> attributeCache =
            new ConcurrentHashMap<Object, EnterpriseIsolateAttr>(1024);

    /** Project configuration providing the global on/off switch. */
    @Autowired
    SqlAnalysisConfig sqlAnalysisConfig;

    /** Sentinel cached for "no annotation" (ConcurrentHashMap cannot store null values). */
    private final EnterpriseIsolateAttr NULL_CACHE_OBJECT = new EnterpriseIsolateAttr();

    public EnterpriseIsolateAspect() {
        NULL_CACHE_OBJECT.setNullObject(true);
    }

    /**
     * Publishes the applicable isolation attributes (if any) while the mapper method runs.
     *
     * @param invocation the intercepted mapper call
     * @return the mapper method's result
     * @throws Throwable whatever the mapper method throws
     */
    @Override
    public Object invoke(MethodInvocation invocation) throws Throwable {
        // Boolean.TRUE.equals(...) treats an unset (null) switch as "disabled" instead of failing on unboxing.
        if (!Boolean.TRUE.equals(sqlAnalysisConfig.getEnableEnterpriseIsolate())) {
            return invocation.proceed();
        }
        try {
            Class<?> targetClass = (invocation.getThis() != null ? AopUtils.getTargetClass(invocation.getThis()) : null);
            Method method = invocation.getMethod();
            EnterpriseIsolateAttr isolateAttr = getEnterpriseIsolateAttribute(method, targetClass);
            if (isolateAttr != null && isApplicable(isolateAttr, method)) {
                SqlEnterpriseAnalysisUtil.setEnterpriseIsolate(isolateAttr);
            }
            return invocation.proceed();
        } finally {
            // Never leak the value to later work on this (pooled) thread.
            SqlEnterpriseAnalysisUtil.removeEnterpriseIsolate();
        }
    }

    /**
     * Decides whether resolved attributes apply to the invoked method.
     * Method-level annotations always apply. Type-level ones apply to every method for
     * {@link EffectiveScope#ALL}, and otherwise only to methods whose declared return type is a {@link List}.
     */
    private boolean isApplicable(EnterpriseIsolateAttr isolateAttr, Method method) {
        if (!isolateAttr.isAnnoOnClass()) {
            return true;
        }
        if (EffectiveScope.ALL.equals(isolateAttr.getEffectiveScope())) {
            return true;
        }
        // isAssignableFrom also accepts concrete list return types such as ArrayList.
        return List.class.isAssignableFrom(method.getReturnType());
    }

    /**
     * Returns the (cached) isolation attributes for a method, or {@code null} if it is not annotated.
     *
     * @param method      the invoked method
     * @param targetClass the target class, may be {@code null}
     * @return the attributes, or {@code null}
     */
    public EnterpriseIsolateAttr getEnterpriseIsolateAttribute(Method method, Class<?> targetClass) {
        if (method.getDeclaringClass() == Object.class) {
            return null;
        }
        Object cacheKey = getCacheKey(method, targetClass);
        EnterpriseIsolateAttr cached = this.attributeCache.get(cacheKey);
        if (cached != null) {
            return cached == NULL_CACHE_OBJECT ? null : cached;
        }
        EnterpriseIsolateAttr isolateAttr = computeEnterpriseIsolateAttribute(method, targetClass);
        this.attributeCache.put(cacheKey, isolateAttr == null ? NULL_CACHE_OBJECT : isolateAttr);
        return isolateAttr;
    }

    /**
     * Resolves the attributes without caching. It tries the most specific method, its class, the original
     * method, and finally the original method's class.
     *
     * @param method      the invoked method (possibly an interface method)
     * @param targetClass the target class, may be {@code null}
     * @return the attributes with {@code annoOnClass} set, or {@code null}
     */
    protected EnterpriseIsolateAttr computeEnterpriseIsolateAttribute(Method method, Class<?> targetClass) {
        // Ignore CGLIB subclasses - introspect the actual user class.
        Class<?> userClass = ClassUtils.getUserClass(targetClass);
        // The method may be on an interface, but attributes are looked up on the target class first.
        // If the target class is null, the method is returned unchanged.
        Method specificMethod = ClassUtils.getMostSpecificMethod(method, userClass);
        // For methods with generic parameters, resolve the original (non-bridge) method.
        specificMethod = BridgeMethodResolver.findBridgedMethod(specificMethod);

        EnterpriseIsolateAttr isolateAttr = findEnterpriseIsolateAttribute(specificMethod);
        if (isolateAttr != null) {
            isolateAttr.setAnnoOnClass(false);
            return isolateAttr;
        }
        isolateAttr = findEnterpriseIsolateAttribute(specificMethod.getDeclaringClass());
        if (isolateAttr != null && ClassUtils.isUserLevelMethod(method)) {
            isolateAttr.setAnnoOnClass(true);
            return isolateAttr;
        }
        if (specificMethod != method) {
            // Fallback: the originally invoked method (e.g. the mapper interface method) ...
            isolateAttr = findEnterpriseIsolateAttribute(method);
            if (isolateAttr != null) {
                isolateAttr.setAnnoOnClass(false);
                return isolateAttr;
            }
            // ... and finally its declaring type (e.g. the mapper interface).
            isolateAttr = findEnterpriseIsolateAttribute(method.getDeclaringClass());
            if (isolateAttr != null && ClassUtils.isUserLevelMethod(method)) {
                isolateAttr.setAnnoOnClass(true);
                return isolateAttr;
            }
        }
        return null;
    }

    /**
     * Reads {@link EnableEnterpriseIsolate} (including meta-annotation merging) from an element.
     *
     * @param element method or class to inspect
     * @return parsed attributes, or {@code null} if the annotation is absent
     */
    protected EnterpriseIsolateAttr findEnterpriseIsolateAttribute(AnnotatedElement element) {
        AnnotationAttributes attributes = AnnotatedElementUtils.getMergedAnnotationAttributes(
                element, EnableEnterpriseIsolate.class);
        return attributes != null ? parseEnterpriseIsolateAnnotation(attributes) : null;
    }

    /**
     * Copies the annotation values into a fresh {@link EnterpriseIsolateAttr}.
     *
     * @param attributes merged annotation attributes
     * @return a new attribute object ({@code annoOnClass} is set by the caller)
     */
    protected EnterpriseIsolateAttr parseEnterpriseIsolateAnnotation(AnnotationAttributes attributes) {
        EnterpriseIsolateAttr enterpriseIsolateAttr = new EnterpriseIsolateAttr();
        String[] columns = attributes.getStringArray("columns");
        EffectiveScope effectiveScope = attributes.getEnum("effectiveScope");
        boolean operatorOr = attributes.getBoolean("operatorOr");
        enterpriseIsolateAttr.setColumns(columns);
        enterpriseIsolateAttr.setOperatorOr(operatorOr);
        enterpriseIsolateAttr.setEffectiveScope(effectiveScope);
        return enterpriseIsolateAttr;
    }

    /**
     * Determine a cache key for the given method and target class.
     * <p>Must not produce the same key for overloaded methods.
     * Must produce the same key for different instances of the same method.
     *
     * @param method      the method (never {@code null})
     * @param targetClass the target class (may be {@code null})
     * @return the cache key (never {@code null})
     */
    protected Object getCacheKey(Method method, Class<?> targetClass) {
        return new MethodClassKey(method, targetClass);
    }
}
3 在 xml 中定义切面
<!-- The MethodInterceptor that resolves @EnableEnterpriseIsolate for each intercepted mapper call and publishes it via a ThreadLocal. -->
<!-- NOTE(review): the class uses an @Autowired field, so annotation-config (e.g. <context:annotation-config/>) must be active in this context - confirm. -->
<bean id="enterpriseIsolateAspect" class="com.yyigou.ddc.services.sql.analysis.util.aspect.EnterpriseIsolateAspect"/>
<aop:config>
<!-- Matches every public method of any type in the mapper package and its sub-packages; annotation filtering happens inside the interceptor. -->
<aop:pointcut id="point" expression="execution(public * com.yyigou.ddc.services.ddc.task.mybatis.mapper..*.*(..))"/>
<!-- Applies the interceptor above to the matched mapper methods. -->
<aop:advisor pointcut-ref="point" advice-ref="enterpriseIsolateAspect"/>
</aop:config>
使用方式
mapper 层
/**
* Mapper 类上增加租户隔离注解
* 默认只拦截返回值为 list 的方法
* 如果需要指定某个 sql 拦截不同表和列, 则在方法上使用 EnableEnterpriseIsolate 注解
*/
package com.yyigou.ddc.services.ddc.task.mybatis.mapper;
import java.util.List;
/**
 * Example mapper showing how {@code @EnableEnterpriseIsolate} is applied.
 *
 * <p>The type-level annotation checks {@code t_import_task.tenant_no} using the default effective
 * scope (list-returning methods only). One method overrides it with its own column.
 *
 * @author WJP
 * @Date: 2022/7/7 13:40
 */
@EnableEnterpriseIsolate(columns = {"t_import_task.tenant_no"})
public interface CustomImportTaskMapper {

    /**
     * Returns a single {@code TaskVO}. Under the type-level annotation's default scope it is therefore not checked.
     *
     * @param taskId the task id
     */
    TaskVO getImportTaskWithTemplateInfo(@Param("taskId") String taskId);

    /**
     * List-returning; checked against {@code t_import_task.tenant_no} by the type-level annotation.
     *
     * @param taskIds  task ids
     * @param tenantNo tenant number
     */
    List<TaskVO> listImportTaskWithTemplateInfo(
            @Param("taskIds") List<String> taskIds,
            @Param("tenantNo") String tenantNo);

    /**
     * List-returning; checked against {@code t_import_task.tenant_no} by the type-level annotation.
     *
     * @param taskIds task ids
     */
    List<TaskVO> listImportTaskWithTemplateInfoNoSession(@Param("taskIds") List<String> taskIds);

    /**
     * The method-level annotation takes precedence over the type-level one and checks {@code t_import_task.task_id} instead.
     *
     * @param taskIds task ids
     */
    @EnableEnterpriseIsolate(columns = {"t_import_task.task_id"})
    List<TaskVO> listImportTaskWithTemplateInfoNoSession1(@Param("taskIds") List<String> taskIds);
}
Sql 拦截器部分
这部分与主题无关,记录的是我用 sql 拦截器分析 sql、判断表对应的租户条件是否缺失的逻辑代码,不感兴趣可以直接跳过。
1 sql 拦截器
package com.yyigou.ddc.services.sql.analysis.util.interceptor;
import com.alibaba.fastjson.JSONObject;
import com.yyigou.ddc.common.exception.BusinessException;
import com.yyigou.ddc.services.message.kafka.KafkaProducerClient;
import com.yyigou.ddc.services.sql.analysis.util.annotation.EnterpriseIsolateAttr;
import com.yyigou.ddc.services.sql.analysis.util.conf.SqlAnalysisConfig;
import com.yyigou.ddc.services.sql.analysis.util.enums.SqlEnterpriseNoAnalysisEnum;
import com.yyigou.ddc.services.sql.analysis.util.tools.SqlEnterpriseAnalysisUtil;
import com.yyigou.ddc.services.sql.analysis.util.tools.SqlTableSourceAnalysisRes;
import com.yyigou.ddc.services.sql.analysis.util.tools.TraceMethodUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.skywalking.apm.toolkit.trace.TraceContext;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Properties;
/**
* @author WJP
* @Date: 2023/2/21 13:12
* @Description:
*/
@Intercepts(
{
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
RowBounds.class, ResultHandler.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
}
)
@Slf4j
@Component
public class EnterpriseIsolateInterceptor implements Interceptor {
private SqlAnalysisConfig sqlAnalysisConfig;
private KafkaProducerClient kafkaProducerClient;
@Override
public Object intercept(Invocation invocation) throws Throwable {
Object target = invocation.getTarget();
Object[] args = invocation.getArgs();
Object parameter = args[1];
MappedStatement mappedStatement = (MappedStatement) (args[0]);
// 如果不需要拦截或者需要跳过当前 sqlId
EnterpriseIsolateAttr enterpriseIsolate = SqlEnterpriseAnalysisUtil.getEnterpriseIsolate();
if (enterpriseIsolate == null || sqlAnalysisConfig.getEnterpriseIsolateSkipSqlIdMap().containsKey(mappedStatement.getId())) {
return invocation.proceed();
}
//获取到原始sql语句
Configuration configuration = mappedStatement.getConfiguration();
StatementHandler handler = configuration.newStatementHandler((Executor) target, mappedStatement,
parameter, RowBounds.DEFAULT, null, null
);
BoundSql boundSql = handler.getBoundSql();
String sql = boundSql.getSql();
String[] columns = enterpriseIsolate.getColumns();
boolean checkFlag = false;
// 如果是 and
if (!enterpriseIsolate.isOperatorOr()){
checkFlag = true;
}
for (String column : columns) {
String[] split = column.trim().split("\.");
if (split.length < 2){
throw new BusinessException("租户隔离注解值存在异常, 异常值为: " + column);
}
SqlTableSourceAnalysisRes sqlTableSourceAnalysisRes = SqlEnterpriseAnalysisUtil.analysisSqlContainsEnterpriseNo(sql, split[0], split[1]);
// 如果当前表和列校验通过
if (SqlEnterpriseNoAnalysisEnum.NORMAL.equals(sqlTableSourceAnalysisRes.getSqlEnterpriseNoAnalysisEnum())){
if (enterpriseIsolate.isOperatorOr()){
checkFlag = true;
break;
}
} else {
if (!enterpriseIsolate.isOperatorOr()){
checkFlag = false;
break;
}
}
}
if (!checkFlag){
String elkMsg = String.format("sqlAnalysis警告, 当前sql存在租户隔离风险: 入参为: %s", JSONObject.toJSONString(parameter));
log.warn(elkMsg);
}
return invocation.proceed();
}
@Override
public Object plugin(Object target) {
if (target instanceof Executor) {
return Plugin.wrap(target, this);
} else {
return target;
}
}
@Override
public void setProperties(Properties properties) {
}
public void setConfig(SqlAnalysisConfig sqlAnalysisConfig) {
this.sqlAnalysisConfig = sqlAnalysisConfig;
}
public void setKafkaProducerClient(KafkaProducerClient kafkaProducerClient) {
this.kafkaProducerClient = kafkaProducerClient;
}
}
2 sql 分析 util
package com.yyigou.ddc.services.sql.analysis.util.tools;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLInListExpr;
import com.alibaba.druid.sql.ast.expr.SQLInSubQueryExpr;
import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource;
import com.alibaba.druid.sql.ast.statement.SQLSelect;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLSubqueryTableSource;
import com.alibaba.druid.sql.ast.statement.SQLTableSource;
import com.alibaba.druid.sql.ast.statement.SQLUnionQueryTableSource;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUnionQuery;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.google.common.collect.Maps;
import com.yyigou.ddc.services.sql.analysis.util.annotation.EnterpriseIsolateAttr;
import com.yyigou.ddc.services.sql.analysis.util.enums.SqlEnterpriseNoAnalysisEnum;
import java.util.Map;
import java.util.regex.Pattern;
/**
* @author WJP
* @Date: 2023/2/21 16:56
* @Description:
*/
public class SqlEnterpriseAnalysisUtil {
/**
* 列名的 pattern
*/
private static Map<String, Pattern> columnEnterprisePatternMap = Maps.newConcurrentMap();
/**
* 租户隔离的 ThreadLocal
*/
private static final ThreadLocal<EnterpriseIsolateAttr> enterpriseIsolateThreadLocal = new ThreadLocal<EnterpriseIsolateAttr>();
public static SqlTableSourceAnalysisRes analysisSqlContainsEnterpriseNo(String sql, String tableName, String columnName) {
// 新建 MySQL Parser
SQLStatementParser parser = new MySqlStatementParser(sql);
// 使用Parser解析生成AST,这里SQLStatement就是AST
SQLSelectStatement statement = (SQLSelectStatement) parser.parseStatement();
// 从 statement 中拿出 select 信息
SQLSelect select = statement.getSelect();
SQLSelectQuery sqlSelectQuery = select.getQuery();
// 分析 select 语句
SqlTableSourceAnalysisRes sqlTableSourceAnalysisRes = analysisFromSqlSelectQuery(sqlSelectQuery, tableName, columnName);
return sqlTableSourceAnalysisRes;
// System.out.println(select.toString());
}
/**
* 从 sqlQuery 中提取信息
*
* @param sqlSelectQuery
*/
private static SqlTableSourceAnalysisRes analysisFromSqlSelectQuery(SQLSelectQuery sqlSelectQuery, String tableName,
String columnName) {
SqlTableSourceAnalysisRes sqlTableSourceAnalysisRes = new SqlTableSourceAnalysisRes();
sqlTableSourceAnalysisRes.setSqlEnterpriseNoAnalysisEnum(SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO);
if (sqlSelectQuery instanceof SQLSelectQueryBlock) {
// 如果是查询语句块
SQLSelectQueryBlock sqlSelectQueryBlock = (SQLSelectQueryBlock) sqlSelectQuery;
// 获取 from 和 where 里包含的东西
SQLTableSource from = sqlSelectQueryBlock.getFrom();
SQLExpr where = sqlSelectQueryBlock.getWhere();
// 如果是配置的表名则返回 true,进行修改条件操作
SqlTableSourceAnalysisRes sqlTableSourceAnalysisResTemp = analysisFromSqlTableSource(from, tableName, columnName);
// 如果不存在表或者表不存在 where 条件
if (sqlTableSourceAnalysisResTemp == null || where == null) {
return sqlTableSourceAnalysisRes;
}
// 如果当前查询已经包含了 on enterprise_no = ? 条件
if (SqlEnterpriseNoAnalysisEnum.NORMAL.equals(sqlTableSourceAnalysisResTemp.getSqlEnterpriseNoAnalysisEnum())) {
return sqlTableSourceAnalysisResTemp;
}
sqlTableSourceAnalysisResTemp.setSqlEnterpriseNoAnalysisEnum(SqlEnterpriseNoAnalysisEnum.getByType(analysisFromCondition(where, tableName, columnName)));
return sqlTableSourceAnalysisResTemp;
} else if (sqlSelectQuery instanceof MySqlUnionQuery) {
// 如果是 union 语句块
MySqlUnionQuery mySqlUnionQuery = (MySqlUnionQuery) sqlSelectQuery;
// 递归调用
SqlTableSourceAnalysisRes sqlTableSourceAnalysisResTemp = analysisFromSqlSelectQuery(mySqlUnionQuery.getLeft(), tableName, columnName);
if (!sqlTableSourceAnalysisResTemp.getSqlEnterpriseNoAnalysisEnum().equals(SqlEnterpriseNoAnalysisEnum.NORMAL.getCode())) {
return sqlTableSourceAnalysisResTemp;
}
return analysisFromSqlSelectQuery(mySqlUnionQuery.getRight(), tableName, columnName);
}
return sqlTableSourceAnalysisRes;
}
private static SqlTableSourceAnalysisRes analysisFromSqlTableSource(SQLTableSource sqlTableSource, String tableName,
String columnName){
SqlTableSourceAnalysisRes sqlTableSourceAnalysisRes = new SqlTableSourceAnalysisRes();
sqlTableSourceAnalysisRes.setSqlEnterpriseNoAnalysisEnum(SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO);
if (sqlTableSource instanceof SQLJoinTableSource) {
// 如果是 join 语句块
SQLJoinTableSource sqlJoinTableSource = (SQLJoinTableSource) sqlTableSource;
SQLTableSource left = sqlJoinTableSource.getLeft();
SQLTableSource right = sqlJoinTableSource.getRight();
// 分析表名是否存在, join on 中是否包含租户条件
SqlTableSourceAnalysisRes leftSqlTableSourceAnalysisRes = analysisFromSqlTableSource(left, tableName, columnName);
// 如果不存在表名或者 join on 中已包含租户条件, 则不需要进行 where 的判断
if (analysisJoinTableSource(tableName, columnName, sqlJoinTableSource, leftSqlTableSourceAnalysisRes)){
return leftSqlTableSourceAnalysisRes;
}
SqlTableSourceAnalysisRes rightSqlTableSourceAnalysisRes = analysisFromSqlTableSource(right, tableName, columnName);
if (analysisJoinTableSource(tableName, columnName, sqlJoinTableSource, rightSqlTableSourceAnalysisRes)) {
return rightSqlTableSourceAnalysisRes;
}
return sqlTableSourceAnalysisRes;
}if (sqlTableSource instanceof SQLSubqueryTableSource) {
// 如果是子查询
SQLSelectQueryBlock query = (SQLSelectQueryBlock) ((SQLSubqueryTableSource) sqlTableSource).getSelect()
.getQuery();
return analysisFromSqlSelectQuery(query, tableName, columnName);
} else if (sqlTableSource instanceof SQLUnionQueryTableSource) {
// 如果是 union
MySqlUnionQuery union = (MySqlUnionQuery) ((SQLUnionQueryTableSource) sqlTableSource).getUnion();
return analysisFromSqlSelectQuery(union, tableName, columnName);
} else if (sqlTableSource instanceof SQLExprTableSource) {
// 如果是表名(已到终点)
SQLExpr expr = ((SQLExprTableSource) sqlTableSource).getExpr();
if (expr instanceof SQLIdentifierExpr) {
SQLIdentifierExpr sqlIdentifierExpr = (SQLIdentifierExpr) expr;
String lowerName = sqlIdentifierExpr.getLowerName();
// 如果是指定的表名, 忽略大小写
if (lowerName.equalsIgnoreCase(tableName) || lowerName.equalsIgnoreCase(String.format("`%s`", tableName))) {
sqlTableSourceAnalysisRes.setAlias(sqlTableSource.getAlias() == null ? "": sqlTableSource.getAlias());
}
}
}
return sqlTableSourceAnalysisRes;
}
/**
* 判断 joinTableSource 是否是正常状态
* @param tableName
* @param columnName
* @param sqlJoinTableSource
* @param leftSqlTableSourceAnalysisRes
* @return
*/
private static boolean analysisJoinTableSource(String tableName, String columnName, SQLJoinTableSource sqlJoinTableSource, SqlTableSourceAnalysisRes leftSqlTableSourceAnalysisRes) {
if (leftSqlTableSourceAnalysisRes.getAlias() == null){
return false;
} else if (leftSqlTableSourceAnalysisRes.getAlias() != null) {
if (SqlEnterpriseNoAnalysisEnum.NORMAL.equals(leftSqlTableSourceAnalysisRes.getSqlEnterpriseNoAnalysisEnum())) {
return true;
}else {
// 构造条件
int i = analysisFromCondition(sqlJoinTableSource.getCondition(), tableName, columnName);
SqlEnterpriseNoAnalysisEnum byType = SqlEnterpriseNoAnalysisEnum.getByType(i);
if (byType.equals(SqlEnterpriseNoAnalysisEnum.NORMAL)){
leftSqlTableSourceAnalysisRes.setSqlEnterpriseNoAnalysisEnum(byType);
return true;
}
}
}
return false;
}
private static int analysisFromCondition(SQLExpr expr, String tableName, String columnName) {
int max = SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
if (expr instanceof SQLBinaryOpExpr) {
SQLBinaryOperator operator = ((SQLBinaryOpExpr) expr).getOperator();
String operatorName = operator.getName();
int left = 0, right = 0;
// 如果两个都是 PropertyExpr , 则说明是 join 的 = 条件, 不能作为判断条件存在的依据
if (((SQLBinaryOpExpr) expr).getLeft() instanceof SQLPropertyExpr && ((SQLBinaryOpExpr) expr).getRight() instanceof SQLPropertyExpr) {
left = SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
right = SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
} else {
left = SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
if (((SQLBinaryOpExpr) expr).getLeft() != null) {
left = analysisFromCondition(((SQLBinaryOpExpr) expr).getLeft(), tableName, columnName);
}
right = SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
if (((SQLBinaryOpExpr) expr).getRight() != null) {
right = analysisFromCondition(((SQLBinaryOpExpr) expr).getRight(), tableName, columnName);
}
}
if ("or".equalsIgnoreCase(operatorName)) {
if (left != right) {
max = SqlEnterpriseNoAnalysisEnum.ENTERPRISE_NO_COVERED.getCode();
} else {
max = left;
}
} else if ("and".equalsIgnoreCase(operatorName)) {
if (left == SqlEnterpriseNoAnalysisEnum.NORMAL.getCode() || right == SqlEnterpriseNoAnalysisEnum.NORMAL.getCode()) {
max = SqlEnterpriseNoAnalysisEnum.NORMAL.getCode();
} else {
max = Math.max(left, right);
}
} else {
max = Math.max(left, right);
}
} else if (expr instanceof SQLInSubQueryExpr) {
// 如果是 query 则进行 query 遍历
SQLSelectQuery query = ((SQLInSubQueryExpr) expr).getSubQuery().getQuery();
SqlTableSourceAnalysisRes sqlTableSourceAnalysisRes = analysisFromSqlSelectQuery(query, tableName, columnName);
max = Math.max(max, sqlTableSourceAnalysisRes.getSqlEnterpriseNoAnalysisEnum().getCode());
} else if (expr instanceof SQLPropertyExpr) {
max = columnNameMatch(columnName, ((SQLPropertyExpr) expr).getName()) ? SqlEnterpriseNoAnalysisEnum.NORMAL.getCode() : SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
} else if (expr instanceof SQLIdentifierExpr) {
max = columnNameMatch(columnName, ((SQLIdentifierExpr) expr).getName()) ? SqlEnterpriseNoAnalysisEnum.NORMAL.getCode() : SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
} else if (expr instanceof SQLMethodInvokeExpr) {
for (SQLExpr parameter : ((SQLMethodInvokeExpr) expr).getParameters()) {
max = Math.max(max, analysisFromCondition(parameter, tableName, columnName));
}
} else if (expr instanceof SQLInListExpr){
max = analysisFromCondition(((SQLInListExpr) expr).getExpr(), tableName, columnName);
}
return max;
}
/**
* 判断 expr 的 name 是否和 columnName 匹配
* @param columnName
* @param exprName
* @return
*/
private static boolean columnNameMatch(String columnName, String exprName){
// 把 columnName 的 pattern 放入 map 中, 避免频繁创建 pattern
if(!columnEnterprisePatternMap.containsKey(columnName)){
Pattern enterprisePattern = Pattern.compile(columnName, Pattern.CASE_INSENSITIVE);
columnEnterprisePatternMap.put(columnName, enterprisePattern);
}
Pattern pattern = columnEnterprisePatternMap.get(columnName);
if (pattern.matcher(exprName).find()){
return true;
}
return false;
}
public static void setEnterpriseIsolate(EnterpriseIsolateAttr enterpriseIsolate){
enterpriseIsolateThreadLocal.set(enterpriseIsolate);
}
public static void removeEnterpriseIsolate(){
enterpriseIsolateThreadLocal.remove();
}
public static EnterpriseIsolateAttr getEnterpriseIsolate(){
return enterpriseIsolateThreadLocal.get();
}
}
package com.yyigou.ddc.services.sql.analysis.util.enums;
import java.util.HashMap;
import java.util.Map;
/**
 * Outcome of analysing a SQL statement for a tenant (enterprise) condition.
 * The integer codes are ordered so that a larger code is a "worse" finding when results are combined with max.
 *
 * @author WJP
 * @Date: 2023/2/21 14:33
 */
public enum SqlEnterpriseNoAnalysisEnum {
    /**
     * No tenant condition was found.
     **/
    NO_ENTERPRISE_NO(0, "不存在租户条件"),
    /**
     * The tenant condition is present; nothing suspicious found.
     **/
    NORMAL(1, "无异常"),
    /**
     * The tenant condition is combined via OR with a branch that differs from it, so it does not restrict every row.
     **/
    ENTERPRISE_NO_COVERED(2, "租户条件被 or 覆盖");

    private final Integer code;

    private final String message;

    SqlEnterpriseNoAnalysisEnum(Integer code, String message) {
        this.code = code;
        this.message = message;
    }

    public Integer getCode() {
        return code;
    }

    public String getMessage() {
        return message;
    }

    /** Code-to-constant lookup, filled once in the static initializer and never modified afterwards. */
    private static final Map<Integer, SqlEnterpriseNoAnalysisEnum> maps = new HashMap<Integer, SqlEnterpriseNoAnalysisEnum>();

    static {
        for (SqlEnterpriseNoAnalysisEnum item : SqlEnterpriseNoAnalysisEnum.values()) {
            maps.put(item.getCode(), item);
        }
    }

    /**
     * Looks up a constant by its code.
     *
     * @param code the code, may be {@code null}
     * @return the matching constant, or {@code null} for a {@code null} or unknown code
     */
    public static SqlEnterpriseNoAnalysisEnum getByType(final Integer code) {
        if (code == null) {
            return null;
        }
        return maps.get(code);
    }
}