Spring 切面无法通过注解拦截 mapper 方法的问题解决

1,283 阅读7分钟

Spring 切面无法通过注解拦截 mapper 方法的问题解决

背景

项目中需要做一个 sql 拦截并分析条件的功能(用作租户隔离的校验)。该功能需要用到注解:不同方法可以指定不同的表和列,并在 mybatis 拦截器中获取注解值、进行 sql 分析。一开始的想法是用切面拦截 mapper 层的注解,把注解值写入 ThreadLocal,再在 mybatis 拦截器中读取 ThreadLocal 的值并进行 sql 分析。但是做到最后,突然发现 spring4 无法对 mapper 接口上的注解切面进行织入(参考: blog.csdn.net/z69183787/a…)。

于是只能通过其它方式。经过摸索,发现 @Transactional 注解并不是靠 @Aspect 注解式切面实现的,而是通过 Advisor 加 AOP Alliance 的 MethodInterceptor(即 TransactionInterceptor)实现的。于是仿照 TransactionInterceptor,对 mapper 包路径进行拦截,并在拦截器中判断方法或类上是否存在注解,以此实现 mapper 的 Aop。

实现

1 首先定义一个注解类和一个注解属性类

注解类

package com.yyigou.ddc.services.sql.analysis.util.annotation;
​
import com.yyigou.ddc.services.sql.analysis.util.enums.EffectiveScope;
​
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
​
/**
 * DAO上开启租户隔离注解
 *
 * 1. 开启后, 将在指定的列进行进行风控排查
 * 2. 检查
 *
 * @author miaoyc
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
@Inherited
public @interface EnableEnterpriseIsolate {
    /**
     * 要检查的租户属性列, 比如: bdc_goods.enterprise_no
     * 如果存在多个, 用逗号分割
     *
     * @return
     */
    String[] columns() default {};
​
    /**
     * 默认拦截返回值为 List 的 select 语句
     * @return
     */
    EffectiveScope effectiveScope() default EffectiveScope.SELECT_FOR_LIST;
​
    /**
     * 操作符所有列是同时满足还是满足其一即可, 默认OR, 即满足其一即可
     */
    boolean operatorOr() default true;
}
​

注解属性类

package com.yyigou.ddc.services.sql.analysis.util.annotation;
​
import com.yyigou.ddc.services.sql.analysis.util.enums.EffectiveScope;
import lombok.Data;
​
import java.io.Serializable;
​
/**
 * @author WJP
 * @Date: 2023/2/22 19:03
 * @Description:
 */
@Data
public class EnterpriseIsolateAttr implements Serializable {
​
​
    String[] columns;
​
    /**
     * 操作符所有列是同时满足还是满足其一即可, 默认OR, 即满足其一即可
     */
    boolean operatorOr;
​
    /**
     * 默认拦截返回值为 List 的 select 语句
     * @return
     */
    EffectiveScope effectiveScope;
​
    /**
     * 是否是类上的注解(如果是类上的注解, 需要判断生效范围 effectiveScope, 决定是对所有 select sql 拦截还是只对返回值为 list 的 sql 拦截)
     */
    boolean annoOnClass;
​
    /**
     * 是否是空对象(用来对无注解的方法进行缓存)
     */
    boolean nullObject;
}

2 定义切面类

package com.yyigou.ddc.services.sql.analysis.util.aspect;
​
import com.yyigou.ddc.common.util.CallContextFilterUtil;
import com.yyigou.ddc.services.sql.analysis.util.annotation.EnableEnterpriseIsolate;
import com.yyigou.ddc.services.sql.analysis.util.annotation.EnterpriseIsolateAttr;
import com.yyigou.ddc.services.sql.analysis.util.conf.SqlAnalysisConfig;
import com.yyigou.ddc.services.sql.analysis.util.enums.EffectiveScope;
import com.yyigou.ddc.services.sql.analysis.util.tools.SqlEnterpriseAnalysisUtil;
import lombok.extern.slf4j.Slf4j;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.springframework.aop.support.AopUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.BridgeMethodResolver;
import org.springframework.core.MethodClassKey;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.AnnotationAttributes;
import org.springframework.util.ClassUtils;
​
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
​
/**
 * @author WJP
 * @Date: 2021/8/23 14:57
 */
