xiaoyan 测试解析豆子2

26 阅读1分钟
/**
 * SQL tree rewriter that appends a filter condition to the WHERE clause of every
 * SELECT whose FROM clause directly references one of the target tables.
 *
 * <p>Every SELECT is handled in its own scope, so a table referenced inside a subquery
 * is filtered inside that subquery only. The condition for each table reference is built
 * by {@code createCondition(alias)}, which is not defined in this class and must be
 * supplied by the enclosing class.
 *
 * <p>Table names are matched case-insensitively and without their schema qualifier.
 * Only plain identifiers and {@code identifier AS alias} items (optionally combined by
 * joins) are recognised as table references; other FROM items are ignored.
 *
 * <p>NOTE(review): the filter is always added to WHERE, also for tables on the nullable
 * side of an outer join, which presumably turns that join into an inner join - confirm
 * this is intended.
 *
 * <p>An instance keeps mutable traversal state (the scope stack).
 */
private static class FilterInjectionShuttle extends SqlShuttle {
    /** Upper-cased names of the tables that must receive the filter. */
    private final Set<String> targetTables;
    /** Stack of scopes; the top entry belongs to the SELECT currently being rewritten. */
    private final Deque<ScopeContext> scopes = new ArrayDeque<>();

    /** One table reference found in a FROM clause. */
    private static final class TableRef {
        /** Upper-cased physical table name, without schema. */
        final String tableName;
        /** Name under which the table is visible in the query (alias, or the table name itself). */
        final String alias;

        TableRef(String tableName, String alias) {
            this.tableName = tableName;
            this.alias = alias;
        }
    }

    /** Table references that belong to a single SELECT, in FROM-clause order. */
    private static class ScopeContext {
        // A list rather than a map keyed by table name: the same table may be referenced
        // several times under different aliases (self-join), and each needs its own filter.
        final List<TableRef> tables = new ArrayList<>();
    }

    /**
     * Creates a shuttle that filters the given tables.
     *
     * @param targetTables names of the tables to filter; compared case-insensitively,
     *                     {@code null} elements are ignored
     */
    public FilterInjectionShuttle(List<String> targetTables) {
        Set<String> normalized = new HashSet<>();
        for (String table : targetTables) {
            if (table != null) {
                // Collected table names are upper-cased, so the targets must be as well.
                normalized.add(normalize(table));
            }
        }
        this.targetTables = normalized;
        scopes.push(new ScopeContext());
    }

    /**
     * Rewrites SELECT calls in a fresh scope; every other call is handled by the default
     * shuttle traversal.
     */
    @Override
    public SqlNode visit(SqlCall call) {
        if (!(call instanceof SqlSelect)) {
            return super.visit(call);
        }
        scopes.push(new ScopeContext()); // enter a new scope
        try {
            return handleSelect((SqlSelect) call);
        } finally {
            scopes.pop(); // leave the scope even if the rewrite fails
        }
    }

    /**
     * Rewrites the children of {@code select} and then adds the filter conditions for the
     * tables referenced directly in its FROM clause.
     *
     * @param select the SELECT to rewrite
     * @return the rewritten SELECT (the same instance when no child node was replaced)
     */
    private SqlNode handleSelect(SqlSelect select) {
        // The default traversal walks every operand of the SELECT (FROM and WHERE included)
        // exactly once. Visiting FROM/WHERE by hand in addition to this call would rewrite
        // nested SELECTs twice and duplicate their injected conditions.
        SqlNode visited = super.visit(select);
        if (!(visited instanceof SqlSelect)) {
            return visited;
        }
        SqlSelect result = (SqlSelect) visited;

        SqlNode from = result.getFrom();
        if (from != null) {
            collectLocalTables(from); // only tables of the current scope
        }

        result.setWhere(injectLocalConditions(result.getWhere()));
        return result;
    }

    /**
     * Records the table references of a FROM item in the current scope. Joins are
     * descended into; subqueries are not, because they are handled in their own scope.
     *
     * @param from a FROM clause or a part of it
     */
    private void collectLocalTables(SqlNode from) {
        if (from instanceof SqlJoin) {
            SqlJoin join = (SqlJoin) from;
            collectLocalTables(join.getLeft());
            collectLocalTables(join.getRight());
        } else if (from instanceof SqlBasicCall) {
            SqlBasicCall call = (SqlBasicCall) from;
            if (call.getOperator() != SqlStdOperatorTable.AS) {
                return;
            }
            List<SqlNode> operands = call.getOperandList();
            if (operands.size() < 2) {
                return;
            }
            SqlNode tableNode = operands.get(0);
            SqlNode aliasNode = operands.get(1);
            // Only "table AS alias" counts; "(subquery) AS alias" is skipped here.
            if (tableNode instanceof SqlIdentifier && aliasNode instanceof SqlIdentifier) {
                registerTable(
                        extractPhysicalTableName(tableNode),
                        ((SqlIdentifier) aliasNode).getSimple());
            }
        } else if (from instanceof SqlIdentifier) {
            // No alias: the table is visible under its own (unqualified) name.
            String tableName = extractPhysicalTableName(from);
            registerTable(tableName, tableName);
        }
    }

    /** Adds a table reference to the scope on top of the stack. */
    private void registerTable(String tableName, String alias) {
        scopes.peek().tables.add(new TableRef(normalize(tableName), alias));
    }

    /**
     * Combines the existing WHERE condition with the filter conditions of all target
     * tables referenced in the current scope.
     *
     * @param originalWhere the current WHERE condition, may be {@code null}
     * @return {@code originalWhere} when nothing has to be injected, otherwise the
     *         conjunction of {@code originalWhere} (if present) and the injected conditions
     */
    private SqlNode injectLocalConditions(SqlNode originalWhere) {
        SqlNode injectedCondition = null;
        for (TableRef ref : scopes.peek().tables) {
            if (!targetTables.contains(ref.tableName)) {
                continue;
            }
            SqlNode condition = createCondition(ref.alias);
            injectedCondition = injectedCondition == null
                    ? condition
                    : and(injectedCondition, condition);
        }

        if (injectedCondition == null) {
            return originalWhere;
        }
        if (originalWhere == null) {
            return injectedCondition;
        }
        return and(originalWhere, injectedCondition);
    }

    /** Builds {@code left AND right}. */
    private static SqlNode and(SqlNode left, SqlNode right) {
        return new SqlBasicCall(
                SqlStdOperatorTable.AND,
                Arrays.asList(left, right),
                SqlParserPos.ZERO);
    }

    /**
     * Returns the last component of an identifier, i.e. the table name without its
     * schema/catalog qualifier.
     *
     * @param node the node naming the table
     * @return the unqualified table name, or an empty string if {@code node} is not an identifier
     */
    private String extractPhysicalTableName(SqlNode node) {
        if (node instanceof SqlIdentifier) {
            List<String> names = ((SqlIdentifier) node).names;
            return names.get(names.size() - 1); // real table name (schema ignored)
        }
        return "";
    }

    /** Upper-cases a table name independently of the JVM default locale. */
    private static String normalize(String tableName) {
        return tableName.toUpperCase(java.util.Locale.ROOT);
    }
}