概述
在一个Web App中,所有数据,包括用户信息、发布的日志、评论等,都存储在数据库中。
Web App里面有很多地方都要访问数据库。访问数据库需要创建数据库连接、游标对象,然后执行SQL语句,最后处理异常,清理资源。这些访问数据库的代码如果分散到各个函数中,势必无法维护,也不利于代码复用。
所以,我们要首先把常用的SELECT、INSERT、UPDATE和DELETE操作用函数封装起来。
由于Web框架使用了基于asyncio的aiohttp,这是基于协程的异步模型。在协程中,不能调用普通的同步IO操作,因为所有用户都是由一个线程服务的,协程的执行速度必须非常快,才能处理大量用户的请求。而耗时的IO操作不能在协程中以同步的方式调用,否则,等待一个IO操作时,系统无法响应任何其他用户。
这就是异步编程的一个原则:一旦决定使用异步,则系统每一层都必须是异步,“开弓没有回头箭”。
幸运的是aiomysql
为MySQL数据库提供了异步IO的驱动。
简单的orm实现技术原理可参考先前写的博文:12、元类(metaclass)实现精简ORM框架
一、创建连接池
我们需要创建一个全局的连接池,每个HTTP请求都可以从连接池中直接获取数据库连接。使用连接池的好处是不必频繁地打开和关闭数据库连接,而是能复用就尽量复用。
连接池由全局变量__pool
存储,缺省情况下将编码设置为utf8
,自动提交事务:
@asyncio.coroutine
def create_pool(loop, **kwargs):
logging.info('create database connection pool...')
global __pool
__pool = yield from aiomysql.create_pool(
host=kwargs.get('host', 'localhost'),
port=kwargs.get('port', 3306),
user=kwargs['user'],
password=kwargs['password'],
db=kwargs['db'],
charset=kwargs.get('charset', 'utf8'),
autocommit=kwargs.get('autocommit', True),
maxsize=kwargs.get('maxsize', 10),
minsize=kwargs.get('minsize', 1),
loop=loop
)
关于aiomysql.create_pool的详细讲述,请参考博文:16、【翻译】aiomysql-Pool
create_pool方法中的kwargs是关键字参数,保存着连接数据库所必须的host、port、user、password等信息,这些关键字参数在函数内部自动组装为一个dict。
二、封装select语句
# 该协程封装的是查询事务,第一个参数为sql语句,第二个为sql语句中占位符的参数列表,第三个参数是要查询数据的数量 @asyncio.coroutine def select(sql, args, size=None): log(sql, args) #显示sql语句和参数 global __pool #引用全局变量 with (yield from __pool) as conn: # 以上下文方式打开conn连接,无需再调用conn.close() 或写成 with await __pool as conn: cur = yield from conn.cursor(aiomysql.DictCursor) # 创建一个DictCursor类指针,返回dict形式的结果集 yield from cur.execute(sql.replace('?', '%s'), args or ()) # 替换占位符,SQL语句占位符为?,MySQL为%s。 if size: rs = yield from cur.fetchmany(size) #接收size条返回结果行. else: rs = yield from cur.fetchall() #接收全部的返回结果行. yield from cur.close() #关闭游标 logging.info('rows returned: %s' % len(rs)) #打印返回结果行数 return rs #返回结果
SQL语句的占位符是?
,而MySQL的占位符是%s
,select()
函数在内部自动替换。注意要始终坚持使用带参数的SQL,而不是自己拼接SQL字符串,这样可以防止SQL注入攻击。
注意到yield from
将调用一个子协程(也就是在一个协程中调用另一个协程)并直接获得子协程的返回结果。
如果传入size
参数,就通过fetchmany()
获取最多指定数量的记录,否则,通过fetchall()
获取所有记录。
三、封装INSERT、UPDATE、DELETE语句
#执行update,insert,delete语句,可以统一用一个execute函数执行, # 因为它们所需参数都一样,而且都只返回一个整数表示影响的行数。 @asyncio.coroutine def execute(sql, args, autocommit=True): log(sql) with (yield from __pool) as conn: if not autocommit: yield from conn.begin() try: cur = yield from conn.cursor() yield from cur.execute(sql.replace('?', '%s'), args) affected = cur.rowcount yield from cur.close() if not autocommit: yield from conn.commit() except BaseException as e: #如果事务处理出现错误,则回退 if not autocommit: yield from conn.rollback() raise return affected
execute()
函数和select()
函数所不同的是,cursor对象不返回结果集,而是通过rowcount
返回结果数。
四、ORM
设计ORM需要从上层调用者角度来设计。
我们先考虑如何定义一个User
对象,然后把数据库表users
和它关联起来。
from orm import Model, StringField, IntegerField class User(Model): __table__ = 'users' id = IntegerField(primary_key=True) name = StringField()
注意到定义在User
类中的__table__
、id
和name
是类的属性,不是实例的属性,类的所有示例都可以访问!!!所以,在类级别上定义的属性用来描述User
对象和表的映射关系,而实例属性用来描述数据库表中的一行数据,必须通过__init__()
方法去初始化,所以两者互不干扰:
# 创建实例: user = User(id=123, name='Michael') # 存入数据库: user.insert() # 查询所有User对象: users = User.findAll()
五、Field以及各种Field子类
用来描述数据库中表字段的属性(字段名、类型、是否主键等等)。
首先定义基类Field:
class Field(object): def __init__(self, name, column_type, primary_key, default): self.name = name self.column_type = column_type self.primary_key = primary_key self.default = default def __str__(self): return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)
__str__()是Python中有特殊用途的函数,用来定制类。当我们print(Field或Field子类对象)时,会打印该对象(字段)的类名,字段类别以及字段名称。
然后在Field的基础上,进一步定义各种类型的Field:
# 字符串类型字段,继承自父类Field class StringField(Field): #如果一个函数的参数中含有默认参数,则这个默认参数后的所有参数都必须是默认参数 , # 否则会抛出:SyntaxError: non-default argument follows default argument的异常。 def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'): super(StringField, self).__init__(name, ddl, primary_key, default) # 布尔值类型字段,继承自父类Field class BooleanField(Field): def __init__(self, name=None, default=False): super(BooleanField, self).__init__(name, 'boolean', False, default) # 整数类型字段,继承自父类Field class IntegerField(Field): def __init__(self, name=None, primary_key=False, default=0): super(IntegerField, self).__init__(name, 'bigint', primary_key, default) # 浮点数类型字段,继承自父类Field class FloatField(Field): def __init__(self, name=None, primary_key=False, default=0.0): super(FloatField, self).__init__(name, 'real', primary_key, default) # 文本类型字段,继承自父类Field class TextField(Field): def __init__(self, name=None, default=None): super(TextField, self).__init__(name, 'text', False, default)
上述子类生成对象时,均会调用父类的Init方法初始化。
可见,数据库表字段共4个属性:字段名、字段类型、是否主键、默认值。
六、编写元类—ModelMetaclass
1 class ModelMetaclass(type): 2 3 def __new__(cls, name, bases, attrs): 4 # 排除Model类本身: 5 if name=='Model': 6 return type.__new__(cls, name, bases, attrs) 7 # 获取table名称: 8 tableName = attrs.get('__table__', None) or name 9 logging.info('found model: %s (table: %s)' % (name, tableName)) 10 # 获取所有的Field和主键名: 11 mappings = dict() 12 fields = [] 13 primaryKey = None 14 for k, v in attrs.items(): 15 if isinstance(v, Field): 16 logging.info(' found mapping: %s ==> %s' % (k, v)) 17 mappings[k] = v 18 if v.primary_key: 19 # 找到主键: 20 if primaryKey: 21 raise RuntimeError('Duplicate primary key for field: %s' % k) 22 primaryKey = k 23 else: 24 fields.append(k) 25 if not primaryKey: 26 raise RuntimeError('Primary') 27 for k in mappings.keys(): 28 attrs.pop(k) 29 escaped_fields = list(map(lambda f: '`%s`' % f, fields)) 30 attrs['__mappings__'] = mappings # 保存属性和列的映射关系 31 attrs['__table__'] = tableName 32 attrs['__primary_key__'] = primaryKey # 主键属性名 33 attrs['__fields__'] = fields # 除主键外的属性名 34 # 构造默认的SELECT, INSERT, UPDATE和DELETE语句: 35 ##以下四种方法保存了默认了增删改查操作,其中添加的反引号``,是为了避免与sql关键字冲突的,否则sql语句会执行出错 36 attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName) 37 attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1)) 38 attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey) 39 attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey) 40 return type.__new__(cls, name, bases, attrs)
1、首先进行判断,如果将要创建的类是Model,无需做个性化定制,直接通过type创建,排除对Model类的修改;
2、获取table名称,即类名;
3、mappings保存类属性和表字段的映射关系;primaryKey保存映射表主键的类属性;fields保存映射其余表字段的类属性;
4、注意匿名函数【 lambda f: '`%s`' % f, fields 】 的用法,实际上第29行代码就是这个意思:
fields = ['one', 'two', 'three'] def fun(f): return '`%s`' % f escaped_fields = list(map(fun, fields))
使用匿名函数lambda,针对fields中每个元素,如 name,加上反引号后:`name`后返回;
为何要加上反引号?它是为了区分MYSQL的保留字与普通字符而引入的符号。
5、下面是一系列为定制类动态添加的属性:
(1) attrs['__mappings__'] = mappings -》 保存类属性和表字段的映射关系;
(2)attrs['__table__'] = tableName -》 保存该类对应的表名;
(3)attrs['__primary_key__'] = primaryKey -》 保存映射表中主键字段的类属性;
(4)attrs['__fields__'] = fields -》 保存映射非主键字段的类属性;
接着是SQL语句模板,届时调用时只需要将参数传递给Mysql占位符?即可:
(5)attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)
例:
(6)attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
例:
(7)attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)
例:
(8)attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
例:
6、在模块加载时使用type动态地定制类。
7、注意:
(1)以上属性都是类的属性,属类所有,所有实例对象共享一个类属性。而实例属性属于各个实例所有,互不干扰;在编写程序的时候,千万不要对实例属性和类属性使用相同的名字,因为相同名称的实例属性将屏蔽掉类属性,但是当你删除实例属性后,再使用相同的名称,访问到的将是类属性。
(2)表的字段名使用类属性名,即名字相同!
七、编写基类——Model
1 class Model(dict, metaclass=ModelMetaclass): 2 3 def __init__(self, **kwargs): 4 super(Model, self).__init__(**kwargs) 5 6 def __getattr__(self, key): 7 try: 8 return self[key] 9 except KeyError: 10 raise AttributeError(r"'Model' object has no attribute '%s'" % key) 11 12 def __setattr__(self, key, value): 13 self[key] = value 14 15 def getValue(self, key): 16 return getattr(self, key, None) 17 18 def getValueOrDefault(self, key): 19 value = getattr(self, key, None) 20 if value is None: 21 field = self.__mappings__[key] 22 if field.default is not None: 23 value = field.default() if callable(field.default) else field.default 24 logging.debug('using default value for %s: %s' % (key, str(value))) 25 setattr(self, key, value) 26 return value 27 28 @classmethod 29 @asyncio.coroutine 30 def findAll(cls, where=None, args=None, **kwargs): 31 'find objects by where clause' 32 sql = [cls.__select__] #sql是list类型,元素是定制类的类属性——select查询语句模板 33 if where: 34 sql.append('where') 35 sql.append(where) 36 if args is None: 37 args = [] 38 orderBy = kwargs.get('orderBy', None) 39 if orderBy: 40 sql.append('order by') 41 sql.append(orderBy) 42 limit = kwargs.get('limit', None) 43 if limit is not None: 44 sql.append('limit') 45 if isinstance(limit, int): 46 sql.append('?') 47 args.append(limit) 48 elif isinstance(limit, tuple) and len(limit) == 2: 49 sql.append('?, ?') 50 args.extend(limit) 51 else: 52 raise ValueError('Invalid limit value: %s' % str(limit)) 53 rs = yield from select(' '.join(sql), args) #传入sql语句及参数,调用select语句获取查询结果 54 return [cls(**r) for r in rs] 55 56 @classmethod 57 @asyncio.coroutine 58 def findNumber(cls, selectField, where=None, args=None): 59 'find number by select and where.' 60 sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)] 61 if where: 62 sql.append('where') 63 sql.append(where) 64 rs = yield from select(' '.join(sql), args, 1) 65 if len(rs) == 0: 66 return None 67 return rs[0]['_num_'] 68 69 @classmethod 70 @asyncio.coroutine 71 def find(cls, pk): 72 'find object by primary key.' 73 rs = yield from select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1) 74 if len(rs) == 0: 75 return None 76 return cls(**rs[0]) 77 78 @asyncio.coroutine 79 def save(self): 80 args = list(map(self.getValueOrDefault, self.__fields__)) 81 args.append(self.getValueOrDefault(self.__primary_key__)) 82 rows = yield from execute(self.__insert__, args) 83 if rows != 1: 84 logging.warn('failed to insert record: affected rows: %s' % rows) 85 86 @asyncio.coroutine 87 def update(self): 88 args = list(map(self.getValue, self.__fields__)) 89 args.append(self.getValue(self.__primary_key__)) 90 rows = yield from execute(self.__update__, args) 91 if rows != 1: 92 logging.warn('failed to update by primary key: affected rows: %s' % rows) 93 94 @asyncio.coroutine 95 def remove(self): 96 args = [self.getValue(self.__primary_key__)] 97 rows = yield from execute(self.__delete__, args) 98 if rows != 1: 99 logging.warn('failed to remove by primary key: affected rows: %s' % rows)
1、__getattr__为内置方法,当使用点号获取实例属性,例如 stu.score 时,如果属性score不存在就自动调用__getattr__方法。注意:已有的属性,比如name,不会在__getattr__中查找;
2、__setattr__当设置实例属性时自动调用,如 stu.score=5时,就会调用__setattr__方法 self.[score]=5;
3、getValueOrDefault() -> 获取属性值,如果为空,则取默认值;
4、@classmethod装饰的方法是类方法,直接使用类名调用,所有子类都可以调用类方法。不需要实例化,不需要 self 参数,第一个参数是表示自身类的 cls 参数。
5、分析findAll() 方法:
(1)第53行语句: rs = yield from select(' '.join(sql), args),调试可见返回结果是list类型,元素是dict类型的每行表数据:
(2)第54行语句: return [cls(**r) for r in rs],不太能理解,故编写语句 result = User.findAll() 来将返回值保存在result参数中,调试可得:
由此可得出结论,[cls(**r) for r in rs] 是将查询数据库表得到的每行结果,生成cls类的对象。
6、分析save() 方法
我们编写下列语句调用save方法:
@asyncio.coroutine def save(self): args = list(map(self.getValueOrDefault, self.__fields__)) args.append(self.getValueOrDefault(self.__primary_key__)) rows = yield from execute(self.__insert__, args) if rows != 1: logging.warn('failed to insert record: affected rows: %s' % rows)
经分析可得,获取调用save函数的类属性__friends__及__primary_key__的值,有默认值就传入默认值,作为sql语句参数执行。
八、定义映射数据库表的类
def next_id(): return '%015d%s000' % (int(time.time() * 1000), uuid.uuid4().hex) class User(Model): __table__ = 'users' id = StringField(primary_key=True, default=next_id, ddl='varchar(50)') email = StringField(ddl='varchar(50)') passwd = StringField(ddl='varchar(50)') admin = BooleanField() name = StringField(ddl='varchar(50)') image = StringField(ddl='varchar(500)') created_at = FloatField(default=time.time)
九、编写测试代码
import orm from models import User, Blog, Comment import asyncio loop = asyncio.get_event_loop() async def test(): # 创建连接池,里面的host,port,user,password需要替换为自己数据库的信息 await orm.create_pool(loop=loop, host='127.0.0.1', port=3306, user='root', password='root', db='awesome') # 没有设置默认值的一个都不能少 u = User(name='Test', email='547280745@qq.com', passwd='1234567890', image='about:blank', id="123") await u.save() result = await User.findAll() loop.run_until_complete(test())
在Mysql数据中查询结果可知导入数据成功: