根据Mybatis自己实现一个ORM框架!

536 阅读8分钟

上次分析完了Mybatis启动以及执行流程后,本次根据上次的思路实现一个简单的mybatis框架。

git地址:github.com/dp33221/my_…

首先需要分析一下什么场景下出现了ORM框架,最早的时候我们使用数据库链接步骤如下

原生jdbc操作

可以看到每次都需要注册驱动,建立连接,执行查询,获取结果,关闭连接。显然这存在高耦合以及硬编码问题。所以我们一步一步的解决问题。

第一四个问题:注册驱动以及建立数据库连接

这边我们可以封装一个连接池,这样就可以解决耦合问题,然后数据库连接配置的话我们需要使用方提供,这边借鉴Mybatis的方式,使用xml进行配置。所以需要一个configuration对象。

第二三个问题:执行查询

为了能够解耦以及可配置性,这边也是想到的使用配置文件的方式进行配置,参考Mybatis的Mapper.xml文件,这边需要一个mapperStatement对象存储对应的xml信息,分析一下这个对象需要哪些属性,首先肯定需要记录sql,然后要记录入参parameter,然后就是返回结果result,还有执行的类型也就是executorType,以及对应的id这个id默认就是标签中的id。

mappedStatement

xml中每个标签都会被解析成一个mappedstatement对象,所以我们需要有一个映射关系来保存,最容易想到的就是Map也就是键值对,键就取xml中的namespace+id,值就是对应的MappedStatement对象。存储在configuration对象中。所以configuration对象要存一个连接池,以及解析的xml的数据。

@Data
@Builder
public class MappedStatement {
    private String id;
    private Class<?> paramType;
    private Class<?> resultType;
    private String sql;
    private String executorType;
}
Conifguration

首先解决前两个问题,第一步需要配置配置文件sqlMapConfig.xml,其中我设置一个根标签,以及子标签property来配置数据库连接属性解决硬编码问题。在解析configuration文件时也可以同时解析mapper.xml所以在设置一个mapper标签来设置要解析的mapper路径。

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class Configuration {
    private DataSource dataSource;
    @Builder.Default
    private Map<String,MappedStatement> mappedStatementList=new HashMap<String, MappedStatement>();

}
sqlMapConfig.xml

这就是基础的configuration配置,其中property就是数据库连接的基本配置,mapper就是要解析的mapper.xml

sqlMapConfig.xml

mapper.xml文件配置了一些sql信息包括namespace、id、paramType、resultType、sql这些信息。 image.png

解析配置文件

在解析配置文件之前,要理清楚一个大体的流程。首先需要创建一个SqlSession对象的工厂用于生成SqlSession,同时解析配置文件,然后生成sqlsession对象,之后在进行excutor操作。 总体结构

第一步 初始化解析配置文件

Resources.Java

1.开始解析xml配置文件,我借助dom4j以及jaxen来实现。首先第一步要将xml配置文件读成输入流的形式,借助classloader的getResourcesAsStream方法实现。

public class Resources {
    /**
     * @Description: 获取资源转化为字节输入流
     * @param path 资源路径
     * @return java.io.InputStream
     * @Author: dingpei
     */
    public static InputStream getResourcesAsStream(String path){
        return Resources.class.getClassLoader().getResourceAsStream(path);
    }
}
SqlSessionFactoryBuild

通过SqlSessionFactoryBuild解析输入流

public class SqlSessionFactoryBuild {

    private Configuration configuration;

    public SqlSessionFactoryBuild(){
        this.configuration=new Configuration();
    }

    /**
     * @Description: 解析输入流
     * @param in 输入流
     * @return sqlsession.SqlSessionFactory
     * @Author: dingpei
     */
    public SqlSessionFactory build(InputStream in) throws DocumentException, PropertyVetoException, ClassNotFoundException {
        //解析mapper和SqlMapConfig
        final Configuration configuration = new XmlConfigBuild().buildConfiguration(in);
        DefaultBeanFactory.singletonObjects.put("dateSource",configuration.getDataSource());
        //创建sqlsessionfactory
        return new DefaultSqlSessionFactory(configuration);

    }
}

