手撸一个数据库表转实体类工具,代替 MyBatis Generator

1,044 阅读 · 阅读时长约 5 分钟

假设数据表已经生成并且表结构为:

-- Example table that the generator below reads via SHOW FULL COLUMNS.
-- Drop any previous version so the script can be re-run.
DROP TABLE IF EXISTS `tb_user`;
CREATE TABLE `tb_user`  (
  `id` bigint(20) UNSIGNED NOT NULL COMMENT '主键id',
  `username` varchar(64) NOT NULL COMMENT '用户名,唯一',
  `password` varchar(255) NOT NULL COMMENT '密码',
  PRIMARY KEY (`id`) USING BTREE,
  -- NOTE(review): this UNIQUE INDEX on `id` duplicates the PRIMARY KEY above and looks redundant.
  UNIQUE INDEX `id`(`id`) USING BTREE,
  -- Reported as Key = 'UNI' by SHOW COLUMNS, which the generator turns into "UNIQUE".
  UNIQUE INDEX `username`(`username`) USING BTREE
) ENGINE = InnoDB CHARACTER SET = utf8mb4 COMMENT = '用户表';

现在想把数据库表映射为 Java 实体类。我原本用的是 MyBatis Generator,生成效果如下:

package com.cc.model.entity;

import java.util.Date;
import java.io.Serializable;
import javax.persistence.Entity;
import javax.persistence.Table;
import javax.persistence.Column;
import javax.persistence.Id;

/**
 * Sample JPA entity for table {@code tb_user}, shown as the output of MyBatis Generator
 * for comparison with the hand-written generator below.
 *
 * <p>NOTE(review): the class implements {@link Serializable} but declares no
 * {@code serialVersionUID}. Also, {@link #toString()} includes the {@code password} value,
 * so logging an instance would print it.
 */
@Entity
@Table(name = "tb_user")
public class TbUserMember implements Serializable {
    @Id
    @Column(name = "id")
    private Long id;

    @Column(name = "username")
    private String username;

    @Column(name = "password")
    private String password;

    public Long getId() {
        return id;
    }
    public void setId(Long id) {
        this.id = id;
    }

    public String getUsername() {
        return username;
    }
    public void setUsername(String username) {
        this.username = username;
    }

    public String getPassword() {
        return password;
    }
    public void setPassword(String password) {
        this.password = password;
    }

    /**
     * Renders the simple class name, the hash code and every field.
     *
     * @return a debug representation such as {@code TbUserMember [Hash = ..., id=..., ...]}
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(getClass().getSimpleName());
        sb.append(" [");
        // hashCode() is not overridden in this class, so this is the identity hash.
        sb.append("Hash = ").append(hashCode());
        sb.append(", id=").append(id);
        sb.append(", username=").append(username);
        sb.append(", password=").append(password);
        sb.append("]");
        return sb.toString();
    }
}

我希望在这个的基础上可以增加:

  • @ApiModelProperty注解
  • 显示字段更详细的信息,如username字段的:varchar(64) NOT NULL COMMENT '用户名,唯一'
  • 更多的定制性需求

但我没有找到让 MyBatis Generator 实现以上需求的方法(也可能它确实不支持),于是手撸了一个映射工具类。在我看来,相比 MyBatis Generator,这个工具类有以下优点:

  • 单类,使用方便
  • 源码简单,只用到入门时就接触过的 JDBC
  • 定制性高

源码

话不多说,直接上源码:

package com.cc.utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.sql.*;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Generates Java entity classes from existing database tables by reading each table's column
 * metadata with {@code SHOW FULL COLUMNS}.
 *
 * <p>Usage:
 * <ol>
 *   <li>Fill in the database settings ({@code URL}, {@code USERNAME}, {@code PASSWORD}).</li>
 *   <li>Choose whether JPA ({@code @Entity}) and Swagger annotations should be emitted.</li>
 *   <li>Map each table to an entity class name in {@code alias}; only those tables are generated.</li>
 *   <li>Create a model package and put its absolute directory in {@code entitySavePath}.</li>
 *   <li>Set {@code entityPackageName} to that package's name.</li>
 *   <li>Run {@link #main(String[])}.</li>
 * </ol>
 *
 * <p>Generated files are written as UTF-8 so that non-ASCII column comments survive regardless
 * of the platform's default charset.
 *
 * @author chen
 * @date 2021-06-10 14:10
 */
public class DatabaseConvertUtil {
    private static final String DRIVER = "com.mysql.cj.jdbc.Driver";
    private static final String URL = "jdbc:mysql://myserverhost:3306/mymall?characterEncoding=UTF-8&useSSL=false&serverTimezone=UTC";
    private static final String USERNAME = "root";
    private static final String PASSWORD = "root";

