Python实现ORM
ORM即把数据库中的一个数据表给映射到代码里的一个类上,表的字段对应着类的属性。将增删改查等基本操作封装为类对应的方法,从而写出更干净和更富有层次性的代码。
以查询数据为例,原始的写法需要把Python代码和SQL语句混在一起,示例代码如下:
import os
import sys


def main():
    '''
    Print every row of the ``users`` table using raw MySQLdb calls.

    This is the "before" example: connection handling, SQL text and result
    processing are all written by hand inside application code.
    '''
    # Imported here so the example module can be imported without the MySQL
    # driver installed; only running main() needs it.
    import MySQLdb

    # Demo credentials only; real code should not hard-code them.
    conn = MySQLdb.connect(host="localhost", port=3306, passwd='toor', user='root')
    try:
        conn.select_db("xdyweb")
        cursor = conn.cursor()
        try:
            count = cursor.execute("select * from users")
            # fetchall() returns every row; fetchmany() without a size only
            # returns cursor.arraysize rows, which defaults to 1.
            result = cursor.fetchall()
            print(isinstance(result, tuple))
            print(type(result))
            print(len(result))
            for row in result:
                print(row)
                for value in row:
                    print(value)
            print("row count is %s" % count)
        finally:
            # Release the cursor even if the query or printing fails.
            cursor.close()
    finally:
        conn.close()


if __name__ == "__main__":
    cp = os.path.abspath('.')
    sys.path.append(cp)
    main()
而我们现在想要实现的是类似这样的效果:
# Query: fetch one row by its primary-key value (Model.get takes it positionally)
u = user.get(1)
# Insert: build an instance from column values, then write it to the table
u = user(name='y', password='y', email='1@q.com')
u.insert()
实现思路是遍历Model的属性,得出要操作的字段,然后根据不同的操作要求(增,删,改,查)去动态生成不同的sql语句。
#coding:utf-8

#author:xudongyang

#19:25 2015/4/15

import logging
import os
import sys
import threading
import time

# Presumably the author's own database helper module; the code below calls its
# select, select_one, select_int, update and insert functions.
import test as db

logging.basicConfig(level=logging.INFO)


class Field(object):
    '''
    Describe one column of a database table.

    Keyword arguments (all optional):
        name        -- column name; the metaclass fills in the attribute name
                       when it is left empty
        default     -- default value, or a zero-argument callable producing it
        primary_key -- whether the column is the primary key (default False)
        nullable    -- whether NULL is allowed (default False)
        updatable   -- whether Model.update() writes the column (default True)
        insertable  -- whether Model.insert() writes the column (default True)
        ddl         -- SQL column type such as 'varchar(255)' (default '')
    '''
    # Creation counter shared by all fields: each Field stores its position so
    # _gen_sql() can emit columns in declaration order.
    _count = 0

    def __init__(self, **kw):
        self.name = kw.get('name', None)
        self._default = kw.get('default', None)
        self.primary_key = kw.get('primary_key', False)
        self.nullable = kw.get('nullable', False)
        self.updatable = kw.get('updatable', True)
        self.insertable = kw.get('insertable', True)
        self.ddl = kw.get('ddl', '')
        self._order = Field._count
        Field._count += 1

    @property
    def default(self):
        '''The default value; a callable default is invoked each time.'''
        d = self._default
        return d() if callable(d) else d

    def __repr__(self):
        return '<%s:%s,%s,default(%s)>' % (
            self.__class__.__name__, self.name, self.ddl, self._default)


class StringField(Field):
    '''A varchar(255) column whose default is the empty string.'''

    def __init__(self, **kw):
        kw.setdefault('default', '')
        kw.setdefault('ddl', 'varchar(255)')
        super(StringField, self).__init__(**kw)


class IntegerField(Field):
    '''A bigint column whose default is 0.'''

    def __init__(self, **kw):
        kw.setdefault('default', 0)
        kw.setdefault('ddl', 'bigint')
        super(IntegerField, self).__init__(**kw)


class FloatField(Field):
    '''A real column whose default is 0.0.'''

    def __init__(self, **kw):
        kw.setdefault('default', 0.0)
        kw.setdefault('ddl', 'real')
        super(FloatField, self).__init__(**kw)


class BooleanField(Field):
    '''A bool column whose default is False.'''

    def __init__(self, **kw):
        kw.setdefault('default', False)
        kw.setdefault('ddl', 'bool')
        super(BooleanField, self).__init__(**kw)


class TextField(Field):
    '''A text column whose default is the empty string.'''

    def __init__(self, **kw):
        kw.setdefault('default', '')
        kw.setdefault('ddl', 'text')
        super(TextField, self).__init__(**kw)


class BlobField(Field):
    '''A blob column whose default is the empty string.'''

    def __init__(self, **kw):
        kw.setdefault('default', '')
        kw.setdefault('ddl', 'blob')
        super(BlobField, self).__init__(**kw)


