/**
 * {@link SqlShuttle} that ANDs an extra predicate onto the WHERE clause of every SELECT
 * whose FROM clause directly references one of a configured set of tables.
 *
 * <p>Each SELECT is rewritten in its own scope. A predicate is added only to the query
 * block whose own FROM clause names the table, whether through joins,
 * {@code table AS alias}, or a bare table identifier. A table read inside a sub-query
 * gets its predicate inside that sub-query.
 *
 * <p>The predicate for each reference is produced by {@code createCondition(qualifier)},
 * which is defined outside this class. The qualifier is the alias when one is given;
 * otherwise it is the last component of the table identifier.
 *
 * <p>Table names are matched case-insensitively against the last identifier component.
 * For example, {@code sales.orders} matches a configured {@code "ORDERS"} or
 * {@code "orders"}.
 *
 * <p>FROM and WHERE are replaced in place on the visited {@link SqlSelect} nodes. The
 * instance keeps mutable traversal state without synchronization. A SELECT the instance
 * has already rewritten is returned untouched if it is visited again.
 */
private static class FilterInjectionShuttle extends SqlShuttle {
    /** Names of the tables to filter, case-folded with {@link #normalize}. */
    private final Set<String> targetTables;

    /** One entry per SELECT currently being rewritten; the top is the innermost one. */
    private final Deque<ScopeContext> scopes = new ArrayDeque<>();

    /**
     * SELECT nodes already rewritten, compared by identity. {@code super.visit(select)}
     * in {@link #handleSelect} re-visits every operand, including the FROM and WHERE
     * clauses this class already processed. Without this guard, every nested sub-query
     * would receive its predicate once per enclosing query level.
     */
    private final Set<SqlNode> rewrittenSelects =
            java.util.Collections.newSetFromMap(new java.util.IdentityHashMap<>());

    /** Per-SELECT state collected from that SELECT's own FROM clause. */
    private static class ScopeContext {
        /**
         * Qualifier (alias, or bare table name) of each target-table reference, in FROM
         * order. A qualifier appears at most once.
         */
        final List<String> targetAliases = new ArrayList<>();
    }

    /**
     * @param targetTables names of the tables whose references must be filtered; matched
     *     case-insensitively. {@code null} elements are ignored.
     */
    public FilterInjectionShuttle(List<String> targetTables) {
        Set<String> normalized = new HashSet<>();
        for (String table : targetTables) {
            if (table != null) {
                normalized.add(normalize(table));
            }
        }
        this.targetTables = normalized;
        // Root scope, so that peek() never returns null outside a SELECT.
        scopes.push(new ScopeContext());
    }

    /**
     * Rewrites SELECT nodes in a fresh scope and delegates every other call to
     * {@link SqlShuttle}.
     */
    @Override
    public SqlNode visit(SqlCall call) {
        if (!(call instanceof SqlSelect)) {
            return super.visit(call);
        }
        if (!rewrittenSelects.add(call)) {
            // Already rewritten on an earlier visit (see rewrittenSelects).
            return call;
        }
        scopes.push(new ScopeContext());
        try {
            return handleSelect((SqlSelect) call);
        } finally {
            // Always pop, so an exception cannot leave a stale scope on the stack.
            scopes.pop();
        }
    }

    /**
     * Rewrites one SELECT inside the scope already pushed for it. It visits FROM (which
     * rewrites nested sub-queries), records the target-table references in FROM, and
     * visits WHERE. It then ANDs the injected predicates onto WHERE and finally visits
     * the remaining clauses through {@link SqlShuttle}.
     */
    private SqlNode handleSelect(SqlSelect select) {
        SqlNode from = select.getFrom();
        if (from != null) {
            from = from.accept(this);
            select.setFrom(from);
            collectLocalTables(from);
        }
        SqlNode where = select.getWhere();
        if (where != null) {
            where = where.accept(this);
        }
        select.setWhere(injectLocalConditions(where));
        // Visits the remaining clauses (select list, GROUP BY, HAVING, ...). FROM and
        // WHERE are traversed again, but their sub-queries are skipped via rewrittenSelects.
        return super.visit(select);
    }

    /**
     * Walks a FROM clause and records, in the current scope, the qualifier of each
     * reference to a target table. It handles joins, {@code identifier AS alias} and bare
     * identifiers. Other FROM items, such as sub-queries, are not recorded here.
     */
    private void collectLocalTables(SqlNode from) {
        if (from instanceof SqlJoin) {
            SqlJoin join = (SqlJoin) from;
            collectLocalTables(join.getLeft());
            collectLocalTables(join.getRight());
        } else if (from instanceof SqlBasicCall) {
            SqlBasicCall call = (SqlBasicCall) from;
            if (call.getOperator() == SqlStdOperatorTable.AS) {
                // Operand 0 is the aliased item and operand 1 is the alias.
                // Any further operands are ignored.
                List<SqlNode> operands = call.getOperandList();
                if (operands.size() >= 2
                        && operands.get(0) instanceof SqlIdentifier
                        && operands.get(1) instanceof SqlIdentifier) {
                    String tableName = extractPhysicalTableName(operands.get(0));
                    String alias = ((SqlIdentifier) operands.get(1)).getSimple();
                    recordTable(tableName, alias);
                }
            }
        } else if (from instanceof SqlIdentifier) {
            String tableName = extractPhysicalTableName(from);
            // No alias: the last component of the table identifier is the qualifier.
            recordTable(tableName, tableName);
        }
    }

    /**
     * Adds {@code qualifier} to the current scope if {@code tableName} is a target table.
     * Every reference is kept, so both sides of a self-join get their own predicate.
     */
    private void recordTable(String tableName, String qualifier) {
        if (!targetTables.contains(normalize(tableName))) {
            return;
        }
        List<String> aliases = scopes.peek().targetAliases;
        if (!aliases.contains(qualifier)) {
            aliases.add(qualifier);
        }
    }

    /**
     * Combines {@code originalWhere} with one {@code createCondition(qualifier)} per
     * target-table reference recorded in the current scope.
     *
     * @param originalWhere the already-visited WHERE clause; may be null
     * @return {@code originalWhere} unchanged (possibly null) when nothing is injected;
     *     the injected predicate alone when {@code originalWhere} is null; otherwise
     *     {@code originalWhere AND injected}
     */
    private SqlNode injectLocalConditions(SqlNode originalWhere) {
        SqlNode injected = null;
        for (String alias : scopes.peek().targetAliases) {
            SqlNode condition = createCondition(alias);
            injected = injected == null ? condition : and(injected, condition);
        }
        if (injected == null) {
            return originalWhere;
        }
        return originalWhere == null ? injected : and(originalWhere, injected);
    }

    /** Builds the call {@code left AND right}. */
    private static SqlNode and(SqlNode left, SqlNode right) {
        return new SqlBasicCall(
                SqlStdOperatorTable.AND,
                Arrays.asList(left, right),
                SqlParserPos.ZERO);
    }

    /**
     * Case-folds a table name for matching. {@code Locale.ROOT} keeps the result
     * independent of the default locale (e.g. Turkish dotted/dotless i).
     */
    private static String normalize(String tableName) {
        return tableName.toUpperCase(java.util.Locale.ROOT);
    }

    /**
     * Returns the last component of an identifier (for example {@code t} for
     * {@code s.t}), or {@code ""} when {@code node} is not a {@link SqlIdentifier}.
     */
    private String extractPhysicalTableName(SqlNode node) {
        if (node instanceof SqlIdentifier) {
            List<String> names = ((SqlIdentifier) node).names;
            return names.get(names.size() - 1);
        }
        return "";
    }
}