@Slf4j
//@Aspect
public class EnterpriseIsolateAspect implements MethodInterceptor {
​
​
    /**
     * 缓存 cache
     */
    private final Map<Object, EnterpriseIsolateAttr> attributeCache =
            new ConcurrentHashMap<Object, EnterpriseIsolateAttr>(1024);
​
    /**
     * 配置类
     */
    @Autowired
    SqlAnalysisConfig sqlAnalysisConfig;
​
    /**
     * 用作空对象缓存
     */
    private final EnterpriseIsolateAttr NULL_CACHE_OBJECT = new EnterpriseIsolateAttr();
​
    public EnterpriseIsolateAspect() {
        NULL_CACHE_OBJECT.setNullObject(true);
    }
​
    @Override
    public Object invoke(MethodInvocation invocation) throws Throwable {
​
        // 如果项目未开启租户隔离分析
        if (!sqlAnalysisConfig.getEnableEnterpriseIsolate()){
            return invocation.proceed();
        }
        try{
            Class<?> targetClass = (invocation.getThis() != null ? AopUtils.getTargetClass(invocation.getThis()) : null);
            EnterpriseIsolateAttr transactionAttribute = getEnterpriseIsolateAttribute(invocation.getMethod(), targetClass);
            if (transactionAttribute != null){
                if (!transactionAttribute.isAnnoOnClass()){
                    // 如果注解是在方法上, 则不用判断 scope
                    SqlEnterpriseAnalysisUtil.setEnterpriseIsolate(transactionAttribute);
                }else {
                    // 如果方法上没有注解, 则从类上找注解
                    EffectiveScope effectiveScope = transactionAttribute.getEffectiveScope();
                    // 如果对所有 select sql 生效
                    if (EffectiveScope.ALL.equals(effectiveScope)){
                        SqlEnterpriseAnalysisUtil.setEnterpriseIsolate(transactionAttribute);
                    }else {
                        if (invocation.getMethod().getReturnType() == List.class){
                            SqlEnterpriseAnalysisUtil.setEnterpriseIsolate(transactionAttribute);
                        }
                    }
                }
            }
​
            return invocation.proceed();
        }finally {
            SqlEnterpriseAnalysisUtil.removeEnterpriseIsolate();
        }
    }
​
    public EnterpriseIsolateAttr getEnterpriseIsolateAttribute(Method method, Class<?> targetClass) {
        if (method.getDeclaringClass() == Object.class) {
            return null;
        }
​
        // First, see if we have a cached value.
        Object cacheKey = getCacheKey(method, targetClass);
        EnterpriseIsolateAttr cached = this.attributeCache.get(cacheKey);
        if (cached != null) {
            if (cached == NULL_CACHE_OBJECT) {
                return null;
            }
            else {
                return cached;
            }
        }
        else {
            // We need to work it out.
            EnterpriseIsolateAttr txAttr = computeEnterpriseIsolateAttribute(method, targetClass);
            // Put it in the cache.
            if (txAttr == null) {
                this.attributeCache.put(cacheKey, NULL_CACHE_OBJECT);
            }
            else {
                this.attributeCache.put(cacheKey, txAttr);
            }
            return txAttr;
        }
    }
​
​
    protected EnterpriseIsolateAttr computeEnterpriseIsolateAttribute(Method method, Class<?> targetClass) {
​
        // Ignore CGLIB subclasses - introspect the actual user class.
        Class<?> userClass = ClassUtils.getUserClass(targetClass);
        // The method may be on an interface, but we need attributes from the target class.
        // If the target class is null, the method will be unchanged.
        Method specificMethod = ClassUtils.getMostSpecificMethod(method, userClass);
        // If we are dealing with method with generic parameters, find the original method.
        specificMethod = BridgeMethodResolver.findBridgedMethod(specificMethod);
​
        // First try is the method in the target class.
        EnterpriseIsolateAttr txAttr = findEnterpriseIsolateAttribute(specificMethod);
        if (txAttr != null) {
            txAttr.setAnnoOnClass(false);
            return txAttr;
        }
​
        // Second try is the transaction attribute on the target class.
        txAttr = findEnterpriseIsolateAttribute(specificMethod.getDeclaringClass());
        if (txAttr != null && ClassUtils.isUserLevelMethod(method)) {
            txAttr.setAnnoOnClass(true);
            return txAttr;
        }
​
        if (specificMethod != method) {
            // Fallback is to look at the original method.
            txAttr = findEnterpriseIsolateAttribute(method);
            if (txAttr != null) {
                txAttr.setAnnoOnClass(false);
                return txAttr;
            }
            // Last fallback is the class of the original method.
            txAttr = findEnterpriseIsolateAttribute(method.getDeclaringClass());
            if (txAttr != null && ClassUtils.isUserLevelMethod(method)) {
                txAttr.setAnnoOnClass(true);
                return txAttr;
            }
        }
​
        return null;
    }
​
​
    protected EnterpriseIsolateAttr findEnterpriseIsolateAttribute(AnnotatedElement element) {
        AnnotationAttributes attributes = AnnotatedElementUtils.getMergedAnnotationAttributes(
                element, EnableEnterpriseIsolate.class);
        if (attributes != null) {
            return parseEnterpriseIsolateAnnotation(attributes);
        }
        else {
            return null;
        }
    }
​
​
    protected EnterpriseIsolateAttr parseEnterpriseIsolateAnnotation(AnnotationAttributes attributes) {
​
        EnterpriseIsolateAttr enterpriseIsolateAttr = new EnterpriseIsolateAttr();
        String[] columns = attributes.getStringArray("columns");
        EffectiveScope effectiveScope = attributes.getEnum("effectiveScope");
        boolean operatorOr = attributes.getBoolean("operatorOr");
        enterpriseIsolateAttr.setColumns(columns);
        enterpriseIsolateAttr.setOperatorOr(operatorOr);
        enterpriseIsolateAttr.setEffectiveScope(effectiveScope);
​
​
        return enterpriseIsolateAttr;
    }
​
​
    /**
     * Determine a cache key for the given method and target class.
     * <p>Must not produce same key for overloaded methods.
     * Must produce same key for different instances of the same method.
     * @param method the method (never {@code null})
     * @param targetClass the target class (may be {@code null})
     * @return the cache key (never {@code null})
     */
    protected Object getCacheKey(Method method, Class<?> targetClass) {
        return new MethodClassKey(method, targetClass);
    }
​
}
​

