一个比mybatis的动态SQL更加动态的ORM小工具

1,222 阅读3分钟

前言:在大数据部门遇到这样一个场景:数据分散存储在MySQL,Hive,Redis,es等平台,部门想要把各个存储介质中的数据汇总,对外提供一个统一的访问接口;即:其他部门以一种访问方式,就可以获取到分布在异构存储介质的数据;

部门大佬给出的方案是:给不同的存储介质写不同的查询引擎,例如MySQL的查询引擎可以接收一些参数并转化成完整的SQL,Redis的查询引擎接收同样的参数可以转化成Redis能够识别的查询语句...

这有点像JVM的跨平台方案,也像GraalVM的跨语言方案:将问题转化成若干子问题,再分治地解决各个子问题即可

我对此有更进一步的想法,对于MySQL而言,通过几个参数(要查询的字段,where条件,表名,limit以及order情况)就可以构造出完整的SQL,相比于之前每个select要写动态SQL而言,自己编码能够更加灵活

以下是代码

  1. 项目结构

image.png

  2. 核心方法
/**
 * Builds a MySQL {@code SELECT} statement from a {@link SqlParamConvert} request.
 *
 * <p>The generated statement has the shape
 * {@code select <fields> from <table> [where <c1> and <c2> ...] [limit <count> | <offset>,<count>]}.
 * The {@code dataSource} and {@code order} inputs listed in the original specification are not
 * handled by this class.
 *
 * <p>Values are embedded as SQL literals: numeric values are validated against a numeric pattern
 * and written unquoted; all other values are written as quoted strings with {@code '} and
 * {@code \} escaped (this assumes MySQL's default {@code sql_mode}, i.e. without
 * {@code NO_BACKSLASH_ESCAPES}).
 *
 * <p>NOTE(review): table, field and column names are still emitted verbatim. If they can come
 * from an untrusted caller they must be checked against an allow-list before reaching this class;
 * a {@code PreparedStatement} cannot bind identifiers.
 */
public class GenSql {

    /** Lower-case names of the MySQL column types whose values are written without quotes. */
    public static final Set<String> NUMBER_TYPES = new HashSet<>();

    /** An optionally signed integer or decimal number with an optional exponent. */
    private static final java.util.regex.Pattern NUMERIC_LITERAL =
            java.util.regex.Pattern.compile("[+-]?(\\d+(\\.\\d*)?|\\.\\d+)([eE][+-]?\\d+)?");

    /** A non-negative integer, as required by both parts of a {@code LIMIT} clause. */
    private static final java.util.regex.Pattern UNSIGNED_INTEGER =
            java.util.regex.Pattern.compile("\\d+");

    static {
        NUMBER_TYPES.add("tinyint");
        NUMBER_TYPES.add("smallint");
        NUMBER_TYPES.add("mediumint");
        NUMBER_TYPES.add("int");
        NUMBER_TYPES.add("bigint");
        NUMBER_TYPES.add("double");
        NUMBER_TYPES.add("float");
        NUMBER_TYPES.add("decimal");
    }

    /**
     * Renders the request as a single SQL {@code SELECT} string.
     *
     * @param params the fields to select, the table name, the optional where conditions and the
     *               optional limit
     * @return the SQL text; the {@code where} and {@code limit} clauses are omitted when the
     *         request has no conditions or no limit
     * @throws IllegalArgumentException if no field is requested, or a condition or the limit is
     *                                  malformed
     */
    public static String produce(SqlParamConvert params) {
        String[] fields = params.getFields();
        if (fields == null || fields.length == 0) {
            throw new IllegalArgumentException("at least one field must be selected");
        }

        Where[] wheres = params.getWheres();
        String[] conditions = new String[wheres == null ? 0 : wheres.length];
        for (int i = 0; i < conditions.length; i++) {
            conditions[i] = toCondition(wheres[i]);
        }

        StringBuilder sql = new StringBuilder("select ")
                .append(String.join(",", fields))
                .append(" from ")
                .append(params.getTableName());
        if (conditions.length > 0) {
            // Joining complete conditions keeps the spacing right whichever kind comes first.
            sql.append(" where ").append(String.join(" and ", conditions));
        }
        String limitClause = toLimit(params.getLimit(), params.isLimitHasOffset());
        if (!limitClause.isEmpty()) {
            sql.append(" limit ").append(limitClause);
        }
        return sql.toString();
    }

    /** Renders one where condition, e.g. {@code age=30} or {@code d between 'a' and 'b'}. */
    private static String toCondition(Where where) {
        if (where == null) {
            throw new IllegalArgumentException("where condition must not be null");
        }
        String name = where.getName();
        String value = where.getValue();
        if (value == null) {
            throw new IllegalArgumentException("value of condition on '" + name + "' must not be null");
        }
        boolean numeric = isNumeric(where);

        WhereType whereType = where.getWhereType();
        if (WhereType.SIMPLE_EQUALS.equals(whereType)) {
            return name + "=" + toLiteral(value, numeric);
        }
        if (!WhereType.RANGE_EQUALS.equals(whereType)) {
            throw new IllegalArgumentException(
                    "unsupported whereType " + whereType + " for condition on '" + name + "'");
        }
        if (where.getRangeType() == null) {
            throw new IllegalArgumentException("rangeType is required for range condition on '" + name + "'");
        }
        return switch (where.getRangeType()) {
            case IN -> name + " in (" + toLiteralList(value, numeric) + ")";
            case BETWEEN -> {
                String[] bounds = value.split(",");
                if (bounds.length != 2) {
                    throw new IllegalArgumentException(
                            "between on '" + name + "' needs exactly two comma-separated values: " + value);
                }
                yield name + " between " + toLiteral(bounds[0].strip(), numeric)
                        + " and " + toLiteral(bounds[1].strip(), numeric);
            }
            case LESS_THAN -> name + " < " + toLiteral(value, numeric);
            case MORE_THAN -> name + " > " + toLiteral(value, numeric);
            case LESS_THAN_EQUALS -> name + " <= " + toLiteral(value, numeric);
            case MORE_THAN_EQUALS -> name + " >= " + toLiteral(value, numeric);
        };
    }

    /** Tells whether the condition's column type is one of {@link #NUMBER_TYPES}. */
    private static boolean isNumeric(Where where) {
        if (where.getType() == null) {
            // Without a declared type the value is treated as a string, which is always quoted.
            return false;
        }
        // Enum names are upper-case while NUMBER_TYPES is lower-case, so ignore case.
        String typeName = where.getType().name();
        return NUMBER_TYPES.stream().anyMatch(typeName::equalsIgnoreCase);
    }

    /** Renders a comma-separated value list as individually formatted literals. */
    private static String toLiteralList(String value, boolean numeric) {
        String[] items = value.split(",");
        String[] literals = new String[items.length];
        for (int i = 0; i < items.length; i++) {
            // Each element is quoted on its own; surrounding blanks after a comma are dropped.
            literals[i] = toLiteral(items[i].strip(), numeric);
        }
        return String.join(",", literals);
    }

    /** Renders one value as an unquoted number or as an escaped, single-quoted string. */
    private static String toLiteral(String raw, boolean numeric) {
        if (numeric) {
            String number = raw.strip();
            if (!NUMERIC_LITERAL.matcher(number).matches()) {
                throw new IllegalArgumentException("not a numeric value: " + raw);
            }
            return number;
        }
        // Escape the backslash first so the quotes doubled afterwards are not affected.
        return "'" + raw.replace("\\", "\\\\").replace("'", "''") + "'";
    }

    /**
     * Validates the limit and renders it as {@code count} or {@code offset,count}; returns an
     * empty string when no limit was given.
     */
    private static String toLimit(String limit, boolean hasOffset) {
        if (limit == null || limit.isBlank()) {
            return "";
        }
        String[] parts = limit.split(",");
        int expectedParts = hasOffset ? 2 : 1;
        if (parts.length != expectedParts) {
            throw new IllegalArgumentException(
                    "limit must have " + expectedParts + " comma-separated part(s): " + limit);
        }
        for (int i = 0; i < parts.length; i++) {
            parts[i] = parts[i].strip();
            // Only plain non-negative integers are accepted, so nothing else can ride along.
            if (!UNSIGNED_INTEGER.matcher(parts[i]).matches()) {
                throw new IllegalArgumentException("limit parts must be non-negative integers: " + limit);
            }
        }
        return String.join(",", parts);
    }
}
  3. 传入参数
/**
 * Request parameters from which a {@code SELECT} statement is generated: the where conditions,
 * the table to read, the fields to select and the limit.
 */
@Data
public class SqlParamConvert {
    /** Conditions of the where clause. */
    private Where[] wheres;
    /** Name of the table to query. */
    private String tableName;
    /** Names of the fields to select. */
    private String[] fields;
    /** Whether {@link #limit} carries two comma-separated parts instead of one. */
    private boolean limitHasOffset;
    /** Text placed after the {@code limit} keyword. */
    private String limit;

    /** Returns a one-line description listing every property in declaration order. */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("SqlParamConvert{");
        text.append("wheres=").append(Arrays.toString(wheres));
        text.append(", tableName='").append(tableName).append('\'');
        text.append(", fields=").append(Arrays.toString(fields));
        text.append(", limitHasOffset=").append(limitHasOffset);
        text.append(", limit='").append(limit).append('\'');
        text.append('}');
        return text.toString();
    }
}
/**
 * One condition of a where clause: a column name, the value to compare it with, the kind of
 * comparison and the column's data type.
 */
@Data
public class Where {

// Kind of condition: plain equality or a range-style comparison.
WhereType whereType;
// Comparison operator; read only when whereType is RANGE_EQUALS.
RangeType rangeType;
// Column name the condition applies to.
String name;
// Value to compare with, kept as text; IN and BETWEEN carry several values separated by commas.
String value;
// Data type of the column; GenSql reads it to decide whether the value is quoted.
MySQLDataType type;
}
/** Kind of where condition. */
public enum WhereType {
// Plain equality: column = value.
SIMPLE_EQUALS,
// Range-style comparison; the concrete operator is given by RangeType.
RANGE_EQUALS
}
package com.example.demo.param;  
  
/** Comparison operator used by a range-style where condition. */
public enum RangeType {

// column in (v1, v2, ...)
IN,
// column between v1 and v2
BETWEEN,
// column < value
LESS_THAN,
// column > value
MORE_THAN,
// column <= value
LESS_THAN_EQUALS,
// column >= value
MORE_THAN_EQUALS
}
package com.example.demo.param;  
  
/** MySQL column data types, grouped by category. */
public enum MySQLDataType {
// Numeric types
TINYINT,
SMALLINT,
MEDIUMINT,
INT,
BIGINT,
FLOAT,
DOUBLE,
DECIMAL,

// Date and time types
DATE,
TIME,
DATETIME,
TIMESTAMP,
YEAR,

// String types
CHAR,
VARCHAR,
BINARY,
VARBINARY,
BLOB,
TINYBLOB,
MEDIUMBLOB,
LONGBLOB,
TEXT,
TINYTEXT,
MEDIUMTEXT,
LONGTEXT,

// Enumeration and set types
ENUM,
SET;
}
  4. 查询引擎
/**
 * Runs the SQL generated by {@link GenSql} against MySQL and maps every result row onto a new
 * instance of a caller-chosen class through reflection.
 */
@Component
public class SelectEngine {

    /**
     * Executes the query described by {@code sqlParam} and returns one object per row.
     *
     * <p>Only declared fields of the target class whose name is listed in
     * {@code sqlParam.getFields()} are populated; each is read from the result-set column of the
     * same name. The connection, statement and result set are closed before this method returns.
     *
     * @param sqlParam description of the query to run
     * @param tClass   fully qualified name of the class to instantiate per row; it needs an
     *                 accessible no-argument constructor
     * @param <T>      type the caller expects {@code tClass} to denote (not checked at runtime)
     * @return the mapped rows in result-set order, empty if the query matches nothing
     * @throws ClassNotFoundException if {@code tClass} or the JDBC driver cannot be loaded
     * @throws SQLException           if the query fails
     */
    public <T> List<T> query(SqlParamConvert sqlParam, String tClass) throws SQLException, ClassNotFoundException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
        System.out.println("sqlParam = " + sqlParam);
        String sql = GenSql.produce(sqlParam);

        // Resolve the target class before touching the database so a bad name fails early.
        Class<?> rowType = Class.forName(tClass);

        Set<String> requested = new HashSet<>();
        Collections.addAll(requested, sqlParam.getFields());

        // Pick the fields to fill once instead of re-checking them for every row.
        List<Field> mapped = new ArrayList<>();
        for (Field f : rowType.getDeclaredFields()) {
            if (requested.contains(f.getName())) {
                f.setAccessible(true);
                mapped.add(f);
            }
        }