    private final static Logger log = LoggerFactory.getLogger(DatabaseConvertUtil.class);

    static private boolean doEntity = false;    // emit JPA annotations (@Entity, @Table, @Id, @Column)
    static private boolean doSwagger = false;   // emit Swagger @ApiModelProperty annotations

    static private boolean doDao = false;   // also generate a tk.mybatis Mapper interface per table
    static private String daoPackageName;   // package of the generated DAO interfaces
    static private String daoSavePath;  // directory the DAO files are written to

    // Table name -> entity class name, e.g. tb_user -> TbUser. Only tables listed here are generated.
    static private Map<String, String> alias = new HashMap<>();
    static private String entityPackageName;  // package of the generated entities
    static private String entitySavePath; // directory the entity files are written to
    static private String fileSuffix = ".java"; // output suffix; ".txt" is handy for dry runs

    static private Map<String, String> importMap;
    static private Map<String, String> typeMap;

    // Column defaults that are plain numbers and can be emitted unquoted.
    private static final Pattern NUMERIC_DEFAULT = Pattern.compile("-?\\d+(\\.\\d+)?");
    // Column defaults that are time functions and must stay unquoted keywords.
    private static final Pattern FUNCTION_DEFAULT =
            Pattern.compile("(?i)(current_timestamp|now|localtime|localtimestamp)(\\(\\d*\\))?");

    public static void main(String[] args) {
        doEntity = true;
        doSwagger = true;
        doDao = false;

        entityPackageName = "com.cc.model.entity";
        daoPackageName = "com.cc.dao";

        daoSavePath = "E:\\xxx\\src\\main\\java\\com\\cc\\dao";

        entitySavePath = "E:\\xxx\\src\\main\\java\\com\\cc\\model\\entity";

        alias.put("tb_user_member", "TbUserMember");
        alias.put("tb_user_admin", "TbUserAdmin");
        alias.put("tb_user_resource", "TbUserResource");
        alias.put("tb_user_role", "TbUserRole");
        alias.put("tb_user_role_relation", "TbUserRoleRelation");
        alias.put("tb_user_role_resource_relation", "TbUserRoleResourceRelation");
        alias.put("tb_user_menu", "TbUserMenu");

        System.out.println("开始进行映射...");
        generatorEntityClass();
        System.out.println("数据库表映射到实体类完成");
    }

    static {
        // Build the lookup tables first and unconditionally. Previously they were built after
        // Class.forName inside the same try, so a missing driver left them null (NPE later).

        // Java type -> import line it needs in the generated class.
        importMap = new HashMap<>();
        importMap.put("Date", "import java.util.Date;");
        importMap.put("BigDecimal", "import java.math.BigDecimal;");
        // typeMap emits Time/Timestamp; without these imports the generated class did not compile.
        importMap.put("Time", "import java.sql.Time;");
        importMap.put("Timestamp", "import java.sql.Timestamp;");

        // Base SQL type name (compared case-insensitively) -> Java field type.
        typeMap = new HashMap<>();
        typeMap.put("CHAR", "String");
        typeMap.put("VARCHAR", "String");
        typeMap.put("LONGVARCHAR", "String");
        typeMap.put("TINYTEXT", "String");
        typeMap.put("TEXT", "String");
        typeMap.put("MEDIUMTEXT", "String");
        typeMap.put("LONGTEXT", "String");
        typeMap.put("NUMERIC", "BigDecimal");
        typeMap.put("DECIMAL", "BigDecimal");
        typeMap.put("BIT", "Boolean");
        typeMap.put("BOOLEAN", "Boolean");
        typeMap.put("TINYINT", "Byte");
        typeMap.put("SMALLINT", "Short");
        typeMap.put("MEDIUMINT", "Integer");
        typeMap.put("INTEGER", "Integer");
        typeMap.put("INT", "Integer");
        typeMap.put("BIGINT", "Long");
        typeMap.put("REAL", "Float");
        typeMap.put("FLOAT", "Double");
        typeMap.put("DOUBLE", "Double");
        typeMap.put("DATE", "Date");
        typeMap.put("DATETIME", "Date");
        typeMap.put("TIME", "Time");
        typeMap.put("TIMESTAMP", "Timestamp");

        try {
            Class.forName(DRIVER);
        } catch (ClassNotFoundException e) {
            log.error("can not load jdbc driver", e);
        }
    }

