手写Mybatis-plus(ORM框架)

68 阅读11分钟

ORM* *(Object Relational Mapping)对象关系映射

O* (对象模型):*

实体对象,即我们在程序中根据数据库表结构建立的一个个实体javaBean

R* (关系型数据库的数据结构):*

关系数据库领域的Relational(建立的数据库表)

M* (映射):*

从R(数据库)到O(对象模型)的映射,可通过XML文件映射

(1)让实体类和数据库表进行一一对应关系

先让实体类和数据库表对应

通过注解进行实现,定义对应的value值为数据库字段。

再让实体类属性和表里面字段对应

查询玩结果需要封装映射到实体类上面。

(2)不需要直接操作数据库表,直接操作表对应的实体类对象

底层逻辑:反射+代理+注释

例如常用的工具lombak和mybatis-plus,底层都是通过注释标明要操作的类,通过反射获取这个类的属性方法,通过添加代理,让代理对其进行实现。

@Data和@MapperScan皆是如此实现

下面主要以mybatis-plus为目标,一步步写如何实现mybatis-plus**ORM框架

Mapper泛型(MapperBase)

定义Mapper泛型内部方法

// 通用 Mapper 接口
public interface MyBaseMapper<T> {
​
    // 根据 ID 查询
    T selectById(Serializable id);
​
}

注释

定义表名,避免实体类和数据库表不一致

@Retention(RetentionPolicy.RUNTIME) // 运行时保留,反射可获取
@Target(ElementType.TYPE) // 只能用在类/接口上
public @interface Table {
    String value(); // 存储表名,例如 @Table("user_info")
}

定义属性,避免实体类内部定义属性和数据库属性不一致

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Column {
    String value() default ""; // 字段名(默认与属性名一致)
    boolean exist() default true; // 是否为数据库字段
}

定义主键,标明主键,方便通过主键来进行增删改查

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface MyTableId {
    String value() default ""; // 主键字段名
}

字段处理

定义一个对上述实体类定义的字段,我们取到他的值,封装好数据,方便进行生成sql语句

public class TableInfo {
    private String tableName; // 表名
    private String primaryKey; // 主键字段名(数据库列名)
    private String primaryKeyProperty; // 主键属性名(实体类字段名)
    private List<String> columnList; // 所有数据库字段名(排除 exist=false 的字段)
    private List<String> propertyList; // 所有实体类属性名(排除 exist=false 的字段)
​
    // 构造方法:传入实体类 Class,解析表名和字段(与 @Column 联动)
    public TableInfo(Class<?> entityClass) {
        // 1. 解析表名(@Table 注解)
        Table tableAnnotation = entityClass.getAnnotation(Table.class);
        if (tableAnnotation != null && !tableAnnotation.value().isEmpty()) {
            //获取对应注释的值
            this.tableName = tableAnnotation.value();
        } else {
            //用默认名,类名
            this.tableName = entityClass.getSimpleName().toLowerCase(); // 默认表名:类名小写
        }
​
        // 2. 解析字段(含 @Column 注解处理)
        this.columnList = new ArrayList<>();
        this.propertyList = new ArrayList<>();
        Field[] fields = entityClass.getDeclaredFields();
//        getFields():获得某个类的所有的公共(public)的字段,包括父类中的字段。
//        getDeclaredFields():获得某个类的所有声明的字段,即包括public、private和proteced,但是不包括父类的申明字段。
        for (Field field : fields) {
            // 跳过静态/瞬态字段
            //private transient String tempData; → 瞬态字段(transient),不参与序列化和数据库映射,需跳过。
            //java.lang.reflect.Modifier的方法,isStatic判断是否为静态字段,isTransient判断是否为瞬时字段
            //getModifiers方法用于获取类、字段或方法的修饰符,对于类就是获取修饰符,对于属性就是获取修饰符
            if (java.lang.reflect.Modifier.isStatic(field.getModifiers())
                    || java.lang.reflect.Modifier.isTransient(field.getModifiers())) {
                continue;
            }
​
            // -------------------------- 核心:解析 @Column 注解 --------------------------
            //获取注释定义的字段
            //通过 @Column(exist = false) 显式排除 实体类中有但数据库表中无的字段。
            Column columnAnnotation = field.getAnnotation(Column.class);
            // a. 处理 exist=false:排除非数据库字段(不加入 columnList 和 propertyList)
            if (columnAnnotation != null && !columnAnnotation.exist()) {
                System.out.println("【调试】排除非数据库字段:" + field.getName());
                continue;
            }
​
            // b. 获取实体类属性名(如 "nickName")
            String property = field.getName();
​
            // c. 获取数据库字段名(优先 @Column.value(),默认与属性名一致)
            String column;
            if (columnAnnotation != null && !columnAnnotation.value().isEmpty()) {
                // 注解指定了字段名(如 @Column("nickname") → column="nickname")
                column = columnAnnotation.value();
            } else {
                // 未指定注解或 value 为空:使用属性名作为字段名(如 "nickName" → 默认 "nickName",需手动确保数据库字段名一致)
                column = property;
                // 【可选优化】添加驼峰转下划线(如 "nickName" → "nick_name"),需自行实现工具类
                // column = StringUtils.camelToUnderline(property);
            }
​
            // d. 解析主键(@MyTableId 注解)
            //field.isAnnotationPresent(MyTableId.class):检查当前循环到的字段(如 id 字段)是否被 @MyTableId 注解标记。w
            if (field.isAnnotationPresent(MyTableId.class)) {
                //主键用单独的变量进行记录
                this.primaryKey = column; // // 记录主键在数据库中的字段名(如 "id")
                this.primaryKeyProperty = property; // 记录主键在实体类中的属性名(如 "id")
            }
​
            // e. 添加到字段列表(仅包含 exist=true 的字段)
            this.columnList.add(column);
            this.propertyList.add(property);
            System.out.println("【调试】解析字段:属性=" + property + ",数据库字段=" + column);
        }
​
        // 校验:确保主键存在
        if (this.primaryKey == null) {
            throw new RuntimeException("实体类 " + entityClass.getName() + " 未标记 @MyTableId 主键");
        }
    }
                        // Getter/Setter 方法(保持不变)给上述全部属性加上getset方法
****************************************给上述全部属性加上getset方法**************************************************************
    }}

