sqlalchemy线程和异步安全实践

1,050 阅读3分钟

SQLAlchemy 线程安全与异步安全的实现:异步连接使用 aiomysql,同步连接使用 mysqlconnector;如果使用其他数据库驱动库,请将代码中的 driver 修改为对应的名称。

import contextlib
import logging
import os
from asyncio import current_task
from typing import List, Optional, Tuple, Union

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.asyncio import async_scoped_session
from sqlalchemy.orm import sessionmaker, scoped_session


def get_engine_url(db, is_async=True) -> str:
    """
    Build the database connection URL for a named database.

    The result has the form ``mysql+<driver>://user:passwd@host:port/dbname``.

    :param str db: name of the database configuration; only ``"dev"`` is
        defined in this function
    :param bool is_async: use the ``aiomysql`` driver when truthy, otherwise
        ``mysqlconnector``
    :return: the connection URL
    :raises RuntimeError: if ``db`` is not a known configuration name
    """
    # Guard clause: unknown names are rejected before any setting is read.
    if db != "dev":
        raise RuntimeError(f"Not supported db: {db}, config this db by `.env` file.")

    user, passwd = "user", "passwd"
    host, port = "host", 3306
    dbname = "dbname"
    driver = "aiomysql" if is_async else "mysqlconnector"
    return f"mysql+{driver}://{user}:{passwd}@{host}:{port}/{dbname}"


