package self.assets.platform.mysql;
import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLLimit;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLAllColumnExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import lombok.extern.slf4j.Slf4j;
import org.agrona.collections.ObjectHashSet;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.sql.Connection;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class IllegalSqlInnerInterceptor implements InnerInterceptor {
private final static MessageDigest MESSAGE_DIGEST;
static {
try {
MESSAGE_DIGEST = MessageDigest.getInstance("MD5");
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
}
private final ObjectHashSet<SqlMd5LongPair> cacheValidResult = new ObjectHashSet<>();
private record SqlMd5LongPair(long high, long low) {
}
@Override
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpStatementHandler = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpStatementHandler.mappedStatement();
SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.INSERT || InterceptorIgnoreHelper.willIgnoreIllegalSql(ms.getId())) {
return;
}
BoundSql boundSql = mpStatementHandler.boundSql();
String originalSql = boundSql.getSql();
log.debug("检查SQL是否合规,SQL:{}", originalSql);
byte[] digest = MESSAGE_DIGEST.digest(originalSql.getBytes(StandardCharsets.UTF_8));
ByteBuffer digestBuffer = ByteBuffer.wrap(digest);
SqlMd5LongPair sqlMd5LongPair = new SqlMd5LongPair(digestBuffer.getLong(), digestBuffer.getLong());
if (cacheValidResult.contains(sqlMd5LongPair)) {
log.debug("该SQL已验证,无需再次验证,SQL:{}", originalSql);
return;
}
List<SQLStatement> sqlStatements = SQLUtils.parseStatements(originalSql, DbType.mysql);
for (SQLStatement sqlStatement : sqlStatements) {
sqlCheck(sqlStatement);
}
cacheValidResult.add(sqlMd5LongPair);
}
public static void sqlCheck(SQLStatement sqlStatement) {
if (sqlStatement == null) {
return;
}
if (sqlStatement instanceof SQLSelectStatement sqlSelectStatement) {
selectCheck(sqlSelectStatement);
} else if (sqlStatement instanceof SQLUpdateStatement sqlUpdateStatement) {
updateCheck(sqlUpdateStatement);
} else if (sqlStatement instanceof SQLInsertStatement sqlInsertStatement) {
insertCheck(sqlInsertStatement);
} else if (sqlStatement instanceof SQLDeleteStatement sqlDeleteStatement) {
deleteCheck(sqlDeleteStatement);
} else {
log.error("不支持的SQL类型:{},{}", sqlStatement.getClass(), sqlStatement);
}
}
public static void updateCheck(SQLUpdateStatement statement) {
mustHaveLimit(statement.getLimit());
mustHaveWhere(statement.getWhere());
}
public static void selectCheck(SQLSelectStatement statement) {
SQLSelect select = statement.getSelect();
Optional.ofNullable(select.getQuery())
.filter(query -> query instanceof SQLSelectQueryBlock)
.map(query -> (SQLSelectQueryBlock) query)
.ifPresent(queryBlock -> mustHaveLimit(queryBlock.getLimit()));
Optional.ofNullable(select.getQueryBlock())
.map(SQLSelectQueryBlock::getSelectList)
.map(selectItems -> selectItems.stream()
.anyMatch(selectItem -> selectItem.getExpr() instanceof SQLAllColumnExpr))
.ifPresent(hasAllColumn -> {
if (hasAllColumn) {
log.error("查询语句包含*");
}
});
}
public static void deleteCheck(SQLDeleteStatement statement) {
mustHaveWhere(statement.getWhere());
}
public static void insertCheck(SQLInsertStatement statement) {
if (statement.getColumns()
.isEmpty()) {
log.error("缺失列定义");
}
}
private static void mustHaveLimit(SQLLimit limit) {
if (limit == null) {
log.error("缺失limit关键字");
}
}
private static void mustHaveWhere(SQLExpr where) {
if (where == null) {
log.error("缺失where条件");
}
}
}