问题:为什么要有字段和属性,不是只需要保留?

获取全部字段,主要逻辑就是首先判断字段的colmus的注解是否为空,为空则继续下一次循环,下面的代码不参与,也不会添加到字段列表里面,为静态代码也将会排除。

field.getAnnotation(Column.class):获取字段上的 @Column 注解

field.getName():获取实体类的字段名(属性名)

下面的代码就是 实体类字段->数据库字段名 映射,如何映射呢,通过注解里面的内容,如果有内容就需要将对应的实体类字段替换为数据库字段名,注解里面有value的就进行替换,加入注解的值,如果没有value或者注释的则直接添加到column

这样我们就得到两个数组,一个是和数据库映射对应的column,另一个是实体类字段property

主键用单独的变量进行记录

sql生成器

方便proxy进行调用生成sql语句

%s为占位符,传入的参数为后面的参数

?为占位符,后面通过调用ps.setObject(1, args[0]);左边参数为第几个?(占位符),右边的为要传入的值

String.join拼接字符串

public class SqlGenerator {
    // 根据方法名和表信息生成 SQL
    public static String generateSql(String methodName, TableInfo tableInfo) {
        switch (methodName) {
            case "selectById":
                // 生成:SELECT id,username,... FROM user WHERE id = ?
                return String.format(
                        "SELECT %s FROM %s WHERE %s = ? ",
                        String.join(",", tableInfo.getColumnList()), // 所有字段
                        tableInfo.getTableName(),                     // 表名
                        tableInfo.getPrimaryKey()                     // 主键字段
                );
            case "insert":
                // 生成:INSERT INTO user (id,username,...) VALUES (?,?,...)
                return String.format(
                        "INSERT INTO %s (%s) VALUES (%s)",
                        tableInfo.getTableName(),
                        String.join(",", tableInfo.getColumnList()),
                        String.join(",", tableInfo.getColumnList().stream().map(c -> "?").toList()) // 参数占位符
                );
            case "updateById":
                // 生成:UPDATE user SET username=?,password=? WHERE id = ?
                String setClause = tableInfo.getColumnList().stream()
                        .map(column -> column + " = ?")
                        .collect(Collectors.joining(","));
                return String.format(
                        "UPDATE %s SET %s WHERE %s = ?",
                        tableInfo.getTableName(),
                        setClause,
                        tableInfo.getPrimaryKey()
                );
            case "deleteById":
                // 生成:DELETE FROM user WHERE id = ?
                return String.format(
                        "DELETE FROM %s WHERE %s = ?",
                        tableInfo.getTableName(),
                        tableInfo.getPrimaryKey()
                );
            default:
                throw new RuntimeException("不支持的方法名:" + methodName);
        }
    }
}

