python - sqlachemy另类用法
这里只是给出一个思路,或许对于未来解决问题有一些参考意义。
仿 JAP 的写法
这种写法很像 java 环境中的 JPA,如果引入模版引擎,则可以大幅增强实用性。
但是,在 python 环境中,这不符合主流的 ORM 框架。
潜在风险:代码检测的时候,可能会被误判,因为我们定义了一大堆空的函数。
# 注解式事务 start ---------------------------------------------
@update(sql='UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1')
def modify(params: dict = None) -> int:
pass
@query(sql='SELECT * FROM `t_temp` WHERE (`id`= :id) LIMIT 1', result_type=dict)
def queryById(params: dict = None) -> list:
pass
@query(sql='SELECT * FROM `t_temp` WHERE (`id`= :id) LIMIT 1', result_type=dict)
def queryById2(id: int) -> list:
pass
@transactional()
def test_annotation():
ret = modify({'id': 18, 'desc': 'OR 1=1'})
print(ret)
result = queryById2(18)
print(result)
代码封装
import inspect
import logger_factory
import typing
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.engine import Result, CursorResult
logger = logger_factory.get_logger()
# 定义数据库连接字符串
DATABASE_URI = 'mysql+pymysql://{username}:{password}@{host}:{port}/{dbname}?charset=utf8mb4'
# 替换为你的数据库用户名、密码、主机、端口和数据库名
USERNAME = 'root'
PASSWORD = 'root'
HOST = 'localhost'
PORT = '3306'
DBNAME = 'med'
# 创建数据库引擎,使用连接池
engine = create_engine(
DATABASE_URI.format(
username=USERNAME,
password=PASSWORD,
host=HOST,
port=PORT,
dbname=DBNAME
),
echo=False, # 如果设置为True,SQLAlchemy将打印所有执行的SQL语句,通常用于调试
pool_size=10, # 连接池大小
max_overflow=20, # 超过连接池大小外最多创建的连接数
pool_timeout=30, # 连接池中没有线程可用时,在抛出异常前等待的时间
pool_recycle=3600 # 多少秒之后对连接进行一次回收(重置)
)
# do a test
with engine.connect() as con:
rs = con.execute(text('SELECT 1'))
rs.fetchone()
logger.debug('create engine succeed!')
# session-maker
Session = sessionmaker(bind=engine)
# thread safe session-maker
DBSession = scoped_session(Session)
# with Session() as session:
# # 获取数据库连接
# connection = session.connection()
# savepoint = connection.begin_nested()
# print(savepoint)
def getEffectRows(result: Result) -> int:
r"""
获取受影响行数
这里有点问题:源码部分 rowcount 是一个 callable,但实际应该是 int;
这里绕一点,确保不会出问题,如果返回 -1,说明出现了意料之外的情况
:param result: 结果集
:return: 受影响行数
"""
if isinstance(result, CursorResult):
effect_row = result.rowcount
if isinstance(effect_row, int):
return effect_row
if callable(effect_row):
return effect_row()
return -1
def resultAsDict(result: Result) -> list:
r"""
将查询结果转换为 dict-list
:param result: 结果集
:return: dict 列表
"""
keys = result.keys()
ret = list()
for item in result.fetchall():
ret.append(dict(zip(keys, item)))
return ret
def execute(sql: str, params: dict = None) -> Result:
r"""
执行一条查询语句
:param sql: 查询语句
:param params: 参数
:return: 结果集
"""
if sql is None:
raise ValueError('sql cannot be None')
logger.debug('execute sql: ' + sql)
logger.debug('parameter : ' + str(params))
return DBSession().execute(text(sql), params)
def executeQuery(sql: str, params: dict = None, result_type: type = tuple) -> typing.Sequence:
r"""
执行一个查询
:param sql: sql
:param params: dict
:param result_type: 结果集类型,可选:tuple、dict
:return: 序列
"""
result = execute(sql, params)
if result_type == dict:
return resultAsDict(result)
pass
# default return_type tuple-list
return result.fetchall()
def executeUpdate(sql: str, params: dict = None) -> int:
r"""
执行一个查询
:param sql: sql 执行语句
:param params: dict 查询参数
:return: 受影响行数
"""
result = execute(sql, params)
return getEffectRows(result)
def transactional(rollback: type = Exception):
r"""
注解式事务
用法类似于 spring 环境下的 @Transactional 注解
注意: 事务控制在 session 级别,不能兼容事务嵌套的场景(理想状态下,应当通过 save-point 实现)
推荐: 如果遇到很复杂的事务嵌套,显式调用 session,手动控制事务
:param rollback: 指定触发回滚的异常类型
:return: 装饰器函数
"""
def decorator(func):
def call(*args, **kwargs):
session = None
try:
session = DBSession()
ret = func(*args, **kwargs)
session.commit()
return ret
except rollback as e:
if session:
session.rollback()
logger.exception(f'transaction exception, rollback: {str(e)}')
raise
finally:
if session:
session.close()
return call
return decorator
pass
def update(sql: str = None):
r"""
注解式查询,E.G.::
@update(sql='UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1')
def modify(params: dict = None) -> int:
pass
:param sql: 要执行的 sql
:return: decorator
"""
def decorator(func):
def call(*args, **kwargs):
result = execute(sql, args[0])
return getEffectRows(result)
return call
return decorator
pass
def query(sql: str = None, result_type: type = tuple):
r"""
注解式查询,E.G.::
E.G.::
@query(sql='SELECT * FROM `t_temp` WHERE (`id`= :id) LIMIT 1', result_type=dict)
def queryById2(id: int) -> list:
pass
:param sql: 要执行的 sql
:param result_type: 结果集类型,可选:tuple、dict
:return: decorator
"""
def decorator(func):
def call(*args, **kwargs):
if sql is None:
raise ValueError('sql cannot be None')
first = args[0]
if isinstance(first, dict):
result = DBSession().execute(text(sql), args)
else:
names = inspect.signature(func).parameters.values()
params = dict()
for idx, name in enumerate(names):
params[name.name] = args[idx]
print(params)
result = DBSession().execute(text(sql), params)
if result_type == dict:
keys = result.keys()
ret = list()
for item in result.fetchall():
ret.append(dict(zip(keys, item)))
return ret
# default return_type tuple
pass
return result.fetchall()
return call
return decorator
pass
@transactional()
def test_transaction():
r"""
测试注解式事务
:return: None
"""
session = DBSession()
session.execute(text("UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1"), {'id': 18, 'desc': 'OR 1=3'})
session.execute(text("UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1"), {'id': 18, 'desc': 'OR 1=4'})
# raise exception
raise SyntaxError('Syntax error')
@transactional()
def test_api():
r"""
测试封装过的函数
:return: None
"""
execute("UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1", {'id': 18, 'desc': 'OR 1=1'})
execute("UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1", {'id': 18, 'desc': 'OR 1=2'})
# raise exception
raise SyntaxError('Syntax error')
疯狂的妞妞 :每一天,做什么都好,不要什么都不做!