3 在 xml 中定义切面

    <!-- The interceptor is declared as a plain bean so it can be referenced as advice below. -->
    <bean id="enterpriseIsolateAspect" class="com.yyigou.ddc.services.sql.analysis.util.aspect.EnterpriseIsolateAspect"/>
    <aop:config>
        <!-- Matches every public method of every type in the mapper package and its sub-packages. -->
        <aop:pointcut id="point" expression="execution(public * com.yyigou.ddc.services.ddc.task.mybatis.mapper..*.*(..))"/>
        <!-- Binds the pointcut to the interceptor bean. -->
        <aop:advisor  pointcut-ref="point" advice-ref="enterpriseIsolateAspect"/>
    </aop:config>

使用方式

mapper 层

/**
 * Mapper 类上增加租户隔离注解
 * 默认只拦截返回值为 list 的方法
 * 如果需要指定某个 sql 拦截不同表和列, 则在方法上使用 EnableEnterpriseIsolate 注解
 */
package com.yyigou.ddc.services.ddc.task.mybatis.mapper;
​
import java.util.List;
​
/**
 * @author WJP
 * @Date: 2022/7/7 13:40
 */
@EnableEnterpriseIsolate(columns = {"t_import_task.tenant_no"})
public interface CustomImportTaskMapper {
​
    TaskVO getImportTaskWithTemplateInfo(@Param("taskId") String taskId);
​
    List<TaskVO> listImportTaskWithTemplateInfo(@Param("taskIds") List<String> taskIds,
                                                @Param("tenantNo") String tenantNo);
​
    List<TaskVO> listImportTaskWithTemplateInfoNoSession(@Param("taskIds") List<String> taskIds);
​
    @EnableEnterpriseIsolate(columns = {"t_import_task.task_id"})
    List<TaskVO> listImportTaskWithTemplateInfoNoSession1(@Param("taskIds") List<String> taskIds);
}
​

