Loading

ORM框架疏理——廖雪峰实战系列(一)

ORM(Object Relational Mapping,对象关系映射),是一种程序设计技术,用于实现面向对象编程语言里不同类型系统的数据之间的转换。从效果上来说,它其实创建了一个可在编程语言里使用的“虚拟对象数据库”。

上面是维基百科的解释,但是为什么要用ORM这种编程技术呢?

就这个实战作业来看:

  博客——标题、摘要、内容、评论、作者、创作时间

  评论——内容、评论人、评论文章、评论时间

  用户——姓名、邮箱、口令、权限

上述信息,都需要有组织的存储在数据库中。数据库方面很简单,只需要维护三张表,如果想要加强各表间的联系,可以使用外键。但是,Python该如何组织这些信息呢?每篇博客有不同标题、摘要、内容...,每篇评论和用户信息也个不相同。像C这种面向过程的语言必须创建高级数据结构,而Python这种面向对象的语言真是有天然的优势,我们只需要把每篇博客、评论或者用户看作对象,使用属性表示其蕴含的信息。最后,我们还要解决一个问题:Python和数据库如何高效有组织的交换数据呢?数据库表是由一条条记录组成的,每条记录又包含不同字段。记录和字段,对象和属性...看起来两者关系类不类似?这就是我们的思路——

  将数据库表的每条记录映射为对象,每条记录的字段和对象的属性相应;同时透过对象方法执行SQL命令。

我们编写的ORM框架就是实现上述想法。

  1 #!/usr/bin/env python3
  2 # -*- coding: utf-8 -*-
  3 #Program:
  4 #       This is a ORM for MySQL.
  5 #History:
  6 #2017/06/29       smile          First release
  7 
  8 import logging
  9 import asyncio
 10 import aiomysql
 11 
 12 def log(sql,args=()):
 13     logging.info('SQL: %s' % sql)
 14 
 15 #Close pool
 16 async def destory_pool():
 17     global __pool
 18     __pool.close()
 19     await __pool.wait_closed()
 20 
 21 #Create connect pool
 22 #Parameter: host,port,user,password,db,charset,autocommit
 23 #           maxsize,minsize,loop
 24 async def create_pool(loop,**kw):
 25     logging.info('Create database connection pool...')
 26     global __pool
 27     __pool = await aiomysql.create_pool(
 28         host = kw.get('host', 'localhost'),
 29         port = kw.get('port', 3306),
 30         user = kw['user'],
 31         password = kw['password'],
 32         db = kw['db'],
 33         charset = kw.get('charset', 'utf8'),
 34         autocommit = kw.get('autocommit', 'True'),
 35         maxsize = kw.get('maxsize', 10),
 36         minsize = kw.get('minsize', 1),
 37         loop = loop
 38     )
 39 
 40 #Package SELECT function that can execute SELECT command.
 41 #Setup 1:acquire connection from connection pool.
 42 #Setup 2:create a cursor to execute MySQL command.
 43 #Setup 3:execute MySQL command with cursor.
 44 #Setup 4:return query result.
 45 async def select(sql,args,size=None):
 46     log(sql,args)
 47     global __pool
 48     async with __pool.acquire() as conn:
 49         async with conn.cursor(aiomysql.DictCursor) as cur:
 50             await cur.execute(sql.replace('?','%s'),args or ())
 51             if size:
 52                 rs = await cur.fetchmany(size)
 53             else:
 54                 rs = await cur.fetchall()
 55 
 56         logging.info('rows returned: %s' % len(rs))
 57         return rs
 58 
 59 #Package execute function that can execute INSERT,UPDATE and DELETE command
 60 async def execute(sql,args,autocommit=True):
 61     global __pool
 62     #acquire connection from connection pool
 63     async with __pool.acquire() as conn:
 64         #如果MySQL禁止隐式提交,则标记事务开始
 65         if not autocommit:
 66             await conn.begin()
 67         try:
 68             #create cursor to execute MySQL command
 69             async with conn.cursor(aiomysql.DictCursor) as cur:
 70                 await cur.execute(sql.replace('?','%s'),args or ())
 71                 affectrow = cur.rowcount
 72                 #如果MySQL禁止隐式提交,手动提交事务
 73                 if not autocommit:
 74                     await cur.commit()
 75         #如果事务处理出现错误,则回退
 76         except BaseException as e:
 77             await conn.rollback()
 78             raise
 79 
 80         #return number of affected rows
 81         return affectrow
 82 
 83 #Create placeholder with '?'
 84 def create_args_string(num):
 85     L = []
 86     for i in range(num):
 87         L.append('?')
 88     return ', '.join(L)
 89 
 90 #A base class about Field
 91 #描述字段的字段名,数据类型,键信息,默认值
 92 class Field(object):
 93     def __init__(self,name,column_type,primary_key,default):
 94         self.name = name
 95         self.column_type = column_type
 96         self.primary_key = primary_key
 97         self.default = default
 98 
 99     def __str__(self):
