peewee自动生成mysql表对应的model工具

639 阅读7分钟

1. Peewee 框架介绍

Peewee是一个Python编程语言下的轻量级ORM(对象关系映射)框架,它提供了简单易用的API来简化与数据库的交互。ORM框架允许开发者使用面向对象的方式来操作数据库,而不需要直接编写SQL语句。

Peewee支持多种数据库后端,包括SQLite、MySQL和PostgreSQL等,因此可以适用于各种项目的需求。它提供了一个简洁的模型定义语法,使得开发者可以快速创建模型类来映射数据库中的表结构。

Peewee还提供了丰富的查询API,可以轻松地执行复杂的数据库查询操作。此外,它还支持数据的增、删、改操作,并且具有优秀的性能表现。

2. 使用docker安装mysql

为了验证 peewee 框架的使用,这里使用 docker 安装一个 mysql 服务

  1. 在服务器上安装Docker
    首先,确保服务器上安装有Docker。如果未安装Docker,请按照官方文档的指引进行安装。

  2. 拉取MySQL镜像
    使用以下命令从Docker Hub拉取MySQL镜像到本地:

    docker pull mysql

安装好 Docker 后,其 registry server 是默认指向 https://hub.docker.com 的。在国内该hub源访问速度异常慢,尤其是大一点的镜像经常出现timeout。我们可以通过切换至国内镜像仓库来解决这一问题:

docker pull docker.m.daocloud.io/library/mysql

docker 镜像重命名

要重命名一个Docker镜像,可以使用以下命令:

docker tag old_image_name:old_tag new_image_name:new_tag

其中,old_image_name是旧镜像的名称,old_tag是旧镜像的标签,new_image_name是新镜像的名称,new_tag是新镜像的标签。

例如,如果要将一个名为docker.m.daocloud.io/library/mysql的镜像重命名为mysql,可以使用以下命令:

docker tag docker.m.daocloud.io/library/mysql:latest mysql:latest

这样就成功将ubuntu:latest镜像重命名为mysql:latest

  1. 启动MySQL容器
    接下来,使用以下命令启动一个MySQL容器,并且映射外部访问端口:

    docker run --name my-mysql -e MYSQL_ROOT_PASSWORD=my-secret-pw -p 3306:3306 -d mysql:latest

其中,some-mysql是容器的名称,-e MYSQL_ROOT_PASSWORD=my-secret-pw是设置MySQL的root密码,-p 3306:3306映射容器的3306端口到主机的3306端口,-d参数表示以守护进程方式运行。

  1. 外部访问MySQL
    现在,通过主机的IP地址和端口号3306就可以访问MySQL数据库了。

如果是在本地环境,可以使用localhost:3306进行访问;如果是在其他机器上,则需要使用服务器的公网IP地址和端口号进行访问。

注意:安全起见应该配置MySQL进行身份验证,例如设置不允许root用户远程访问。

执行下面的命令进入 mysql 容器中

docker exec -it my-mysql bash

执行下面命令在mysql容器中创建一个数据库:

create database my_database;

创建一个表,由于下面的测试:

CREATE TABLE `approver` (
  `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment '主键',
  `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间',
  `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间',
  `work_no` varchar(32) NOT NULL comment '员工工号',
  `name` varchar(32) DEFAULT NULL comment '员工花名',
  `biz_line_scope` blob DEFAULT NULL comment '业务线范围',
  `org_scope` text DEFAULT NULL comment '组织架构范围',
  `biz_line_process_type` varchar(256) DEFAULT NULL comment '业务线流程类型',
  `org_process_type` varchar(256) DEFAULT NULL comment '组织架构业务类型',
  `creator` varchar(16) NOT NULL comment '创建人工号',
  `modifier` varchar(16) DEFAULT NULL comment '修改人工号',
  PRIMARY KEY(`id`),
  KEY `idx_work_no`(`work_no`)
) AUTO_INCREMENT = 100001 DEFAULT CHARSET = utf8mb4 COMMENT = '审批人员配置表';

3. 使用pwiz自动生成model

使用 peewee 根据 mysql 的表生成对应的 model,可以按照以下步骤:

Step 1: 安装 peewee
首先需要安装 peewee,可以通过 pip 来安装:

pip install peewee

Step 2: 生成 model
然后,根据 mysql 的表生成对应的 model,可以使用 peewee 内置脚本 pwiz:

在命令行中输入: python -m pwiz 可以查看 pwiz 的使用参数说明,如下所示:

 ~ % python -m pwiz
Missing required parameter "database"
Usage: pwiz.py [options] database_name

Options:
  -h, --help            show this help message and exit
  -H HOST, --host=HOST  
  -p PORT, --port=PORT  
  -u USER, --user=USER  
  -P, --password        
  -e ENGINE, --engine=ENGINE
                        Database type, e.g. sqlite, mysql, postgresql or
                        cockroachdb. Default is "postgresql".
  -s SCHEMA, --schema=SCHEMA
  -t TABLES, --tables=TABLES
                        Only generate the specified tables. Multiple table
                        names should be separated by commas.
  -v, --views           Generate model classes for VIEWs in addition to
                        tables.
  -i, --info            Add database information and other metadata to top of
                        the generated file.
  -o, --preserve-order  Model definition column ordering matches source table.
  -I, --ignore-unknown  Ignore fields whose type cannot be determined.
  -L, --legacy-naming   Use legacy table- and column-name generation.

对于mysql数据库来说,可以使用下面的命令生成对应数据表对应的model,默认生成的代码会直接输出到控制台:

% python -m pwiz -e mysql -H 127.0.0.1 -p 3306 -u root -P -o my_database  
Password: 
from peewee import *

database = MySQLDatabase('my_database', **{'charset': 'utf8', 'sql_mode': 'PIPES_AS_CONCAT', 'use_unicode': True, 'host': '127.0.0.1', 'port': 3306, 'user': 'root', 'password': 'my-secret-pw'})

class UnknownField(object):
    def __init__(self, *_, **__): pass

class BaseModel(Model):
    class Meta:
        database = database

class Approver(BaseModel):
    id = BigAutoField()
    gmt_create = DateTimeField(constraints=[SQL('DEFAULT CURRENT_TIMESTAMP')])
    gmt_modified = DateTimeField(constraints=[SQL('DEFAULT CURRENT_TIMESTAMP')])
    work_no = CharField(index=True)
    name = CharField(null=True)
    biz_line_scope = TextField(null=True)
    org_scope = TextField(null=True)
    biz_line_process_type = CharField(null=True)
    org_process_type = CharField(null=True)
    creator = CharField()
    modifier = CharField(null=True)

    class Meta:
        table_name = 'approver'

执行命令 python -m pwiz -e mysql -H 127.0.0.1 -p 3306 -u root -P -o database_name 需要输入数据库的秘密,如果需要将生成的内容输出到文件可以执行下面的命令:

$ python -m pwiz -e mysql -H 127.0.0.1 -p 3306 -u root -P -o my_database > models.py

这会生成一个名为 models.py 的文件,里面包含了根据 mysql 的表自动生成的 model。

Step 4: 使用 model
最后,可以在代码中使用生成的 model 来操作 mysql 的表数据,例如:

from models import MyTable

# 查询数据
data = MyTable.select()

# 插入数据
MyTable.create(field1='value1', field2='value2')

# 更新数据
query = MyTable.update(field1='new_value').where(MyTable.field2 == 'value2')
query.execute()

# 删除数据
query = MyTable.delete().where(MyTable.field1 == 'value1')
query.execute()

通过上述步骤,就可以使用 peewee 根据 mysql 的表生成对应的 model,并且对 mysql 的表数据进行操作。

4. 自定义工具生成model

使用peewee中自带的pwiz工具类,可以自动生成mysql表中对应的model,为什么要自定义工具类呢,因为pwiz默认生成的model类中没有对应的字段描述,所以为了将表的字段描述生成到model中,所以定义了一个gen_model.py工具,该工具的源代码如下,该工具是参考 pwiz 实现的:

import sys
from getpass import getpass
from optparse import OptionParser

import pymysql
from peewee import *

"""
Field types table参考文档: https://peewee.readthedocs.io/en/latest/peewee/models.html
"""

HEADER = """from peewee import *

database = MySQLDatabase('{database}', **{{'host': '{host}', 'port': {port}, 'user': '{user}', 'password': '{password}'}})

"""

BASE_MODEL = """\
class BaseModel(Model):
    class Meta:
        database = database
"""

UNKNOWN_FIELD = """\
class UnknownField(object):
    def __init__(self, *_, **__): pass
"""

# 数据库连接配置
config = {
    'user': 'root',  # 用户名
    'password': 'infini_rag_flow',  # 密码
    'host': '127.0.0.1',  # 访问地址
    'port': 5455,  # 端口号
    'database': 'rag_flow',  # 数据库名称
    'charset': 'utf8mb4',
    'cursorclass': pymysql.cursors.DictCursor
}

FILED_MAPPING = {
    'bigint': BigIntegerField,
    'varchar': CharField,
    'enum': CharField,
    'text': TextField,
    'datetime': DateTimeField,
    'timestamp': DateTimeField,
    'date': DateField,
    'time': TimeField,
    'float': FloatField,
    'double': DoubleField,
    'decimal': DecimalField,
    'tinyint': BooleanField,
    'mediumint': IntegerField,
    'smallint': IntegerField,
    'bigint unsigned': BigIntegerField,
    'int unsigned': IntegerField,
    'mediumint unsigned': IntegerField,
    'smallint unsigned': IntegerField,
    'tinyint unsigned': IntegerField,
    'bit': BitField,
    'blob': BlobField,
    'tinyblob': BlobField,
    'mediumblob': BlobField,
}

sql_temp = """
SELECT column_name, 
        column_type, 
        column_comment, 
        column_default, 
        is_nullable 
FROM information_schema.columns 
WHERE table_name = '%s'
"""


def err(msg):
    sys.stderr.write('\033[91m%s\033[0m\n' % msg)
    sys.stderr.flush()


def underline_to_camel(name):
    """
    下划线转驼峰
    :param name:
    :return:
    """
    parts = name.split('_')
    camel_name = ''.join(word.title() for word in parts)
    return camel_name


def transform_enum_string(s):
    """
    解析enum类型
    将字符串enum('N','Y') 转换为 ['N','Y']
    :param s:
    :return:
    """
    s = s.strip("enum()")
    s = s.split(',')
    s = [item.strip("'") for item in s]
    return s


def parse_field_type(column_type: str):
    """
    解析数据类型
    输入:bigint,输出:('bigint', None)
    输入:bigint(20) unsigned,输出:('bigint unsigned', 20)
    输入:varchar(65535),输出:('varchar', 65535)
    :param column_type:
    :return:
    """
    import re
    match = re.match(r'^(\w+)(?:\((\d+)\))?\s*(\w+)?$', column_type)
    if match:
        if match.group(3):
            field_type = match.group(1) + " " + match.group(3)
        else:
            field_type = match.group(1)
        if match.group(2):
            length = int(match.group(2))
        else:
            length = None
        result = {
            "field_type": field_type,
            "length": length
        }
        return result
    elif 'enum' in column_type:
        enum_values = transform_enum_string(column_type)
        result = {
            "field_type": 'enum',
            "enum_vale": enum_values
        }
        return result
    else:
        return {}


def print_model(table, rows):
    header = HEADER.format(**config)
    print(header)

    print(BASE_MODEL)

    print(UNKNOWN_FIELD)

    if not rows:
        return
    print('class %s(BaseModel):' % underline_to_camel(table))

    for row in rows:
        column_name = row['column_name']
        column_type = row['column_type']
        column_comment = row['column_comment']
        column_default = row['column_default']
        is_nullable = True if row['is_nullable'] == 'YES' else False

        if 'timestamp' in column_type and column_default == 'CURRENT_TIMESTAMP':
            column_default = 'None'
        elif column_default and type(column_default) == str:
            column_default = f"'{column_default}'"

        column_type = parse_field_type(column_type)

        if column_type:
            field_type = column_type['field_type']
            field = FILED_MAPPING.get(field_type)
            if field:
                if 'varchar' in field_type:
                    length = column_type['length']
                    print(
                        f"    {column_name} = {field.__name__}(null={is_nullable},column_name='{column_name}', max_length={length}, default={column_default}, help_text='{column_comment}')")
                elif 'enum' in field_type:
                    enum_values = column_type['enum_vale']
                    print(
                        f"    {column_name} = {field.__name__}(null={is_nullable},column_name='{column_name}', choices={enum_values}, default={column_default}, help_text='{column_comment}')")
                else:
                    print(
                        f"    {column_name} = {field.__name__}(null={is_nullable},column_name='{column_name}', default={column_default}, help_text='{column_comment}')")
    print('')
    print('    class Meta:')
    print('        table_name = \'%s\'' % table)
    print('')


# 执行查询的函数
def print_models(tables, connect):
    connection = pymysql.connect(**connect)
    try:
        # 使用 cursor() 方法创建一个游标对象 cursor
        with connection.cursor() as cursor:

            # 执行查询语句以获取所有表的名称
            cursor.execute("SHOW TABLES")
            # 获取结果集
            query_tables = cursor.fetchall()
            all_table_names = [table.get(f'Tables_in_{connect["database"]}') for table in query_tables]
            if tables:
                intersection = list(set(all_table_names) & set(tables))
            else:
                intersection = all_table_names
            if not intersection:
                err(f'指定的表{tables}不存在,可以选择的表如下:{all_table_names}, 不指定默认为所有表')
                sys.exit(1)

            for table_name in intersection:
                # 执行查询语句以获取表结构信息
                cursor.execute(sql_temp % table_name)
                # 获取结果集
                results = cursor.fetchall()
                print_model(table_name, results)

    except pymysql.MySQLError as e:
        err(f"Error: {e}")
    finally:
        # 关闭数据库连接
        connection.close()


def get_option_parser():
    parser = OptionParser(usage='usage: %prog [options] database_name')
    ao = parser.add_option
    ao('-H', '--host', dest='host')
    ao('-p', '--port', dest='port', type='int')
    ao('-u', '--user', dest='user')
    ao('-P', '--password', dest='password', action='store_true')
    ao('-t', '--tables', dest='tables',
       help=('Only generate the specified tables. Multiple table names should '
             'be separated by commas.'))
    return parser


if __name__ == '__main__':
    raw_argv = sys.argv

    parser = get_option_parser()
    options, args = parser.parse_args()

    if len(args) < 1:
        err('Missing required parameter "database"')
        parser.print_help()
        sys.exit(1)

    database = args[-1]
    config['database'] = database
    if options.host:
        config['host'] = options.host
    if options.port:
        config['port'] = options.port
    if options.user:
        config['user'] = options.user
    if options.password:
        config['password'] = getpass()

    tables = None
    if options.tables:
        tables = [table.strip() for table in options.tables.split(',')
                  if table.strip()]

    print_models(tables, config)

该工具的实现原理是,使用 pymysql 读取mysql的information_schema.columns表中对应表的字段信息,然后生成对应的model,索引要使用该工具需要安装 pymysql 执行命令:pip install pymysql

使用该工具可以指定的参数如下:

% python gen_model.py                   
Missing required parameter "database"
Usage: gen_model.py [options] database_name

Options:
  -h, --help            show this help message and exit
  -H HOST, --host=HOST  
  -p PORT, --port=PORT  
  -u USER, --user=USER  
  -P, --password        
  -t TABLES, --tables=TABLES
                        Only generate the specified tables. Multiple table
                        names should be separated by commas.

执行下面命令python gen_model.py --host localhost -p 3306 -u root -P my_database 生成前面创建的表对应的model:

 % python gen_model.py  --host localhost -p 3306 -u root -P my_database 
Password: 
from peewee import *

database = MySQLDatabase('my_database', **{'host': 'localhost', 'port': 3306, 'user': 'root', 'password': 'my-secret-pw'})


class BaseModel(Model):
    class Meta:
        database = database

class UnknownField(object):
    def __init__(self, *_, **__): pass

class Approver(BaseModel):
    biz_line_process_type = CharField(null=True,column_name='biz_line_process_type', max_length=256, default=None, help_text='业务线流程类型')
    biz_line_scope = CharField(null=True,column_name='biz_line_scope', max_length=1024, default=None, help_text='业务线范围')
    creator = CharField(null=False,column_name='creator', max_length=16, default=None, help_text='创建人工号')
    gmt_create = DateTimeField(null=False,column_name='gmt_create', default=None, help_text='创建时间')
    gmt_modified = DateTimeField(null=False,column_name='gmt_modified', default=None, help_text='修改时间')
    id = BigIntegerField(null=False,column_name='id', default=None, help_text='主键')
    modifier = CharField(null=True,column_name='modifier', max_length=16, default=None, help_text='修改人工号')
    name = CharField(null=True,column_name='name', max_length=32, default=None, help_text='员工花名')
    org_process_type = CharField(null=True,column_name='org_process_type', max_length=256, default=None, help_text='组织架构业务类型')
    org_scope = CharField(null=True,column_name='org_scope', max_length=1024, default=None, help_text='组织架构范围')
    work_no = CharField(null=False,column_name='work_no', max_length=32, default=None, help_text='员工工号')

    class Meta:
        table_name = 'approver'

通过自定义的工具类,生成的model包含了字段的注释信息和默认值。