FastAPI 集成 Tortoise-ORM 实践

3,176 阅读6分钟

1. 概述

前面写了一个篇文章 peewee自动生成mysql表对应的model工具 简单介绍了一下 peewee 的使用,本文介绍另一个 python 的 ORM 框架 Tortoise-ORM

通过本文的学习,你将有如下收获:

  • 了解 Tortoise-ORM 的作用
  • 了解 Tortoise-ORM 和 peewee 的优缺点
  • 了解 FastAPI 集成 Tortoise-ORM
  • 了解自定义生成 Tortoise Model 的工具

1.1. Tortoise-ORM 简介

Tortoise-ORM是一个Python异步ORM(对象关系映射)框架,专门用于异步web应用和微服务。它基于SQLAlchemy并使用asyncio库来处理异步请求。Tortoise-ORM提供了简单易用的API,允许开发人员使用异步方式将Python对象映射到数据库表中,并进行快速、高效的数据库操作。

Tortoise-ORM支持多种主流的关系数据库(如MySQL、PostgreSQL、SQLite等),同时也提供了类似Django的模型定义、自动迁移和查询构建功能。它还支持简单的CRUD操作、复杂的查询(包括聚合查询和连接查询)、事务处理等数据库操作。

除了数据库操作功能,Tortoise-ORM还提供了内置的数据验证、序列化和反序列化功能,简化了与数据库交互的过程。另外,它还可以与Python的异步web框架(如FastAPI和Sanic)很好地集成,使开发者能够轻松地构建高效的异步web应用和微服务。

总之,Tortoise-ORM是一个功能丰富、易于使用的异步ORM框架,它为开发人员提供了便捷的数据库操作和高效的异步支持,使他们能够更加轻松地构建复杂的异步web应用和微服务。

1.2. Tortoise-ORM 和 peewee

1.2.1. Tortoise ORM

优点:

  • 异步IO支持,适合高并发的异步应用
  • 与FastAPI等异步框架集成良好
  • 灵活性较高,支持多种数据库,包括异步数据库

缺点:

  • 相对较新,生态相对不够完善
  • 学习资源和文档相对较少

技术选型方案:
适合于需要异步IO支持的高并发应用,或者与FastAPI等异步框架集成的场景。

1.2.2. Peewee

优点:

  • 简单易用,学习成本低
  • 对小型项目和原型快速开发友好
  • 轻量级框架,适用于资源受限的应用

缺点:

  • 功能相对较少,不适合复杂的数据库操作
  • 对大型数据库和高并发场景支持不足

技术选型方案:
适合于小型应用或者需要快速开发的场景。

2. FastAPI 集成 Tortoise-ORM 实践

接下来介绍一下 FastAPI 如何集成 Tortoise-ORM。

步骤一:安装相关依赖

  • fastapi:FastAPI是一个现代的、快速的Python web框架,用于构建高性能的API。

  • tortoise-orm:tortoise-orm是一个基于异步IO的Python对象关系映射框架,专注于提供高性能和易用性。

  • aiomysql: aiomysql是一个基于asyncio的Python异步MySQL操作库。

    pip install fastapi tortoise-orm aiomysql

步骤二:使用 RegisterTortoise集成 FastAPI 和 Tortoise-ORM,实现代码如下:

from tortoise.contrib.fastapi import RegisterTortoise
from fastapi import FastAPI
import uvicorn
from contextlib import asynccontextmanager
from tortoise.fields import *
from tortoise import Model

mysql_config = {
    'connections': {
        # Dict format for connection
        'default': {
            'engine': 'tortoise.backends.mysql',
            'credentials': {
                'user': 'root',  # 用户名
                'password': 'my-secret-pw',  # 密码
                'host': '127.0.0.1',  # 访问地址
                'port': 3306,  # 服务器端口号
                'database': 'my_database', # 数据库
            }
        },
    },
    'apps': {
        'models': {
            'models': ['__main__'],
            # If no default_connection specified, defaults to 'default'
            'default_connection': 'default',
        }
    }
}


class Approver(Model):
    """
    数据表对应的 Model
    """
    work_no = CharField(null=False, source_field='work_no', max_length=32, default=None, description='员工工号')
    name = CharField(null=True, source_field='name', max_length=32, default=None, description='员工花名')
    biz_line_scope = CharField(null=True, source_field='biz_line_scope', max_length=1024, default=None,
                               description='业务线范围')
    org_scope = CharField(null=True, source_field='org_scope', max_length=1024, default=None,
                          description='组织架构范围')
    biz_line_process_type = CharField(null=True, source_field='biz_line_process_type', max_length=256, default=None,
                                      description='业务线流程类型')
    org_process_type = CharField(null=True, source_field='org_process_type', max_length=256, default=None,
                                 description='组织架构业务类型')
    creator = CharField(null=False, source_field='creator', max_length=16, default=None, description='创建人工号')
    modifier = CharField(null=True, source_field='modifier', max_length=16, default=None, description='修改人工号')

    class Meta:
        table = 'approver'


