Python3 ORM hacking
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
#                        Python3 ORM hacking
# 说明:
#     之前分析了一个Python2 ORM的源代码,这次分析一个Python3的源代码,在写法上
#     还是有挺大的区别的。
# 2016-10-22  Pingshan Village, Nanshan, Shenzhen  -- Zeng Jianfeng
#
# Source:
#   https://github.com/michaelliao/awesome-python3-webapp/tree/day-03
#
# References:
#   1. Python logging module tutorial
#      http://www.jianshu.com/p/feb86c06c4f4
#   2. Introduction to Python async/await
#      https://ipfans.github.io/2015/08/introduction-to-async-and-await/
#   3. A brief look at Python metaclasses
#      http://jianpx.iteye.com/blog/908121
#   4. Why I got ignored exception when I use aiomysql in python 3.5 #59
#      https://github.com/aio-libs/aiomysql/issues/59
#

__author__ = 'Michael Liao'

import asyncio
import logging

import aiomysql


def log(sql, args=()):
    """Log an SQL statement before it is executed.

    ``args`` is accepted so call sites can pass the query parameters, but it
    is not included in the log line.
    """
    logging.info('SQL: %s' % sql)


async def create_pool(loop, **kw):
    """Create the module-wide aiomysql connection pool.

    Required keyword arguments: ``user``, ``password``, ``db``.
    Optional keyword arguments and their defaults: ``host`` ('localhost'),
    ``port`` (3306), ``charset`` ('utf8'), ``autocommit`` (True),
    ``maxsize`` (10), ``minsize`` (1).

    The pool is stored in the module global ``__pool``, which ``select`` and
    ``execute`` read; this coroutine must complete before either is used.
    """
    logging.info('create database connection pool...')
    # Rebind the module-level global so the other functions in this file
    # can reach the pool.
    global __pool
    __pool = await aiomysql.create_pool(
        host=kw.get('host', 'localhost'),
        port=kw.get('port', 3306),
        user=kw['user'],
        password=kw['password'],
        db=kw['db'],
        charset=kw.get('charset', 'utf8'),
        autocommit=kw.get('autocommit', True),
        maxsize=kw.get('maxsize', 10),
        minsize=kw.get('minsize', 1),
        loop=loop,
    )


async def select(sql, args, size=None):
    """Run a query and return the matching rows as a list of dicts.

    :param sql: SQL text using ``?`` placeholders; they are rewritten to the
        ``%s`` style aiomysql expects before execution.
    :param args: placeholder values, or ``None``/empty for none.
    :param size: if truthy, fetch at most this many rows; otherwise fetch all.
    """
    log(sql, args)
    # Take a connection from the pool asynchronously.
    async with __pool.get() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cur:
            await cur.execute(sql.replace('?', '%s'), args or ())
            if size:
                rs = await cur.fetchmany(size)
            else:
                rs = await cur.fetchall()
        logging.info('rows returned: %s' % len(rs))
        return rs


async def execute(sql, args, autocommit=True):
    """Run a data-modifying statement (INSERT/UPDATE/DELETE).

    :param sql: SQL text using ``?`` placeholders.
    :param args: placeholder values passed through to the cursor.
    :param autocommit: when False, the statement runs inside an explicit
        transaction that is committed on success and rolled back on any
        exception.
    :returns: the number of rows affected (``cursor.rowcount``).
    :raises: re-raises whatever the database call raised, after rollback.
    """
    log(sql)
    async with __pool.get() as conn:
        if not autocommit:
            await conn.begin()
        try:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await cur.execute(sql.replace('?', '%s'), args)
                affected = cur.rowcount
            if not autocommit:
                await conn.commit()
        except BaseException:
            # BaseException so that task cancellation also rolls back.
            if not autocommit:
                await conn.rollback()
            raise
        return affected


def create_args_string(num):
    """Return ``num`` comma-separated ``?`` placeholders, e.g. ``'?, ?, ?'``."""
    return ', '.join(['?'] * num)


class Field(object):
    """Base descriptor for one database column of a model.

    :param name: column name in the table; ``None`` means "same as the
        attribute name the field is assigned to".
    :param column_type: SQL type used for the column.
    :param primary_key: whether this column is the table's primary key.
    :param default: default value, or a zero-argument callable producing it.
    """

    def __init__(self, name, column_type, primary_key, default):
        self.name = name
        self.column_type = column_type
        self.primary_key = primary_key
        self.default = default

    def __str__(self):
        return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)


