[CTMS]基于Fastapi的微服务架构量化交易管理系统数据库操作sqlalchemy实现

92 阅读7分钟

1. 基类DalBase

DalBase 类,用于封装数据库的通用增删改查操作。

1.1 类属性

ORDER_FIELD = ["desc""descending"]
  • ORDER_FIELD 是一个列表,包含表示倒序排序的字符串,用于判断排序方向。

1.2 初始化

def __init__(self, db: AsyncSession = None, model: Any = None, schema: Any = None):
    self.db = db
    self.model = model
    self.schema = schema
  • init 方法接收数据库会话 db、数据库模型 model 和序列化模式 schema 作为参数,并将它们存储为实例属性。

1.3 主要方法

1.3.1 get_data 方法

async def get_data(  
        self,  
        data_id: int = None,  
        v_start_sql: SelectType = None,  
        v_select_from: list[Any] = None,  
        v_join: list[Any] = None,  
        v_outer_join: list[Any] = None,  
        v_options: list[_AbstractLoad] = None,  
        v_where: list[BinaryExpression] = None,  
        v_order: str = None,  
        v_order_field: str = None,  
        v_return_none: bool = False,  
        v_schema: Any = None,  
        v_expire_all: bool = False,  
        **kwargs  
) -> Any
  • 用于获取单个数据,默认使用 ID 查询,也支持关键词查询。
  • 可通过参数指定初始 SQL、连接表、查询条件、排序等。
  • 若未找到数据,根据 v_return_none 参数决定返回 None 还是抛出 HTTPException。
  • 若指定 v_schema,则返回序列化后的数据。

1.3.2 get_datas 方法

async def get_datas(  
        self,  
        page: int = 1,  
        limit: int = 20,  
        v_start_sql: SelectType = None,  
        v_select_from: list[Any] = None,  
        v_join: list[Any] = None,  
        v_outer_join: list[Any] = None,  
        v_options: list[_AbstractLoad] = None,  
        v_where: list[BinaryExpression] = None,  
        v_order: str = None,  
        v_order_field: str = None,  
        v_return_count: bool = False,  
        v_return_scalars: bool = False,  
        v_return_objs: bool = False,  
        v_schema: Any = None,  
        v_distinct: bool = False,  
        v_expire_all: bool = False,  
        **kwargs  
) -> Union[list[Any], ScalarResult, tuple]
  • 用于获取数据列表,支持分页查询。
  • 可通过参数指定初始 SQL、连接表、查询条件、排序、是否返回总数等。
  • 返回值优先级为 v_return_scalars > v_return_objs > v_schema。
  • 字段说明
:param page: 页码  
:param limit: 当前页数据量  
:param v_start_sql: 初始 sql:param v_select_from: 用于指定查询从哪个表开始,通常与 .join() 等方法一起使用。  
:param v_join: 创建内连接(INNER JOIN)操作,返回两个表中满足连接条件的交集。  
:param v_outer_join: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。  
:param v_options: 用于为查询添加附加选项,如预加载、延迟加载等。  
:param v_where: 当前表查询条件,原始表达式  
:param v_order: 排序,默认正序,为 desc 是倒叙  
:param v_order_field: 排序字段  
:param v_return_count: 默认为 False,是否返回 count 过滤后的数据总数,不会影响其他返回结果,会一起返回为一个数组  
:param v_return_scalars: 返回scalars后的结果  
:param v_return_objs: 是否返回对象  
:param v_schema: 指定使用的序列化对象  
:param v_distinct: 是否结果去重  
:param v_expire_all: 使当前会话(Session)中所有已加载的对象过期,确保您获取的是数据库中的最新数据,但可能会有性能损耗
:param kwargs: 查询参数,使用的是自定义表达式  
:return: 返回值优先级:v_return_scalars > v_return_objs > v_schema

1.3.3 create_data 方法

async def create_data(  
        self,  
        data,  
        v_options: list[_AbstractLoad] = None,  
        v_return_obj: bool = False,  
        v_schema: Any = None  
) -> Any
  • 用于创建单个数据,支持字典或模型对象作为输入。
  • 可通过参数指定预加载选项、是否返回对象和序列化模式。

1.3.4 create_datas 方法

async def create_datas(self, datas: list[dict]) -> None
  • 用于批量创建数据,接收字典列表作为输入。

1.3.5 put_data 方法

async def put_data(  
        self,  
        data_id: int,  
        data: Any,  
        v_options: list[_AbstractLoad] = None,  
        v_return_obj: bool = False,  
        v_schema: Any = None  
) -> Any
  • 用于更新单个数据,根据 data_id 查找数据并更新。
  • 可通过参数指定预加载选项、是否返回对象和序列化模式。

1.3.6 delete_datas 方法

async def delete_datas(self, ids: list[int], v_soft: bool = False, **kwargs) -> None
  • 用于删除多条数据,支持软删除。
  • v_soft 为 True 时执行软删除,更新 delete_datetime 和 is_delete 字段。

1.3.7 flush 方法

async def flush(self, obj: Any = None) -> Any
  • 用于将对象刷新到数据库,若传入对象则添加到会话中。
  • 刷新后若传入对象则执行 refresh 操作。

1.3.8 add_relation 方法

def add_relation(  
        self,  
        v_start_sql: SelectType,  
        v_select_from: list[Any] = None,  
        v_join: list[Any] = None,  
        v_outer_join: list[Any] = None,  
        v_options: list[_AbstractLoad] = None,  
) -> SelectType
  • 用于处理关系查询和关系加载,可指定查询起始表、内连接、外连接和查询选项。

1.3.9 dict_filter方法