100         return '<%s,%s:%s>' % (self.__class__.__name__,self.column_type,self.name)
101 
102 #String Field
103 class StringField(Field):
104     def __init__(self,name=None,ddl='varchar(100)',default=None,primary_key=False):
105         super(StringField,self).__init__(name,ddl,primary_key,default)
106 
107 #Bool Fileed
108 class BooleanField(Field):
109     def __init__(self,name=None,ddl='boolean',default=False,primary_key=False):
110         super(BooleanField,self).__init__(name,ddl,primary_key,default)
111 
112 #Integer Field
113 class IntegerField(Field):
114     def __init__(self,name=None,ddl='bigint',default=None,primary_key=None):
115         super(IntegerField,self).__init__(name,ddl,primary_key,default)
116 
117 #Float Field
118 class FloatField(Field):
119     def __init__(self,name=None,ddl='real',default=None,primary_key=None):
120         super(FloatField,self).__init__(name,ddl,primary_key,default)
121 
122 #Text Field
123 class TextField(Field):
124     def __init__(self,name=None,ddl='text',default=None,primary_key=None):
125         super(TextField,self).__init__(name,ddl,primary_key,default)
126 
127 #Meatclass about ORM
128 #作用:
129 #首先,拦截类的创建
130 #然后,修改类
131 #最后,返回修改后的类
132 class ModelMetaclass(type):
133     #采集应用元类的子类属性信息
134     #将采集的信息作为参数传入__new__方法
135     #应用__new__方法修改类
136     def __new__(cls,name,bases,attrs):
137         #不对Model类应用元类
138         if name == 'Model':
139             return type.__new__(cls,name,bases,attrs)
140 
141         #获取数据库表名。若__table__为None,则取用类名
142         tablename = attrs.get('__table__',None) or name
143         logging.info('Found model: %s (table: %s)' % (name,tablename))
144 
145         #存储映射表类的属性(键-值)
146         mappings = dict()
147         #存储映射表类的非主键属性(仅键)
148         fields = []
149         #主键对应字段
150         primarykey = None
151         for k,v in attrs.items():
152             if isinstance(v,Field):
153                 logging.info('Found mapping: %s ==> %s' % (k,v))
154                 mappings[k] = v
155 
156                 if v.primary_key:
157                     logging.info('Found primary key')
158                     if primarykey:
159                         raise Exception('Duplicate primary key for field:%s' % k)
160                     primarykey = k
161                 else:
162                     fields.append(k)
163 
164         #如果没有主键抛出异常
165         if not primarykey:
166             raise Exception('Primary key not found')
167 
168         #删除映射表类的属性,以便应用新的属性
169         for i in mappings.keys():
170             attrs.pop(i)
171 
172         #使用反单引号" ` "区别MySQL保留字,提高兼容性
173         escaped_fields = list(map(lambda f:'`%s`' % f,fields))
174 
175         #重写属性
176         attrs['__mappings__'] = mappings
177         attrs['__table__'] = tablename
178         attrs['__primary_key__'] = primarykey
179         attrs['__fields__'] = fields
180         attrs['__select__'] = 'SELECT `%s`, %s FROM `%s`' % (primarykey,','.join(escaped_fields),tablename)
181         attrs['__insert__'] = 'INSERT `%s` (%s,`%s`) VALUES (%s)' % (tablename,','.join(escaped_fields),primarykey,create_args_string(len(escaped_fields) + 1))
182         attrs['__update__'] = 'UPDATE `%s` SET %s WHERE `%s` = ?' % (tablename,','.join(map(lambda f:'`%s` = ?' % (mappings.get(f).name or f),fields)),primarykey)
183         attrs['__delete__'] = 'DELETE FROM `%s` WHERE `%s` = ?' % (tablename,primarykey)
184 
185         #返回修改后的类
186         return type.__new__(cls,name,bases,attrs)
187 
188 #A base class about Model
189 #继承dict类特性
190 #附加方法:
191 #       以属性形式获取值
192 #       拦截私设属性
193 class Model(dict,metaclass=ModelMetaclass):
194     def __init__(self,**kw):
195         super(Model,self).__init__(**kw)
196 
197     def __getattr__(self,key):
198         try:
199             return self[key]
200         except KeyError:
201             raise AttributeError(r"'Model' object has no attribute '%s'" % key)
202 
203     def __setattr__(self,key,value):
204         self[key] = value
205 
206     def getValue(self,key):
207         return getattr(self,key,None)
208 
209     def getValueorDefault(self,key):
210         value = getattr(self,key,None)
211         if value is None:
212             field = self.__mappings__[key]
213             if field.default is not None:
214                 value = field.default() if callable(field.default) else field.default
215                 logging.debug('using default value for %s: %s' % (key,str(value)))
216                 setattr(self,key,value)
217 
218         return value
219 
220     #ORM框架下,每条记录作为对象返回
221     #@classmethod定义类方法,类对象cls便可完成某些操作
222     @classmethod
223     async def findAll(cls,where=None,args=None,**kw):
224         sql = [cls.__select__]
225         #添加WHERE子句
226         if where:
227             sql.append('WHERE')
228             sql.append(where)
229 
230         if args is None:
231             args = []
232 
233         orderby = kw.get('orderby',None)
234         #添加ORDER BY子句
235         if orderby:
236             sql.append('ORDER BY')
237             sql.append(orderby)
238 
239         limit = kw.get('limit',None)
240         #添加LIMIT子句
241         if limit:
242             sql.append('LIMIT')
243             if isinstance(limit,int):
244                 sql.append('?')
245                 args.append(limit)
246             elif isinstance(limit,tuple):
247                 sql.append('?, ?')
248                 args.extend(limit)
249             else:
250                 raise ValueError('Invalid limit value: %s' % str(limit))
251 
252         #execute SQL
253         rs = await select(' '.join(sql),args)
254         #将每条记录作为对象返回
255         return [cls(**r) for r in rs]
256 
257     
258     #过滤结果数量
259     @classmethod
260     async def findNumber(cls,selectField,where=None,args=None):
261         sql = ['SELECT %s _num_ from `%s`' % (selectField,cls.__table__)]
262 
263         #添加WHERE子句
264         if where:
265             sql.append('WHERE')
266             sql.append(where)
267 
268         rs = await select(' '.join(sql),args)
269         if len(rs) == 0:
270             return None
271         return rs[0]['_num_']
272 
273     #返回主键的一条记录
274     @classmethod
275     async def find(cls,pk):
276         rs = await select('%s WHERE `%s` = ?' % (cls.__select__,cls.__primary_key__),[pk],1)
277         if len(rs) == 0:
278             return None
279         return cls(**rs[0])
280 
281     #INSERT command
282     async def save(self):
283         args = list(map(self.getValueorDefault,self.__fields__))
284         args.append(self.getValueorDefault(self.__primary_key__))
285         rows = await execute(self.__insert__,args)
286         if rows != 1:
287             logging.warn('Faield to insert record:affected rows: %s' % rows)
288 
289     #UPDATE command
290     async def update(self):
291         args = list(map(self.getValue,self.__fields__))
292         args.append(self.getValue(self.__primary_key__))
293         rows = await execute(self.__update__,args)
294         if rows != 1:
295             logging.warn('Faield to update by primary_key:affectesd rows: %s' % rows)
296 
297     #DELETE command
298     async def remove(self):
299         args = [self.getValue(self.__primary_key__)]
300         rows = await execute(self.__delete__,args)
301         if rows != 1:
302             logging.warn('Faield to remove by primary key:affected: %s' % rows)

 

posted @ 2017-07-29 21:15  未夏  阅读(1814)  评论(0编辑  收藏  举报