class StringField(Field):
    """String column; the SQL type defaults to ``varchar(100)``."""

    def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
        super().__init__(name, ddl, primary_key, default)


class BooleanField(Field):
    """Boolean column; never a primary key."""

    def __init__(self, name=None, default=False):
        super().__init__(name, 'boolean', False, default)


class IntegerField(Field):
    """Integer column stored as ``bigint``."""

    def __init__(self, name=None, primary_key=False, default=0):
        super().__init__(name, 'bigint', primary_key, default)


class FloatField(Field):
    """Floating-point column stored as ``real``."""

    def __init__(self, name=None, primary_key=False, default=0.0):
        super().__init__(name, 'real', primary_key, default)


class TextField(Field):
    """Long text column stored as ``text``; never a primary key."""

    def __init__(self, name=None, default=None):
        super().__init__(name, 'text', False, default)


class ModelMetaclass(type):
    """Metaclass that turns ``Field`` class attributes into a table mapping.

    For every subclass of ``Model`` it collects the ``Field`` attributes,
    removes them from the class (so instance dict values are not shadowed),
    and pre-builds the SELECT/INSERT/UPDATE/DELETE statements.
    """

    def __new__(cls, name, bases, attrs):
        # name:  the class name
        # bases: tuple of base classes
        # attrs: dict of the class's attributes and methods
        # The Model base class declares no fields; create it unchanged.
        if name == 'Model':
            return type.__new__(cls, name, bases, attrs)
        # Table name comes from __table__, falling back to the class name.
        tableName = attrs.get('__table__', None) or name
        logging.info('found model: %s (table: %s)' % (name, tableName))
        mappings = dict()   # attribute name -> Field
        fields = []         # non-primary-key attribute names, in declaration order
        primaryKey = None   # attribute name of the primary key
        for k, v in attrs.items():
            if isinstance(v, Field):
                logging.info(' found mapping: %s ==> %s' % (k, v))
                mappings[k] = v
                if v.primary_key:
                    # Only one primary key per model is supported.
                    if primaryKey:
                        raise RuntimeError('Duplicate primary key for field: %s' % k)
                    primaryKey = k
                else:
                    fields.append(k)
        if not primaryKey:
            raise RuntimeError('Primary key not found.')
        # Drop the Field class attributes: otherwise, after Model.__getattr__
        # falls back to dict lookup, obj.attr would find the class-level Field
        # first and never reach the stored value.
        for k in mappings:
            attrs.pop(k)

        def column(attr):
            # Database column for an attribute: Field.name if given, else the attribute name.
            return mappings[attr].name or attr

        def select_expr(attr):
            # Alias renamed columns back to the attribute name so that the
            # row dicts returned by select() match the model's attributes.
            col = column(attr)
            return '`%s`' % col if col == attr else '`%s` as `%s`' % (col, attr)

        pkColumn = column(primaryKey)
        attrs['__mappings__'] = mappings            # attribute -> Field
        attrs['__table__'] = tableName              # table name
        attrs['__primary_key__'] = primaryKey       # primary-key attribute name
        attrs['__primary_key_column__'] = pkColumn  # primary-key column name
        attrs['__fields__'] = fields                # non-key attribute names
        # Statements are built from one column list, which keeps them
        # consistent and valid even for a model with only a primary key.
        select_list = ', '.join(select_expr(k) for k in [primaryKey] + fields)
        insert_columns = ', '.join('`%s`' % column(k) for k in fields + [primaryKey])
        attrs['__select__'] = 'select %s from `%s`' % (select_list, tableName)
        # INSERT placeholders follow the same order: non-key fields, then the key.
        attrs['__insert__'] = 'insert into `%s` (%s) values (%s)' % (
            tableName, insert_columns, create_args_string(len(fields) + 1))
        attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (
            tableName, ', '.join('`%s`=?' % column(f) for f in fields), pkColumn)
        attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, pkColumn)
        return type.__new__(cls, name, bases, attrs)