    public static Map<String, String> getImportMap() {
        return importMap;
    }

    public static Map<String, String> getTypeMap() {
        return typeMap;
    }

    /**
     * Opens a connection with the configured URL and credentials.
     *
     * @return the connection, or {@code null} if it could not be opened (the error is logged)
     */
    public static Connection getConnection() {
        Connection conn = null;
        try {
            conn = DriverManager.getConnection(URL, USERNAME, PASSWORD);
        } catch (SQLException e) {
            log.error("get connection failure", e);
        }
        return conn;
    }

    /** Closes {@code conn} if non-null, logging (not throwing) any failure. */
    public static void closeConnection(Connection conn) {
        if(conn != null) {
            try {
                conn.close();
            } catch (SQLException e) {
                log.error("close connection failure", e);
            }
        }
    }

    /**
     * Generates one entity class (and, if {@code doDao}, one DAO interface) for every table in
     * {@code alias}. A failure on one table is logged and the remaining tables are still processed.
     */
    public static void generatorEntityClass() {
        for (Map.Entry<String, String> entry : alias.entrySet()) {
            String tableName = entry.getKey();
            String className = entry.getValue();

            System.out.printf("正在映射表:%s 到实体类:%s%n", tableName, className);

            if (doDao) {
                try {
                    writeDao(className);
                } catch (IOException e) {
                    log.error("failed to write dao for {}", className, e);
                }
            }

            try {
                writeEntity(tableName, className);
            } catch (SQLException | IOException e) {
                log.error("failed to map table {} to {}", tableName, className, e);
            }
        }
    }

    /** Writes a tk.mybatis {@code Mapper} interface for {@code className} into {@code daoSavePath}. */
    private static void writeDao(String className) throws IOException {
        // Drops the first two characters, assuming a "Tb"-style prefix (TbUser -> UserDao).
        // Names too short for that are used whole instead of throwing StringIndexOutOfBoundsException.
        String baseName = className.length() > 2 ? className.substring(2) : className;
        String interfaceName = baseName + "Dao";
        try (BufferedWriter out = newWriter(daoSavePath, interfaceName + fileSuffix)) {
            out.write(String.format("package %s;\n", daoPackageName));
            out.write("\n");
            out.write(String.format("import %s.%s;\n", entityPackageName, className));
            out.write("import org.springframework.stereotype.Repository;\n");
            out.write("import tk.mybatis.mapper.common.Mapper;\n");
            out.write("\n");
            out.write("@Repository\n");
            out.write(String.format("public interface %s extends Mapper<%s> {\n", interfaceName, className));
            out.write("}\n");
        }
    }