@asynccontextmanager
async def lifespan(app: FastAPI):
    # 应用程序接受第一个请求之前执行下面的逻辑
    register_tortoise = RegisterTortoise(app, config=mysql_config)
    await register_tortoise.init_orm()
    yield
    # 应用程序在处理完最后一个请求之后执行下面的逻辑
    await register_tortoise.close_orm()


app = FastAPI(lifespan=lifespan)

# 创建一个 Pydantic 模型
Approver_Pydantic = pydantic_model_creator(Approver)

@app.get("/users/{id}", response_model=Approver_Pydantic, response_model_exclude_unset=True)
async def get_user(id: str):
    result = await Approver.get(id=id)
    p = await Approver_Pydantic.from_tortoise_orm(result)
    print("One Event:", p.model_dump_json(indent=4))
    return p



@app.get(path='/')
async def test():
    result = await Approver.all()
    data = []
    for item in result:
        print(item.__dict__)
        data.append(item.__dict__)
    return {'success': True, 'data': data}


if __name__ == '__main__':
    uvicorn.run(app=app, host='127.0.0.1', port=8000)

使用上面的代码,可以查询出对应表的所有数据。

调用查询的内容如下:

自定义自动生成Tortoise Model 工具

在使用 Tortoise-ORM 时,需要创建与 mysql 表对应的 Model,如果手动创建比较繁琐,能否自动根据表结构生成对应的model,查阅了相关文档没有找到可以直接使用的工具,所以自定义了一个工具类来生成 mysql 表对应的 Model,代码如下:

import sys
from getpass import getpass
from optparse import OptionParser

import pymysql
from . import tortoise_model

# 数据库连接配置
config = {
    'user': 'root',  # 用户名
    'password': 'infini_rag_flow',  # 密码
    'host': '127.0.0.1',  # 访问地址
    'port': 3306,  # 端口号
    'database': 'rag_flow',  # 数据库名称
    'charset': 'utf8mb4',
    'cursorclass': pymysql.cursors.DictCursor
}

sql_temp = """
SELECT COLUMN_NAME as column_name, 
        COLUMN_TYPE as column_type, 
        COLUMN_COMMENT as column_comment, 
        COLUMN_DEFAULT as column_default, 
        IS_NULLABLE as is_nullable, 
        EXTRA as extra
FROM information_schema.columns 
WHERE table_name = '%s'
"""


def err(msg):
    sys.stderr.write('\033[91m%s\033[0m\n' % msg)
    sys.stderr.flush()


# 执行查询的函数
def print_models(tables, connect, orm: str):
    connection = pymysql.connect(**connect)
    try:
        # 使用 cursor() 方法创建一个游标对象 cursor
        with connection.cursor() as cursor:

            # 执行查询语句以获取所有表的名称
            cursor.execute("SHOW TABLES")
            # 获取结果集
            query_tables = cursor.fetchall()
            all_table_names = [table.get(f'Tables_in_{connect["database"]}') for table in query_tables]
            if tables:
                intersection = list(set(all_table_names) & set(tables))
            else:
                intersection = all_table_names
            if not intersection:
                err(f'指定的表{tables}不存在,可以选择的表如下:{all_table_names}, 不指定默认为所有表')
                sys.exit(1)

            if orm == 'peewee':
                peewee_model.print_header(config)
            elif orm == 'tortoise':
                tortoise_model.print_header()

            for table_name in intersection:
                # 执行查询语句以获取表结构信息
                cursor.execute(sql_temp % table_name)
                # 获取结果集
                results = cursor.fetchall()
                if orm == 'peewee':
                    peewee_model.print_model(table_name, results)
                elif orm == 'tortoise':
                    tortoise_model.print_model(table_name, results)

    except pymysql.MySQLError as e:
        err(f"Error: {e}")
    finally:
        # 关闭数据库连接
        connection.close()


def get_option_parser():
    parser = OptionParser(usage='usage: %prog [options] database_name')
    ao = parser.add_option
    ao('-H', '--host', dest='host')
    ao('-p', '--port', dest='port', type='int')
    ao('-u', '--user', dest='user')
    ao('-P', '--password', dest='password', action='store_true')
    ao('-o', '--orm', dest='orm', choices=["peewee", "tortoise"],
       help='Choose an ORM to generate code for. default: peewee', default='peewee')
    ao('-t', '--tables', dest='tables',
       help=('Only generate the specified tables. Multiple table names should '
             'be separated by commas.'))
    return parser