代理

看起来执行的是mapper调用Mapper方法,原理调用proxy内部的方法

public class MapperProxy<T> implements InvocationHandler {
    private Class<T> mapperInterface; // Mapper 接口(如 UserMapper.class)
    private Class<T> entityClass;     // 实体类(如 User.class)
    private TableInfo tableInfo;      // 表映射信息(从 Step 2 获取)
​
    // 构造方法:初始化 Mapper 接口和实体类信息
    public MapperProxy(Class<T> mapperInterface, Class<T> entityClass) {
        this.mapperInterface = mapperInterface;
        this.entityClass = entityClass;
        this.tableInfo = new TableInfo(entityClass); // 解析实体类元数据
    }
​
    // 创建代理对象(对外提供的入口方法)
    public T getProxy() {
        return (T) Proxy.newProxyInstance(
                mapperInterface.getClassLoader(),
                new Class[]{mapperInterface},
                this // 当前 InvocationHandler
        );
    }
​
    // 核心:拦截方法调用(当调用 userMapper.selectById 时执行)
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
        if (Object.class.equals(method.getDeclaringClass())) {
            // 调用 Object 类的默认方法(如 toString()、hashCode()),不进入代理逻辑
            return method.invoke(this, args); // 用当前 MapperProxy 实例作为调用者
        }
        String methodName = method.getName();
        // 1. 打印参数,排查是否接收到
        System.out.println("【调试】调用方法:" + methodName + ",参数:" + Arrays.toString(args));
        System.out.println(args[0]);
        // 2. 生成 SQL
        String sql = SqlGenerator.generateSql(methodName, tableInfo);
        System.out.println("[迷你ORM] 生成 SQL:" + sql);
        
        // 3. 从数据库连接(修复连接参数和资源关闭)
        try (
                Connection conn = DriverManager.getConnection(
                        "jdbc:mysql://localhost:3306/db_spzx?useSSL=false&serverTimezone=UTC&allowPublicKeyRetrieval=true"+ "&tinyInt1isBit=false",
                        //tinyInt1isBit=false,将
                        "root", "123456"
                );
                PreparedStatement ps = conn.prepareStatement(sql);
​
​
​
        ) {
            DatabaseMetaData metaData = conn.getMetaData();
            System.out.println("【调试】当前连接的数据库:" + metaData.getDatabaseProductName()
                    + ",数据库名:" + conn.getCatalog());
            // 4. 设置参数(修复强转问题)
            if ("selectById".equals(methodName) || "deleteById".equals(methodName)) {
                if (args == null || args.length == 0) {
                    throw new IllegalArgumentException(methodName + "方法必须传入参数(如 id)");
                }
                ps.setObject(1, args[0]); // 关键:使用 setObject 适配所有类型
            } else if ("insert".equals(methodName)) {
                // insert 方法参数处理(省略,参考原代码)
            }
        //  --------------------------将 ResultSet(数据库查询结果集)中的数据通过反射赋值给实体类对象(如 UserInfo)--------------------------
            // 5. 执行 SQL 并处理结果
            if ("selectById".equals(methodName)) {
            //ps.executeQuery():执行 SQL 查询,返回 ResultSet(数据库结果集,可理解为“一行数据的容器”)。
                try (ResultSet rs = ps.executeQuery()) {
                    if (rs.next()) {
                        // 1. 创建实体类对象(空对象)
                        T entity = entityClass.getDeclaredConstructor().newInstance();
​
                        // -------------------------- 2. 反射映射字段值(必须添加!) --------------------------
                        for (int i = 0; i < tableInfo.getPropertyList().size(); i++) {
                            // a. 获取实体类属性名(如 "id"、"username"、"nickName")
                            String property = tableInfo.getPropertyList().get(i);
​
                            // b. 获取数据库字段名(如 "id"、"username"、"nickname",已通过 @Column 注解修正)
                            String column = tableInfo.getColumnList().get(i);
​
                            // c. 获取实体类的 Field 对象(反射访问字段)
             //entityClass.getDeclaredField(property):通过反射获取实体类中 属性名对应的 Field 对象(可理解为“字段的句柄”,用于后续赋值)。
             //field.setAccessible(true):允许访问 private 修饰的字段(如 UserInfo 中的 private Long id)
             //field 对象是 实体类中对应属性的“元数据描述对象”,它存储了该字段的 名称、类型、访问修饰符 等信息,不是具体的值
                            Field field = entityClass.getDeclaredField(property);
                            field.setAccessible(true); // 允许访问 private 字段
​
                            // d. 从 ResultSet 中获取数据库字段的值(根据 column 字段名)
                            Object value = rs.getObject(column); // 例如:rs.getObject("nickname") → "test"
​
                            // e. 将值设置到实体类对象中(例如:entity.setNickName("test"))
                            field.set(entity, value);
​
                            // 调试日志:验证映射是否成功(必须添加,查看字段是否有值)
                            System.out.println("【调试】映射字段:" + property + " = " + value);
                    }
                        return entity;
                    }
             else {
​
                return ps.executeUpdate();
            }
        } catch (SQLException e) {
            System.err.println("【SQL 执行失败】" + e.getMessage());
            throw e; // 抛出异常,让上层感知
        }
​
​
    }}
        return null;
    }}

