python - sqlachemy另类用法

这里只是给出一个思路,或许对于未来解决问题有一些参考意义。

仿 JAP 的写法

这种写法很像 java 环境中的 JPA,如果引入模版引擎,则可以大幅增强实用性。

但是,在 python 环境中,这不符合主流的 ORM 框架。

潜在风险:代码检测的时候,可能会被误判,因为我们定义了一大堆空的函数。

# 注解式事务 start ---------------------------------------------

@update(sql='UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1')
def modify(params: dict = None) -> int:
    pass


@query(sql='SELECT * FROM `t_temp` WHERE (`id`= :id) LIMIT 1', result_type=dict)
def queryById(params: dict = None) -> list:
    pass


@query(sql='SELECT * FROM `t_temp` WHERE (`id`= :id) LIMIT 1', result_type=dict)
def queryById2(id: int) -> list:
    pass


@transactional()
def test_annotation():
    ret = modify({'id': 18, 'desc': 'OR 1=1'})
    print(ret)

    result = queryById2(18)
    print(result)

代码封装


import inspect

import logger_factory
import typing

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, scoped_session

from sqlalchemy.engine import Result, CursorResult

logger = logger_factory.get_logger()

# 定义数据库连接字符串
DATABASE_URI = 'mysql+pymysql://{username}:{password}@{host}:{port}/{dbname}?charset=utf8mb4'

# 替换为你的数据库用户名、密码、主机、端口和数据库名
USERNAME = 'root'
PASSWORD = 'root'
HOST = 'localhost'
PORT = '3306'
DBNAME = 'med'

# 创建数据库引擎,使用连接池
engine = create_engine(
    DATABASE_URI.format(
        username=USERNAME,
        password=PASSWORD,
        host=HOST,
        port=PORT,
        dbname=DBNAME
    ),
    echo=False,  # 如果设置为True,SQLAlchemy将打印所有执行的SQL语句,通常用于调试
    pool_size=10,  # 连接池大小
    max_overflow=20,  # 超过连接池大小外最多创建的连接数
    pool_timeout=30,  # 连接池中没有线程可用时,在抛出异常前等待的时间
    pool_recycle=3600  # 多少秒之后对连接进行一次回收(重置)
)

# do a test
with engine.connect() as con:
    rs = con.execute(text('SELECT 1'))
    rs.fetchone()
    logger.debug('create engine succeed!')

# session-maker
Session = sessionmaker(bind=engine)
# thread safe session-maker
DBSession = scoped_session(Session)


# with Session() as session:
#     # 获取数据库连接
#     connection = session.connection()
#     savepoint = connection.begin_nested()
#     print(savepoint)


def getEffectRows(result: Result) -> int:
    r"""
    获取受影响行数

    这里有点问题:源码部分 rowcount 是一个 callable,但实际应该是 int;
    这里绕一点,确保不会出问题,如果返回 -1,说明出现了意料之外的情况

    :param result: 结果集
    :return: 受影响行数
    """
    if isinstance(result, CursorResult):
        effect_row = result.rowcount
        if isinstance(effect_row, int):
            return effect_row
        if callable(effect_row):
            return effect_row()
    return -1


def resultAsDict(result: Result) -> list:
    r"""
    将查询结果转换为 dict-list

    :param result: 结果集
    :return: dict 列表
    """
    keys = result.keys()

    ret = list()
    for item in result.fetchall():
        ret.append(dict(zip(keys, item)))

    return ret


def execute(sql: str, params: dict = None) -> Result:
    r"""
    执行一条查询语句
    :param sql: 查询语句
    :param params: 参数
    :return: 结果集
    """
    if sql is None:
        raise ValueError('sql cannot be None')
    logger.debug('execute sql: ' + sql)
    logger.debug('parameter  : ' + str(params))
    return DBSession().execute(text(sql), params)


def executeQuery(sql: str, params: dict = None, result_type: type = tuple) -> typing.Sequence:
    r"""
    执行一个查询

    :param sql: sql
    :param params: dict
    :param result_type: 结果集类型,可选:tuple、dict
    :return: 序列
    """
    result = execute(sql, params)

    if result_type == dict:
        return resultAsDict(result)
        pass

    # default return_type tuple-list
    return result.fetchall()


def executeUpdate(sql: str, params: dict = None) -> int:
    r"""

    执行一个查询

    :param sql: sql 执行语句
    :param params: dict 查询参数
    :return: 受影响行数
    """
    result = execute(sql, params)
    return getEffectRows(result)


def transactional(rollback: type = Exception):
    r"""
    注解式事务

    用法类似于 spring 环境下的 @Transactional 注解

    注意: 事务控制在 session 级别,不能兼容事务嵌套的场景(理想状态下,应当通过 save-point 实现)

    推荐: 如果遇到很复杂的事务嵌套,显式调用 session,手动控制事务

    :param rollback: 指定触发回滚的异常类型
    :return: 装饰器函数
    """

    def decorator(func):
        def call(*args, **kwargs):
            session = None
            try:
                session = DBSession()
                ret = func(*args, **kwargs)
                session.commit()
                return ret
            except rollback as e:
                if session:
                    session.rollback()
                logger.exception(f'transaction exception, rollback: {str(e)}')
                raise
            finally:
                if session:
                    session.close()

        return call

    return decorator
    pass


def update(sql: str = None):
    r"""
    注解式查询,E.G.::

        @update(sql='UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1')
        def modify(params: dict = None) -> int:
            pass

    :param sql: 要执行的 sql
    :return: decorator
    """

    def decorator(func):
        def call(*args, **kwargs):
            result = execute(sql, args[0])
            return getEffectRows(result)

        return call

    return decorator
    pass


def query(sql: str = None, result_type: type = tuple):
    r"""
    注解式查询,E.G.::

    E.G.::
        @query(sql='SELECT * FROM `t_temp` WHERE (`id`= :id) LIMIT 1', result_type=dict)
        def queryById2(id: int) -> list:
            pass

    :param sql: 要执行的 sql
    :param result_type: 结果集类型,可选:tuple、dict
    :return:  decorator
    """

    def decorator(func):
        def call(*args, **kwargs):
            if sql is None:
                raise ValueError('sql cannot be None')

            first = args[0]
            if isinstance(first, dict):
                result = DBSession().execute(text(sql), args)
            else:
                names = inspect.signature(func).parameters.values()
                params = dict()
                for idx, name in enumerate(names):
                    params[name.name] = args[idx]
                print(params)
                result = DBSession().execute(text(sql), params)

            if result_type == dict:
                keys = result.keys()

                ret = list()
                for item in result.fetchall():
                    ret.append(dict(zip(keys, item)))

                return ret
                # default return_type tuple
                pass
            return result.fetchall()

        return call

    return decorator
    pass


@transactional()
def test_transaction():
    r"""
    测试注解式事务
    :return: None
    """
    session = DBSession()
    session.execute(text("UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1"), {'id': 18, 'desc': 'OR 1=3'})
    session.execute(text("UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1"), {'id': 18, 'desc': 'OR 1=4'})

    # raise exception
    raise SyntaxError('Syntax error')


@transactional()
def test_api():
    r"""
    测试封装过的函数
    :return: None
    """

    execute("UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1", {'id': 18, 'desc': 'OR 1=1'})
    execute("UPDATE `t_temp` SET `desc`= :desc WHERE (`id`= :id) LIMIT 1", {'id': 18, 'desc': 'OR 1=2'})

    # raise exception
    raise SyntaxError('Syntax error')

posted on 2024-12-06 08:43  疯狂的妞妞  阅读(3)  评论(0编辑  收藏  举报

导航