XmlConfigBuild.java

2.将字节输入流传递进来开始进行xml文件解析,执行XmlConfigBuild的buildConfiguration方法,首先获取字节流,然后获取xml跟标签, 通过xpath表达式得到节点信息,然后存储到properties中,创建连接池,然后设置数据源,存储到configuration对象中。然后获取所有的mapper标签,并遍历解析。

public class XmlConfigBuild {
    private Configuration configuration;
    public XmlConfigBuild(){
        this.configuration=new Configuration();
    }

    /**
     * @Description: 解析sqlMapConfig文件 封装成Configuration对象
     * @param inputStream 字节输入流
     * @return config.Configuration
     * @Author: dingpei
     */
    public Configuration buildConfiguration(InputStream inputStream) throws DocumentException, PropertyVetoException, ClassNotFoundException {
        //解析字节流
        Document document=new SAXReader().read(inputStream);
        //获取根标签
        final Element rootElement = document.getRootElement();
        //获取所有property配置
        final List<Element> propertyList = rootElement.selectNodes("//property");
        //存储解析的k v
        Properties resources=new Properties();
        for (Element element : propertyList) {
            final String name = element.attributeValue("name");
            final String value = element.attributeValue("value");
            resources.put(name,value);
        }
        //创建链接池
        ComboPooledDataSource comboPooledDataSource=new ComboPooledDataSource();
        comboPooledDataSource.setJdbcUrl(resources.getProperty("jdbcUrl"));
        comboPooledDataSource.setDriverClass(resources.getProperty("driverClass"));
        comboPooledDataSource.setUser(resources.getProperty("username"));
        comboPooledDataSource.setPassword(resources.getProperty("password"));
        //设置数据源
        configuration.setDataSource(comboPooledDataSource);
        XmlMapperBuild xmlMapperBuild=new XmlMapperBuild(configuration);
        //加载mapper.xml
        final List<Element> mapperList = rootElement.selectNodes("//mapper");
        for (Element element : mapperList) {
            final String mapperPath = element.attributeValue("resource");
            InputStream in= Resources.getResourcesAsStream(mapperPath);
            xmlMapperBuild.build(in);
        }
        return configuration;

    }
}

XmlMapperBuild.java

3.调用的xmlMapperBuild的build方法。在新建xmlmapperbuild对象时,要将已经设置数据源的configuration传入。解析Mapper的步骤与sqlMapConfig差不多,这边主要讲一下解析各个标签的步骤,首先就是获取到对应的id然后返回类型以及参数类型,再获取到sql语句,生成statementid(namespace+id)以及MappedStatement,存入configuration对象中的Map集合中。到这一步就执行完文件解析的步骤了。

public class XmlMapperBuild {
    private Configuration configuration;
    public XmlMapperBuild(Configuration configuration){
        this.configuration=configuration;
    }

    /**
     * @Description: 读取mapper文件 封装成mappedStatement对象
     * @param inputStream 字节输入流
     * @return void
     * @Author: dingpei
     */
    public void build(InputStream inputStream) throws DocumentException, ClassNotFoundException {
        //解析字节流
        final Document document = new SAXReader().read(inputStream);
        //获取根标签
        final Element rootElement = document.getRootElement();
        //获取mapper文件namespace
        final String namespace = rootElement.attributeValue("namespace");
        //查询所有的select标签
        final List<Element> selectList = rootElement.selectNodes("//select");

        iteratorElement(selectList,namespace,"select");

        //查询所有update标签
        final List<Element> updateList = rootElement.selectNodes("//update");
        iteratorElement(updateList,namespace,"update");

        //查询所有insert标签
        final List<Element> insertList = rootElement.selectNodes("//insert");
        iteratorElement(insertList,namespace,"insert");

        //查询所有delete标签
        final List<Element> deleteList = rootElement.selectNodes("//delete");
        iteratorElement(deleteList,namespace,"delete");

    }