def main():
    parser = get_option_parser()
    options, args = parser.parse_args()

    if len(args) < 1:
        err('Missing required parameter "database"')
        parser.print_help()
        sys.exit(1)

    database = args[-1]
    config['database'] = database
    if options.host:
        config['host'] = options.host
    if options.port:
        config['port'] = options.port
    if options.user:
        config['user'] = options.user
    if options.password:
        config['password'] = getpass()

    tables = None
    if options.tables:
        tables = [table.strip() for table in options.tables.split(',')
                  if table.strip()]

    print_models(tables, config, options.orm)


if __name__ == '__main__':
    main()

上面的 gen_model.py 文件是用来连接数据库,查询对应的表结构的。

下面的 tortoise_model.py 文件定义了如果输出 Model 类。

"""
在Django中,Tortoise ORM和MySQL表字段的映射对应表可以如下所示:

Tortoise ORM字段	MySQL表字段
CharField	VARCHAR
TextField	TEXT
IntegerField	INT
FloatField	FLOAT
DateTimeField	DATETIME
BooleanField	BOOLEAN
注意:这只是一些常见的字段类型映射,还有其他更多字段类型可以使用,具体取决于您的数据表结构和模型定义。

"""

from .model_utils import underline_to_camel, parse_field_type
from tortoise.fields import *

HEADER = """from tortoise.fields import *
from tortoise.models import Model
"""


FILED_MAPPING = {
    'bigint': IntField,
    'varchar': CharField,
    'enum': CharField,
    'text': TextField,
    'datetime': DatetimeField,
    'decimal': DecimalField,
    'tinyint': BooleanField,
    'date': DateField,
    'time': TimeField,
    'longtext': TextField,
}


def print_header():
    print(HEADER)


def print_model(table, rows):
    if not rows:
        return
    print('class %s(Model):' % underline_to_camel(table))

    for row in rows:
        column_name = row['column_name']
        column_type = row['column_type']
        column_comment = row['column_comment']
        column_default = row['column_default']
        is_nullable = True if row['is_nullable'] == 'YES' else False

        if 'timestamp' in column_type and column_default == 'CURRENT_TIMESTAMP':
            column_default = 'None'
        elif column_default and type(column_default) == str:
            column_default = f"'{column_default}'"

        column_type = parse_field_type(column_type)

        if column_type:
            field_type = column_type['field_type']
            field = FILED_MAPPING.get(field_type)
            if field:
                if 'varchar' in field_type:
                    length = column_type['length']
                    print(
                        f"    {column_name} = {field.__name__}(null={is_nullable}, source_field='{column_name}', max_length={length}, default={column_default}, description='{column_comment}')")
                else:
                    print(
                        f"    {column_name} = {field.__name__}(null={is_nullable}, source_field='{column_name}', default={column_default}, description='{column_comment}')")
    print('')
    print('    class Meta:')
    print('        table = \'%s\'' % table)
    print('')

其中model_utils文件的内容如下:

def underline_to_camel(name):
    """
    下划线转驼峰
    :param name:
    :return:
    """
    parts = name.split('_')
    camel_name = ''.join(word.title() for word in parts)
    return camel_name


def transform_enum_string(s):
    """
    解析enum类型
    将字符串enum('N','Y') 转换为 ['N','Y']
    :param s:
    :return:
    """
    s = s.strip("enum()")
    s = s.split(',')
    s = [item.strip("'") for item in s]
    return s


def parse_field_type(column_type: str):
    """
    解析数据类型
    输入:bigint,输出:('bigint', None)
    输入:bigint(20) unsigned,输出:('bigint unsigned', 20)
    输入:varchar(65535),输出:('varchar', 65535)
    :param column_type:
    :return:
    """
    import re
    match = re.match(r'^(\w+)(?:\((\d+)\))?\s*(\w+)?$', column_type)
    if match:
        if match.group(3):
            field_type = match.group(1) + " " + match.group(3)
        else:
            field_type = match.group(1)
        if match.group(2):
            length = int(match.group(2))
        else:
            length = None
        result = {
            "field_type": field_type,
            "length": length
        }
        return result
    elif 'enum' in column_type:
        enum_values = transform_enum_string(column_type)
        result = {
            "field_type": 'enum',
            "enum_vale": enum_values
        }
        return result
    else:
        return {}

参考文档

Tortoise-ORM FastAPI integration

Lifespan Events