Sql 拦截器部分

这部分和主题无关,只是记录我使用 sql 拦截器分析 sql、并判断指定表是否缺少对应(租户)条件的逻辑代码,不想看的可以直接无视。

1 sql 拦截器

package com.yyigou.ddc.services.sql.analysis.util.interceptor;
​
import com.alibaba.fastjson.JSONObject;
import com.yyigou.ddc.common.exception.BusinessException;
import com.yyigou.ddc.services.message.kafka.KafkaProducerClient;
import com.yyigou.ddc.services.sql.analysis.util.annotation.EnterpriseIsolateAttr;
import com.yyigou.ddc.services.sql.analysis.util.conf.SqlAnalysisConfig;
import com.yyigou.ddc.services.sql.analysis.util.enums.SqlEnterpriseNoAnalysisEnum;
import com.yyigou.ddc.services.sql.analysis.util.tools.SqlEnterpriseAnalysisUtil;
import com.yyigou.ddc.services.sql.analysis.util.tools.SqlTableSourceAnalysisRes;
import com.yyigou.ddc.services.sql.analysis.util.tools.TraceMethodUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.skywalking.apm.toolkit.trace.TraceContext;
import org.springframework.stereotype.Component;
​
import java.util.List;
import java.util.Properties;
​
/**
 * @author WJP
 * @Date: 2023/2/21 13:12
 * @Description:
 */
@Intercepts(
        {
                @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                        RowBounds.class, ResultHandler.class}),
                @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                        RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
        }
)
@Slf4j
@Component
public class EnterpriseIsolateInterceptor implements Interceptor {
​
    private SqlAnalysisConfig sqlAnalysisConfig;
    private KafkaProducerClient kafkaProducerClient;
​
​
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
​
        Object target = invocation.getTarget();
        Object[] args = invocation.getArgs();
        Object parameter = args[1];
        MappedStatement mappedStatement = (MappedStatement) (args[0]);
​
        // 如果不需要拦截或者需要跳过当前 sqlId
        EnterpriseIsolateAttr enterpriseIsolate = SqlEnterpriseAnalysisUtil.getEnterpriseIsolate();
        if (enterpriseIsolate == null || sqlAnalysisConfig.getEnterpriseIsolateSkipSqlIdMap().containsKey(mappedStatement.getId())) {
            return invocation.proceed();
        }
​
        //获取到原始sql语句
        Configuration configuration = mappedStatement.getConfiguration();
        StatementHandler handler = configuration.newStatementHandler((Executor) target, mappedStatement,
                parameter, RowBounds.DEFAULT, null, null
        );
        BoundSql boundSql = handler.getBoundSql();
        String sql = boundSql.getSql();
​
        String[] columns = enterpriseIsolate.getColumns();
​
        boolean checkFlag = false;
        // 如果是 and
        if (!enterpriseIsolate.isOperatorOr()){
            checkFlag = true;
        }
        for (String column : columns) {
            String[] split = column.trim().split("\.");
            if (split.length < 2){
                throw new BusinessException("租户隔离注解值存在异常, 异常值为: " + column);
            }
            SqlTableSourceAnalysisRes sqlTableSourceAnalysisRes = SqlEnterpriseAnalysisUtil.analysisSqlContainsEnterpriseNo(sql, split[0], split[1]);
            // 如果当前表和列校验通过
            if (SqlEnterpriseNoAnalysisEnum.NORMAL.equals(sqlTableSourceAnalysisRes.getSqlEnterpriseNoAnalysisEnum())){
                if (enterpriseIsolate.isOperatorOr()){
                    checkFlag = true;
                    break;
                }
            } else {
                if (!enterpriseIsolate.isOperatorOr()){
                    checkFlag = false;
                    break;
                }
            }
        }
        if (!checkFlag){
            String elkMsg = String.format("sqlAnalysis警告, 当前sql存在租户隔离风险: 入参为: %s", JSONObject.toJSONString(parameter));
            log.warn(elkMsg);
        }
        return invocation.proceed();
​
    }
​
    @Override
    public Object plugin(Object target) {
        if (target instanceof Executor) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }
​
    @Override
    public void setProperties(Properties properties) {
​
    }
​
​
    public void setConfig(SqlAnalysisConfig sqlAnalysisConfig) {
        this.sqlAnalysisConfig = sqlAnalysisConfig;
    }
​
    public void setKafkaProducerClient(KafkaProducerClient kafkaProducerClient) {
        this.kafkaProducerClient = kafkaProducerClient;
    }
}
​

2 sql 分析 util

package com.yyigou.ddc.services.sql.analysis.util.tools;
​
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLInListExpr;
import com.alibaba.druid.sql.ast.expr.SQLInSubQueryExpr;
import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource;
import com.alibaba.druid.sql.ast.statement.SQLSelect;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLSubqueryTableSource;
import com.alibaba.druid.sql.ast.statement.SQLTableSource;
import com.alibaba.druid.sql.ast.statement.SQLUnionQueryTableSource;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUnionQuery;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.google.common.collect.Maps;
import com.yyigou.ddc.services.sql.analysis.util.annotation.EnterpriseIsolateAttr;
import com.yyigou.ddc.services.sql.analysis.util.enums.SqlEnterpriseNoAnalysisEnum;
​
import java.util.Map;
import java.util.regex.Pattern;
​
/**
 * @author WJP
 * @Date: 2023/2/21 16:56
 * @Description:
 */
