import org.apache.calcite.config.Lex;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.dialect.OracleSqlDialect;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParser.Config;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.util.SqlShuttle;

import java.util.*;
/**
 * Injects a fixed row filter ({@code <qualifier>.HRP_NAME IN (1,2,3)}) into Oracle SQL queries.
 *
 * <p>Every SELECT whose own FROM clause reads one of the target tables gets one predicate per such
 * FROM item, ANDed onto its WHERE clause. Each SELECT is treated as its own scope: a target table
 * read inside a derived table or subquery is filtered inside that subquery, never from the
 * enclosing query (where the table is not visible).
 */
public class OracleSqlFilterInjector {
    private static final Config SQL_PARSER_CONFIG = SqlParser.config()
            .withLex(Lex.ORACLE)
            .withCaseSensitive(false);

    /** Predicate text appended after the quoted qualifier of each filtered FROM item. */
    private static final String FILTER_PREDICATE = "HRP_NAME IN (1,2,3)";

    /**
     * Parses {@code originalSql}, injects the filter into every SELECT that reads a target table,
     * and renders the result back to Oracle SQL.
     *
     * @param originalSql  a single Oracle SQL query
     * @param targetTables table names to filter; compared case-insensitively against either the
     *                     bare table name (last identifier part) or the full dotted name
     * @return the rewritten SQL. Double quotes are removed from plain upper-case identifiers and
     *         kept where removing them would change the identifier; string literals are untouched
     * @throws SqlParseException if {@code originalSql} cannot be parsed
     */
    public static String injectFilter(String originalSql, List<String> targetTables)
            throws SqlParseException {
        SqlParser parser = SqlParser.create(originalSql, SQL_PARSER_CONFIG);
        SqlNode sqlNode = parser.parseQuery();
        SqlNode modifiedSqlNode = sqlNode.accept(new FilterInjectionShuttle(targetTables));
        String sql = modifiedSqlNode.toSqlString(OracleSqlDialect.DEFAULT, true).getSql();
        return stripRedundantQuotes(sql);
    }

    /**
     * Removes double quotes around identifiers that Oracle reads identically without them: an
     * upper-case letter followed by upper-case letters, digits, {@code _}, {@code $} or
     * {@code #}. Other quoted identifiers keep their quotes, and single-quoted string literals
     * are copied verbatim (a {@code "} inside a literal is never touched).
     *
     * <p>NOTE(review): a quoted reserved word such as {@code "DATE"} is still unquoted, which
     * makes it invalid as an identifier. The previous blanket quote removal had the same
     * limitation.
     */
    private static String stripRedundantQuotes(String sql) {
        StringBuilder out = new StringBuilder(sql.length());
        int i = 0;
        while (i < sql.length()) {
            char c = sql.charAt(i);
            if (c != '\'' && c != '"') {
                out.append(c);
                i++;
                continue;
            }
            int end = skipQuoted(sql, i, c);
            if (end < 0) {
                // Unterminated quote: copy the remainder unchanged rather than guess.
                out.append(sql, i, sql.length());
                break;
            }
            String body = sql.substring(i + 1, end - 1);
            if (c == '"' && isPlainIdentifier(body)) {
                out.append(body);
            } else {
                out.append(sql, i, end);
            }
            i = end;
        }
        return out.toString();
    }

    /**
     * Returns the index just past the quote that closes the quoted run starting at
     * {@code start}, treating a doubled quote as an escaped quote character; -1 if unterminated.
     */
    private static int skipQuoted(String sql, int start, char quote) {
        int i = start + 1;
        while (i < sql.length()) {
            if (sql.charAt(i) == quote) {
                if (i + 1 < sql.length() && sql.charAt(i + 1) == quote) {
                    i += 2;
                    continue;
                }
                return i + 1;
            }
            i++;
        }
        return -1;
    }