    /**
     * @Description: 遍历标签存入configuraation的mappedStatement集合
     * @param elementList element集合
     * @param namespace 全限名
     * @return void
     * @Author: dingpei
     */
    private void iteratorElement(List<Element> elementList,String namespace,String executorType) throws ClassNotFoundException {
        for (Element element : elementList) {
            final String id = element.attributeValue("id");
            final String resultType = element.attributeValue("resultType");
            final String paramType = element.attributeValue("paramType");
            Class<?> paramClassType=getClass(paramType);
            Class<?> resultClassType= getClass(resultType);
            final String sql = element.getTextTrim();
            String statementId=namespace+id;
            configuration.getMappedStatementList().put(statementId, MappedStatement.builder().id(id).paramType(paramClassType).resultType(resultClassType).executorType(executorType).sql(sql).build());
        }
    }

    private Class<?> getClass(String s) throws ClassNotFoundException {
        if(null==s||s.equals("")){
            return null;
        }
        return Class.forName(s);
    }
}

第二步 生成回话对象

通过工厂生成sqlsession
public class DefaultSqlSessionFactory implements SqlSessionFactory {
    private Configuration configuration;
    public DefaultSqlSessionFactory(Configuration configuration){
        this.configuration=configuration;
    }
    @Override
    public DefaultSqlSession openSession() {
        return new DefaultSqlSession(configuration);
    }
}

接着就是生成sqlsession对象,sqlsession主要就是sql对话对象,执行sql操作

sqlsession.java

定义接口

public interface SqlSession {
    <E> List<E> selectAll(String statementId,Object... param) throws NoSuchFieldException, SQLException, InvocationTargetException, IntrospectionException, InstantiationException, IllegalAccessException;
    <E> E selectOne(String statementId,Object... params) throws IllegalAccessException, IntrospectionException, InstantiationException, SQLException, InvocationTargetException, NoSuchFieldException;
    <E> E getMapper(Class c);
 }

defaultSqlSession.java

首先定义一个接口 需要实现哪些方法,然后生成一个默认的实现类来实现这些方法,主要就是一个查询select方法以及一个更新update方法。其中提供了getMapper方法来实现Mapper接口的代理对象创建(jdk动态代理),具体步骤就是通过全限定类名+方法名去获取到MappedStatement对象(注意:因此namespace需要是类的全限定路径,id是类的方法,同时这也导致了方法无法重载)。然后通过返回值类型以及操作类型判断是走查询还是走更新操作,如果走更新操作则会记录方法的参数数组。sqlsession是一个对外的会话对象,其中执行sql语句还是通过Executor对象来执行的。

public class DefaultSqlSession implements SqlSession {
    public DefaultSqlSession(){

    }

    private Configuration configuration;
    public DefaultSqlSession(Configuration configuration){
        this.configuration=configuration;
    }


    /**
     * @Description: 查询全部
     * @param statementId 标签id
     * @param params 参数
     * @return java.util.List<E>
     * @Author: dingpei
     */
    @Override
    public <E> List<E> selectAll(String statementId, Object... params) throws NoSuchFieldException, SQLException, InvocationTargetException, IntrospectionException, InstantiationException, IllegalAccessException {
        final MappedStatement mappedStatement = configuration.getMappedStatementList().get(statementId);
        return new DefaultExcutor().query(configuration,mappedStatement,params);
    }

    /**
     * @Description: 查询单个
     * @param statementId 标签id
     * @param params 参数
     * @return java.util.List<E>
     * @Author: dingpei
     */
    @Override
    public <E> E selectOne(String statementId, Object... params) throws IllegalAccessException, IntrospectionException, InstantiationException, SQLException, InvocationTargetException, NoSuchFieldException {
        final List<Object> objects = selectAll(statementId, params);
        if(objects.size()==0){
            return (E) new Object();
        }
        if(objects.size()>1){
            throw new RuntimeException("out of 2");
        }
        return (E) objects.get(0);
    }