class VersionField(Field):
    '''A bigint version-counter column with default 0.'''

    def __init__(self, name=None):
        super(VersionField, self).__init__(name=name, default=0, ddl='bigint')


def _gen_sql(table_name, mappings):
    '''
    Build a CREATE TABLE statement for ``table_name``.

    ``mappings`` maps attribute names to Field objects. Columns are emitted in
    the order the fields were declared. A column is marked ``not null`` unless
    its field is nullable. The primary-key clause is added when one of the
    fields is a primary key.

    Raises TypeError if a field has no ``ddl`` attribute.
    '''
    pk = None
    columns = []
    for f in sorted(mappings.values(), key=lambda fld: fld._order):
        if not hasattr(f, 'ddl'):
            raise TypeError('no ddl in field "%s".' % f.name)
        if f.primary_key:
            pk = f.name
        if f.nullable:
            columns.append(' `%s` %s' % (f.name, f.ddl))
        else:
            columns.append(' `%s` %s not null' % (f.name, f.ddl))
    if pk is not None:
        columns.append(' primary key(`%s`)' % pk)
    # Joining with ',\n' puts commas between entries but never after the last.
    sql = '\n'.join([
        '-- generating SQL for %s:' % table_name,
        'create table `%s` (' % table_name,
        ',\n'.join(columns),
        ');',
    ])
    logging.info('sql is :' + sql)
    return sql


# Class names the metaclass treats as abstract bases and leaves untouched.
_ABSTRACT_MODEL_NAMES = ('Model', '_ModelBase')


class ModelMetaclass(type):
    '''
    Metaclass for model objects.

    For every concrete model class it:
      * collects the Field class attributes into ``__mappings__``
        (attribute name -> Field);
      * records the primary-key Field in ``__primary_key__`` and forces it to
        be non-updatable and non-nullable;
      * defaults ``__table__`` to the lower-cased class name;
      * adds a ``__sql__()`` method returning the CREATE TABLE statement.

    Raises TypeError if a model declares no primary key or more than one.
    '''
    # Q: Why is __new__ called more than once?
    # A: It runs once for every class whose metaclass this is. That includes
    #    the abstract bases (Model and _ModelBase, returned unchanged below)
    #    and each model subclass such as ``user``, which inherits the
    #    metaclass from Model.
    # Q: Why pop the Field attributes out of attrs? And why does u.name then
    #    return 'yy' instead of a Field object?
    # A: Normal attribute lookup finds class attributes first.
    #    Model.__getattr__, which reads the dict entries, is only consulted
    #    when that lookup fails. Left on the class, the Field objects would
    #    shadow the values stored in the instance. Removing them lets u.name
    #    fall through to __getattr__ and return u['name'].
    def __new__(cls, name, bases, attrs):
        if name in _ABSTRACT_MODEL_NAMES:
            return type.__new__(cls, name, bases, attrs)

        # Remember every model class name so redefinitions can be reported.
        if not hasattr(cls, 'subclasses'):
            cls.subclasses = {}
        if name not in cls.subclasses:
            cls.subclasses[name] = name
        else:
            logging.warning('Redefine class: %s' % name)

        logging.info('Scan ORMapping %s...' % name)
        mappings = dict()
        # Kept across loop iterations so a second primary key is detected.
        primary_key = None
        for k, v in attrs.items():
            if not isinstance(v, Field):
                continue
            if not v.name:
                v.name = k
            logging.info('Found mapping: %s => %s' % (k, v))
            if v.primary_key:
                if primary_key is not None:
                    raise TypeError('Cannot define more than 1 primary key in class: %s' % name)
                if v.updatable:
                    logging.warning('NOTE: change primary key to non-updatable.')
                    v.updatable = False
                if v.nullable:
                    logging.warning('NOTE: change primary key to non-nullable.')
                    v.nullable = False
                primary_key = v
            mappings[k] = v
        if primary_key is None:
            raise TypeError('Primary key not defined in class: %s' % name)
        for k in mappings:
            attrs.pop(k)
        if '__table__' not in attrs:
            attrs['__table__'] = name.lower()
        table = attrs['__table__']
        attrs['__mappings__'] = mappings
        attrs['__primary_key__'] = primary_key
        attrs['__sql__'] = lambda self: _gen_sql(table, mappings)
        return type.__new__(cls, name, bases, attrs)


# Alias for the other spelling of the metaclass name used in this file.
ModelMetaClass = ModelMetaclass

# Create the metaclass-bearing base by calling the metaclass directly. This
# works on Python 2 and Python 3 alike: Python 3 ignores a ``__metaclass__``
# class attribute, and its ``metaclass=`` keyword is a syntax error on Python 2.
_ModelBase = ModelMetaclass('_ModelBase', (dict,), {})


