【实战】学习了泛型、注解和反射,我想开发一个ORM框架

1,280 阅读15分钟

我正在参与掘金创作者训练营第4期,点击了解活动详情,一起学习吧!

前言

之前学习了泛型、注解和反射(不清楚的可以看我之前的文章),是时候把这 3 个东西结合起来使用了。

本文讲述的是,我如何运用这几个知识点来实现一个小型的 ORM 框架。

小型 ORM 框架的需求如下:

  1. 有数据库连接池,能够连接上数据库(这是基本需求)。
  2. 能够做到数据表字段和Java实体类的属性自动映射,不考虑子查询以及复合实体类。
  3. 有基本的 CURD 功能(也就是删改查都通过 ID 来删改查)。
  4. 支持$#来非预编译和预编译的方式拼装SQL

数据库连接池

  1. 有数据库连接池,能够连接上数据库(这是基本需求)。 首先我们需要一个数据库连接池。

数据库连接池跟线程池差不多,都是为了减少反复地创建和销毁,数据库连接池对应的就是数据库链接,线程池对应的就是线程。具体的原理可以去看看我的设计模式专栏中的享元模式,我想差不多是这样设计的。

简而言之,数据库连接池可以为我们提供更快的响应,减少性能和内存的消耗。

这里选用了号称性能杀手的HiKariCP作为测试的数据库连接池,当然也可以使用C3P0或者是DBCP

HiKariCP的使用

  1. 在 pom.xml 增加依赖。
<dependency>
  <groupId>com.zaxxer</groupId>
  <artifactId>HikariCP</artifactId>
  <version>2.5.1</version>
</dependency>
  1. 在 resource 文件夹下创建 hikari.properties 文件。 这个文件名可以是其他的,我们后面使用的时候是可以自定义的。
#数据库地址
jdbcUrl=jdbc:mysql://localhost:3306/test?useSSL=false&useUnicode=true&characterEncoding=UTF-8
#数据库驱动
driverClassName=com.mysql.jdbc.Driver
#数据库用户名
dataSource.user=root
#数据库密码
dataSource.password=123456
dataSource.databaseName=test
dataSource.serverName=localhost
#最大数据库链接数,当数据库链接不够用时,就会增长到这个数
dataSource.maximumPoolSize=10

这里配置上自己的数据库地址和账号密码。

  1. 测试
public class HikariTest {
    public static void main(String[] args) {
        try (InputStream is = HikariTest.class.getClassLoader().getResourceAsStream("hikari.properties")) {
            // 加载属性文件并解析:
            Properties props = new Properties();
            props.load(is);
            HikariConfig config = new HikariConfig(props);
            HikariDataSource hikariDataSource = new HikariDataSource(config);
            Connection connection = hikariDataSource.getConnection();
            Statement statement = connection.createStatement();
            System.out.println();
        } catch (IOException | SQLException e) {
            e.printStackTrace();
        }
    }
}

System.out.println();可以在这行代码打上断点 hikariDataSource、connection、statement 不为 null,就说明我们的数据库连接池配置成功了。

数据库连接池工厂

虽然我们的HiKariCP数据库连接池已经配置成功了,但是为了更加方便的使用,我们还需要一个工厂来让我们更加方便地获得一个链接。

/**
 * 数据库连接工厂
 */
public class ConnectFactory {
    private static DataSource dataSource = getHikariDataSource();