    /**
     * @Description: 增删改操作
     * @param statementId 标签id
     * @param parameters 顺序
     * @param params 参数
     * @return java.lang.Integer
     * @Author: dingpei
     */
    public Integer update(String statementId,Parameter[] parameters, Object... params) throws ClassNotFoundException, SQLException, IllegalAccessException, NoSuchFieldException {
        final MappedStatement mappedStatement = configuration.getMappedStatementList().get(statementId);
        return new DefaultExcutor().update(configuration,mappedStatement,parameters,params);
    }

    /**
     * @Description: 代理
     * @Author: dingpei
     * @Date: 2021/6/1 9:09 下午
     */
    @Override
    public <E> E getMapper(Class c) {
        final Object o = Proxy.newProxyInstance(DefaultSqlSession.class.getClassLoader(), new Class[]{c}, new InvocationHandler() {
            @Override
            public Object invoke(Object o, Method method, Object[] objects) throws Throwable {
                //获取类全限定名+方法名
                final String methodName = method.getName();
                final String className = method.getDeclaringClass().getName();
                String statementId=className+methodName;
                final MappedStatement mappedStatement = configuration.getMappedStatementList().get(statementId);
                final Type genericReturnType = method.getGenericReturnType();
                if(genericReturnType instanceof ParameterizedType){
                    return selectAll(statementId,objects);
                }
                //如果不是基本数据类型则是查询单个
                else if(!(genericReturnType.getClass().getClassLoader() ==null)|| "select".equals(mappedStatement.getExecutorType())) {
                    return selectOne(statementId,objects);
                }

                //获取方法的型参名
                final Parameter[] parameters = method.getParameters();
                return update(statementId,parameters,objects);
            }
        });
        return (E) o;

    }
}

第三步

执行sql操作

Excutor.java

接口定义规范

public interface Excutor {
    //查询
    <E> List<E> query(Configuration configuration, MappedStatement mappedStatement,Object...params) throws SQLException, NoSuchFieldException, IllegalAccessException, IntrospectionException, InstantiationException, InvocationTargetException;
    //增删改
    int update(Configuration configuration, MappedStatement mappedStatement, Parameter[] parameters, Object...params) throws SQLException, NoSuchFieldException, IllegalAccessException, ClassNotFoundException;
}

DefaultExecutor.java

和sqlsession一样先创建一个接口定义规则,然后创建一个defaultExecutor进行默认实现。其中有一个查询功能一个更新功能,我这边注释写的比较全面,就不赘述了。

public class DefaultExcutor implements Excutor {

    /**
     * @Description: select执行
     * @param configuration 配置解析
     * @param mappedStatement mapper解析
     * @param params 入参
     * @return java.util.List<E>
     * @Author: dingpei
     */
    @Override
    public <E> List<E> query(Configuration configuration, MappedStatement mappedStatement, Object... params) throws SQLException, NoSuchFieldException, IllegalAccessException, IntrospectionException, InstantiationException, InvocationTargetException {
        //获取链接
        Connection connection = Resource.conn.get();
        if(connection==null){
            connection=configuration.getDataSource().getConnection();
            Resource.conn.set(connection);
        }
        //获取sql
        final String sql = mappedStatement.getSql();
        //新建处理器
        final ParameterMappingTokenHandler parameterMappingTokenHandler = new ParameterMappingTokenHandler();
        //解析sql 将#{}转化为?
        final String parse = new GenericTokenParser("#{", "}", parameterMappingTokenHandler).parse(sql);
        //sql中#{}包含的集合
        final List<ParameterMapping> parameterMappings = parameterMappingTokenHandler.getParameterMappings();
        //预编译sql
        PreparedStatement preparedStatement=connection.prepareStatement(parse);
        //找出入参类型
        final Class<?> paramType = mappedStatement.getParamType();
        if(paramType!=null){
            int index=1;
            for (ParameterMapping parameterMapping : parameterMappings) {
                //字段名
                final String name = parameterMapping.getName();
                //反射获取字段
                final Field declaredField = paramType.getDeclaredField(name);
                //开启访问权限
                declaredField.setAccessible(true);
                //获取入参的值
                final Object o = declaredField.get(params[0]);
                preparedStatement.setObject(index,o);
                index++;
            }
        }

        final ResultSet resultSet = preparedStatement.executeQuery();
        //获取返回值类型
        final Class<?> resultType = mappedStatement.getResultType();
        List<E> res=new ArrayList<>();
        //封装返回结果集
        while(resultSet.next()){
            //查询结果
            final ResultSetMetaData metaData = resultSet.getMetaData();
            //获取多少列数
            final int columnCount = metaData.getColumnCount();
            final Object o = resultType.newInstance();
            //遍历结果并封装对象
            for(int i=1;i<=columnCount;i++){
                //字段名
                final String columnName = metaData.getColumnName(i);
                //字段值
                final Object value = resultSet.getObject(columnName);
                //内省
                PropertyDescriptor propertyDescriptor=new PropertyDescriptor(columnName,resultType);
                //创建写方法
                final Method writeMethod = propertyDescriptor.getWriteMethod();
                writeMethod.invoke(o,value);
            }
            res.add((E) o);

        }

        return res;
    }