        List<T> ans = new ArrayList<>();
        // Resources close in reverse order (connection, statement, result set); closing an
        // already closed JDBC object is a no-op. A null resource is skipped.
        try (ResultSet rs = prepareResultSet(sql);
             Statement stat = rs.getStatement();
             Connection conn = (stat == null) ? null : stat.getConnection()) {
            while (rs.next()) {
                @SuppressWarnings("unchecked") // the caller vouches that tClass denotes T
                T row = (T) rowType.getDeclaredConstructor().newInstance();
                for (Field f : mapped) {
                    Object converted = convertFieldFromString(f, rs.getString(f.getName()));
                    if (converted == null && f.getType().isPrimitive()) {
                        // SQL NULL cannot be stored in a primitive; keep the field's default.
                        continue;
                    }
                    f.set(row, converted);
                }
                ans.add(row);
            }
        }
        return ans;
    }

    /**
     * Opens a new connection and executes {@code sql} on it.
     *
     * <p>The caller owns the returned result set and must close it together with its statement
     * and connection (reachable through {@link ResultSet#getStatement()} and
     * {@link Statement#getConnection()}).
     *
     * @param sql the query text to execute
     * @return the open result set of the query
     * @throws ClassNotFoundException if the MySQL driver class cannot be loaded
     * @throws SQLException           if connecting or executing fails; the connection is closed
     *                                in that case
     */
    public ResultSet prepareResultSet(String sql) throws ClassNotFoundException, SQLException {
        System.out.println("sql = " + sql);
        Class.forName("com.mysql.cj.jdbc.Driver");
        // NOTE(review): URL and root credentials are hard-coded; move them to configuration.
        Connection com = DriverManager.getConnection("jdbc:mysql://192.168.56.10/demo", "root", "root");
        try {
            Statement stat = com.createStatement();
            return stat.executeQuery(sql);
        } catch (SQLException | RuntimeException e) {
            // Nothing was handed to the caller, so release the connection here.
            try {
                com.close();
            } catch (SQLException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Converts the textual column value into an object assignable to field {@code f}.
     *
     * @param f     the field that will receive the value
     * @param value the column value as text, or {@code null} for SQL NULL
     * @return {@code null} for a {@code null} value; otherwise the value parsed according to the
     *         field type (integral, floating-point, boolean, {@code BigDecimal} or enum), or the
     *         text itself for any other type
     * @throws IllegalArgumentException if the text cannot be parsed for the field type
     *                                  ({@link NumberFormatException} for numbers)
     */
    public Object convertFieldFromString(Field f, String value) {
        if (value == null) {
            return null;
        }
        Class<?> type = f.getType();
        if (type.equals(Integer.class) || type.equals(int.class)) {
            return Integer.parseInt(value);
        } else if (type.equals(Long.class) || type.equals(long.class)) {
            return Long.parseLong(value);
        } else if (type.equals(Short.class) || type.equals(short.class)) {
            return Short.parseShort(value);
        } else if (type.equals(Byte.class) || type.equals(byte.class)) {
            return Byte.parseByte(value);
        } else if (type.equals(Double.class) || type.equals(double.class)) {
            return Double.parseDouble(value);
        } else if (type.equals(Float.class) || type.equals(float.class)) {
            return Float.parseFloat(value);
        } else if (type.equals(Boolean.class) || type.equals(boolean.class)) {
            // Accept "1" as well as "true" so numeric flag columns map to true.
            return "1".equals(value) || Boolean.parseBoolean(value);
        } else if (type.equals(java.math.BigDecimal.class)) {
            return new java.math.BigDecimal(value);
        } else if (type.isEnum()) {
            // A raw String cannot be assigned to an enum field; look the constant up by name.
            @SuppressWarnings({"unchecked", "rawtypes"})
            Object constant = Enum.valueOf((Class<? extends Enum>) type, value);
            return constant;
        } else {
            return value;
        }
    }
}

测试:

数据库中数据如下

image.png

查询的Request Body如下

{
  "wheres": [
    {
      "whereType": "SIMPLE_EQUALS",
      "name": "age",
      "value": "30",
      "type": "INT"
    },
    {
      "whereType": "RANGE_EQUALS",
      "rangeType": "BETWEEN",
      "name": "registration_date",
      "value": "2022-01-01, 2022-01-15",
      "type": "DATE"
    }
  ],
  "tableName": "users",
  "fields": ["username", "age", "email", "registration_date"],
  "limitHasOffset": true,
  "limit": "0,10"
}

结果:

image.png