    /**
     * Reads the columns of {@code tableName} and writes {@code className} into {@code entitySavePath}.
     * Columns whose SQL type has no mapping are emitted as a comment and get no accessors, so the
     * generated class still compiles.
     */
    private static void writeEntity(String tableName, String className) throws SQLException, IOException {
        Set<String> importStrArray = new TreeSet<>();          // import lines, sorted for stable output
        List<String> fieldStrArray = new ArrayList<>();        // field declarations with annotations
        List<String> getterAndSetterArray = new ArrayList<>(); // getter/setter pairs
        List<String> fields = new ArrayList<>();               // camelCase field names for toString()
        boolean hasId = false;

        // try-with-resources closes the connection, statement and result set on every path;
        // previously the connection leaked whenever the query failed before rs was assigned.
        try (Connection connection = getConnection()) {
            if (connection == null) {
                log.warn("skip table {}: no database connection", tableName);
                return;
            }
            // SHOW cannot take the table name as a bind parameter, so a plain Statement is used
            // and the identifier is back-quoted (also allows reserved-word table names).
            try (Statement statement = connection.createStatement();
                 ResultSet rs = statement.executeQuery("show full columns from " + quoteIdentifier(tableName))) {
                while (rs.next()) {
                    String column = rs.getString("Field");
                    String fieldStr = getFieldStr(rs);
                    if (fieldStr == null) {
                        log.warn("unsupported column type {} for {}.{}", rs.getString("Type"), tableName, column);
                        fieldStrArray.add("    // " + column + " 解析失败: " + rs.getString("Type"));
                        continue;
                    }

                    String importStr = getImportStr(rs);
                    if (importStr != null) {
                        importStrArray.add(importStr);
                    }
                    fieldStrArray.add(fieldStr);
                    getterAndSetterArray.add(getterAndSetter(rs));
                    fields.add(HumpLineUtil.lineToHump(column));

                    if ("PRI".equals(rs.getString("Key"))) {
                        hasId = true;
                    }
                }
            }
        }

        try (BufferedWriter out = newWriter(entitySavePath, className + fileSuffix)) {
            out.write("package " + entityPackageName + ";\n\n");

            for (String s : importStrArray) {
                out.write(s);
                out.write("\n");
            }
            out.write("import java.io.Serializable;\n");
            if (doEntity) {
                out.write("import javax.persistence.Entity;\n");
                out.write("import javax.persistence.Table;\n");
                out.write("import javax.persistence.Column;\n");
                if (hasId) {
                    out.write("import javax.persistence.Id;\n");
                }
            }
            if (doSwagger) {
                out.write("import io.swagger.annotations.ApiModelProperty;\n");
            }
            out.write("\n");

            if (doEntity) {
                out.write("@Entity\n");
                out.write("@Table(name = \"" + escapeJava(tableName) + "\")\n");
            }
            out.write("public class " + className + " implements Serializable {");
            out.write("\n");

            for (String s : fieldStrArray) {
                out.write(s);
                out.write("\n");
            }
            out.write("    private static final long serialVersionUID = 1L;\n\n");

            for (String s : getterAndSetterArray) {
                out.write(s);
                out.write("\n");
            }

            out.write(getToString(fields));
            out.write("}");
        }
    }

    /** Opens a UTF-8 writer for {@code dir/fileName}, creating or truncating the file. */
    private static BufferedWriter newWriter(String dir, String fileName) throws IOException {
        return Files.newBufferedWriter(Paths.get(dir, fileName), StandardCharsets.UTF_8);
    }

    /** Back-quotes each dot-separated part of a (possibly schema-qualified) MySQL identifier. */
    private static String quoteIdentifier(String name) {
        StringBuilder sb = new StringBuilder();
        for (String part : name.split("\\.", -1)) {
            if (sb.length() > 0) {
                sb.append('.');
            }
            sb.append('`').append(part.replace("`", "``")).append('`');
        }
        return sb.toString();
    }

    /**
     * Returns the import line needed by the current column's Java type.
     *
     * @param rs result set positioned on a {@code SHOW FULL COLUMNS} row
     * @return e.g. {@code import java.util.Date;}, or {@code null} if no import is needed
     */
    static String getImportStr(ResultSet rs) throws SQLException {
        String type = getType(rs.getString("Type"));
        if (type == null) {
            return null;
        }
        for (Map.Entry<String, String> entry : getImportMap().entrySet()) {
            if (entry.getKey().equalsIgnoreCase(type)) {
                return entry.getValue();
            }
        }
        return null;
    }