    /** True if {@code name} is a non-empty upper-case Oracle identifier that needs no quoting. */
    private static boolean isPlainIdentifier(String name) {
        if (name.isEmpty() || name.charAt(0) < 'A' || name.charAt(0) > 'Z') {
            return false;
        }
        for (int i = 1; i < name.length(); i++) {
            char c = name.charAt(i);
            boolean allowed = (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')
                    || c == '_' || c == '$' || c == '#';
            if (!allowed) {
                return false;
            }
        }
        return true;
    }

    /** Shuttle that rewrites every SELECT in the tree, each within its own FROM scope. */
    private static class FilterInjectionShuttle extends SqlShuttle {
        /** Upper-cased (Locale.ROOT) target table names, simple or dotted. */
        private final Set<String> targetTables = new HashSet<>();

        public FilterInjectionShuttle(List<String> targetTables) {
            Objects.requireNonNull(targetTables, "targetTables");
            for (String table : targetTables) {
                if (table != null) {
                    this.targetTables.add(table.toUpperCase(Locale.ROOT));
                }
            }
        }

        @Override
        public SqlNode visit(SqlCall call) {
            if (call instanceof SqlSelect) {
                return handleSelect((SqlSelect) call);
            }
            return super.visit(call);
        }

        /**
         * Rewrites nested queries first, then ANDs the filter for this SELECT's own target tables
         * onto its WHERE clause.
         */
        private SqlNode handleSelect(SqlSelect select) {
            // The default traversal visits every clause (FROM, WHERE, select list, HAVING, ...)
            // exactly once, so each nested SELECT is rewritten once. Visiting FROM/WHERE manually
            // as well would inject the filter into nested SELECTs twice.
            SqlNode visited = super.visit(select);
            if (!(visited instanceof SqlSelect)) {
                return visited;
            }
            SqlSelect rewritten = (SqlSelect) visited;
            rewritten.setWhere(injectConditions(rewritten.getWhere(), rewritten.getFrom()));
            return rewritten;
        }

        /**
         * Returns {@code originalWhere AND (c1 AND c2 ...)}, with one condition per target-table
         * FROM item, or {@code originalWhere} unchanged when there is none.
         */
        private SqlNode injectConditions(SqlNode originalWhere, SqlNode from) {
            List<List<String>> qualifiers = new ArrayList<>();
            collectTargetQualifiers(from, qualifiers);
            SqlNode injected = null;
            for (List<String> qualifier : qualifiers) {
                SqlNode condition = createCondition(qualifier);
                injected = injected == null ? condition : and(injected, condition);
            }
            if (injected == null) {
                return originalWhere;
            }
            return originalWhere == null ? injected : and(originalWhere, injected);
        }

        /**
         * Appends, for each FROM item of the current SELECT that reads a target table, the name
         * parts used to qualify the filtered column. That is the alias when the table is aliased,
         * otherwise the table identifier as written (possibly schema-qualified). Derived tables
         * are skipped: their own SELECT is filtered when the shuttle visits it.
         */
        private void collectTargetQualifiers(SqlNode from, List<List<String>> out) {
            if (from instanceof SqlJoin) {
                SqlJoin join = (SqlJoin) from;
                collectTargetQualifiers(join.getLeft(), out);
                collectTargetQualifiers(join.getRight(), out);
            } else if (from instanceof SqlIdentifier) {
                SqlIdentifier table = (SqlIdentifier) from;
                if (isTargetTable(table)) {
                    out.add(table.names);
                }
            } else if (from instanceof SqlBasicCall
                    && ((SqlBasicCall) from).getOperator() == SqlStdOperatorTable.AS) {
                List<SqlNode> operands = ((SqlBasicCall) from).getOperandList();
                if (operands.size() >= 2
                        && operands.get(0) instanceof SqlIdentifier
                        && operands.get(1) instanceof SqlIdentifier
                        && isTargetTable((SqlIdentifier) operands.get(0))) {
                    String alias = ((SqlIdentifier) operands.get(1)).getSimple();
                    out.add(Collections.singletonList(alias));
                }
            }
        }

        /** Matches the table's last name part, or its full dotted name, against the targets. */
        private boolean isTargetTable(SqlIdentifier table) {
            List<String> names = table.names;
            if (names.isEmpty()) {
                return false;
            }
            String simple = names.get(names.size() - 1).toUpperCase(Locale.ROOT);
            String qualified = String.join(".", names).toUpperCase(Locale.ROOT);
            return targetTables.contains(simple) || targetTables.contains(qualified);
        }

        /** Builds {@code "<q1>"."<q2>"...HRP_NAME IN (1,2,3)} for the given qualifier parts. */
        private SqlNode createCondition(List<String> qualifier) {
            StringBuilder expression = new StringBuilder();
            for (String part : qualifier) {
                // Quote each part (doubling embedded quotes) so aliases and names with lower-case
                // or special characters round-trip exactly; Lex.ORACLE keeps quoted case as-is.
                expression.append('"').append(part.replace("\"", "\"\"")).append("\".");
            }
            expression.append(FILTER_PREDICATE);
            try {
                return SqlParser.create(expression.toString(), SQL_PARSER_CONFIG)
                        .parseExpression();
            } catch (SqlParseException e) {
                throw new RuntimeException("Condition creation failed", e);
            }
        }

        /** Binary AND of two conditions. */
        private static SqlNode and(SqlNode left, SqlNode right) {
            return new SqlBasicCall(
                    SqlStdOperatorTable.AND,
                    Arrays.asList(left, right),
                    SqlParserPos.ZERO);
        }
    }

    /**
     * Demo. ORDERS is filtered inside the derived table (unaliased, so qualified as ORDERS);
     * the outer query is filtered only on alias C.
     */
    public static void main(String[] args) throws SqlParseException {
        String sql = "SELECT * FROM (SELECT * FROM orders) o JOIN customers c ON o.cust_id = c.id";
        String modifiedSql = injectFilter(sql, Arrays.asList("ORDERS", "CUSTOMERS"));
        System.out.println(modifiedSql);
    }
}