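"""Thread-safe and asyncio-safe SQLAlchemy session management.

The async path uses the ``aiomysql`` driver and the sync path uses
``mysqlconnector``. To use different DBAPI libraries, change the driver names
accordingly.
"""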
import contextlib
import logging
import os
from asyncio import current_task
from typing import List, Optional, Tuple, Union
from urllib.parse import quote_plus

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.asyncio import async_scoped_session
from sqlalchemy.orm import sessionmaker, scoped_session
def get_engine_url(db, is_async=True) -> str:
    """Build the SQLAlchemy connection URL for a named database.

    The URL has the form ``mysql+<driver>://<user>:<passwd>@<host>:<port>/<dbname>``.
    The user name and password are percent-encoded, so characters such as
    ``@``, ``:`` or ``/`` in a password cannot corrupt the URL (SQLAlchemy
    decodes them again when it parses the URL).

    :param str db: logical database name; only ``"dev"`` is configured here.
    :param bool is_async: if True use the ``aiomysql`` driver (for
        ``create_async_engine``), otherwise ``mysqlconnector``.
    :return: the connection URL string.
    :raises RuntimeError: if ``db`` is not a configured database.
    """
    # Connection settings per logical database name; add entries to support
    # more databases.
    configs = {
        "dev": {
            "user": "user",
            "passwd": "passwd",
            "host": "host",
            "port": 3306,
            "dbname": "dbname",
        },
    }
    cfg = configs.get(db)
    if cfg is None:
        raise RuntimeError(f"Not supported db: {db}, config this db by `.env` file.")
    driver = "aiomysql" if is_async else "mysqlconnector"
    user = quote_plus(cfg["user"])
    passwd = quote_plus(cfg["passwd"])
    return f"mysql+{driver}://{user}:{passwd}@{cfg['host']}:{cfg['port']}/{cfg['dbname']}"
class DBManager:
    """Manage a SQLAlchemy engine and scoped sessions for one MySQL database.

    In sync mode (``is_async=False``) sessions come from a ``scoped_session``;
    in async mode they come from an ``async_scoped_session`` scoped to the
    current asyncio task.
    """

    def __init__(self, db='dev', engine_options=None, session_options=None, is_async=False):
        """
        :param db: database name, resolved to a URL by ``get_engine_url``.
        :param engine_options: extra keyword arguments for the engine factory.
            The caller's dict is copied, never modified. Keys with defaults:
            echo: print SQL statements (default: True unless the ``ENV_NAME``
                environment variable equals ``"PROD"``)
            pool_size: connection pool size (default 20)
            max_overflow: connections allowed beyond pool_size under load (default 10)
            pool_timeout: seconds to wait for a connection from the pool (default 100)
            pool_recycle: connection recycle time in seconds (default 2 hours)
        :param session_options: keyword arguments for ``sessionmaker``, e.g.
            autoflush (SQLAlchemy default True) or expire_on_commit (SQLAlchemy
            default True; set False to avoid re-loading attributes after
            commit). An optional ``scopefunc`` key is given to the sync
            ``scoped_session`` instead of ``sessionmaker``.
        :param bool is_async: use the async engine and sessions.
        """
        # Copy so that popping the known keys never mutates the caller's dict.
        engine_options = dict(engine_options or {})
        echo_default = os.getenv("ENV_NAME") != "PROD"
        self._engine_options = {
            "echo": engine_options.pop("echo", echo_default),
            "pool_size": engine_options.pop("pool_size", 20),
            "max_overflow": engine_options.pop("max_overflow", 10),
            "pool_timeout": engine_options.pop("pool_timeout", 100),
            "pool_recycle": engine_options.pop("pool_recycle", 60 * 60 * 2),
        }
        self.is_async = is_async
        self.url = get_engine_url(db, is_async=is_async)
        # Any remaining caller options are passed through to the engine factory.
        self._engine_options.update(engine_options)
        self._session_options = dict(session_options or {})
        # Sync session state.
        self._session_factory = None
        self._session = None
        self._engine = None
        # Async session state.
        self._async_engine = None
        self._async_factory = None
        self._async_scoped = None

    @property
    def engine(self):
        """The lazily created, cached sync or async engine."""
        return self.create_engine()

    @property
    def session_factory(self):
        """The cached ``sessionmaker``, built from ``session_options`` on first use."""
        factory = self._async_factory if self.is_async else self._session_factory
        if factory is None:
            factory = self.create_session(**self._sessionmaker_options())
        return factory

    @property
    def safe_session(self):
        """Sync mode: the ``scoped_session``. Async mode: a coroutine yielding the ``async_scoped_session``."""
        if self.is_async:
            return self.async_safe_session()
        return self.get_safe_session()

    def _sessionmaker_options(self):
        """Return ``session_options`` without ``scopefunc``, which is not a sessionmaker argument."""
        return {key: value for key, value in self._session_options.items() if key != "scopefunc"}

    def create_scoped_session(self, factory, **session_options):
        """Wrap ``factory`` in a ``sqlalchemy.orm.scoped_session``.

        Only the ``scopefunc`` keyword is used here; other keywords are ignored
        (they are applied to the ``sessionmaker`` by ``session_factory``).

        :param factory: a session factory, e.g. from ``create_session``.
        :param session_options: may contain ``scopefunc`` for the registry.
        :return: the scoped session registry.
        """
        scopefunc = session_options.pop("scopefunc", None)
        return scoped_session(factory, scopefunc=scopefunc)

    def create_engine(self):
        """Create (once) and return the engine for ``self.url``.

        :return: an ``AsyncEngine`` in async mode, otherwise a sync ``Engine``.
        """
        if self.is_async:
            if self._async_engine is None:
                self._async_engine = create_async_engine(self.url, **self._engine_options)
            return self._async_engine
        if self._engine is None:
            self._engine = create_engine(self.url, **self._engine_options)
        return self._engine

    def create_session(self, **session_options):
        """Build a new ``sessionmaker`` bound to ``self.engine`` and cache it.

        In async mode the factory produces ``AsyncSession`` objects.

        :param session_options: keyword arguments forwarded to ``sessionmaker``.
        :return: the new session factory.
        """
        # Same mode test as create_engine, so the factory class always matches the engine type.
        if not self.is_async:
            self._session_factory = sessionmaker(self.engine, **session_options)
            return self._session_factory
        self._async_factory = sessionmaker(class_=AsyncSession, bind=self.engine, **session_options)
        return self._async_factory

    def __call__(self, *args, **kwargs):
        """Return the session for the current scope from the scoped registry (arguments are ignored)."""
        if self._session is None:
            self._session = self.get_safe_session()
        return self._session()

    def get_safe_session(self):
        """Return the cached ``scoped_session`` registry, creating it on first use.

        When used directly, the caller must call ``remove()`` itself, e.g.
        >>> session = db.get_safe_session()
        >>> res = session.execute("SELECT 1")
        >>> print(res.fetchall())
        [(1,)]
        >>> session.remove()
        """
        if self._session is None:
            self._session = self.create_scoped_session(self.session_factory, **self._session_options)
        return self._session

    def get_single_session(self):
        """Return a standalone, non-scoped session (not thread-safe).

        The session is bound to the shared sync engine (not to a separately
        opened connection), so closing it returns its connection to the pool.
        Objects are not expired on commit.
        >>> session = db.get_single_session()
        >>> with session as s:
        >>>     s.execute("select 1")

        :raises RuntimeError: in async mode, where only an async engine exists.
        """
        if self.is_async:
            raise RuntimeError("get_single_session() requires a sync DBManager (is_async=False)")
        factory = sessionmaker(bind=self.create_engine())
        return factory(expire_on_commit=False)

    @contextlib.contextmanager
    def session(self):
        """Context manager for the sync scoped session.

        Commits on normal exit, rolls back and re-raises on error, and always
        removes the scoped session afterwards.
        >>> with db.session() as session:
        >>>     res = session.execute("SELECT 1")
        >>>     print(res.fetchall())
        [(1,)]
        """
        _session = self()
        try:
            yield _session
            _session.commit()
        except Exception:
            _session.rollback()
            raise
        finally:
            self._session.remove()

    async def async_safe_session(self):
        """Return the cached ``async_scoped_session``, scoped to the current asyncio task."""
        if self._async_scoped is None:
            self._async_scoped = async_scoped_session(self.session_factory, scopefunc=current_task)
        return self._async_scoped

    @contextlib.asynccontextmanager
    async def async_session(self):
        """Async context manager for the task-scoped session.

        Commits on normal exit, rolls back and re-raises on error, then removes
        the session and disposes the async engine.
        >>> async with async_db.async_session() as session:
        >>>     res = await session.execute("SELECT 1")
        >>>     print(res.scalar())
        1
        """
        scoped = await self.async_safe_session()
        try:
            _session = scoped()
            yield _session
            await scoped.commit()
        except Exception:
            await scoped.rollback()
            raise
        finally:
            try:
                # Close this task's session first so its connection is released.
                await scoped.remove()
            finally:
                # Dispose the engine explicitly. Otherwise pooled connections
                # outlive the context and are garbage-collected later; SQLAlchemy
                # cannot run async cleanup from __del__/weakref callbacks, which
                # triggers "RuntimeError: Event loop is closed".
                if self._async_engine is not None:
                    await self._async_engine.dispose()

    def callproc(self, proc_name, *args, log_resu=False) -> Tuple[Optional[tuple], List[Tuple]]:
        """Call a stored procedure over a raw DBAPI connection.

        Errors are best-effort: any exception from the call, the commit or the
        result reading is logged with its traceback, the transaction is rolled
        back and ``(None, [])`` is returned.
        NOTE(review): ``stored_results()`` is a mysql-connector cursor API, so
        this assumes sync mode (``mysqlconnector``); use ``acallproc`` in async mode.

        Args:
            proc_name: stored procedure name.
            *args: procedure arguments.
            log_resu (bool): log the SELECT results, default False.
        Returns:
            tuple: the procedure's IN & OUT values as returned by
            ``cursor.callproc``, and a list with the first row of each result set.
        Examples
            >>> db.callproc('proc_demo', 'in', 'out')
            (('in', 'Hello in'), [('{"MESSAGE": "000000"}',)])
        """
        logger = logging.getLogger(__name__)
        conn = self.create_engine().raw_connection()
        try:
            with conn.cursor() as cursor:
                try:
                    proc_args = cursor.callproc(proc_name, args)
                    conn.commit()
                    # First row of every result set produced by the procedure.
                    values = [result.fetchone() for result in cursor.stored_results()]
                    if log_resu:
                        logger.info(f"Proc result: {values}")
                    return proc_args, values
                except Exception as e:
                    logger.exception("Call proc error: %s", e)
                    conn.rollback()
        finally:
            conn.close()
        return None, []

    @staticmethod
    def _out_arg_vars(proc_name, args, out_args):
        """Map each OUT argument to the ``@_<proc_name>_<index>`` variable set by ``callproc``.

        The index is found by value with ``args.index``, so when several
        arguments are equal the first position wins. Raises ValueError if a
        value in ``out_args`` is not in ``args``.
        """
        return [f'@_{proc_name}_{args.index(arg)}' for arg in out_args]

    async def acallproc(
        self,
        proc_name, args: Union[List, Tuple],
        out_args: Optional[Union[List, Tuple]] = None
    ) -> Tuple[Optional[tuple], Optional[tuple]]:
        """Call a stored procedure with aiomysql and optionally read OUT values.

        OUT parameter positions are matched automatically by value against
        ``args``. The transaction is committed before the connection closes,
        as in ``callproc``.

        Args:
            proc_name: stored procedure name.
            args: all procedure arguments.
            out_args: the arguments whose OUT values should be fetched.
        Returns:
            tuple: the fetched OUT values (None when ``out_args`` is empty) and
            the first row of the procedure's first result set.
        Raises:
            ValueError: if a value in ``out_args`` does not appear in ``args``.
        Examples
            >>> async def acall():
            >>>     return await db.acallproc('proc_demo', args=['in', 'out'], out_args=['out'])
            (('Hello in',), ('{"MESSAGE": "0000"}',))
        """
        # Imported lazily: only this async path needs the driver module directly.
        import aiomysql

        url = self.engine.url
        conn = await aiomysql.connect(
            host=url.host,
            port=url.port or 3306,  # the URL may omit the port
            user=url.username,
            password=url.password,
            db=url.database,
            echo=self._engine_options.get('echo', False),
        )
        out_values = None
        async with conn:
            async with conn.cursor() as cur:
                await cur.callproc(proc_name, args=args)
                results = await cur.fetchone()
                if out_args:
                    # Move past the procedure's result set before the next query.
                    await cur.nextset()
                    out_vars = ','.join(self._out_arg_vars(proc_name, args, out_args))
                    await cur.execute(f'select {out_vars}')
                    out_values = await cur.fetchone()
            # aiomysql connections default to autocommit off; without this the
            # procedure's writes would be rolled back when the connection closes.
            await conn.commit()
        return out_values, results