public class SqlEnterpriseAnalysisUtil {
​
    /**
     * 列名的 pattern
     */
    private static Map<String, Pattern> columnEnterprisePatternMap = Maps.newConcurrentMap();
    /**
     * 租户隔离的 ThreadLocal
     */
    private static final ThreadLocal<EnterpriseIsolateAttr> enterpriseIsolateThreadLocal = new ThreadLocal<EnterpriseIsolateAttr>();
​
    public static SqlTableSourceAnalysisRes analysisSqlContainsEnterpriseNo(String sql, String tableName, String columnName) {
        // 新建 MySQL Parser
        SQLStatementParser parser = new MySqlStatementParser(sql);
​
        // 使用Parser解析生成AST,这里SQLStatement就是AST
        SQLSelectStatement statement = (SQLSelectStatement) parser.parseStatement();
        // 从 statement 中拿出 select 信息
        SQLSelect select = statement.getSelect();
        SQLSelectQuery sqlSelectQuery = select.getQuery();
        // 分析 select 语句
        SqlTableSourceAnalysisRes sqlTableSourceAnalysisRes = analysisFromSqlSelectQuery(sqlSelectQuery, tableName, columnName);
        return sqlTableSourceAnalysisRes;
//        System.out.println(select.toString());
    }
​
    /**
     * 从 sqlQuery 中提取信息
     *
     * @param sqlSelectQuery
     */
    private static SqlTableSourceAnalysisRes analysisFromSqlSelectQuery(SQLSelectQuery sqlSelectQuery, String tableName,
                                                  String columnName) {
        SqlTableSourceAnalysisRes sqlTableSourceAnalysisRes = new SqlTableSourceAnalysisRes();
        sqlTableSourceAnalysisRes.setSqlEnterpriseNoAnalysisEnum(SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO);
        if (sqlSelectQuery instanceof SQLSelectQueryBlock) {
            // 如果是查询语句块
            SQLSelectQueryBlock sqlSelectQueryBlock = (SQLSelectQueryBlock) sqlSelectQuery;
            // 获取 from 和 where 里包含的东西
            SQLTableSource from = sqlSelectQueryBlock.getFrom();
            SQLExpr where = sqlSelectQueryBlock.getWhere();
            // 如果是配置的表名则返回 true,进行修改条件操作
            SqlTableSourceAnalysisRes sqlTableSourceAnalysisResTemp = analysisFromSqlTableSource(from, tableName, columnName);
            // 如果不存在表或者表不存在 where 条件
            if (sqlTableSourceAnalysisResTemp == null || where == null) {
                return sqlTableSourceAnalysisRes;
            }
            // 如果当前查询已经包含了 on enterprise_no = ? 条件
            if (SqlEnterpriseNoAnalysisEnum.NORMAL.equals(sqlTableSourceAnalysisResTemp.getSqlEnterpriseNoAnalysisEnum())) {
                return sqlTableSourceAnalysisResTemp;
            }
            sqlTableSourceAnalysisResTemp.setSqlEnterpriseNoAnalysisEnum(SqlEnterpriseNoAnalysisEnum.getByType(analysisFromCondition(where, tableName, columnName)));
            return sqlTableSourceAnalysisResTemp;
        } else if (sqlSelectQuery instanceof MySqlUnionQuery) {
            // 如果是 union 语句块
            MySqlUnionQuery mySqlUnionQuery = (MySqlUnionQuery) sqlSelectQuery;
            // 递归调用
            SqlTableSourceAnalysisRes sqlTableSourceAnalysisResTemp = analysisFromSqlSelectQuery(mySqlUnionQuery.getLeft(), tableName, columnName);
            if (!sqlTableSourceAnalysisResTemp.getSqlEnterpriseNoAnalysisEnum().equals(SqlEnterpriseNoAnalysisEnum.NORMAL.getCode())) {
                return sqlTableSourceAnalysisResTemp;
            }
            return analysisFromSqlSelectQuery(mySqlUnionQuery.getRight(), tableName, columnName);
        }
        return sqlTableSourceAnalysisRes;
    }
​
    private static SqlTableSourceAnalysisRes analysisFromSqlTableSource(SQLTableSource sqlTableSource, String tableName,
                                                                        String columnName){
        SqlTableSourceAnalysisRes sqlTableSourceAnalysisRes = new SqlTableSourceAnalysisRes();
        sqlTableSourceAnalysisRes.setSqlEnterpriseNoAnalysisEnum(SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO);
        if (sqlTableSource instanceof SQLJoinTableSource) {
            // 如果是 join 语句块
            SQLJoinTableSource sqlJoinTableSource = (SQLJoinTableSource) sqlTableSource;
            SQLTableSource left = sqlJoinTableSource.getLeft();
            SQLTableSource right = sqlJoinTableSource.getRight();
​
            // 分析表名是否存在, join on 中是否包含租户条件
            SqlTableSourceAnalysisRes leftSqlTableSourceAnalysisRes = analysisFromSqlTableSource(left, tableName, columnName);
            // 如果不存在表名或者 join on 中已包含租户条件, 则不需要进行 where 的判断
            if (analysisJoinTableSource(tableName, columnName, sqlJoinTableSource, leftSqlTableSourceAnalysisRes)){
                return leftSqlTableSourceAnalysisRes;
            }
            SqlTableSourceAnalysisRes rightSqlTableSourceAnalysisRes = analysisFromSqlTableSource(right, tableName, columnName);
            if (analysisJoinTableSource(tableName, columnName, sqlJoinTableSource, rightSqlTableSourceAnalysisRes)) {
                return rightSqlTableSourceAnalysisRes;
            }
            return sqlTableSourceAnalysisRes;
        }if (sqlTableSource instanceof SQLSubqueryTableSource) {
            // 如果是子查询
            SQLSelectQueryBlock query = (SQLSelectQueryBlock) ((SQLSubqueryTableSource) sqlTableSource).getSelect()
                    .getQuery();
            return analysisFromSqlSelectQuery(query, tableName, columnName);
        } else if (sqlTableSource instanceof SQLUnionQueryTableSource) {
            // 如果是 union
            MySqlUnionQuery union = (MySqlUnionQuery) ((SQLUnionQueryTableSource) sqlTableSource).getUnion();
            return analysisFromSqlSelectQuery(union, tableName, columnName);
        } else if (sqlTableSource instanceof SQLExprTableSource) {
            // 如果是表名(已到终点)
            SQLExpr expr = ((SQLExprTableSource) sqlTableSource).getExpr();
            if (expr instanceof SQLIdentifierExpr) {
                SQLIdentifierExpr sqlIdentifierExpr = (SQLIdentifierExpr) expr;
                String lowerName = sqlIdentifierExpr.getLowerName();
                // 如果是指定的表名, 忽略大小写
                if (lowerName.equalsIgnoreCase(tableName) || lowerName.equalsIgnoreCase(String.format("`%s`", tableName))) {
                    sqlTableSourceAnalysisRes.setAlias(sqlTableSource.getAlias() == null ? "": sqlTableSource.getAlias());
                }
            }
        }
        return sqlTableSourceAnalysisRes;
    }
​
    /**
     * 判断 joinTableSource 是否是正常状态
     * @param tableName
     * @param columnName
     * @param sqlJoinTableSource
     * @param leftSqlTableSourceAnalysisRes
     * @return
     */
    private static boolean analysisJoinTableSource(String tableName, String columnName, SQLJoinTableSource sqlJoinTableSource, SqlTableSourceAnalysisRes leftSqlTableSourceAnalysisRes) {
        if (leftSqlTableSourceAnalysisRes.getAlias() == null){
            return false;
        } else if (leftSqlTableSourceAnalysisRes.getAlias() != null) {
            if (SqlEnterpriseNoAnalysisEnum.NORMAL.equals(leftSqlTableSourceAnalysisRes.getSqlEnterpriseNoAnalysisEnum())) {
                return true;
            }else {
                // 构造条件
                int i = analysisFromCondition(sqlJoinTableSource.getCondition(), tableName, columnName);
                SqlEnterpriseNoAnalysisEnum byType = SqlEnterpriseNoAnalysisEnum.getByType(i);
                if (byType.equals(SqlEnterpriseNoAnalysisEnum.NORMAL)){
                    leftSqlTableSourceAnalysisRes.setSqlEnterpriseNoAnalysisEnum(byType);
                    return true;
                }
            }
        }
        return false;
    }
​
​
    private static int analysisFromCondition(SQLExpr expr, String tableName, String columnName) {
        int max = SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
        if (expr instanceof SQLBinaryOpExpr) {
            SQLBinaryOperator operator = ((SQLBinaryOpExpr) expr).getOperator();
            String operatorName = operator.getName();
            int left = 0, right = 0;
            // 如果两个都是 PropertyExpr , 则说明是 join 的 = 条件, 不能作为判断条件存在的依据
            if (((SQLBinaryOpExpr) expr).getLeft() instanceof SQLPropertyExpr && ((SQLBinaryOpExpr) expr).getRight() instanceof SQLPropertyExpr) {
                left = SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
                right = SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
            } else {
                left = SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
                if (((SQLBinaryOpExpr) expr).getLeft() != null) {
                    left = analysisFromCondition(((SQLBinaryOpExpr) expr).getLeft(), tableName, columnName);
                }
                right = SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
                if (((SQLBinaryOpExpr) expr).getRight() != null) {
                    right = analysisFromCondition(((SQLBinaryOpExpr) expr).getRight(), tableName, columnName);
                }
            }
​
            if ("or".equalsIgnoreCase(operatorName)) {
                if (left != right) {
                    max = SqlEnterpriseNoAnalysisEnum.ENTERPRISE_NO_COVERED.getCode();
                } else {
                    max = left;
                }
            } else if ("and".equalsIgnoreCase(operatorName)) {
                if (left == SqlEnterpriseNoAnalysisEnum.NORMAL.getCode() || right == SqlEnterpriseNoAnalysisEnum.NORMAL.getCode()) {
                    max = SqlEnterpriseNoAnalysisEnum.NORMAL.getCode();
                } else {
                    max = Math.max(left, right);
                }
            } else {
                max = Math.max(left, right);
            }
        } else if (expr instanceof SQLInSubQueryExpr) {
            // 如果是 query 则进行 query 遍历
            SQLSelectQuery query = ((SQLInSubQueryExpr) expr).getSubQuery().getQuery();
            SqlTableSourceAnalysisRes sqlTableSourceAnalysisRes = analysisFromSqlSelectQuery(query, tableName, columnName);
            max = Math.max(max, sqlTableSourceAnalysisRes.getSqlEnterpriseNoAnalysisEnum().getCode());
        } else if (expr instanceof SQLPropertyExpr) {
            max = columnNameMatch(columnName, ((SQLPropertyExpr) expr).getName()) ? SqlEnterpriseNoAnalysisEnum.NORMAL.getCode() : SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
​
        } else if (expr instanceof SQLIdentifierExpr) {
            max = columnNameMatch(columnName, ((SQLIdentifierExpr) expr).getName()) ? SqlEnterpriseNoAnalysisEnum.NORMAL.getCode() : SqlEnterpriseNoAnalysisEnum.NO_ENTERPRISE_NO.getCode();
        } else if (expr instanceof SQLMethodInvokeExpr) {
            for (SQLExpr parameter : ((SQLMethodInvokeExpr) expr).getParameters()) {
                max = Math.max(max, analysisFromCondition(parameter, tableName, columnName));
            }
        } else if (expr instanceof SQLInListExpr){
            max = analysisFromCondition(((SQLInListExpr) expr).getExpr(), tableName, columnName);
        }
        return max;
    }
​
​
    /**
     * 判断 expr 的 name 是否和 columnName 匹配
     * @param columnName
     * @param exprName
     * @return
     */
    private static boolean columnNameMatch(String columnName, String exprName){
        // 把 columnName 的 pattern 放入 map 中, 避免频繁创建 pattern
        if(!columnEnterprisePatternMap.containsKey(columnName)){
            Pattern enterprisePattern = Pattern.compile(columnName, Pattern.CASE_INSENSITIVE);
            columnEnterprisePatternMap.put(columnName, enterprisePattern);
        }
        Pattern pattern = columnEnterprisePatternMap.get(columnName);
        if (pattern.matcher(exprName).find()){
            return true;
        }
        return false;
​
    }
​
​
    public static void setEnterpriseIsolate(EnterpriseIsolateAttr enterpriseIsolate){
        enterpriseIsolateThreadLocal.set(enterpriseIsolate);
    }
​
    public static void removeEnterpriseIsolate(){
        enterpriseIsolateThreadLocal.remove();
    }
​
    public static EnterpriseIsolateAttr getEnterpriseIsolate(){
        return enterpriseIsolateThreadLocal.get();
    }
​
​
}
​
package com.yyigou.ddc.services.sql.analysis.util.enums;
​
import java.util.HashMap;
import java.util.Map;
​
/**
 * Outcome of analysing a SQL statement for a tenant condition. Each constant
 * carries a numeric code and a human-readable message.
 *
 * @author WJP
 * @Date: 2023/2/21 14:33
 */