    /**
     * Builds the field declaration for the current column, including optional JPA and Swagger
     * annotations, e.g. {@code private Long id;}.
     *
     * @param rs result set positioned on a {@code SHOW FULL COLUMNS} row
     * @return the declaration ending in a newline, or {@code null} if the SQL type is unmapped
     */
    static String getFieldStr(ResultSet rs) throws SQLException {
        String _field = rs.getString("Field");
        String _key = rs.getString("Key");
        String type = getType(rs.getString("Type"));
        if (type == null) {
            return null;
        }

        StringBuilder str = new StringBuilder();
        if (doEntity) {
            if ("PRI".equals(_key)) {
                str.append("    @Id\n");
            }
            // Both values end up inside Java string literals, so quotes/backslashes/newlines
            // coming from the database must be escaped or the generated class won't compile.
            str.append("    @Column(name = \"").append(escapeJava(_field))
               .append("\", columnDefinition = \"").append(escapeJava(getColumnDefinition(rs)))
               .append("\")\n");
        }
        if (doSwagger) {
            str.append("    @ApiModelProperty(value = \"").append(escapeJava(rs.getString("Comment"))).append("\")\n");
        }
        str.append("    private ").append(type).append(' ')
           .append(HumpLineUtil.lineToHump(_field)).append(";\n");
        return str.toString();
    }

    /**
     * Builds the getter and setter for the current column.
     *
     * @param rs result set positioned on a {@code SHOW FULL COLUMNS} row with a mapped type
     * @return the two methods as source text
     */
    static String getterAndSetter(ResultSet rs) throws SQLException {
        String field = HumpLineUtil.lineToHump(rs.getString("Field"));
        String type = getType(rs.getString("Type"));
        String property = upperFirstLetter(field);

        StringBuilder str = new StringBuilder();
        str.append("    public ").append(type).append(" get").append(property).append("() {\n");
        str.append("        return ").append(field).append(";\n");
        str.append("    }\n");

        str.append("    public void").append(" set").append(property)
           .append("(").append(type).append(" ").append(field).append(") {\n");
        str.append("        this.").append(field).append(" = ").append(field).append(";\n");
        str.append("    }\n");
        return str.toString();
    }

    /**
     * Builds a {@code toString()} that prints the class name, hash code, every field and
     * {@code serialVersionUID}.
     *
     * @param fields camelCase field names in declaration order
     * @return the method as source text
     */
    static String getToString(List<String> fields) {
        StringBuilder str = new StringBuilder();
        str.append("    @Override\n");
        str.append("    public String toString() {\n");
        str.append("        StringBuilder sb = new StringBuilder();\n");
        str.append("        sb.append(getClass().getSimpleName());\n");
        str.append("        sb.append(\" [\");\n");
        str.append("        sb.append(\"Hash = \").append(hashCode());\n");

        for (String field : fields) {
            str.append("        sb.append(\", ").append(field).append("=\").append(").append(field).append(");\n");
        }

        str.append("        sb.append(\", serialVersionUID=\").append(serialVersionUID);\n");
        str.append("        sb.append(\"]\");\n");
        str.append("        return sb.toString();\n");
        str.append("    }\n");
        return str.toString();
    }

    /**
     * Maps a MySQL column type to a Java type name.
     *
     * @param _type the {@code Type} column, e.g. {@code bigint(20) unsigned} or {@code int unsigned}
     * @return e.g. {@code Long}, or {@code null} if the base type is not in {@code typeMap}
     */
    private static String getType(String _type) {
        if (_type == null) {
            return null;
        }
        // Keep only the base name: cut at "(" and at whitespace, because MySQL 8.0.19+ reports
        // e.g. "int unsigned" without a display width.
        String base = _type.trim().split("[(\\s]", 2)[0];
        for (Map.Entry<String, String> entry : getTypeMap().entrySet()) {
            if (entry.getKey().equalsIgnoreCase(base)) {
                return entry.getValue();
            }
        }
        return null;
    }

    /**
     * Upper-cases the first character. The old {@code cs[0] -= 32} corrupted any first char that
     * was not a lowercase ASCII letter and threw on empty input.
     */
    private static String upperFirstLetter(String name) {
        if (name == null || name.isEmpty()) {
            return name;
        }
        return Character.toUpperCase(name.charAt(0)) + name.substring(1);
    }