class Model(_ModelBase):
    '''
    Base class for ORM models.

    An instance is a dict of column values that also allows attribute access:
    ``u.name`` reads ``u['name']`` and ``u.name = x`` writes it. Subclasses
    declare their columns as Field class attributes, and ModelMetaclass turns
    them into ``__mappings__``, ``__primary_key__``, ``__table__`` and
    ``__sql__``. SQL is sent through the ``db`` module using ``?`` placeholders
    (presumably translated by that module).

    Subclasses may define ``pre_insert``, ``pre_update`` or ``pre_delete``
    methods. Each one is called before the matching write.
    '''
    # Hook defaults: None means "no hook". Without these class attributes the
    # hook lookups in insert/update/delete would fall through to __getattr__
    # and raise AttributeError.
    pre_insert = None
    pre_update = None
    pre_delete = None

    def __init__(self, **kw):
        super(Model, self).__init__(**kw)

    def __getattr__(self, key):
        # Only reached when normal lookup fails; expose dict entries as
        # attributes and keep hasattr() working via AttributeError.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(r"'%s' object has no attribute '%s'" % (type(self).__name__, key))

    def __setattr__(self, key, value):
        self[key] = value

    @classmethod
    def get(cls, pk):
        '''
        Get by primary key. Return an instance, or None if no row matches.
        '''
        d = db.select_one('select * from `%s` where `%s`=?' % (cls.__table__, cls.__primary_key__.name), pk)
        return cls(**d) if d else None

    @classmethod
    def find_first(cls, where, *args):
        '''
        Find by where clause and return one result. If multiple results found,
        only the first one returned. If no result found, return None.

        ``where`` is inserted into the SQL verbatim; pass values through
        ``?`` placeholders and ``args`` rather than string formatting.
        '''
        d = db.select_one('select * from `%s` %s' % (cls.__table__, where), *args)
        return cls(**d) if d else None

    @classmethod
    def find_all(cls, *args):
        '''
        Find all rows and return a list of instances. ``args`` is ignored.
        '''
        L = db.select('select * from `%s`' % cls.__table__)
        return [cls(**d) for d in L]

    @classmethod
    def find_by(cls, where, *args):
        '''
        Find by where clause and return list.

        ``where`` is inserted verbatim; values belong in ``args``.
        '''
        L = db.select('select * from `%s` %s' % (cls.__table__, where), *args)
        return [cls(**d) for d in L]

    @classmethod
    def count_all(cls):
        '''
        Find by 'select count(pk) from table' and return integer.
        '''
        return db.select_int('select count(`%s`) from `%s`' % (cls.__primary_key__.name, cls.__table__))

    @classmethod
    def count_by(cls, where, *args):
        '''
        Find by 'select count(pk) from table where ... ' and return int.
        '''
        return db.select_int('select count(`%s`) from `%s` %s' % (cls.__primary_key__.name, cls.__table__, where), *args)

    def update(self):
        '''
        Write every updatable column of this row, keyed by its primary key.

        Missing values are filled with the field default (and stored on the
        instance). Returns self.
        '''
        if self.pre_update:
            self.pre_update()
        assignments = []
        args = []
        for k, v in self.__mappings__.items():
            if not v.updatable:
                continue
            if hasattr(self, k):
                arg = getattr(self, k)
            else:
                arg = v.default
                setattr(self, k, arg)
            # Use the column name, as insert() does, not the attribute name.
            assignments.append('`%s`=?' % v.name)
            args.append(arg)
        pk = self.__primary_key__.name
        args.append(getattr(self, pk))
        db.update('update `%s` set %s where `%s`=?' % (self.__table__, ','.join(assignments), pk), *args)
        return self

    def delete(self):
        '''Delete this row by its primary key. Returns self.'''
        if self.pre_delete:
            self.pre_delete()
        pk = self.__primary_key__.name
        args = (getattr(self, pk), )
        db.update('delete from `%s` where `%s`=?' % (self.__table__, pk), *args)
        return self

    def insert(self):
        '''
        Insert this row with every insertable column.

        Missing values are filled with the field default. Returns self.
        '''
        if self.pre_insert:
            self.pre_insert()
        params = {}
        for k, v in self.__mappings__.items():
            if v.insertable:
                if not hasattr(self, k):
                    setattr(self, k, v.default)
                params[v.name] = getattr(self, k)
        db.insert('%s' % self.__table__, **params)
        return self


class user(Model):
    name = StringField(name='name', primary_key=True)
    password = StringField(name='password')


def main():
    '''Demo: build a user, log its CREATE TABLE SQL and field names, set a value.'''
    u = user(name='yy', password='yyp')
    # __sql__ is a method; call it so the SQL is logged, not the method object.
    logging.info(u.__sql__())
    logging.info(sorted(u.__mappings__))
    u.password = 'xxx'
    print(u.password)


if __name__ == '__main__':
    main()
要注意的是遍历Model属性的这部分代码:它利用了Python的元类(__metaclass__)机制,拦截了Model及其子类的创建过程,从而在类创建时遍历并处理Model的属性,具体代码见ModelMetaclass的__new__方法实现。
这是模仿廖老师的代码(http://liaoxuefeng.com),在此致谢。还有两个疑问以注释的形式写在了代码中,希望看明白的朋友帮忙解惑。