    /**
     * @Description: 增删改操作
     * @param configuration 配置信息
     * @param mappedStatement xml解析信息
     * @param params 入参
     * @return int
     * @Author: dingpei
     */
    @Override
    public int update(Configuration configuration, MappedStatement mappedStatement, Parameter[] parameters, Object... params) throws SQLException, NoSuchFieldException, IllegalAccessException, ClassNotFoundException {
        //获取链接
        Connection connection = Resource.conn.get();
        if(connection==null){
            connection=configuration.getDataSource().getConnection();
            Resource.conn.set(connection);
        }
        //获取sql
        final String sql = mappedStatement.getSql();
        //新建处理器
        final ParameterMappingTokenHandler parameterMappingTokenHandler = new ParameterMappingTokenHandler();
        //解析sql 将#{}转化为?
        final String parse = new GenericTokenParser("#{", "}", parameterMappingTokenHandler).parse(sql);
        //sql中#{}包含的集合
        final List<ParameterMapping> parameterMappings = parameterMappingTokenHandler.getParameterMappings();
        //预编译sql
        PreparedStatement preparedStatement=connection.prepareStatement(parse);
        //获取入参类型
        final Class<?> paramType = mappedStatement.getParamType();
        //如果未设置入参类型,则默认基本数据类型或引用数据类型
        if(paramType==null){
            //将传入的值排序
            for (int j=0;j<parameterMappings.size();j++) {
                //parameterMappings是参数插入的顺序,在型参找到对应的位置
                for (int i = 0; i < parameters.length; i++) {
                   //此处parameterMappings是sql中#{}解析出来的顺序,parameters是形参顺序
                   if(parameterMappings.get(j).getName().equals(parameters[i].getName())){
                        preparedStatement.setObject(j+1,params[i]);
                    }
                }

            }
        }
        //请求类型不为空
        if(paramType!=null){
            int index=1;
            for (ParameterMapping parameterMapping : parameterMappings) {
                //字段名
                final String name = parameterMapping.getName();
                //反射获取字段
                final Field declaredField = paramType.getDeclaredField(name);
                //开启访问权限
                declaredField.setAccessible(true);
                //获取入参的值
                final Object o = declaredField.get(params[0]);
                preparedStatement.setObject(index,o);
                index++;
            }
        }
        return preparedStatement.executeUpdate();
    }
}

其中有在处理请求参数时我借用了mybatis中的处理类。并进行了简化

utils

参数处理 GenericTokenParser、ParameterMappingTokenHandler、TokenHandler,如果感兴趣可以至git查看

测试结果

如图就是自定义ORM框架的查询结果,如果要查看增删改方法则调用update/delete/insert方法

image.png

Mapper中方的方法

public interface PaymentChannelMapper {
    List<PaymentChannel> selectAll();
    PaymentChannel selectOne(PaymentChannel paymentChannel);
    int update(Integer id,String channel_name);
    int insert(PaymentChannel paymentChannel);
    int delete(int id);
}