    /**
     * Builds the SQL {@code columnDefinition} for {@code @Column}, e.g.
     * {@code VARCHAR(64) UNIQUE NOT NULL COMMENT '用户名,唯一'}.
     *
     * @param rs result set positioned on a {@code SHOW FULL COLUMNS} row
     * @return the definition as raw SQL (not yet escaped for a Java string literal)
     */
    static String getColumnDefinition(ResultSet rs) throws SQLException {
        String _type = rs.getString("Type");
        String _null = rs.getString("Null");
        String _key = rs.getString("Key");
        String _default = rs.getString("Default");
        String _extra = rs.getString("Extra");
        String _comment = rs.getString("Comment");

        StringBuilder columnDefinition = new StringBuilder(_type.toUpperCase(Locale.ROOT));

        if ("UNI".equals(_key)) {
            columnDefinition.append(" UNIQUE");
        }
        if ("NO".equals(_null)) {
            columnDefinition.append(" NOT NULL");
        }

        // MySQL 8 marks expression/function defaults with DEFAULT_GENERATED in Extra.
        boolean generatedDefault = _extra != null && _extra.contains("DEFAULT_GENERATED");
        if (_default != null) {
            columnDefinition.append(" DEFAULT ").append(formatDefault(_default, generatedDefault));
        }

        // DEFAULT_GENERATED is metadata only and is not valid inside a column definition.
        String extra = _extra == null ? "" : _extra.replace("DEFAULT_GENERATED", "").trim();
        if (!extra.isEmpty()) {
            columnDefinition.append(' ').append(extra);
        }

        if (_comment != null && !_comment.isEmpty()) {
            columnDefinition.append(" COMMENT '").append(_comment.replace("'", "''")).append('\'');
        }
        return columnDefinition.toString();
    }

    /**
     * Renders a column default as SQL. Time functions stay unquoted keywords, numbers and bit
     * literals stay as reported, and everything else is quoted as a string literal. The old code
     * upper-cased and left it unquoted, e.g. {@code DEFAULT ABC}.
     */
    private static String formatDefault(String value, boolean expression) {
        if (FUNCTION_DEFAULT.matcher(value).matches()) {
            return value.toUpperCase(Locale.ROOT);
        }
        if (expression) {
            // NOTE(review): assumes SHOW COLUMNS prints the bare expression; extra parentheses
            // are harmless if it already includes them.
            return "(" + value + ")";
        }
        if (NUMERIC_DEFAULT.matcher(value).matches() || value.startsWith("b'")) {
            return value;
        }
        return "'" + value.replace("'", "''") + "'";
    }

    /** Escapes text for use inside a Java string literal; {@code null} becomes empty. */
    private static String escapeJava(String s) {
        if (s == null) {
            return "";
        }
        return s.replace("\\", "\\\\")
                .replace("\"", "\\\"")
                .replace("\r", "\\r")
                .replace("\n", "\\n");
    }

    /**
     * CamelCase / snake_case conversion helpers, adapted from
     * https://blog.csdn.net/turbo_sky/article/details/84814518
     */
    public static class HumpLineUtil {
        private static final Pattern linePattern = Pattern.compile("(_)(\\w)");

        /**
         * Converts camelCase to snake_case, e.g. {@code TbUser -> _tb_user}.
         *
         * @param str camelCase text
         * @return the snake_case form (a leading capital yields a leading underscore)
         */
        public static String humpToLine(String str){
            return str.replaceAll("[A-Z]", "_$0").toLowerCase();
        }

        /**
         * Converts snake_case to camelCase, e.g. {@code tb_user -> tbUser}.
         *
         * @param str snake_case text
         * @return the camelCase form; characters not following an underscore are unchanged
         */
        public static String lineToHump(String str){
            Matcher matcher = linePattern.matcher(str);
            StringBuffer sb = new StringBuffer();
            while(matcher.find()){
                matcher.appendReplacement(sb, matcher.group(2).toUpperCase());
            }
            matcher.appendTail(sb);
            return sb.toString();
        }
    }
}

使用说明

先填写数据库连接配置,再手动指定实体类存放的绝对路径,并声明是否需要 @Entity 注解和 Swagger 注解。

此外,我还新增了 doDao / daoSavePath 配置,用于顺带生成 tk.mybatis 的 Dao 接口,这正体现了我想要的定制性。