public enum SqlEnterpriseNoAnalysisEnum {

    /**
     * The analysed SQL contains no tenant condition.
     **/
    NO_ENTERPRISE_NO(0, "不存在租户条件"),

    /**
     * The analysed SQL shows no anomaly.
     **/
    NORMAL(1, "无异常"),

    /**
     * The tenant condition is overridden by an {@code or}.
     **/
    ENTERPRISE_NO_COVERED(2, "租户条件被 or 覆盖");

    /**
     * Lookup table from code to constant, filled once when the enum is initialised.
     */
    private static final Map<Integer, SqlEnterpriseNoAnalysisEnum> maps = new HashMap<Integer, SqlEnterpriseNoAnalysisEnum>();

    static {
        for (SqlEnterpriseNoAnalysisEnum item : SqlEnterpriseNoAnalysisEnum.values()) {
            maps.put(item.getCode(), item);
        }
    }

    private final Integer code;
    private final String message;

    SqlEnterpriseNoAnalysisEnum(Integer code, String message) {
        this.code = code;
        this.message = message;
    }

    /**
     * @return the numeric code of this constant
     */
    public Integer getCode() {
        return code;
    }

    /**
     * @return the message of this constant
     */
    public String getMessage() {
        return message;
    }

    /**
     * Finds the constant with the given code.
     *
     * @param code the numeric code (may be {@code null})
     * @return the matching constant, or {@code null} when {@code code} is
     *         {@code null} or no constant has that code
     */
    public static SqlEnterpriseNoAnalysisEnum getByType(final Integer code) {
        if (code == null) {
            return null;
        }
        return maps.get(code);
    }

}
​