代理工厂(Spring 整合 ORM 的“桥梁” )

方便对于不同的Mapper进行生成prox对象,通过 Spring 的 FactoryBean 机制,动态创建 Mapper 接口的代理对象

有了 MapperFactoryBean 后,Spring 会 自动管理代理对象的生命周期

  1. 扫描发现:自定义扫描器发现 UserMapper 接口,创建 MapperFactoryBean 实例(传入 UserMapper.classUser.class)。
  2. 容器注册:Spring 容器将 MapperFactoryBean 注册为“工厂Bean”。
  3. 按需创建:当其他 Bean(如 UserController)依赖 UserMapper 时,Spring 调用 MapperFactoryBean.getObject() 生成代理对象并注入。
public class MapperFactoryBean<T> implements FactoryBean<T> {
    private Class<T> mapperInterface; // Mapper 接口(如 UserMapper.class)
    private Class<T> entityClass;     // 实体类(如 User.class)
​
    // 构造方法:传入 Mapper 接口和实体类(由扫描器注入)
    public MapperFactoryBean(Class<T> mapperInterface, Class<T> entityClass) {
        this.mapperInterface = mapperInterface;
        this.entityClass = entityClass;
    }
​
    // 核心:创建 Mapper 代理对象(通过 MapperProxy)
    @Override
    public T getObject() throws Exception {
        // 委托给 MapperProxy 创建代理对象(解耦代理逻辑)
        return new MapperProxy<>(mapperInterface, entityClass).getProxy();
    }
​
    // 返回 Mapper 接口类型(Spring 需要知道 Bean 的类型)
    @Override
    public Class<?> getObjectType() {
        return mapperInterface;
    }
​
    // 单例模式(默认)
    @Override
    public boolean isSingleton() {
        return true;
    }
}

注册器

注释

添加类似MapperScan的注解,后期也要添加到SpringAppliationBoot启动类上,使得能够扫描到对应路径下的mapper包

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Import(MapperScanner.class) // 导入扫描注册器
public @interface MyMapperScan {
​
    @AliasFor("basePackages")
    String[] value() default {};
​
    @AliasFor("value")
    String[] basePackages() default {};
}

扫描注册器((扫描Mapper)+()+()())

public class MapperScanner implements ImportBeanDefinitionRegistrar {
    @Override
    public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
        System.out.println("[自定义ORM] MapperScanner 开始执行扫描...");
​
        // 1. 获取 @MyMapperScan 注解属性
        AnnotationAttributes attributes = AnnotationAttributes.fromMap(
                importingClassMetadata.getAnnotationAttributes(MyMapperScan.class.getName())
        );
        if (attributes == null) {
            throw new RuntimeException("未找到 @MyMapperScan 注解");
        }
​
        // 2. 解析 basePackages
        String[] basePackages = attributes.getStringArray("basePackages");
        if (basePackages.length == 0) {
            basePackages = new String[]{ClassUtils.getPackageName(importingClassMetadata.getClassName())};
        }
        System.out.println("[自定义ORM] 扫描路径:" + String.join(",", basePackages));
​
        // 3. 扫描并注册 Mapper 接口
        scanAndRegisterMappers(basePackages, registry);
    }
    //扫描和注册,调用下面的两个方法
    private void scanAndRegisterMappers(String[] basePackages, BeanDefinitionRegistry registry) {
        for (String basePackage : basePackages) {
            Set<Class<?>> mapperInterfaces = scanMapperInterfaces(basePackage);
            for (Class<?> mapperInterface : mapperInterfaces) {
                registerMapperBean(mapperInterface, registry);
            }
        }
    }
