数据库操作

67 阅读2分钟
from datetime import datetime
from sqlmodel import SQLModel, Field, Session, create_engine, select


# 数据表模型
class Table1(SQLModel, table=True):
    __tablename__ = 'Table1'
    id: int = Field(default=None, primary_key=True)  # id
    timestamp: datetime = Field(default_factory=datetime.now)  # 时间戳
    data: str | None = Field()  # 数据


class Table2(SQLModel, table=True):
    __tablename__ = 'Table2'
    id: int = Field(default=None, primary_key=True)  # id
    timestamp: datetime = Field(default_factory=datetime.now)  # 时间戳
    data: str | None = Field()  # 数据


TableT = Table1 | Table2


# 数据库 ORM
class DatabaseORM:
    def __init__(self, path: str):
        '''
        初始化数据库 ORM
        :param path: 数据库路径
        '''
        self.engine = create_engine(f'sqlite:///{path}')
        SQLModel.metadata.create_all(self.engine)

    def add(self, data: TableT) -> None:
        '''
        向表中新增一行数据
        :param data: 新增数据
        '''
        with Session(self.engine) as session:
            session.add(data)  # 表单添加数据加入会话
            session.commit()  # 提交会话更改

    def add_many(self, datas: list[TableT]) -> None:
        '''
        向表中新增多行数据
        :param datas: 新增数据列表
        '''
        with Session(self.engine) as session:
            session.add_all(datas)
            session.commit()

    def query(self, table: type[TableT], where) -> TableT | None:
        '''
        查询表数据
        :param table: 表模型
        :param where: 查询条件
        :return 若存在数据则返回该数据模型,若无数据则返回 None
        '''
        with Session(self.engine) as session:
            statement = select(table).where(where)
            results = session.exec(statement)  # noqa
            return results.first()

    def update_data(self, table: type[TableT], where, new_data: TableT) -> None:
        '''
        按条件更新表中的字段
        :param table: 表模型
        :param where: 查询条件
        :param new_data: 新数据
        '''
        with Session(self.engine) as session:
            statement = select(table).where(where)
            results = session.exec(statement)
            data = results.first()
            # 排除未设置的默认值,与'id'字段
            updates = new_data.model_dump(exclude_unset=True, exclude={'id'})
            # 更新字段值
            for k, v in updates.items():
                setattr(data, k, v)
            session.add(data)  # 数据加入会话
            session.commit()  # 提交会话更改

    def update_value(self, table: type[TableT], where, key: ..., value: ...) -> None:
        '''
        按条件更新表中字段的值
        :param table: 表模型
        :param where: 查询条件
        :param key: 更新字段
        :param value: 更新值
        '''
        with Session(self.engine) as session:
            statement = select(table).where(where)
            results = session.exec(statement)
            field = results.first()
            # 更新字段值
            setattr(field, key.key, value)
            session.add(field)
            session.commit()

    def get_table(self, table: type[TableT]) -> list[TableT]:
        '''获取整张表的数据'''
        with Session(self.engine) as session:
            statement = select(table)
            results = session.exec(statement)
            return results.all()  # noqa

    def close(self):
        '''关闭引擎'''
        self.engine.dispose()


if __name__ == '__main__':
    databaseORM = DatabaseORM(path='database.db')
    # 增加数据
    # data = Table1()
    # databaseORM.add(data=data)
    # datas = [Table1(), Table1()]
    # databaseORM.add_many(datas=datas)
    # 查询数据
    # result = databaseORM.query(table=Table1, where=Table1.data == 1)
    # print(result)
    # 更新字段
    # data = Table1(data='123123')
    # databaseORM.update_data(table=Table1, where=Table1.id == 1, new_data=data)
    # 更新字段值
    # databaseORM.update_value(table=Table1, where=Table1.id == 1, key=Table1.data, value=666)
    # 获取表
    # table = databaseORM.get_table(Table1)
    # print(table)