class DBManager:

    def __init__(self, db='dev', engine_options=None, session_options=None, is_async=False):
        """
        Record connection settings; engines and sessions are created lazily.

        :param db: database configuration name, passed to ``get_engine_url``
        :param engine_options: keyword arguments for the engine factory. Keys
            that are absent fall back to these defaults:
            echo: log SQL statements; True unless the ``ENV_NAME`` environment
                variable equals ``"PROD"``
            pool_size: connection pool size, 20
            max_overflow: connections allowed beyond pool_size, 10
            pool_timeout: seconds to wait for a pooled connection, 100
            pool_recycle: seconds before a connection is recycled, 2 hours
            Any other keys are kept and forwarded unchanged.
        :param session_options: session settings stored on the instance
            (e.g. autoflush, autocommit, expire_on_commit, scopefunc)
        :param bool is_async: build the URL for the async driver and use the
            async engine/session code paths
        """
        defaults = {
            "echo": os.getenv("ENV_NAME") != "PROD",
            "pool_size": 20,
            "max_overflow": 10,
            "pool_timeout": 100,
            "pool_recycle": 60 * 60 * 2,
        }
        # Merge into a fresh dict so the caller's mapping is never mutated
        # (popping keys from it would break reuse of one options dict across
        # several managers).
        self._engine_options = {**defaults, **(engine_options or {})}
        self.is_async = is_async
        self.url = get_engine_url(db, is_async=is_async)
        # Copy for the same reason: later changes by the caller must not leak in.
        self._session_options = dict(session_options or {})
        # Synchronous session state
        self._session_factory = None
        self._session = None
        self._engine = None
        # Asynchronous session state
        self._async_engine = None
        self._async_factory = None
        self._async_scoped = None

    @property
    def engine(self):
        return self.create_engine()

    @property
    def session_factory(self):
        return self.create_session()

    @property
    def safe_session(self):
        if self.is_async:
            return self.async_safe_session()
        else:
            return self.get_safe_session()

    def create_scoped_session(self, factory, **session_options):
        """
        Wrap a session factory in a ``sqlalchemy.orm.scoped_session``.

        :param factory: session factory, normally the ``sessionmaker``
            returned by ``create_session``
        :param session_options: ``scopefunc`` is passed to ``scoped_session``
            (``None`` keeps its default scoping). Every other option is
            applied to ``factory`` through its ``configure()`` method, when
            it has one, so that it affects the sessions the registry creates.
        :return: the ``scoped_session`` registry
        """
        scopefunc = session_options.pop("scopefunc", None)
        # The remaining options used to be dropped silently, so settings such
        # as expire_on_commit never took effect. Note that configure() updates
        # the factory in place.
        configure = getattr(factory, "configure", None)
        if session_options and callable(configure):
            configure(**session_options)
        return scoped_session(factory, scopefunc=scopefunc)

    def create_engine(self):
        """
        生成数据库连接engine,方便内外部调用
        :return:
        """
        if self.is_async:
            if self._async_engine is None:
                self._async_engine = create_async_engine(self.url, **self._engine_options)
            return self._async_engine
        else:
            if self._engine is None:
                self._engine = create_engine(self.url, **self._engine_options)
            return self._engine

    def create_session(self, **session_options):
        """
        Build a ``sessionmaker`` bound to this manager's engine.

        A new factory is created on every call and stored on the instance
        (``_async_factory`` in async mode, ``_session_factory`` otherwise).

        :param session_options: keyword arguments forwarded to ``sessionmaker``
        :return: the newly created ``sessionmaker``
        """
        # Use the same truthiness test as create_engine(): the previous
        # ``is False`` comparison sent values such as 0 or None down the async
        # branch while create_engine() built a synchronous engine for them.
        if self.is_async:
            self._async_factory = sessionmaker(class_=AsyncSession, bind=self.engine, **session_options)
            return self._async_factory
        self._session_factory = sessionmaker(self.engine, **session_options)
        return self._session_factory

    def __call__(self, *args, **kwargs):
        """
        生成一个线程安全的sessin会话
        :param args:
        :param kwargs:
        :return:
        """
        if not self._session:
            self._session = self.get_safe_session()
        return self._session()

    def get_safe_session(self):
        """
        线程安全的session,如果单独用这个函数,需要手动remove,例如
        >>> session = db.get_safe_session()
        >>> res = session.execute("SELECT 1")
        >>> print(res.fetchall())
        [(1,)]
        >>> session.remove()

        :return:
        """
        if not self._session:
            self._session = self.create_scoped_session(self.session_factory, **self._session_options)
        return self._session

    def get_single_session(self):
        """
        Return a standalone session that is not scoped, so it is not safe to
        share between threads.

        >>> session = db.get_single_session()
        >>> with session as s:
        >>>     s.execute("select 1")

        The session is created with ``expire_on_commit=False``.

        :return: a new session bound to a synchronous engine built from
            ``self.url``
        """
        # Build the synchronous engine once and keep it in ``_engine`` (the
        # same slot create_engine() uses in sync mode). Creating a fresh
        # engine per call left a new, never-disposed connection pool behind
        # every time.
        if self._engine is None:
            self._engine = create_engine(self.url, **self._engine_options)
        # Bind to the engine rather than to an eagerly opened Connection: a
        # session does not close a Connection it was handed, so that
        # connection was never released. Bound to the engine, the session
        # takes a pooled connection when needed and gives it back on close.
        factory = sessionmaker(bind=self._engine, expire_on_commit=False)
        return factory()

    @contextlib.contextmanager
    def session(self):
        """
        线程安全的session上下文管理

        >>> with db.session() as session:
        >>>     res = session.execute("SELECT 1")
        >>>     print(res.fetchall())
        [(1,)]

        :return:
        """
        _session = self()
        try:
            yield _session
            _session.commit()
        except Exception as e:
            _session.rollback()
            raise e
        finally:
            self._session.remove()

    async def async_safe_session(self):
        """
        生成一个异步安全的session回话
        :return:
        """
        if not self._async_scoped:
            self._async_scoped = async_scoped_session(self.session_factory, scopefunc=current_task)
        return self._async_scoped

    @contextlib.asynccontextmanager
    async def async_session(self):
        """
        异步session上下文管理封装
        :return:

        >>> async with async_db.async_session() as session:
        >>>     res = await session.execute("SELECT 1")
        >>>     print(res.scalar())
        1

        """
        scoped = await self.async_safe_session()
        try:
            _session = scoped()
            yield _session
            await self._async_scoped.commit()
        except Exception as e:
            await self._async_scoped.rollback()
            raise e
        finally:
            # 显式的调用engine.dispose,否则会脱离上下文,对象被回收
            # sqlalchemy无法处理异步中的__del__和弱引用,会触发异常
            # RuntimeError: Event loop is closed
            await self._async_engine.dispose()
            await self._async_scoped.remove()

    def callproc(self, proc_name, *args, log_resu=False) -> Tuple[Optional[tuple], List[Tuple]]:
        """
        执行存储过程,返回存储过程的IN和OUT,异常则返回None
        Args:
            proc_name: 存储过程名
            *args: 存储过程参数
            log_resu(bool): 是否记录存储过程返回,默认为False

        Returns:
            tuple 第一个为存储过程IN&OUT, 第二个为SELECT输出
        Examples
            >>> db.callproc('proc_demo', 'in', 'out')
            (('in', 'Hello in'), [('{"MESSAGE": "000000"}',)])

        """
        engine = self.create_engine()
        conn = engine.raw_connection()
        try:
            with conn.cursor() as cursor:
                try:
                    r = cursor.callproc(proc_name, args)
                    conn.commit()
                    values = []
                    for sr in cursor.stored_results():
                        values.append(sr.fetchone())
                    if log_resu:
                        logger.info(f"Proc result: {values}")
                    return r, values
                except Exception as e:
                    logger.info(f"Call proc error: {e}")
                    conn.rollback()
        finally:
            conn.close()
        return None, []

    async def acallproc(
        self,
        proc_name, args: Union[List, Tuple],
        out_args: Union[List, Tuple] = None
    ) -> Tuple[Optional[tuple], Optional[tuple]]:
        """
        Call a stored procedure over a dedicated aiomysql connection and
        optionally read back its OUT parameters.

        OUT parameters are located by looking up each entry of ``out_args`` in
        ``args`` (``args.index``) and selecting the matching
        ``@_<proc_name>_<position>`` session variable. Because the lookup is
        by value, an OUT placeholder whose value also appears earlier in
        ``args`` resolves to that earlier position.

        The transaction is committed once the call and all reads succeed. If
        anything raises, the exception propagates and nothing is committed.

        Args:
            proc_name: stored procedure name
            args: all arguments of the stored procedure
            out_args: the arguments whose OUT values should be fetched

        Returns:
            tuple: first the fetched OUT values, or None when ``out_args`` is
            empty; second the first row of the procedure's SELECT output
        Examples
            >>> async def acall():
            >>>     return await db.acallproc('proc_demo', args=['in', 'out'], out_args=['out'])
            (('Hello in',), ('{"MESSAGE": "0000"}',))
        """
        # Imported here because the module's import block does not import
        # aiomysql.
        import aiomysql

        engine = self.engine
        conn = await aiomysql.connect(
            host=engine.url.host,
            port=engine.url.port,
            user=engine.url.username,
            password=engine.url.password,
            db=engine.url.database,
            echo=self._engine_options.get('echo', False)
        )
        out_values = None
        async with conn:
            async with conn.cursor() as cur:
                await cur.callproc(proc_name, args=args)
                results = await cur.fetchone()
                if out_args:
                    # Move past the procedure's result set, then read the OUT
                    # values from the session variables set by callproc.
                    await cur.nextset()
                    out_vars = ','.join(f'@_{proc_name}_{args.index(arg)}' for arg in out_args)
                    await cur.execute(f'select {out_vars}')
                    out_values = await cur.fetchone()
            # Commit like the synchronous callproc does; without it the
            # procedure's writes are discarded when the connection closes.
            await conn.commit()
        return out_values, results