// xiaoyan 测试解析豆子
//
// 56 阅读 · 1分钟 (56 reads · 1-minute read)
import org.apache.calcite.config.Lex;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.dialect.OracleSqlDialect;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParser.Config;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.util.SqlShuttle;
import java.util.*;

public class OracleSqlFilterInjector {
    private static final Config SQL_PARSER_CONFIG = SqlParser.config()
            .withLex(Lex.ORACLE)
            .withCaseSensitive(false);

    public static String injectFilter(String originalSql, List<String> targetTables) throws SqlParseException {
        SqlParser parser = SqlParser.create(originalSql, SQL_PARSER_CONFIG);
        SqlNode sqlNode = parser.parseQuery();
        
        FilterInjectionShuttle shuttle = new FilterInjectionShuttle(targetTables);
        SqlNode modifiedSqlNode = sqlNode.accept(shuttle);
        
        return modifiedSqlNode.toSqlString(OracleSqlDialect.DEFAULT, true)
                .getSql()
                .replace("\"", ""); // Remove Oracle quotes
    }

    private static class FilterInjectionShuttle extends SqlShuttle {
        private final Set<String> targetTables;
        private final Map<String, String> tableAliases = new HashMap<>();

        public FilterInjectionShuttle(List<String> targetTables) {
            this.targetTables = new HashSet<>(targetTables);
        }

        @Override
        public SqlNode visit(SqlCall call) {
            if (call instanceof SqlSelect) {
                return handleSelect((SqlSelect) call);
            }
            return super.visit(call);
        }

        private SqlNode handleSelect(SqlSelect select) {
            // Process FROM clause
            if (select.getFrom() != null) {
                SqlNode from = select.getFrom().accept(this);
                select.setFrom(from);
                collectTableAliases(from);
            }

            // Process WHERE clause
            SqlNode where = select.getWhere();
            if (where != null) {
                where = where.accept(this);
            }
            SqlNode newWhere = injectConditions(where, select.getFrom());
            select.setWhere(newWhere);

            return super.visit(select);
        }

        private void collectTableAliases(SqlNode from) {
            if (from instanceof SqlJoin) {
                collectTableAliases(((SqlJoin) from).getLeft());
                collectTableAliases(((SqlJoin) from).getRight());
            } else if (from instanceof SqlBasicCall) {
                SqlBasicCall call = (SqlBasicCall) from;
                if (call.getOperator() == SqlStdOperatorTable.AS) {
                    processAsOperator(call);
                } else {
                    // Handle other types of calls (e.g. subqueries)
                    for (SqlNode operand : call.getOperandList()) {
                        collectTableAliases(operand);
                    }
                }
            } else if (from instanceof SqlIdentifier) {
                processIdentifier((SqlIdentifier) from);
            }
        }

        private void processAsOperator(SqlBasicCall call) {
            List<SqlNode> operands = call.getOperandList();
            if (operands.size() >= 2) {
                SqlNode tableNode = operands.get(0);
                SqlNode aliasNode = operands.get(1);
                
                if (tableNode instanceof SqlIdentifier && aliasNode instanceof SqlIdentifier) {
                    String tableName = ((SqlIdentifier) tableNode).names.get(0);
                    String alias = ((SqlIdentifier) aliasNode).getSimple();
                    tableAliases.put(tableName.toUpperCase(), alias);
                }
            }
        }

        private void processIdentifier(SqlIdentifier identifier) {
            String tableName = identifier.names.get(0);
            tableAliases.put(tableName.toUpperCase(), tableName);
        }

        private SqlNode injectConditions(SqlNode originalWhere, SqlNode from) {
            List<SqlNode> conditions = new ArrayList<>();
            
            collectTables(from).stream()
                    .map(String::toUpperCase)
                    .filter(targetTables::contains)
                    .forEach(table -> {
                        String alias = tableAliases.getOrDefault(table, table);
                        conditions.add(createCondition(alias));
                    });

            if (conditions.isEmpty()) {
                return originalWhere;
            }

            SqlNode injectedCondition = conditions.stream()
                    .reduce((a, b) -> new SqlBasicCall(
                            SqlStdOperatorTable.AND,
                            Arrays.asList(a, b),
                            SqlParserPos.ZERO))
                    .get();

            if (originalWhere == null) {
                return injectedCondition;
            }

            return new SqlBasicCall(
                    SqlStdOperatorTable.AND,
                    Arrays.asList(originalWhere, injectedCondition),
                    SqlParserPos.ZERO);
        }

        private List<String> collectTables(SqlNode from) {
            List<String> tables = new ArrayList<>();
            if (from == null) return tables;

            if (from instanceof SqlJoin) {
                tables.addAll(collectTables(((SqlJoin) from).getLeft()));
                tables.addAll(collectTables(((SqlJoin) from).getRight()));
            } else if (from instanceof SqlBasicCall) {
                SqlBasicCall call = (SqlBasicCall) from;
                if (call.getOperator() == SqlStdOperatorTable.AS) {
                    tables.addAll(collectTables(call.getOperandList().get(0)));
                } else {
                    for (SqlNode operand : call.getOperandList()) {
                        tables.addAll(collectTables(operand));
                    }
                }
            } else if (from instanceof SqlIdentifier) {
                tables.add(((SqlIdentifier) from).names.get(0));
            } else if (from instanceof SqlSelect) {
                tables.addAll(collectTables(((SqlSelect) from).getFrom()));
            }
            return tables;
        }

        private SqlNode createCondition(String alias) {
            try {
                return SqlParser.create(
                        String.format("%s.HRP_NAME IN (1,2,3)", alias),
                        SQL_PARSER_CONFIG
                ).parseExpression();
            } catch (SqlParseException e) {
                throw new RuntimeException("Condition creation failed", e);
            }
        }
    }

    public static void main(String[] args) throws SqlParseException {
        String sql = "SELECT * FROM (SELECT * FROM orders) o JOIN customers c ON o.cust_id = c.id";
        String modifiedSql = injectFilter(sql, Arrays.asList("ORDERS", "CUSTOMERS"));
        System.out.println(modifiedSql);
    }
}