def __dict_filter(self, **kwargs) -> list[BinaryExpression]:  
    """  
    字典过滤  
    :param model:    :param kwargs:    """    conditions = []  
    for field, value in kwargs.items():  
        if value is not None and value != "":  
            attr = getattr(self.model, field)  
            if isinstance(value, tuple):  
                if len(value) == 1:  
                    if value[0] == "None":  
                        conditions.append(attr.is_(None))  
                    elif value[0] == "not None":  
                        conditions.append(attr.isnot(None))  
                    else:  
                        raise CustomException("SQL查询语法错误")  
                elif len(value) == 2 and value[1not in [None, [], ""]:  
                    if value[0] == "date":  
                        conditions.append(func.date_format(attr, "%Y-%m-%d") == value[1])  
                    elif value[0] == "like":  
                        conditions.append(attr.like(f"%{value[1]}%"))  
                    elif value[0] == "in":  
                        conditions.append(attr.in_(value[1]))  
                    elif value[0] == "between" and len(value[1]) == 2:  
                        conditions.append(attr.between(value[1][0], value[1][1]))  
                    elif value[0] == "month":  
                        conditions.append(func.date_format(attr, "%Y-%m") == value[1])  
                    elif value[0] == "!=":  
                        conditions.append(attr != value[1])  
                    elif value[0] == ">":  
                        conditions.append(attr > value[1])  
                    elif value[0] == ">=":  
                        conditions.append(attr >= value[1])  
                    elif value[0] == "<=":  
                        conditions.append(attr <= value[1])  
                    else:  
                        raise CustomException("SQL查询语法错误")  
            else:  
                conditions.append(attr == value)  
    return conditions
  • 用于字典过滤,根据关键词参数生成查询条件列表。
  • 支持多种查询语法,如日期查询、模糊查询、范围查询等。

2. 具体类实现

定义一个名为 AccountDal 的类,它继承自 DalBase,用于处理与账户数据相关的数据库操作。以下是对该类及其方法的详细解释:

2.1 类定义和初始化

class AccountDal(DalBase):  
  
    def __init__(selfdb: AsyncSession):  
        super(AccountDalself).__init__()  
        self.db = db  
        self.model = models.Account  
        self.schema = schemas.AccountSimpleOut
  • AccountDal 继承自 DalBase,这意味着它可以复用 DalBase 类中的一些通用方法。
  1. init 方法接收一个异步数据库会话 db,并将其存储在实例属性 self.db 中。
  2. self.model 指向 models.Account,表示该类操作的数据库模型。
  3. self.schema 指向 schemas.AccountSimpleOut,用于序列化输出的数据。

2.2 create_data 方法

async def create_data(  
        self,  
        data: schemas.Account,  
        v_options: list[_AbstractLoad] = None,  
        v_return_obj: bool = False,  
        v_schema: Any = None  
) -> Any:  
    date_delta = data.profit_date - timedelta(days=1)  
    total_profit = await self.get_total_profit(apikey_id=data.apikey_id, end=date_delta)  
    unique = await self.get_data(apikey_id=data.apikey_id, profit_date=date_delta, v_return_none=True)  
    if unique:  
        data.profit = float(data.balance) - float(unique.balance)  
    obj = await self.get_data(profit_date=data.profit_date, apikey_id=data.apikey_id, v_return_none=True)  
    if obj:  # 更新  
        obj_dict = jsonable_encoder(data)  
        for key, value in obj_dict.items():  
            setattr(obj, key, value)  
        obj.total_profit = total_profit + data.profit  
        await self.flush(obj)  
    else:  # 创建  
        data.total_profit = total_profit + data.profit  
        obj = self.model(**data.model_dump())  
        await self.flush(obj)  
    return await self.out_dict(obj, v_options, v_return_obj, v_schema)
  • 该方法覆盖了基类的create_data方法,用于创建或更新账户数据。

  • 参数:

    1. data:一个 schemas.Account 类型的对象,包含要创建或更新的账户数据。
    2. v_options:一个可选的 _AbstractLoad 列表,用于指定查询选项。
    3. v_return_obj:一个布尔值,指示是否返回对象本身。
    4. v_schema:一个可选的模式,用于序列化输出的数据。
  • 实现步骤:

    1. 计算 profit_date 前一天的日期 date_delta。
    2. 调用 get_total_profit 方法计算截至 date_delta 的总利润。
    3. 调用 get_data 方法检查 date_delta 当天是否存在相同 apikey_id 的账户数据。
    4. 如果存在,则计算当天的利润 data.profit。
    5. 检查 profit_date 当天是否存在相同 apikey_id 的账户数据。
    6. 如果存在,则更新该账户数据;否则,创建一个新的账户数据。
    7. 调用 out_dict 方法将结果序列化并返回。

2.3 get_total_profit 方法

async def get_total_profit(self,  
                           apikey_id: int = 1,  
                           begin: date = date(197011),  
                           end: date = date.today()) -> float:  
    queryset = await (self.db.execute(  
        select(func.sum(self.model.profit)).filter(  
            self.model.profit_date.between(begin, end), self.model.apikey_id == apikey_id)  
    ))  
    total_profit = queryset.scalar_one()  
    return total_profit or 0
  • 该方法用于计算指定 apikey_id 在指定日期范围内的总利润。

  • 参数:

    1. apikey_id:账户的 API 密钥 ID,默认为 1。
    2. begin:开始日期,默认为 1970 年 1 月 1 日。
    3. end:结束日期,默认为当前日期。
  • 实现步骤:

    1. 执行一个 SQL 查询,使用 func.sum 函数计算 profit 字段的总和。
    2. 使用 filter 方法筛选出 profit_date 在指定范围内且 apikey_id 匹配的记录。
    3. 使用 scalar_one 方法获取查询结果的第一个标量值。
    4. 如果结果为 None,则返回 0。

3. 联系方式

  • 公众号:ScienceStudio

公众号二维码 1.png