class Model(dict, metaclass=ModelMetaclass):
    """Base class for mapped models: a dict with attribute-style access."""

    def __init__(self, **kw):
        super(Model, self).__init__(**kw)

    def __getattr__(self, key):
        """Expose dict items as attributes (``obj.key`` -> ``obj['key']``)."""
        try:
            return self[key]
        except KeyError:
            raise AttributeError(r"'Model' object has no attribute '%s'" % key) from None

    def __setattr__(self, key, value):
        """Store attribute assignments as dict items."""
        self[key] = value

    def getValue(self, key):
        """Return the attribute's value, or ``None`` when it is not set."""
        return getattr(self, key, None)

    def getValueOrDefault(self, key):
        """Return the attribute's value, filling in the Field default if unset.

        A callable default is called to produce the value. The resolved
        default is stored on the instance so later reads return the same value.
        """
        value = getattr(self, key, None)
        if value is None:
            field = self.__mappings__[key]
            if field.default is not None:
                value = field.default() if callable(field.default) else field.default
                logging.debug('using default value for %s: %s' % (key, str(value)))
                setattr(self, key, value)
        return value

    @classmethod
    async def findAll(cls, where=None, args=None, **kw):
        """Find objects matching an optional WHERE clause.

        :param where: SQL condition (``?`` placeholders) appended after WHERE.
        :param args: values for the placeholders; never modified.
        :param kw: ``orderBy`` (SQL text for ORDER BY) and ``limit`` (an int,
            or a 2-tuple ``(offset, count)``).
        :raises ValueError: if ``limit`` is neither an int nor a 2-tuple.
        """
        # Start from the metaclass-generated SELECT and extend it.
        sql = [cls.__select__]
        if where:
            sql.append('where')
            sql.append(where)
        # Work on a copy: the LIMIT values are appended below, and the
        # caller's sequence (possibly a tuple) must not be touched.
        args = list(args) if args else []
        orderBy = kw.get('orderBy', None)
        if orderBy:
            sql.append('order by')
            sql.append(orderBy)
        limit = kw.get('limit', None)
        if limit is not None:
            sql.append('limit')
            if isinstance(limit, int):
                sql.append('?')
                args.append(limit)
            elif isinstance(limit, tuple) and len(limit) == 2:
                sql.append('?, ?')
                args.extend(limit)
            else:
                raise ValueError('Invalid limit value: %s' % str(limit))
        rs = await select(' '.join(sql), args)
        # Each row dict's keys are the model's attribute names, so it can
        # construct an instance directly.
        return [cls(**r) for r in rs]

    @classmethod
    async def findNumber(cls, selectField, where=None, args=None):
        """Evaluate ``selectField`` (e.g. ``'count(id)'``) and return its value.

        Returns ``None`` if the query yields no row.
        """
        # `_num_` is a column alias for the selected expression (MySQL allows
        # omitting AS), so the value can be read back by a fixed key.
        sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)]
        if where:
            sql.append('where')
            sql.append(where)
        rs = await select(' '.join(sql), args, 1)
        if len(rs) == 0:
            return None
        return rs[0]['_num_']

    @classmethod
    async def find(cls, pk):
        """Find an object by primary key; return ``None`` if it does not exist."""
        rs = await select('%s where `%s`=?' % (cls.__select__, cls.__primary_key_column__), [pk], 1)
        if len(rs) == 0:
            return None
        return cls(**rs[0])

    async def save(self):
        """INSERT this object, filling unset fields with their defaults."""
        args = list(map(self.getValueOrDefault, self.__fields__))
        args.append(self.getValueOrDefault(self.__primary_key__))
        rows = await execute(self.__insert__, args)
        if rows != 1:
            logging.warning('failed to insert record: affected rows: %s' % rows)

    async def update(self):
        """UPDATE every non-key column of this object's row, matched by primary key."""
        args = list(map(self.getValue, self.__fields__))
        args.append(self.getValue(self.__primary_key__))
        rows = await execute(self.__update__, args)
        if rows != 1:
            logging.warning('failed to update by primary key: affected rows: %s' % rows)

    async def remove(self):
        """DELETE this object's row, matched by primary key."""
        args = [self.getValue(self.__primary_key__)]
        rows = await execute(self.__delete__, args)
        if rows != 1:
            logging.warning('failed to remove by primary key: affected rows: %s' % rows)