    /**
     * 获取数据源
     *
     * @return
     */
    private static DataSource getHikariDataSource() {
        HikariDataSource hikariDataSource = null;
        try (InputStream is = HikariTest.class.getClassLoader().getResourceAsStream("hikari.properties")) {
            // 加载属性文件并解析:
            Properties props = new Properties();
            props.load(is);
            HikariConfig config = new HikariConfig(props);
            hikariDataSource = new HikariDataSource(config);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return hikariDataSource;
    }

    /**
     * 获取连接
     *
     * @return
     */
    public static Connection getConnection() {
        try {
            return dataSource.getConnection();
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * 关闭资源
     *
     * @param connection
     * @param statement
     * @param resultSet
     */
    public static void close(Connection connection, Statement statement, ResultSet resultSet) {
        try {
            if (resultSet != null) {
                resultSet.close();
            }
            if (statement != null) {
                statement.close();
            }
            if (connection != null) {
                connection.close();
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }

    }

}

注解

  1. 能够做到数据表字段和Java实体类的属性自动映射,不考虑子查询以及复合实体类。

我们可以使用注解来存储表信息和字段信息,我们只需要在entity中配置好就可以直接拿来根据反射获取注解中的值,就可以拿来拼装一段 SQL 了。

@DataTableName

/**
 * 表注解
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface DataTableName {
    String value();
}

@DataTableField

/**
 * 字段注解
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface DataTableField {
    String value();
}

@DataTablePkey

/**
 * 主键注解
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface DataTablePkey {
}

在公司使用的是ORACLE,所以习惯主键使用的是 PKEY,各位朋友也可以使用 ID,看个人喜好。

entity

User

/**
 * 用户类
 */
@DataTableName("T_USER")
public class User {
    public User() {
    }

    public User(String name, Integer age, String address) {
        this.name = name;
        this.age = age;
        this.address = address;
    }

    public User(Long pkey, String name, Integer age, String address) {
        Pkey = pkey;
        this.name = name;
        this.age = age;
        this.address = address;
    }

    @DataTablePkey
    @DataTableField("PKEY")
    private Long Pkey;

    @DataTableField("NAME")
    private String name;

    @DataTableField("AGE")
    private Integer age;

    @DataTableField("ADDRESS")
    private String address;

    public String getName() {
        return name;
    }

    public User setName(String name) {
        this.name = name;
        return this;
    }

    public Integer getAge() {
        return age;
    }

    public User setAge(Integer age) {
        this.age = age;
        return this;
    }

    public String getAddress() {
        return address;
    }

    public User setAddress(String address) {
        this.address = address;
        return this;
    }

    public Long getPkey() {
        return Pkey;
    }

    public User setPkey(Long pkey) {
        Pkey = pkey;
        return this;
    }

    @Override
    public String toString() {
        return "User{" +
                "Pkey=" + Pkey +
                ", name='" + name + ''' +
                ", age=" + age +
                ", address='" + address + ''' +
                '}';
    }
}

实现 CURD

  1. 有基本的 CURD 功能。

这一部分就是去获取注解中的值,然后去拼装SQL就完事了。

至于使用#$来拼装SQL,可能需要使用到注解+动态代理了,这个需要后面再做了,现在先走出 CURD 这一步。

创建一个 BaseDao

先来创建一个接口IBaseDao和抽象类类BaseDao,这里为了实现能够获取到entity的 class 对象,编写了的BaseDao抽象类,作为所有 DAO 可以继承的,这个BaseDao抽象类包含公共的基础的 CURD 方法的抽象类。我们的 CURD 基本上都是会在BaseDao抽象类中实现。

BaseDao

private Class<T> clazz;
{
    clazz= (Class<T>) ((ParameterizedType)this.getClass().getGenericSuperclass()).getActualTypeArguments()[0];
}

因为我们创建的 UserDao 是继承 BaseDao,我们的泛型是写在父类上的,所以我们要获取到父类的 class 对象。

getGenericSuperclass()getSuperclass()都是获取当前对象父类的 class 对象啊,区别是前者不会擦除泛型,后者会擦除泛型。

这里为了能够获取到泛型的class,一开始想着能不能new T()?很显然不可以,泛型在编译之后都是会被擦除的,所以BaseDao在编译之后<T>就会变成一个Object,所以必须要把BaseDao<T>创建成一个抽象类,再让一个 Dao 来继承它。

还有 UserDao

public class UserDao extends BaseDao<User> {

}

这是UserDao,后面很长时间都再看到这个 dao 。

删除功能

我们来实现一个根据 PKEY 删除的功能,其实想实现通过一个对象来删除数据也是可以的,不过现在的首要目标是实现一个有基本 CURD 的框架,这里先挖个坑,以后有时间在回头来实现。

创建接口方法

IBaseDao

public interface IBaseDao<T>{
     int deleteByPkey(long PKEY);
}

实现接口方法

这里在BaseDao中实现删除方法就可以了。

这里一写,就写了差不多一百行代码。

INTEGER_STR
LONG_STR
STRING_STR
DATE_STR

这几个类型呢,是比较常用的类型,因为Javaswitch居然不支持通过Integer.class.getSimpleName()获取的字符串?!搞得我只能自己写一个字符串了,这就显得非常的呆,也就是写成这样来给preparedStatement设置参数。

 switch (name){
     case INTEGER_STR:
         preparedStatement.setInt(i,Integer.valueOf(String.valueOf(value)));
         break;
     case LONG_STR:
         preparedStatement.setLong(i,Long.valueOf(String.valueOf(value)));
         break;
     case STRING_STR:
         preparedStatement.setString(i,String.valueOf(value));
         break;
     case DATE_STR:
         //可能会报错
         preparedStatement.setDate(i,(Date) value);
         break;

后来发现,可以直接preparedStatement.setObject(i,value);,阿哲,居然可以直接设置Object类型,那还搞这么多类型干嘛?!我看了一眼,setObject的方法,这个方法就帮我们去判断了实际该调用setInt还是setString,也就是说当我们不确定参数的类型时,就直接调用setObject,这样判断参数类型的操作就直接交给了setObject

executeUpdate方法是打算用来执行设置SQL的参数和执行preparedStatement.executeUpdate()

我把表名给提取出来了,然后我觉得应该也要把实体类的属性和数据库的字段的映射提取出来,这样就能够在后面的代码中基本上不需要再接触到注解和class了,我们本来需要的数据不也是注解上面的值和属性名称嘛。

于是我还加上了entityFieldToDataTableFieldMapdataTableFieldToEntityFieldMap,前者是实体类属性对应数据表字段,后者是数据表字段对应实体类属性,为了更加方便,我还把实体类属性对应Field也抽成了一个entityFieldToFieldMap

setParameters方法的初衷是这样的,因为考虑到后面需要使用自定义注解,所以打算使用paramSequence来存储所有?占位符的位子对应着什么参数,paramMap来存储所有参数对应的值,这样参数的值和占位符就能够对应上了。

各位朋友也不需要担心,我在最后会把我编辑好的源码放到gitee上的,然后把连接地址放在最后,感兴趣的朋友可以下载来玩一玩。

public abstract class BaseDao<T> implements IBaseDao<T> {

    private Class<T> clazz;
    private final Map<String,String> entityFieldToDataTableFieldMap=new HashMap<>();
    private final Map<String,String> dataTableFieldToEntityFieldMap=new HashMap<>();
    private final Map<String,Field> entityFieldToFieldMap=new HashMap<>();

    {
        clazz= (Class<T>) ((ParameterizedType)this.getClass().getGenericSuperclass()).getActualTypeArguments()[0];
        Field[] fields = clazz.getDeclaredFields();
        for (Field field : fields) {
            DataTableField fieldAnnotation = field.getAnnotation(DataTableField.class);
            String entityField=field.getName();
            String dataTableField=field.getName();
            if (fieldAnnotation!=null){
                dataTableField=fieldAnnotation.value();
            }
            entityFieldToDataTableFieldMap.put(entityField,dataTableField);
            dataTableFieldToEntityFieldMap.put(dataTableField,entityField);
            entityFieldToFieldMap.put(entityField,field);
        }
    }


    /**
     * 暂只支持这几种类型
     * 本来也想直接获取 Integer.getClass().getName() 的,但是 switch 不支持啊,没办法。
     */
    private final static String INTEGER_STR="Integer";
    private final static String LONG_STR="Long";
    private final static String STRING_STR="String";
    private final static String DATE_STR="Date";

    private final String tableName = getTableName();
    //表中的 PKEY 字段名
    private final String tablePkeyFieldName=getTablePkeyFieldName();


    /**
     * 根据 PKEY 删除数据
     * @param PKEY
     * @return
     */
    @Override
    public int  deleteByPkey(long PKEY) {
        StringBuilder builder=new StringBuilder("DELETE FROM ");
        builder.append(tableName).append(" WHERE ");
        builder.append(tablePkeyFieldName);
        builder.append("=?");
        List<String> paramSequence=new ArrayList<>(1);
        paramSequence.add(tablePkeyFieldName);
        Map<String,Object> paramMap=new HashMap<>(2);
        paramMap.put(tablePkeyFieldName,PKEY);
        return executeUpdate(builder.toString(),paramSequence,paramMap);
    }

    private String getTablePkeyFieldName() {
        String pkey=null;
        Field[] fields = clazz.getDeclaredFields();
        for (Field item:fields) {
            DataTablePkey pkeyA = item.getAnnotation(DataTablePkey.class);
            if (pkeyA!=null){
                pkey=item.getAnnotation(DataTableField.class).value();
                break;
            }
        }
        if (pkey==null){
            //默认主键为 PKEY
            pkey="PKEY";
        }
        return pkey;
    }

    /**
     * 获取表名
     * @return
     */
    private String getTableName() {
        String tableName;
        DataTableName tableAnnotation = clazz.getAnnotation(DataTableName.class);
        if(tableAnnotation==null){
            //当没有 Table 注解
            tableName=clazz.getSimpleName();
        }else {
            tableName=tableAnnotation.value();
        }
        return tableName;
    }


    /**
     * 执行 preparedStatement.executeUpdate()
     * @param sql
     * @param paramSequence 存放数据的名称,数据名称顺序和要插入的参数顺序要一致
     * @param paramMap 存放 paramSequence 中对应的参数值
     * @return
     */
    private int executeUpdate(String sql, List<String> paramSequence, Map<String, Object> paramMap) {
        Connection connection= ConnectFactory.getConnection();
        try {
            PreparedStatement preparedStatement =connection.prepareStatement(sql);
            setParameters(paramSequence, paramMap, preparedStatement);
            return preparedStatement.executeUpdate();
        } catch (SQLException e) {
            e.printStackTrace();
        }finally {
            ConnectFactory.close(connection,null,null);
        }
        return 0;
    }



    /**
     * 设置参数
     * @param paramSequence 存放数据的名称,数据名称顺序和要插入的参数顺序要一致
     * @param paramMap 存放 paramSequence 中对应的参数值
     * @param preparedStatement
     * @throws SQLException
     */
    private void setParameters(List<String> paramSequence, Map<String, Object> paramMap, PreparedStatement preparedStatement) throws SQLException {
        for (int i = 0; i < paramSequence.size();) {
            String key = paramSequence.get(i);
            i++;
            Object value = paramMap.get(key);
            preparedStatement.setObject(i,value);
        }
    }
}

后面只会贴新增的代码了。

查询功能

写完了删除方法,我发现我在写的过程中一直会修修改改,这必然不是一个好的编程,所以我打算先把思路捋清楚,再开始写。

我们需要用到的SQLSELECT * FROM TABLE_NAME WHERE PKEY=?

我们这里是只用主键来查询,可以看到我们需要使用的数据如下:表明,主键。

我想了一下,根据主键查询其实跟查询全部,有条件查询没什么差别,都是执行SQL,返回一个ResultSet

也就是说,我们需要做三个动作:

  1. 拼装 SQL
  2. 设置参数
  3. 处理结果集

设置参数,我们已经有了setParameters方法,我们只需要按照规则去组装paramSequenceparamMap就可以了。

拼装 SQL 的话,我们只需要拼装WHERE后面的子语句,其实来说还比较简单,所以可以使用一个实体类来做参数,就直接把实体类中有值的属性都加到WHERE子语句后。那SQL就是这样的:SELECT * FROM TABLE_NAME WHERE PKEY=? AND NAME=?,我们肯定要解决“AND”的问题,为了提高所以的效率,我们肯定不能使用1=1的方法去解决多余“AND”的问题,所以这里可以加一个标志,判断当前是不是第一个参数,是的话就不加“AND”,不是的话就加。

处理结果集的话,这里统一返回一个List<T>集合吧。

在拼装 SQL 和处理结果集的时候都需要使用到getset方法,这里就一个方法需要去拼装方法名,然后通过反射去调用传入进来的实参的方法,我们就能够获取参数,设置参数也是同理。

创建接口方法

IBaseDao

List<T> select(T t);

实现接口方法

BaseDao

@Override
public List<T> select(T t) {
    //拼接 SQL
    StringBuilder builder = new StringBuilder("SELECT * FROM " + tableName);
    List<String> paramSequence = new ArrayList<>();
    Map<String, Object> paramMap = new HashMap<>();
    if (t != null) {
        builder.append(" WHERE ");
        //这里我们使用 entityFieldToFieldMap 来遍历
        Set<String> keySet = entityFieldToFieldMap.keySet();
        Iterator<String> iterator = keySet.iterator();
        //判断是否是第一个参数
        boolean isNoFirst=false;
        try {
            while (iterator.hasNext()) {
                String entityFieldName = iterator.next();
                String getMethodName = getGetOrSetMethodName(entityFieldName, true);
                Method method = clazz.getDeclaredMethod(getMethodName);
                Object value = method.invoke(t);
                if (value == null) {
                    continue;
                }
                //当 value 不为 null,就开始处理
                String dataFieldName = entityFieldToDataTableFieldMap.get(entityFieldName);
                if (isNoFirst){
                    builder.append(" AND ");
                }else {
                    //第一次是不增加 AND
                    isNoFirst=true;
                }
                builder.append(dataFieldName).append("=?");
                paramSequence.add(entityFieldName);
                paramMap.put(entityFieldName, value);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    return executeQuery(builder.toString(), paramSequence, paramMap);
}
/**
 * 执行 preparedStatement.executeQuery()
 *
 * @param sql
 * @param paramSequence
 * @param paramMap
 * @return
 */
private List<T> executeQuery(String sql, List<String> paramSequence, Map<String, Object> paramMap) {
    Connection connection = ConnectFactory.getConnection();
    PreparedStatement preparedStatement = null;
    ResultSet resultSet = null;
    List<T> result = new ArrayList<>();
    try {
        preparedStatement = connection.prepareStatement(sql);
        setParameters(paramSequence, paramMap, preparedStatement);
        //执行 SQL
        resultSet = preparedStatement.executeQuery();
        //处理结果集
        while (resultSet.next()) {
            Set<String> keySet = dataTableFieldToEntityFieldMap.keySet();
            Iterator<String> iterator = keySet.iterator();
            T t = clazz.newInstance();
            for (int i = 0; iterator.hasNext(); i++) {
                String dataFieldName = iterator.next();
                //数据表中该列名的值
                Object value = resultSet.getObject(dataFieldName);
                String entityFieldName = dataTableFieldToEntityFieldMap.get(dataFieldName);
                String setMethodName = getGetOrSetMethodName(entityFieldName, false);
                Method method = clazz.getDeclaredMethod(setMethodName, entityFieldToFieldMap.get(entityFieldName).getType());
                //通过 method 反射调用 t
                method.invoke(t, value);
            }
            result.add(t);
        }
    } catch (Exception e) {
        e.printStackTrace();
    } finally {
        ConnectFactory.close(connection, preparedStatement, resultSet);
    }
    return result;
}

select方法的实现超出了我的预期,本来只想实现根据id查询的,后面就想着简单可以去实习根据实体类来获取参数查询,再后面因为传入的参数可能会为NULL,就实现了查询全部数据。

哈哈哈,看来想太多就会觉得很难,还是要动手做起来,才能遇到真实的问题,然后去解决它,而不应该被想象中的问题所难倒。

新增功能

常见的新增SQL:INSERT INTO TABLE_NAME(FIELD1,FIELD2) VALUES(VALUE1,VALUE2)

我们需要的数据有:表明、数据库字段名、实体对象。

具体的做法是:我们要获取传入进来的实体对象中的属性值,如果不为空,就把对应的数据库字段名加入到SQL中,这里需要两个字符串来拼装,一个是fieldString,一个是valueString,后者用于拼接占位符,至于最后一个“,”的问题,我们也只需要截取掉字符串最后一位就可以了。

还要把参数组装到我们的paramSequenceparamMap中。

哦对了,还有一个细节需要实现,就是在插入之后要把新增的主键返回来,这个的前提是我们要在数据表中把主键自增打开,我们需要在调用preparedStatement.executeUpdate()后,再去调用preparedStatement.getGeneratedKeys()来获取新增的主键。

考虑到新增功能是不需要使用实体类中的主键,我在新增一个noPkeyEntityFieldToDataTableFieldMap来存放主键之外的属性,这样使用的时候就不需要总是去判断现在的是不是主键了。增加代码如下图所示。

稍微把之前的代码调整了一下。

image.png image.png

创建接口方法

IBaseDao

Long insert(T t);

实现接口方法

BaseDao

@Override
public Long insert(T t) {
    //拼接 SQL
    StringBuilder fieldBuilder = new StringBuilder("INSERT INTO " + tableName + "(");
    StringBuilder valueBuilder = new StringBuilder("VALUES(");
    List<String> paramSequence = new ArrayList<>();
    Map<String, Object> paramMap = new HashMap<>();
    if (t == null) {
        return null;
    }
    //这里我们使用 dataTableFieldToEntityFieldMap 来遍历
    Set<String> keySet = noPkeyEntityFieldToDataTableFieldMap.keySet();
    Iterator<String> iterator = keySet.iterator();
    try {
        while (iterator.hasNext()) {
            //这里跟 select 的差不多  dataFieldName
            String entityFieldName = iterator.next();
            String dataFieldName = noPkeyEntityFieldToDataTableFieldMap.get(entityFieldName);
            String getMethodName = getGetOrSetMethodName(entityFieldName, true);
            Method method = clazz.getDeclaredMethod(getMethodName);
            Object value = method.invoke(t);
            if (value == null) {
                continue;
            }
            //当 value 不为 null,就开始处理
            fieldBuilder.append(dataFieldName).append(",");
            valueBuilder.append("?,");
            paramSequence.add(entityFieldName);
            paramMap.put(entityFieldName, value);
        }
        fieldBuilder.deleteCharAt(fieldBuilder.length()-1).append(") ");
        valueBuilder.deleteCharAt(valueBuilder.length()-1).append(")");
    } catch (Exception e) {
        e.printStackTrace();
    }
    fieldBuilder.append(valueBuilder.toString());
    return executeInsert(fieldBuilder.toString(),paramSequence,paramMap);
}

/**
 * 执行 preparedStatement.executeUpdate(),返回自增的主键
 * @param sql
 * @param paramSequence
 * @param paramMap
 * @return
 */
private Long executeInsert(String sql, List<String> paramSequence, Map<String, Object> paramMap) {
    Connection connection = ConnectFactory.getConnection();
    PreparedStatement preparedStatement =null;
    ResultSet resultSet=null;

    try {
        //new String[] {tablePkeyFieldName} 设置返回的主键
        preparedStatement=connection.prepareStatement(sql,new String[] {tablePkeyFieldName});
        setParameters(paramSequence, paramMap, preparedStatement);
        int update = preparedStatement.executeUpdate();
        if (update>0){
            resultSet= preparedStatement.getGeneratedKeys();
            resultSet.next();
            return (Long) resultSet.getObject(1);
        }
    } catch (SQLException e) {
        e.printStackTrace();
    } finally {
        ConnectFactory.close(connection, preparedStatement, resultSet);
    }
    return null;
}

其实这里的executeInsertexecuteUpdate方法很像,但是这里额外做多了一步,把新增的主键返回来了,所以就多了一个方法,以后有机会看看怎么把它俩怎么合并起来。

修改功能

还有最后一个修改方法了,能看到这里都是真爱了,希望各位真爱能点点这个👍,收到👍的我会更加有动力给大家带来更好的文章。

按照我们之前的分析方法,我们常见的修改SQLUPDATE TABLE_NAME SET FIELD1=VALUE1,FIELD2=VALUE2 WHERE PKEY=VALUE3

我们需要的数据有:表名、数据库字段名、实体对象。

这里拼接SQL的方法可以说和查询的如出一辙,但又有那么细微的地方不同.

查询:SELECT * FROM TABLE_NAME WHERE FIELD1=VALUE1 AND FIELD2=VALUE2

修改:UPDATE TABLE_NAME SET FIELD1=VALUE1,FIELD2=VALUE2 WHERE PKEY=VALUE3

查询的 WHERE 子语句和修改的 SET 子语句类似

阿哲,不就是分隔符的不同吗?有点摸到了动态 SQL 的感觉。

不过如果想把这一步分抽成一个方法,好像不太划算,又要解决“AND”/","多一个的问题,又要公用,那肯定就需要使用一个队列把字段存储起来了,这里先按照原来的逻辑写吧。

创建接口方法

IBaseDao

int updateByPkey(T t);

实现接口方法

BaseDao

@Override
public int updateByPkey(T t) {
    String pkeyFieldName = dataTableFieldToEntityFieldMap.get(tablePkeyFieldName);
    String getPkeyMethodName = getGetOrSetMethodName(pkeyFieldName, true);
    Object pkey = null;
    try {
        Method getPkeyMethod = clazz.getDeclaredMethod(getPkeyMethodName);
        pkey = getPkeyMethod.invoke(t);
    } catch (Exception e) {
        e.printStackTrace();
    }
    if (t == null || pkey == null) {
        //参数为空,pkey 为空都不处理
        return 0;
    }
    //拼接 SQL
    StringBuilder builder = new StringBuilder("UPDATE " + tableName+" SET ");
    List<String> paramSequence = new ArrayList<>();
    Map<String, Object> paramMap = new HashMap<>();
    //这里我们使用 noPkeyEntityFieldToDataTableFieldMap 来遍历
    Set<String> keySet = noPkeyEntityFieldToDataTableFieldMap.keySet();
    Iterator<String> iterator = keySet.iterator();
    //判断是否是第一个参数
    boolean isNoFirst = false;
    try {
        while (iterator.hasNext()) {
            String entityFieldName = iterator.next();
            String getMethodName = getGetOrSetMethodName(entityFieldName, true);
            Method method = clazz.getDeclaredMethod(getMethodName);
            Object value = method.invoke(t);
            if (value == null) {
                continue;
            }
            //当 value 不为 null,就开始处理
            String dataFieldName = noPkeyEntityFieldToDataTableFieldMap.get(entityFieldName);
            if (isNoFirst) {
                builder.append(",");
            } else {
                //第一次是不增加 ,
                isNoFirst = true;
            }
            builder.append(dataFieldName).append("=?");
            paramSequence.add(entityFieldName);
            paramMap.put(entityFieldName, value);
        }
        builder.append(" WHERE ");
        builder.append(tablePkeyFieldName).append("=?");
        paramSequence.add(tablePkeyFieldName);
        paramMap.put(tablePkeyFieldName, pkey);
    } catch (Exception e) {
        e.printStackTrace();
    }
    return executeUpdate(builder.toString(), paramSequence, paramMap);
}

到这里 CURD 功能就写完了,简单的增删改查场景都能使用。

CURD 功能测试

虽然我是一边开发一边测试,基本上也是没有问题滴,不过还是要给大家一个交代,这里还是放一下测试的结果。

我在 T_USER 表写了一些数据。

image.png 测试代码:

public static void main(String[] args) {
    UserDao userDao = new UserDao();
    List<User> select1 = userDao.select(null);
    System.out.println("-----select1-----");
    for (User user : select1) {
        System.out.println(user.toString());
    }
    int update1 = userDao.deleteByPkey(2L);
    System.out.println("update1:"+update1);
    List<User> select2 = userDao.select(null);
    System.out.println("-----select2-----");
    for (User user : select2) {
        System.out.println(user.toString());
    }
    Long pkey = userDao.insert(new User("小刚", 18, "广东"));
    System.out.println("pkey:"+pkey);
    List<User> select3 = userDao.select(new User().setPkey(pkey));
    System.out.println("-----select3-----");
    for (User user : select3) {
        System.out.println(user.toString());
    }
    int update2 = userDao.updateByPkey(new User(3L, "小蓝", 99, "广东"));
    System.out.println("update2:"+update2);
    List<User> select4 = userDao.select(new User().setPkey(3L));
    System.out.println("-----select4-----");
    for (User user : select4) {
        System.out.println(user.toString());
    }
}

执行!

-----select1-----

User{Pkey=1, name='小明', age=18, address='广东'}

User{Pkey=2, name='小蓝', age=99, address='广东'}

User{Pkey=3, name='小红', age=18, address='广东'}

User{Pkey=4, name='小刚', age=18, address='广东'}

update1:1

-----select2-----

User{Pkey=1, name='小明', age=18, address='广东'}

User{Pkey=3, name='小红', age=18, address='广东'}

User{Pkey=4, name='小刚', age=18, address='广东'}

pkey:9

-----select3-----

User{Pkey=9, name='小刚', age=18, address='广东'}

update2:1

-----select4-----

User{Pkey=3, name='小蓝', age=99, address='广东'}

注意!实体类的整型类型必须改成包装类型,不然会有默认值,会影响我们的查询。

到这里基本功能就完结了,自定义 SQL 还需要构思一下,肯定会用到动态代理的,动态 SQL 就可能要等更晚一点才能去实现了。

路要一口一口地吃,饭要一步一步地走。让我们开始我们下一个目标:自定义 SQL。

gitee 地址:orm

自定义 SQL

  1. 支持$#来非预编译和预编译的方式拼装SQL

这一部分就是讲述如何去实现自定义 SQL 的功能了。

一个小问题

问题描述

这里遇到一个小问题(我差点就要放弃了):

通过new创建的对象是可以获取泛型的 Class 的,但是通过cglib动态代理以及 Java 自带的动态代理是获取不到的。

我们以 UserDao 作示例。

这是通过new创建的对象,可以看到这里的 name 是 org.example.simulate.orm.achieve.entity.User,也就是这个可以获取到当前的泛型的 class 对象。

image.png

接下来是cglib创建的代理对象。

首先,我们需要一个代理对象的生成器,我这里简单写了一个,代码就不贴上来了,代码如下图。

cglib代码生成器

调用代码:

UserDao instance = (UserDao)CglibDynamicProxy.getProxyInstance(UserDao.class);
instance.insert(null);

接下来,让我们看一下这里能不能获取到泛型的 class 对象。

很显然,获取不到,这里显示类型转换错误。

image.png

问题解决方案

为什么会出现这种情况?

那是因为,我们使用了动态代理,它会生成一个继承 UserDao 的子类,我们不能从 UserDao 中获取到泛型的 class 对象,所以要再向上一级获取到 BaseDao,我们才能获取到泛型的 class 对象。

我也只能根据现象做出结论,希望有更专业的大佬来解答一下。

所以我们要修改 BaseDao 中的代码:

{
    clazz = (Class<T>) ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments()[0];
}

改成

{
    // 获取继承的父类(不擦除泛型的)
    Type type = this.getClass().getGenericSuperclass();
    if(type instanceof ParameterizedType){
        // 当前的类继承了 ParameterizedType,可以直接获取泛型
        clazz = (Class<T>) ((ParameterizedType) type).getActualTypeArguments()[0];
    }else {
        // 不能直接获取泛型,需要再获取上一个父类才能获取泛型
        clazz = (Class<T>) ((ParameterizedType) this.getClass().getSuperclass().getGenericSuperclass()).getActualTypeArguments()[0];
    }

}

如果当前的type实现了ParameterizedType接口的话,就可以通过getActualTypeArguments()[0]获取到实际的泛型 class 对象了,否则就再获取当前type的父类,再去获取泛型的 class 对象。

代码改了之后,不会再类型转换的错误了,然后我们再修改一下cglib的代理对象生成器,就可以顺利运行了。

image.png

成功运行!

image.png

可以看到,代理对象生成器需要通过newInstance()创建一个 UserDao 对象才能去调用 BaseDao 中的方法,可见这里需要我们做出一些调整。

我们创建各种 Dao 对象的单例缓存在系统中,我们这里不能创建一个公共的 Dao,因为我们的实体类型是直接写在泛型里的,所以只能把各种 Dao 缓存在系统中。

在遇到这些公共的方法时,就直接去调用 Dao 单例,其他的方法都通过拼装注解中的 SQL 去执行。

代理 Dao 生成器

基本上只需要把我们之前代理对象的生成器复制粘贴一下就好了。

DaoCglibDynamicProxyimage.png

剩下的就是要去做 Dao 单例缓存和公共方法识别,以及最终要实现的自定义 SQL。

前两个功能都是为了调用BaseDao上写的公共方法,后一个就又要开始拼装 SQL 了。

Dao 单例缓存

我们现在先只调用 BaseDao 已经实现的方法作为测试就可以了。

开始实现 Dao 单例缓存:

我们是可以通过DaoCglibDynamicProxy.intercept方法的第一个参数获得我们要的 Dao 的 class 对象的。

image.png

那我们就可以通过这个 class 对象创建一个 UserDao 对象,然后通过类的全限定名作为 key,UserDao 对象作为 value 存入一个 map 中。

当然啦,我们具体的做法肯定是先从 map 中取,没有再新建一个,为了线程安全,双重检查什么的肯定也少不了。

DaoCglibDynamicProxy中新增的代码:

//对象缓存 map
private static Map<String, Object> objectCacheMap = new HashMap<>();

private Object getDaoFromCache(Class clazz) throws ClassNotFoundException {
    return getDaoFromCache(clazz.getName());
}

private Object getDaoFromCache(String clazzName) {
    Object o = objectCacheMap.get(clazzName);
    if (o==null){
        synchronized (objectCacheMap){
            // 要锁住这个 objectCacheMap
            o = objectCacheMap.get(clazzName);
            if(o==null){
                // 再次检验,为 null,则创建一个对象
                Class<?> clazz = null;
                try {
                    clazz = Class.forName(clazzName);
                    o = clazz.newInstance();
                    objectCacheMap.put(clazzName,o);
                } catch (Exception e) {
                    e.printStackTrace();
                }

            }
        }
    }
    return o;
}

对应的intercept()修改一下:

image.png

测试。

image.png

测试通过了,接着开发下一个功能——公共方法识别

公共方法识别

我们可以把所有公共方法的名字都写在一个字符串上,然后判断当前进入intercept()method是否在这里面,在的话就去获取 Dao 单例接着调用公共方法,不在就去走自定义 SQL 的逻辑。

因为我们开发BaseDao时是实现了IBaseDao接口的,并且IBaseDao接口中的方法也是公共方法,所以我们可以通过IBaseDao.class.getMethods()获取到全部的公共方法,我们只需要把这些Method对象的名称拼接在一起就行了。

DaoCglibDynamicProxy中新增的代码:

private static String publicMethodsStr = null;

static {
    StringBuilder builder = new StringBuilder();
    for (Method method : IBaseDao.class.getMethods()) {
        builder.append(method.getName()).append(",");
    }
    publicMethodsStr = builder.toString();
}

DaoCglibDynamicProxy.intercept()就改成了下面这样:

@Override
public Object intercept(Object o, Method method, Object[] objects, MethodProxy methodProxy) throws Throwable {
    if (publicMethodsStr.contains(method.getName())){
        // 属于公共方法
        Object newInstance = getDaoFromCache(o.getClass().getSuperclass());
        return method.invoke(newInstance, objects);
    }
    // 不属于公共方法

    return null;
}

我测试了一下,是可以通过的,让我们接着开发自定义 SQL。

ps:后面发现getProxyInstance()可以使用泛型,这样可以省去我们再强制转换类型一次,所以又修改了一下getProxyInstance()

public static <T> T getProxyInstance(Class<T> clazz){
    //1.工具类
    Enhancer en = new Enhancer();
    //2.设置父类
    en.setSuperclass(clazz);
    //3.设置回调函数
    en.setCallback(proxy);
    //4.创建子类(代理对象)
    return (T) en.create();
}

终于要开发自定义 SQL 了

我们这里有两个需求:

一个是能够预编译,就是使用#{name}作为占位符,每占一个我们都给它替换成?,然后在预编译 SQL 时设置上对应的参数;

另一个是能够非预编译地实现字符串的替换,也就是使用${name}作为占位符,在预编译之前就替换上对应的参数。

BaseDao的方法中可以看到,这里只有 3 种执行方法,executeInsert执行 insert,executeQuery执行 select,executeUpdate执行 update、delete,所以我们至少需要创建 3 个注解,区分不同的 SQL 类型,让对应类型的 SQL 去调用对应的执行方法。

还有,这 3 个方法都是私有的,我们需要创建一个IExecutor作为这 3 个方法的接口,接着再让BaseDao实现这个接口并且开放这 3 个方法。

image.png

所以我们还要做两件事情:

  1. 创建 1 个只有执行方法的接口,并让BaseDao实现这个接口。
  2. 创建 3 个注解。
  3. 当调用的不是公共方法时,去解析注解。

IExecutor

一个执行器的接口,只有三个方法比较简单。

IExecutor:

/**
 * 执行器接口
 * @param <T>
 */
public interface IExecutor<T> {

    Long executeInsert(String sql, List<String> paramSequence, Map<String, Object> paramMap) ;

    List<T> executeQuery(String sql, List<String> paramSequence, Map<String, Object> paramMap) ;

    int executeUpdate(String sql, List<String> paramSequence, Map<String, Object> paramMap);
}

BaseDao: image.png

至于 3 个方法的修改就不放上来啦,需要把private修改成public

3 个注解

我打算把这 3 个注解命名成:InsertSql、SelectSql、UpdateSql,属性只有一个 value 就好,我们不搞那么复杂的功能。

应该没忘记怎么写注解吧?忘记的赶紧复习一下👉注解

SelectSql

/**
 * executeSelect 执行的 sql
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface SelectSql {
    String value();
}

InsertSql

/**
 * executeInsert 执行的 sql
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface InsertSql {
    String value();
}

UpdateSql

/**
 * executeUpdate 执行的 sql
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface UpdateSql {
    String value();
}

好了,让我们用SelectSql来写一个方法,待会可以用来测试。

image.png

我打算统一都使用一个 map 作为参数,我们也可以直接调用执行方法,不需要再去转换了,这里也是之前开发时所犯的错误,这里就要被限制了。

分析注解

我打算创建一个SqlInfo来解析方法上的注解,并把 sql 信息存在SqlInfo对象中。

主要修改了 DaoCglibDynamicProxy 的 intercept 方法,以及增加了 execute 方法。

DaoCglibDynamicProxy:

@Override
public Object intercept(Object o, Method method, Object[] objects, MethodProxy methodProxy) throws Throwable {
    if (publicMethodsStr.contains(method.getName())){
        // 属于公共方法
        Object newInstance = getDaoFromCache(o.getClass().getSuperclass());
        return method.invoke(newInstance, objects);
    }
    // 不属于公共方法
    Map<String,Object> params = (Map<String, Object>) objects[0];

    // Dao 都继承了 IExecutor,这里直接强转
    IExecutor executor = (IExecutor)getDaoFromCache(o.getClass().getSuperclass());

    SqlInfo sqlInfo = SqlInfo.getInstance(o.getClass().getSuperclass(),method);

    return execute(sqlInfo.getSql(params),sqlInfo.getParamSequence(),params,sqlInfo.getSqlClass(),executor);
}

/**
 * 判断该执行那个方法,判断执行 sql。
 * @param executor
 * @return
 */
public  Object execute(String sql,List<String> paramSequence,Map<String,Object> params,Class sqlClass,IExecutor executor) {
    if (sqlClass == InsertSql.class) {
        return executor.executeInsert(sql,paramSequence,params);
    } else if (sqlClass == SelectSql.class) {
        return executor.executeQuery(sql,paramSequence,params);
    } else if (sqlClass == UpdateSql.class) {
        return executor.executeUpdate(sql,paramSequence,params);
    }
    return null;
}

新增了一个 SqlInfo 类:

SqlInfo:

/**
 * Sql 信息类
 * 存放解析后的 sql,以及提供非预编译处理方法
 */
public class SqlInfo {

    /**
     * public List<T> executeQuery(String sql, List<String> paramSequence, Map<String, Object> paramMap)
     * 我们的执行方法的三个参数:sql、paramSequence、paramMap。
     * paramMap 我们直接让用户提供给我们了,所以我们需要分析获得 sql、paramSequence
     */
    // 未经过 ${} 替换的 sql
    private String sql = null;

    // 需要预编译的参数列表
    private List<String> paramSequence = null;

    // 非预编译的参数列表,用来替换 ${}
    private List<String> notPrecompiledParamSequence = null;

    // 注解 class 对象
    private Class sqlClass = null;

    // sqlInfo 缓存
    private static Map<String, SqlInfo> sqlInfoCacheMap = new HashMap<>();

    // 外部不能创建 SqlInfo 对象,只能通过 getInstance 获取
    private SqlInfo(String sql, List<String> paramSequence, List<String> notPrecompiledParamSequence, Class sqlClass) {
        this.sql = sql;
        this.paramSequence = paramSequence;
        this.notPrecompiledParamSequence = notPrecompiledParamSequence;
        this.sqlClass = sqlClass;
    }

    /**
     * 根据 class 和 method 来返回一个 sqlInfo,若没有则分析后创建一个
     *
     * @param clazz
     * @param method
     * @return
     */
    public static SqlInfo getInstance(Class clazz, Method method) {
        String sqlStr = null;
        Class sqlClass = null;
        String functionName = clazz.getName() + "." + method.getName();
        Annotation[] annotations = method.getAnnotations();
        for (int i = 0; i < annotations.length; i++) {
            Annotation item = annotations[i];
            // 识别各种注解
            if (item.annotationType() == InsertSql.class) {
                InsertSql sqlAnnotation = (InsertSql) item;
                sqlStr = sqlAnnotation.value();
                sqlClass = InsertSql.class;
            } else if (item.annotationType() == SelectSql.class) {
                SelectSql sqlAnnotation = (SelectSql) item;
                sqlStr = sqlAnnotation.value();
                sqlClass = SelectSql.class;
            } else if (item.annotationType() == UpdateSql.class) {
                UpdateSql sqlAnnotation = (UpdateSql) item;
                sqlStr = sqlAnnotation.value();
                sqlClass = UpdateSql.class;
            }
        }
        // 获取 sqlInfo
        return getSqlInfo(functionName, sqlStr,sqlClass);
    }

    /**
     * 通过 全限定类名+方法名 获取一个 sqlInfo 对象,缓存中存在就现从缓存中获取。
     *
     * @param functionName 全限定类名+方法名
     * @param sqlStr       需要执行的 sql
     * @param sqlClass  当前注解的 class 对象
     * @return
     */
    private static SqlInfo getSqlInfo(String functionName, String sqlStr, Class sqlClass) {
        SqlInfo info = sqlInfoCacheMap.get(functionName);
        //老样子,双重检验
        if (info == null) {
            synchronized (sqlInfoCacheMap) {
                info = sqlInfoCacheMap.get(functionName);
                if (info == null) {
                    info = creatSqlInfo(sqlStr,sqlClass);
                    sqlInfoCacheMap.put(functionName, info);
                }
            }
        }
        return info;
    }

    /**
     * 根据 sql 创建一个 sqlInfo 对象
     *
     * @param sqlStr
     * @param sqlClass
     * @return
     */
    private static SqlInfo creatSqlInfo(String sqlStr, Class sqlClass) {
        StringBuilder builder = new StringBuilder(sqlStr);
        List<String> paramSequence = new ArrayList<>();
        List<String> notPrecompiledParamSequence = new ArrayList<>();
        // 解析 #{}
        // 获取 #{ 字符的下标 和 } 字符的下标,截取中间的字符,替换
        while (builder.indexOf("#{") > 0) {
            int startIndex = builder.indexOf("#{");
            int endIndex = builder.indexOf("}", startIndex);
            String attribute = builder.substring(startIndex + 2, endIndex);
            paramSequence.add(attribute);
            builder.replace(startIndex, endIndex + 1, "?");
        }

        // 解析 ${}
        while (builder.indexOf("${") > 0) {
            int startIndex = builder.indexOf("${");
            int endIndex = builder.indexOf("}", startIndex);
            String attribute = builder.substring(startIndex + 2, endIndex);
            notPrecompiledParamSequence.add(attribute);
            builder.replace(startIndex, endIndex + 1, attribute);
        }

        SqlInfo sqlInfo = new SqlInfo(builder.toString(), paramSequence, notPrecompiledParamSequence,sqlClass);

        return sqlInfo;
    }

    /**
     * 获取 sql。
     * 若 sql 存在 ${},也就是需要实现非预编译的字符串替换
     *
     * @param paramsMap 字符串替换时需要用到的参数
     * @return
     */
    public String getSql(Map paramsMap) {
        if (this.notPrecompiledParamSequence.size() == 0) {
            // 不需要字符串替换,直接返回 sql。
            return this.sql;
        }
        // 经过字符串替换的 sql
        String sqlStr = this.sql;
        for (String attributeName : this.notPrecompiledParamSequence) {
            String param = String.valueOf(paramsMap.get(attributeName));
            sqlStr = sqlStr.replace(attributeName,param);
        }
        return  sqlStr;
    }

    public List<String> getParamSequence(){
        return this.paramSequence;
    }

    public Class getSqlClass(){
        if (this.sqlClass == null){
            throw new NullPointerException("this.sqlClass is null.");
        }
        return this.sqlClass;
    }
}

TEST

我们在这里就简单写几个注解来测试一下,只测试增改查,删除和修改是一样的就不重复了。

写在UserDao的三个方法如下:

public class UserDao extends BaseDao<User> {

    @SelectSql("SELECT * FROM T_USER WHERE PKEY = #{PKEY}")
    public List<User> selectByPkey(Map params){
        return null;
    }

    @InsertSql("INSERT INTO T_USER(NAME,AGE,ADDRESS)VALUES(#{NAME},#{AGE},#{ADDRESS})")
    public Long insertUser(Map params){
        return null;
    }

    @UpdateSql("UPDATE T_USER SET ADDRESS = #{ADDRESS} WHERE NAME = #{NAME}")
    public int updateAddress(Map params){
        return 0;
    }
}

selectByPkey 的表名就没有使用 ${} 的方式了,这里为了简单点就直接写了表名。

测试代码如下:

UserDao userDao = DaoCglibDynamicProxy.getProxyInstance(UserDao.class);
Map<String, Object> params = new HashMap<>();
params.put("NAME","xu");
params.put("AGE",18);
params.put("ADDRESS","guangdong");
Long pkey = userDao.insertUser(params);
params.put("PKEY",pkey);

List<User> users = userDao.selectByPkey(params);
users.forEach( item ->{
    System.out.println(item.toString());
});

params.put("ADDRESS","china");
userDao.updateAddress(params);

users = userDao.selectByPkey(params);
users.forEach( item ->{
    System.out.println(item.toString());
});

注意获取 Dao 需要使用 DaoCglibDynamicProxy.getProxyInstance(class),才能获取到经过代理的对象,进而去使用注解。

打印结果如下:

image.png

最后

为了不错过创造训练营的 deadline,我先提前发布了。

以上代码暂时测试是通过的,等有空我再来详细测试一下,并且补充测试的内容。

测试内容已经补充了。

关于这个 ORM 框架,我最后做一个总结:

  • 首先,能够完成这个框架,我觉得对我而言已经超出预期了,因为我本来打算不开发 SQL注解 这个功能的,但最后经过一拖又拖地还是完成了。
  • 其次,作为一个只读过一点点源码的小白来说,这个框架完成的还是有一些瑕疵的。

image.png 像这里 MyBatis 使用的都是接口,确实使用接口会比较简洁,而我这里只能是方法,这也是因为我一开始使用了BaseDao,后面想改成接口,只有全部推倒重新来过了,我觉得还是算了,以后可以考虑在DaoCglibDynamicProxy.intercept()中调用,增加方法的多元化。

  • 最后,这个 ORM 框架以后我也不知道会不会再继续增加功能或者优化,有空就搞一搞。

最新的代码已经提交到 gitee 上了,地址:gitee.com/xuxiaojian1…,欢迎各位下载来玩。

image.png

这就全部类、接口、注解了。

虽然这个 ORM 的功能比较简单,但也花费了我数个星期,看到这里不妨留个赞👍。