​
    // 核心修复:补全扫描逻辑,查找继承 MyBaseMapper 的接口
    private Set<Class<?>> scanMapperInterfaces(String basePackage) {
        Set<Class<?>> mappers = new HashSet<>();
        ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
        MetadataReaderFactory metadataReaderFactory = new CachingMetadataReaderFactory(resolver);
​
        try {
        //构建扫描路径模式
        // 示例:basePackage = "com.atiguigu.mapper" → 
        // pattern = "classpath*:com/atiguigu/mapper/**/*.class"(扫描该路径下所有 .class 文件)
        //CLASSPATH_ALL_URL_PREFIX:固定前缀 classpath*:,表示扫描所有类路径下的资源。
        //basePackage.replace(".", "/"):将包名的.替换为/(如 com.atiguigu.mapper → com/atiguigu/mapper)。
        //  /**/*.class:** 表示递归子包,*.class 匹配所有类文件。
            String pattern = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX
                    + basePackage.replace(".", "/")
                    + "/**/*.class";
                    
            Resource[] resources = resolver.getResources(pattern); // 扫描所有 .class 文件
​
            for (Resource resource : resources) {
             // a. 读取类元数据(类名、接口信息等)
                MetadataReader metadataReader = metadataReaderFactory.getMetadataReader(resource);
                // 类名(如 "com.atiguigu.mapper.UserMapper")
                Stri/ng className = metadataReader.getClassMetadata().getClassName();
                Class<?> clazz = Class.forName(className);
​
                // 筛选:必须是接口,且继承 MyBaseMapper
                if (clazz.isInterface() && isMyBaseMapper(clazz)) {
                    mappers.add(clazz);
                    System.out.println("[自定义ORM] 扫描到 Mapper 接口:" + className);
                }
            }
        } catch (Exception e) {
            throw new RuntimeException("扫描 Mapper 接口失败:" + basePackage, e);
        }
        return mappers;
    }
​
    // 判断接口是否继承 MyBaseMapper
    private boolean isMyBaseMapper(Class<?> clazz) {
    // 类名(如 "com.atiguigu.mapper.UserMapper")
    // 获取当前接口的直接父接口
        Class<?>[] interfaces = clazz.getInterfaces();
        for (Class<?> iface : interfaces) {
            if (iface == MyBaseMapper.class) { // 直接继承 MyBaseMapper
                return true;
            }
            if (isMyBaseMapper(iface)) { // 递归检查父接口(如 A extends B, B extends MyBaseMapper → A 符合条件)
                return true;
            }
        }
        return false;
    }
​
    private void registerMapperBean(Class<?> mapperInterface, BeanDefinitionRegistry registry) {
        try {
    //************作用:通过反射获取 Mapper 接口的泛型实体类(如 UserMapper 对应 User.class),后续生成 SQL 需要实体类信息。*********
            // 解析实体类(MyBaseMapper<T> 中的 T)
            // 假设 mapperInterface 是 UserMapper.class,且 UserMapper extends MyBaseMapper<User>
            ParameterizedType genericInterface = (ParameterizedType) mapperInterface.getGenericInterfaces()[0];
            // genericInterface:MyBaseMapper<User>(带泛型参数的父接口)
            Class<?> entityClass = (Class<?>) genericInterface.getActualTypeArguments()[0];
            // getActualTypeArguments()[0]:获取泛型参数 T → User.class(实体类)//作用:BeanDefinition 是 Spring 中描述 Bean 的“元数据”(类似“产品说明书”),告诉 Spring:
            //要创建的 Bean 类型是 MapperFactoryBean
            //创建时需要调用构造方法 new MapperFactoryBean(UserMapper.class, User.class)。
            GenericBeanDefinition beanDefinition = new GenericBeanDefinition();
            beanDefinition.setBeanClass(MapperFactoryBean.class);
            beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(mapperInterface);// 第1个构造参数:UserMapper.class
            beanDefinition.getConstructorArgumentValues().addGenericArgumentValue(entityClass);// 第2个构造参数:User.classs实体类
​
            // 注册到 Spring 容器
            String beanName = ClassUtils.getShortNameAsProperty(mapperInterface);
            //通过注册器将其注册为spring容器
            registry.registerBeanDefinition(beanName, beanDefinition);
            System.out.println("[自定义ORM] 注册 Mapper Bean:" + beanName + "(接口:" + mapperInterface.getName() + ")");
        } catch (Exception e) {
            throw new RuntimeException("注册 Mapper 接口失败:" + mapperInterface.getName(), e);
        }
    }
}